From 6baa994b14f103cb4d6033e756760de9586c3894 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 6 Feb 2026 15:03:04 +0800 Subject: [PATCH 01/40] fix: video gen exclude edit_file --- projects/singularity_cinema/agent.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index d171fc7fc..dc1756486 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -279,6 +279,7 @@ tools: mcp: false allow_read_all_files: true exclude: + - edit_file - list_files - search_file_content - search_file_name From 11bcf068091df889694cd098c2fd45b60f6e4340 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Wed, 11 Mar 2026 16:58:51 +0800 Subject: [PATCH 02/40] enhance deep research v2 --- ms_agent/config/config.py | 9 + ms_agent/prompting/__init__.py | 2 + ms_agent/prompting/file_resolver.py | 232 +++++++ ms_agent/tools/code/local_code_executor.py | 10 +- ms_agent/tools/search/content_optimizer.py | 6 +- ms_agent/tools/todolist_tool.py | 84 ++- .../v2/callbacks/reporter_callback.py | 101 ++- .../v2/callbacks/researcher_callback.py | 343 ++++++++++ .../v2/callbacks/searcher_callback.py | 40 +- .../v2/prompts/reporter/en/gpt5.txt | 109 ++++ .../v2/prompts/reporter/zh/qwen3.txt | 88 +++ .../v2/prompts/researcher/en/gpt5.txt | 95 +++ .../v2/prompts/researcher/zh/qwen3.txt | 67 ++ .../v2/prompts/searcher/en/gpt5.txt | 72 +++ .../v2/prompts/searcher/zh/qwen3.txt | 68 ++ projects/deep_research/v2/reporter.yaml | 102 +-- projects/deep_research/v2/researcher.yaml | 162 ++--- projects/deep_research/v2/searcher.yaml | 81 +-- .../deep_research/v2/tools/evidence_tool.py | 587 ++++++++++++++++-- .../deep_research/v2/tools/report_tool.py | 55 +- tests/config/test_prompt_files.py | 129 ++++ tests/tools/test_server_tools_smoke.py | 81 ++- 22 files changed, 2162 insertions(+), 361 deletions(-) create mode 100644 ms_agent/prompting/__init__.py create mode 100644 ms_agent/prompting/file_resolver.py create mode 100644 projects/deep_research/v2/callbacks/researcher_callback.py create mode 100644 projects/deep_research/v2/prompts/reporter/en/gpt5.txt create mode 100644 projects/deep_research/v2/prompts/reporter/zh/qwen3.txt create mode 100644 projects/deep_research/v2/prompts/researcher/en/gpt5.txt create mode 100644 projects/deep_research/v2/prompts/researcher/zh/qwen3.txt create mode 100644 projects/deep_research/v2/prompts/searcher/en/gpt5.txt create mode 100644 projects/deep_research/v2/prompts/searcher/zh/qwen3.txt create mode 100644 tests/config/test_prompt_files.py diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index caf9d2eea..2f6175524 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -5,6 +5,7 @@ from copy import deepcopy from typing import Any, Dict, Union +from ms_agent.prompting import apply_prompt_files from ms_agent.utils import get_logger from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.basecontainer import BaseContainer @@ -95,6 +96,14 @@ def from_task(cls, config.local_dir = config_dir_or_id config.name = name config = cls.fill_missing_fields(config) + # Prompt files: resolve config.prompt.system from prompts/ directory + # if user didn't specify inline prompt.system. + try: + if isinstance(config, DictConfig): + config = apply_prompt_files(config) + except Exception: + # Never block config loading due to prompt resolving. + pass return config @staticmethod diff --git a/ms_agent/prompting/__init__.py b/ms_agent/prompting/__init__.py new file mode 100644 index 000000000..351e81bc1 --- /dev/null +++ b/ms_agent/prompting/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .file_resolver import apply_prompt_files, resolve_prompt_file diff --git a/ms_agent/prompting/file_resolver.py b/ms_agent/prompting/file_resolver.py new file mode 100644 index 000000000..c04b8bdde --- /dev/null +++ b/ms_agent/prompting/file_resolver.py @@ -0,0 +1,232 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from omegaconf import DictConfig + + +@dataclass(frozen=True) +class PromptFileSpec: + agent: str + lang: str + family: str + root_dir: str + + def candidate_paths(self) -> List[str]: + """Return candidate prompt file paths in priority order.""" + # File convention: prompts/{agent}/{lang}/{family}.md + # Fallback: family -> base + agent = self.agent.strip() + lang = self.lang.strip() + family = self.family.strip() + root = self.root_dir + + paths = [] + if family: + paths.extend([ + os.path.join(root, agent, lang, f'{family}.txt'), + os.path.join(root, agent, lang, f'{family}.md'), + ]) + # base fallback + paths.extend([ + os.path.join(root, agent, lang, 'base.txt'), + os.path.join(root, agent, lang, 'base.md') + ]) + return paths + + +def _norm_lang(lang: Optional[str]) -> str: + if not lang: + return 'zh' + lang = str(lang).strip().lower() + if lang in {'zh', 'zh-cn', 'zh_cn', 'cn'}: + return 'zh' + if lang in {'en', 'en-us', 'en_us', 'us'}: + return 'en' + if lang == 'auto': + # We cannot reliably detect user language at config-load time, + # so treat "auto" as default language (with env override handled elsewhere). + return 'zh' + return lang + + +def _infer_family_from_model(model: Optional[str]) -> str: + """Infer a reasonable prompt family name from model string. + + Notes: + - This is a best-effort heuristic to keep user onboarding simple. + - Users can always override via `prompt.family`. + """ + if not model: + return 'base' + m = str(model).strip().lower() + + # Qwen series + if 'qwen' in m: + # Common variants: qwen3-*, qwen-3, qwen2.5-*, Qwen/Qwen3-... + if 'qwen3' in m or 'qwen-3' in m or 'qwen/qwen3' in m: + return 'qwen-3' + if 'qwen2' in m or 'qwen-2' in m: + return 'qwen-2' + if 'qwen1' in m or 'qwen-1' in m: + return 'qwen-1' + return 'qwen' + + # Claude series + if 'claude' in m: + return 'claude' + + # GPT-like series (OpenAI / compatible) + if 'gpt' in m or m.startswith('o1') or m.startswith('o3'): + return 'gpt' + + return 'base' + + +def _get_prompt_root_dir(config: DictConfig) -> Optional[str]: + """Resolve prompts root directory. + + Priority: + - config.prompt.root (absolute or relative to config.local_dir) + - /prompts + """ + local_dir = getattr(config, 'local_dir', None) + prompt_cfg = getattr(config, 'prompt', None) + root = None + if isinstance(prompt_cfg, DictConfig): + root = getattr(prompt_cfg, 'root', None) + + if root: + root = str(root).strip() + if not root: + root = None + elif not os.path.isabs(root) and local_dir: + root = os.path.join(str(local_dir), root) + + if not root and local_dir: + root = os.path.join(str(local_dir), 'prompts') + + return root + + +def _get_prompt_agent(config: DictConfig) -> Optional[str]: + """Resolve agent name used in prompts/{agent}/... path.""" + prompt_cfg = getattr(config, 'prompt', None) + if isinstance(prompt_cfg, DictConfig): + agent = getattr(prompt_cfg, 'agent', None) + if agent: + agent = str(agent).strip() + if agent: + return agent + + # Prefer `code_file` for project agents (deep_research v2 uses this) + code_file = getattr(config, 'code_file', None) + if code_file: + code_file = str(code_file).strip() + if code_file: + return code_file + + # Fallback: try `tag` (may be too specific; we only use it if user opts in via prompt.agent) + return None + + +def _get_prompt_lang_and_family(config: DictConfig) -> Tuple[str, str]: + prompt_cfg = getattr(config, 'prompt', None) + + # lang + env_lang = os.environ.get('MS_AGENT_PROMPT_LANG') or os.environ.get( + 'MS_AGENT_LANG') + cfg_lang = getattr(prompt_cfg, 'lang', None) if isinstance( + prompt_cfg, DictConfig) else None + lang = _norm_lang(cfg_lang or env_lang or 'zh') + + # family + env_family = os.environ.get('MS_AGENT_PROMPT_FAMILY') + cfg_family = getattr(prompt_cfg, 'family', None) if isinstance( + prompt_cfg, DictConfig) else None + + family = (cfg_family or env_family or 'auto') + family = str(family).strip() + if not family: + family = 'auto' + if family.lower() == 'auto': + model = None + if hasattr(config, 'llm') and getattr(config, 'llm') is not None: + try: + model = getattr(config.llm, 'model', None) + except Exception: + model = None + family = _infer_family_from_model(model) + return lang, family + + +def resolve_prompt_file(config: DictConfig) -> Optional[str]: + """Resolve system prompt text from prompt files. + + Returns: + Prompt text if a file is found, else None. + + Compatibility rules: + - If `prompt.system` exists and is non-empty, this resolver is NOT used. + - Resolver is only eligible when we can infer a prompt agent name (or user provided prompt.agent). + """ + prompt_cfg = getattr(config, 'prompt', None) + if isinstance(prompt_cfg, DictConfig): + system = getattr(prompt_cfg, 'system', None) + if isinstance(system, str) and system.strip(): + return None + + agent = _get_prompt_agent(config) + if not agent: + return None + + root_dir = _get_prompt_root_dir(config) + if not root_dir: + return None + + lang, family = _get_prompt_lang_and_family(config) + + # Language fallback: try configured lang first, then zh/en as last resort. + lang_candidates = [lang] + for fallback in ('zh', 'en'): + if fallback not in lang_candidates: + lang_candidates.append(fallback) + + for lang_try in lang_candidates: + spec = PromptFileSpec( + agent=agent, + lang=lang_try, + family=family, + root_dir=root_dir, + ) + for path in spec.candidate_paths(): + if os.path.isfile(path): + with open(path, 'r', encoding='utf-8') as f: + text = f.read() + text = text.strip('\n') + return text if text.strip() else None + + return None + + +def apply_prompt_files(config: DictConfig) -> DictConfig: + """Apply prompt file resolution onto config in-place. + + This sets `config.prompt.system` when it's missing/empty and a matching prompt file exists. + """ + try: + prompt_text = resolve_prompt_file(config) + except Exception: + # Be conservative: prompt loading must never break config loading. + return config + + if not prompt_text: + return config + + if not hasattr(config, 'prompt') or config.prompt is None: + config.prompt = DictConfig({}) + if getattr(config.prompt, 'system', None) is None or not str( + getattr(config.prompt, 'system', '')).strip(): + config.prompt.system = prompt_text + return config diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index 3b14b8d66..72771563a 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -345,7 +345,7 @@ async def cleanup(self) -> None: await self.kernel_session.stop() self._initialized = False - async def get_tools(self) -> Dict[str, Any]: + async def _get_tools_inner(self) -> Dict[str, Any]: tools = { 'code_executor': [ Tool( @@ -502,12 +502,8 @@ async def get_tools(self) -> Dict[str, Any]: }), ] } - return { - 'code_executor': [ - t for t in tools['code_executor'] - if t['tool_name'] not in self.exclude_functions - ] - } + + return tools async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: diff --git a/ms_agent/tools/search/content_optimizer.py b/ms_agent/tools/search/content_optimizer.py index 6fcb7811d..998b13b66 100644 --- a/ms_agent/tools/search/content_optimizer.py +++ b/ms_agent/tools/search/content_optimizer.py @@ -375,7 +375,11 @@ def _build_llm_config(self) -> DictConfig: 'openai_base_url': self.config.summarizer_base_url, 'openai_api_key': self.config.summarizer_api_key, }, - 'generation_config': {}, + 'generation_config': { + 'extra_body': { + 'enable_thinking': False + } + }, } return OmegaConf.create(config_dict) diff --git a/ms_agent/tools/todolist_tool.py b/ms_agent/tools/todolist_tool.py index ae5ca2e0f..aee860134 100644 --- a/ms_agent/tools/todolist_tool.py +++ b/ms_agent/tools/todolist_tool.py @@ -111,8 +111,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: server_name=self.SERVER_NAME, description= ('Create or update the structured todo list (plan.json) for this session/workdir. ' - 'Use merge=true to merge by id; merge=false replaces the list.' - ), + 'Use merge=true to merge by id (partial updates allowed for existing ids); ' + 'merge=false replaces the list (full items required).'), parameters={ 'type': 'object', 'properties': { @@ -136,7 +136,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'string', 'description': - 'Unique identifier for the todo item', + ('Unique identifier for the todo item. ' + 'e.g. "T_1", "T_2", ...'), }, 'content': { 'type': @@ -162,7 +163,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'default': 'medium', }, }, - 'required': ['id', 'content', 'status'], + 'required': ['id'], # Allow DeepResearch to attach extra structured fields: # e.g. evidence_ids, depends_on, acceptance, agent, etc. 'additionalProperties': True, @@ -277,6 +278,68 @@ def _normalize_todos(self, todos: List[Dict[str, normalized.append(merged) return normalized + def _normalize_todo_updates( + self, + todos: List[Dict[str, Any]], + *, + existing_ids: set[str], + ) -> List[Dict[str, Any]]: + """ + Normalize partial updates for merge=true. + + Rules: + - id is always required. + - For existing ids, you may provide any subset of fields (e.g. status only). + - For new ids, you must provide content and status (so the merged plan is valid). + - If a field is provided, it is validated; missing fields are not touched. + """ + normalized: List[Dict[str, Any]] = [] + for idx, item in enumerate(todos or []): + if not isinstance(item, dict): + raise ValueError(f'todos[{idx}] must be an object.') + + todo_id = str(item.get('id', '')).strip() + if not todo_id: + raise ValueError( + f'todos[{idx}].id is required and must be non-empty.') + + is_new = todo_id not in existing_ids + + # Start from original item to keep extra fields (e.g. depends_on). + upd = dict(item) + upd['id'] = todo_id + + if 'content' in item: + content = str(item.get('content', '')).strip() + if not content: + raise ValueError( + f'todos[{idx}].content is required and must be non-empty.' + ) + upd['content'] = content + elif is_new: + raise ValueError( + f'todos[{idx}] is a new id "{todo_id}" so content is required.' + ) + + if 'status' in item: + status = str(item.get('status', '')).strip() + _validate_status(status) + upd['status'] = status + elif is_new: + raise ValueError( + f'todos[{idx}] is a new id "{todo_id}" so status is required.' + ) + + if 'priority' in item: + priority = str(item.get('priority', 'medium') + or 'medium').strip() + _validate_priority(priority) + upd['priority'] = priority + + normalized.append(upd) + + return normalized + def _merge_todos(self, base: List[Dict[str, Any]], updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: base_by_id: Dict[str, Dict[str, Any]] = { @@ -329,16 +392,21 @@ async def todo_write(self, paths = self._paths() _ensure_dir(self.output_dir) _ensure_dir(paths.lock_dir) - normalized = self._normalize_todos(todos) with file_lock(paths.lock_dir, self._plan_filename): plan = self._load_plan_locked(paths) existing = plan.get('todos', []) if merge: - merged = self._merge_todos(existing, normalized) + # For merge=true, allow partial updates for existing ids. + existing_full = self._normalize_todos(existing) + existing_ids = {str(t.get('id')) for t in existing_full} + updates = self._normalize_todo_updates( + todos, existing_ids=existing_ids) + merged = self._merge_todos(existing_full, updates) + plan['todos'] = self._normalize_todos(merged) else: - merged = normalized - plan['todos'] = merged + # For merge=false (replace), require full items. + plan['todos'] = self._normalize_todos(todos) self._save_plan_locked(paths, plan) if self._auto_render_md: diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index e0cd95c74..5bb8066f0 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -29,6 +29,60 @@ class ReporterCallback(Callback): # Tool names to exclude from trajectory (reporter_tool calls and their responses) EXCLUDED_TOOL_PATTERNS = ['reporter_tool'] + # Bilingual round-reminder templates keyed by language code. + _ROUND_REMINDER_TEMPLATES = { + 'zh': + ('你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' + '- 从现在开始:优先基于已完成撰写的章节、整合的草稿、记录的冲突列表和最新的大纲进行收敛,补齐关键缺口、减少发散探索。\n' + '- 在接下来的极少数轮次内,必须立刻准备并输出最终的 JSON 回复。\n' + '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' + ), + 'en': + ('You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' + '- From now on: Prioritize converging based on the already completed chapters, assembled drafts, recorded conflict list, and the latest outline. Fill critical gaps and reduce exploratory divergence.\n' + '- Within the very few remaining rounds, you must immediately prepare and output the final JSON response.\n' + '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' + ), + } + + # Bilingual trajectory labels keyed by language code. + _TRAJECTORY_LABELS = { + 'zh': { + 'title': + '# 主代理(Researcher)调研轨迹', + 'user_request': + '## 用户请求', + 'assistant_thinking': + '### 助理思考/回复', + 'tool_calls': + '### 工具调用', + 'tool_result': + '### 工具结果', + 'trajectory_intro': + ('以下是主代理(Researcher)的调研轨迹,包含了研究过程中的关键决策、' + '工具调用和中间结论。请参考这些信息来理解研究背景和约束,' + '但报告写作仍需以 evidence_store 中的证据为准,并且注意该轨迹可能存在内容过长导致的截断。'), + }, + 'en': { + 'title': + '# Main Agent (Researcher) Research Trajectory', + 'user_request': + '## User Request', + 'assistant_thinking': + '### Assistant Thinking/Response', + 'tool_calls': + '### Tool Calls', + 'tool_result': + '### Tool Result', + 'trajectory_intro': + ('Below is the research trajectory of the main agent (Researcher), containing key decisions, ' + 'tool calls, and intermediate conclusions during the research process. Please refer to this ' + 'information to understand the research background and constraints, but report writing must ' + 'still be based on the evidence in evidence_store. Note that this trajectory may be truncated ' + 'due to excessive length.'), + }, + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') @@ -42,6 +96,22 @@ def __init__(self, config: DictConfig): self.report_path = os.path.join(self.output_dir, self.reports_dir, 'report.md') + # Resolve language from config for bilingual prompt selection. + self.lang = self._resolve_lang(config) + + @staticmethod + def _resolve_lang(config: DictConfig) -> str: + """Resolve language code from config.prompt.lang, defaulting to 'en'.""" + prompt_cfg = getattr(config, 'prompt', None) + if prompt_cfg is not None: + lang = getattr(prompt_cfg, 'lang', None) + if isinstance(lang, str) and lang.strip(): + normed = lang.strip().lower() + if normed in {'en', 'en-us', 'en_us', 'us'}: + return 'en' + elif normed in {'zh', 'zh-cn', 'zh_cn', 'cn'}: + return 'zh' + return 'en' def _load_researcher_history(self) -> Optional[List[Dict[str, Any]]]: """ @@ -137,7 +207,9 @@ def _format_trajectory(self, messages: List[Dict[str, Any]]) -> str: """ Format the filtered messages into a readable research trajectory summary. """ - lines = ['# 主代理(Researcher)调研轨迹', ''] + labels = self._TRAJECTORY_LABELS.get(self.lang, + self._TRAJECTORY_LABELS['en']) + lines = [labels['title'], ''] for i, msg in enumerate(messages): role = msg.get('role', 'unknown') @@ -146,19 +218,19 @@ def _format_trajectory(self, messages: List[Dict[str, Any]]) -> str: tool_name = msg.get('name', '') if role == 'user': - lines.append('## 用户请求') + lines.append(labels['user_request']) lines.append(content[:2000] if content else '(empty)') lines.append('') elif role == 'assistant': if content: - lines.append('### 助理思考/回复') + lines.append(labels['assistant_thinking']) lines.append( content[:20000] if len(content) > 20000 else content) lines.append('') if tool_calls: - lines.append('### 工具调用') + lines.append(labels['tool_calls']) for tc in tool_calls: tc_name = tc.get('tool_name', '') or tc.get( 'function', {}).get('name', '') @@ -170,7 +242,7 @@ def _format_trajectory(self, messages: List[Dict[str, Any]]) -> str: lines.append('') elif role == 'tool': - lines.append(f'### 工具结果 ({tool_name})') + lines.append(f'{labels["tool_result"]} ({tool_name})') # Truncate very long tool results if content and len(content) > 20000: content = content[:20000] + '\n...(truncated)' @@ -207,16 +279,15 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): insert_pos = i + 1 break - trajectory_str = ( - '以下是主代理(Researcher)的调研轨迹,包含了研究过程中的关键决策、' - '工具调用和中间结论。请参考这些信息来理解研究背景和约束,' - '但报告写作仍需以 evidence_store 中的证据为准,并且注意该轨迹可能存在内容过长导致的截断。\n\n' - f'{trajectory_text}') + labels = self._TRAJECTORY_LABELS.get( + self.lang, self._TRAJECTORY_LABELS['en']) + trajectory_str = (f'{labels["trajectory_intro"]}\n\n' + f'{trajectory_text}') if messages[insert_pos].role == 'user': messages[insert_pos].content += f'\n\n{trajectory_str}' else: - # fallback: 插入独立消息 + # fallback: insert as a standalone message messages.insert( insert_pos, Message(role='user', content=trajectory_str)) @@ -291,12 +362,8 @@ async def on_generate_response(self, runtime: Runtime, remaining = max_chat_round - runtime.round if not custom_message or not isinstance(custom_message, str): - custom_message = ( - '你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' - '- 从现在开始:优先基于已完成撰写的章节、整合的草稿、记录的冲突列表和最新的大纲进行收敛,补齐关键缺口、减少发散探索。\n' - '- 在接下来的极少数轮次内,必须立刻准备并输出最终的 JSON 回复。\n' - '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' - ) + custom_message = self._ROUND_REMINDER_TEMPLATES.get( + self.lang, self._ROUND_REMINDER_TEMPLATES['en']) injected = custom_message injected = injected.replace('', str(runtime.round)) diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py new file mode 100644 index 000000000..3f5f87940 --- /dev/null +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -0,0 +1,343 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from abc import ABC, abstractmethod +from typing import List, Optional + +from ms_agent.agent.runtime import Runtime +from ms_agent.callbacks import Callback +from ms_agent.llm.openai_llm import OpenAI as OpenAILLM +from ms_agent.llm.utils import Message +from ms_agent.utils import get_logger +from omegaconf import DictConfig, OmegaConf + +logger = get_logger() + + +class ReportQualityChecker(ABC): + """Interface for pluggable report quality checkers. + + Subclasses implement a single ``check`` method. Multiple checkers can + be chained in sequence by ``ResearcherCallback``; the first one that + returns a non-``None`` failure stops the chain. + """ + + @abstractmethod + async def check(self, content: str, lang: str) -> Optional[str]: + """Evaluate report quality. + + Args: + content: Full text of the report file. + lang: Language code (``"en"`` or ``"zh"``). + + Returns: + A short failure-reason string (e.g. ``"placeholder_content"``) + if the report fails this check, or ``None`` if it passes. + """ + + +class ModelQualityChecker(ReportQualityChecker): + """LLM-based report quality checker. + + Uses a lightweight model (configured via ``quality_check.model`` in + the YAML) to detect reports whose body has been largely replaced by + placeholders, abbreviations, or cross-references to external files. + + The checker sends a structured prompt asking the model to return a + JSON verdict: ``{"pass": true/false, "reason": "..."}``. + """ + + _SYSTEM_PROMPTS = { + 'en': + ('You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' + 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' + 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' + 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' + 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' + 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' + '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' + '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' + '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' + 'is stored in xxx file", "See full analysis in evidence/xxx".\n' + '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n\n' + 'OUTPUT FORMAT:\n' + 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' + '{"pass": true} or {"pass": false, "reason": ""}\n' + 'Do NOT output anything else.'), + 'zh': + ('你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' + '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' + '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' + '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' + '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' + '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' + '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n\n' + '输出格式:\n' + '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' + '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' + '不要输出任何其他内容。'), + } + + _USER_TEMPLATES = { + 'en': + ('Please audit the following research report against the rules provided in the system instruction.\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---'), + 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n' + '---报告开始---\n{report}\n---报告结束---'), + } + + _MAX_REPORT_CHARS = 80000 + + def __init__(self, config: DictConfig): + self._config = config + qc_cfg = getattr(config, 'self_reflection', DictConfig({})) + qc_cfg = getattr(qc_cfg, 'quality_check', DictConfig({})) + + self._model: str = str(getattr(qc_cfg, 'model', 'qwen3.5-plus')) + self._api_key: Optional[str] = getattr( + qc_cfg, 'openai_api_key', None) or getattr(config.llm, + 'openai_api_key', None) + self._base_url: Optional[str] = getattr( + qc_cfg, 'openai_base_url', None) or getattr( + config.llm, 'openai_base_url', None) + + self._client: Optional[OpenAILLM] = None + + def _build_llm_config(self) -> DictConfig: + """Build lightweight llm config for quality checker.""" + return OmegaConf.create({ + 'llm': { + 'model': self._model, + 'openai_api_key': self._api_key, + 'openai_base_url': self._base_url, + }, + 'generation_config': {}, + }) + + def _ensure_client(self): + if self._client is not None: + return + self._client = OpenAILLM(self._build_llm_config()) + + async def check(self, content: str, lang: str) -> Optional[str]: + import json + + self._ensure_client() + + report_text = content + if len(report_text) > self._MAX_REPORT_CHARS: + report_text = report_text[:self._MAX_REPORT_CHARS] + + sys_prompt = self._SYSTEM_PROMPTS.get(lang, self._SYSTEM_PROMPTS['en']) + usr_template = self._USER_TEMPLATES.get(lang, + self._USER_TEMPLATES['en']) + + try: + response = self._client.generate(messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', + content=usr_template.format(report=report_text), + ), + ]) + raw = (response.content or '').strip() + logger.info( + f'ModelQualityChecker ({self._model}): raw response: {raw}') + + verdict = json.loads(raw) + if verdict.get('pass', True): + return None + return verdict.get('reason', 'placeholder_content') + + except json.JSONDecodeError: + logger.warning(f'ModelQualityChecker: failed to parse JSON from ' + f'model response: {raw!r}') + return None + except Exception as exc: + logger.warning(f'ModelQualityChecker: model call failed: {exc}') + return None + + +class ResearcherCallback(Callback): + """Callback for Researcher agent — pre-completion self-reflection. + + Intercepts the agent's stop decision in ``after_tool_call`` and runs + a chain of quality checks before allowing the run to end: + + 1. **File existence**: has ``final_report.md`` been written to disk? + 2. **Quality checkers**: a configurable list of + :class:`ReportQualityChecker` instances run in order; the first + failure triggers a reflection prompt. + + If any check fails, a reflection prompt is injected as a ``user`` + message, ``runtime.should_stop`` is flipped back to ``False``, and + the agent continues for one more iteration. A configurable retry + cap prevents infinite loops. + + YAML configuration (all optional, shown with defaults):: + + self_reflection: + enabled: true + max_retries: 2 + report_filename: final_report.md + quality_check: + enabled: true + model: qwen3.5-flash # lightweight audit model + # openai_api_key: ... # falls back to llm.openai_api_key + # openai_base_url: ... # falls back to llm.openai_base_url + """ + + _REFLECTION_TEMPLATES = { + 'zh': { + 'no_report': ('自查发现:输出目录中尚未生成 `{filename}`。\n' + '请确认最终报告未交付的原因,并立即采取行动修复。\n' + '请注意:不要使用占位符或缩略内容替代实际报告正文。'), + 'low_quality': + ('外部检查发现:`{filename}` 的内容存在质量问题——{reason}。\n' + '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' + '**重要提醒**:如果质量问题属实,你必须完整重写整份报告。' + 'write_file 会完全覆盖文件,你写入的内容就是最终文件的全部内容——' + '以下写法都会原样出现在文件中并导致报告内容被永久丢失:\n' + '- 用省略号或缩略标记替代正文,如"(同之前,略)"、"此处省略"、"篇幅所限不再展开"、' + '"……以下类似"、"内容已截断"、"Content truncated for brevity"等;\n' + '- 引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、' + '"完整内容如 xxx 所述"、"详见附件"等;\n' + '- 引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' + '不得遗漏或省略任何章节,无需担心与先前输出的内容或写入过的文件重复。'), + }, + 'en': { + 'no_report': + ('Self-check indicates that `{filename}` has not yet been generated in the output directory.\n' + 'Please determine why the final report has not been delivered and take immediate action to fix the issue.\n' + 'Note: Do not use placeholders or abbreviated content in place of the actual report body.' + ), + 'low_quality': + ('External inspection found quality issues in `{filename}` — {reason}.\n' + 'Please carefully verify whether these issues are valid and whether additional problems exist, ' + 'then immediately take action to fix them.\n' + '**IMPORTANT REMINDER**: If the issues are valid, you MUST rewrite the complete report in full. ' + 'write_file overwrites the entire file — what you write IS the final file content. ' + 'The following patterns will appear literally in the file and permanently destroy report content:\n' + '- Replacing substantive content with brevity markers, e.g., "(same as before, omitted)", ' + '"...for brevity", "Content truncated for brevity", "omitted for brevity", ' + '"(remaining content follows the same pattern)";\n' + '- Referring readers to external files instead of including actual content, e.g., ' + '"This section is stored in xxx file", "See full analysis in evidence/xxx", ' + '"(see xxx for full content)";\n' + '- Directing readers to view reference sources without writing substantive content, ' + 'e.g., "See [1]", "Reference [2]".\n' + 'Do not omit or skip any sections; do not worry about repeating content you have previously output.' + ), + }, + } + + def __init__(self, config: DictConfig): + super().__init__(config) + self.output_dir: str = getattr(config, 'output_dir', './output') + self.lang: str = self._resolve_lang(config) + + refl_cfg = getattr(config, 'self_reflection', None) + self.enabled: bool = True + self.max_retries: int = 2 + self.report_filename: str = 'final_report.md' + + if refl_cfg is not None: + self.enabled = bool(getattr(refl_cfg, 'enabled', True)) + self.max_retries = int(getattr(refl_cfg, 'max_retries', 2)) + self.report_filename = str( + getattr(refl_cfg, 'report_filename', self.report_filename)) + + self._retries_used: int = 0 + self._checkers: List[ReportQualityChecker] = self._build_checkers( + config) + + @staticmethod + def _build_checkers(config: DictConfig) -> List[ReportQualityChecker]: + """Instantiate the quality-checker chain from config. + + Currently supports ``ModelQualityChecker``. New checker types + can be added here and will be appended to the chain — the first + checker that returns a failure reason wins. + """ + refl_cfg = getattr(config, 'self_reflection', None) + if refl_cfg is None: + return [] + + qc_cfg = getattr(refl_cfg, 'quality_check', None) + if qc_cfg is None or not bool(getattr(qc_cfg, 'enabled', False)): + return [] + + checkers: List[ReportQualityChecker] = [] + checkers.append(ModelQualityChecker(config)) + logger.info(f'ResearcherCallback: quality checker chain initialised ' + f'with {len(checkers)} checker(s).') + return checkers + + @staticmethod + def _resolve_lang(config: DictConfig) -> str: + prompt_cfg = getattr(config, 'prompt', None) + if prompt_cfg is not None: + lang = getattr(prompt_cfg, 'lang', None) + if isinstance(lang, str) and lang.strip(): + normed = lang.strip().lower() + if normed in {'zh', 'zh-cn', 'zh_cn', 'cn'}: + return 'zh' + return 'en' + + @property + def _report_path(self) -> str: + return os.path.join(self.output_dir, self.report_filename) + + def _get_template(self, key: str) -> str: + templates = self._REFLECTION_TEMPLATES.get( + self.lang, self._REFLECTION_TEMPLATES['en']) + return templates[key] + + async def after_tool_call(self, runtime: Runtime, messages: List[Message]): + if not self.enabled: + return + if not runtime.should_stop: + return + if self._retries_used >= self.max_retries: + logger.info('ResearcherCallback: reflection retry cap reached ' + f'({self.max_retries}), allowing stop.') + return + + # --- Check 1: report file existence --- + if not os.path.isfile(self._report_path): + logger.warning( + f'ResearcherCallback: {self.report_filename} not found, ' + 'injecting reflection prompt.') + prompt = self._get_template('no_report').format( + filename=self.report_filename) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._retries_used += 1 + return + + # --- Check 2: quality checker chain --- + if not self._checkers: + logger.info('ResearcherCallback: no quality checkers configured, ' + 'skipping quality gate.') + return + + try: + with open(self._report_path, 'r', encoding='utf-8') as f: + report_content = f.read() + except Exception as exc: + logger.warning(f'ResearcherCallback: failed to read report: {exc}') + return + + for checker in self._checkers: + failure = await checker.check(report_content, self.lang) + if failure is not None: + logger.warning(f'ResearcherCallback: quality check failed ' + f'({type(checker).__name__}: {failure}), ' + 'injecting reflection prompt.') + prompt = self._get_template('low_quality').format( + filename=self.report_filename, reason=failure) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._retries_used += 1 + return + + logger.info('ResearcherCallback: all pre-completion checks passed.') diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index 0552b7832..e48d35880 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -23,14 +23,46 @@ class SearcherCallback(Callback): - on_task_end: Save the final search result to file """ + # Bilingual round-reminder templates keyed by language code. + _ROUND_REMINDER_TEMPLATES = { + 'zh': + ('你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' + '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' + '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' + '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' + ), + 'en': + ('You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' + '- From now on: Prioritize summarizing existing evidence and progress, fill critical gaps, and reduce exploratory divergence.\n' + '- Within the very few remaining rounds, immediately prepare and output the final JSON response.\n' + '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' + ), + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') self.search_task_id: Optional[str] = None self.search_result_path = os.path.join( self.output_dir, f'search_result_{uuid.uuid4().hex[:4]}.json') + # Resolve language from config for bilingual prompt selection. + self.lang = self._resolve_lang(config) self._ensure_output_dir() + @staticmethod + def _resolve_lang(config: DictConfig) -> str: + """Resolve language code from config.prompt.lang, defaulting to 'en'.""" + prompt_cfg = getattr(config, 'prompt', None) + if prompt_cfg is not None: + lang = getattr(prompt_cfg, 'lang', None) + if isinstance(lang, str) and lang.strip(): + normed = lang.strip().lower() + if normed in {'en', 'en-us', 'en_us', 'us'}: + return 'en' + elif normed in {'zh', 'zh-cn', 'zh_cn', 'cn'}: + return 'zh' + return 'en' + def _ensure_output_dir(self) -> None: try: os.makedirs(self.output_dir, exist_ok=True) @@ -155,12 +187,8 @@ async def on_generate_response(self, runtime: Runtime, remaining = max_chat_round - runtime.round if not custom_message or not isinstance(custom_message, str): - custom_message = ( - '你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' - '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' - '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' - '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' - ) + custom_message = self._ROUND_REMINDER_TEMPLATES.get( + self.lang, self._ROUND_REMINDER_TEMPLATES['en']) injected = custom_message injected = injected.replace('', str(runtime.round)) diff --git a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt new file mode 100644 index 000000000..04496de99 --- /dev/null +++ b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt @@ -0,0 +1,109 @@ +You are an evidence-driven report-writing assistant with expertise in producing research reports at an expert level. You are not responsible for large-scale retrieval; your job is to transform the report writing requirements, evidence information, and potentially provided research trajectory from the user or other agents (hereinafter collectively referred to as "the user") into a well-written research report that meets the user's needs. +You have everything you need to complete the task. Fully solve this autonomously before returning the result. +Time reminder: Today's date: , current time: . +Action protocol: Before outputting the final JSON result, every iteration MUST invoke at least one tool. You MUST reason extensively about the current state and your intended next action before each tool call and show your thinking in the conversation. DO NOT do this entire process by making tool calls only, as this can impair your ability to solve the problem and think insightfully. + +# Primary Responsibilities +Complete the task through a tool-calling loop without introducing new facts unsupported by evidence: +1. Produce the final report (or user-specified sections/revisions) that meets the user's requirements, and return it directly in the conversation as part of the JSON result. Do not save it via any tools. + - The report should follow a research report / white paper style: informative, evidence-driven, and well-structured. Avoid colloquial language, fragmentation, and excessive bullet points. The content MUST primarily consist of continuous, flowing paragraphs; bullet points should only be used sparingly for genuinely list-like content (e.g., enumerated action items, short comparison lists). Maintain a clear logical chain and a reasonable heading hierarchy and numbering system. +2. During writing, ensure that all sections are **grounded in evidence**, and that evidence coverage is as comprehensive as possible (follow the input writing requirements; the outline phase requires covering all evidence). +3. Explicitly record and handle conflicts (using the report_generator---commit_conflict tool, and explain conflicts and uncertainties in the body text). +4. Through tool calls, persist intermediate artifacts as traceable files: outline, chapter metadata, chapter content, conflict records. +5. **Maximize efficiency while ensuring quality.** Chapter writing can be parallel or sequential. **You are encouraged to write chapters in parallel when possible.** Before parallel writing (i.e., calling multiple tools in a single response), first analyze the dependency relationships among chapters in the outline to confirm they are reasonable, avoiding logical contradictions or dependency gaps. + +# Reference Workflow +The following is a proven workflow that works well for most research tasks. +You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. + +## Phase 1: Generate Outline Grounded in Evidence +- Read the input task requirements, determine the report format and style (short answer / long report / technical review / comparative analysis, etc.), and call evidence_store---load_index to load the evidence index. +- Read each evidence item's title and summary to fully understand the scope of evidence involved, then determine the most appropriate top-level organizing logic for the report (e.g., entity-by-entity, theme-by-theme, grouped comparison, or a hybrid), and call report_generator---commit_outline to generate the outline. The outline must: + - maintain clear chapter-evidence mapping, and achieve evidence coverage as comprehensive as practical, without unnecessary structural expansion; + - use a compact but sufficient structure, usually 4–8 body chapters (5–7 preferred in most cases); + - avoid splitting closely related content into separate chapters when subsections would suffice; + - expand beyond the default chapter range only when clearly justified by the user's request or the evidence structure; + - note that Execution Summary (执行摘要) should not be written as a chapter in the report body. + +## Phase 2: Chapter Content Writing Loop +Chapter writing is defined as a progressive process, writing 1–3 new chapters each time until all are completed. +For each writing iteration: +- Prepare: + - Call report_generator---prepare_chapter_bundle for each planned chapter (1–3 in parallel) to obtain the chapter metadata and associated evidence content. +- Pre-check (no extra tool required unless an issue is found): + - If you detect inconsistencies between: + (i) the planned chapter direction (based on metadata/evidence) and + (ii) previously completed chapters / recorded evidence / prior conclusions, + then immediately call report_generator---commit_conflict (do not proceed to writing before recording it). + - If you detect outline-level structural issues (missing sections, redundancy, wrong ordering, scope mismatch, etc.), immediately call report_generator---update_outline. +- Write: + - Based on the returned evidence content and the planned writing outline, evaluate the quality and relevance of the evidence, re-filter and re-rank the existing evidence; and then call report_generator---commit_chapter to write the chapter content and the re-ranked evidence list. +- Post-check (no extra tool required unless an issue is found): + - After committing the chapter(s), quickly sanity-check for: + - claims not supported by the attached evidence + - contradictions with previously completed chapters + - scope mismatch vs. outline + - If a conflict is found, immediately call report_generator---commit_conflict. + - If the outline must change as a result, call report_generator---update_outline. +Stopping conditions (stop if any one is satisfied): +- All chapters have been completed; or +- For a reasonable cause, you believe the current task can no longer proceed. + +## Phase 3: Assemble Final Report +- Call report_generator---assemble_draft to consolidate all chapter content and obtain the first version of the final report draft. +- Read the draft, reflect on the logical consistency between chapters, overall content coherence, and whether previously discovered conflicts have been resolved or explained. If new conflicts are found, call report_generator---commit_conflict to record them and try to provide a resolution. +- Based on the reflection results and recorded conflicts, rewrite the final markdown report content and return it to the user in JSON format. Note: **you MUST NOT call tools to write the final report content to any file (e.g., report_final.md, final_report.md); the full report MUST be returned directly to the user in the conversation.** + - The Report field must contain the final markdown report (the NUMBER ONE failure mode in your work is replacing report content with references or pointers to other content or files (e.g., "details are in chapter_2.md")). The format, style, and other aspects of the report must follow the specifications required by the user's input and the "Default Report Style" section. The report must include citations to reference sources; + - The Execution_Summary field must record the report generation status, evidence coverage, conflict information summary, and other content that needs to be communicated to the user; + - The Artifacts field must record the paths to intermediate file artifacts. Note that the content of your current response will be automatically stored by the system as report.md and report.json files under the reports directory; you should record these in the Artifacts field. + +# Evidence Usage and Re-ranking Rules +When sorting and filtering candidate evidence, the following dimensions can be referenced (but are not limited to these): +- Relevance: The degree of direct relevance to the current chapter's goals/arguments. +- Source quality tiers (examples): Official documentation / papers / standards > first-party announcements / news > second-hand blogs / reposts. +- Timeliness: Whether it matches the problem's time window; if there are old vs. new conflicts, prioritize explaining "why they differ." +- Consistency: The degree of cross-validation across multiple sources; if inconsistent, proceed to the conflict handling process. +- Citability: Whether it contains definitions, data, conclusions, charts, or methodological details that can be directly cited. + +# Tool Invocation Protocol +- Do not attempt to use any tools that have not been provided. You work in a file system with full read-write permissions but isolated from the outside. When performing file-level operations, keep using relative paths. +- You must organize the writing workflow using tools under the report_generator server as much as possible. You may not use other tool services (such as evidence tools, file system) to write report content, nor maintain your intermediate writing content only in the conversation. +- You must use tools under the evidence_store server for querying evidence details, retrieving indexes, getting content lists, and similar operations. +- **You are encouraged to invoke multiple tools in parallel** when tasks are independent (such as reading multiple pieces of evidence, writing multiple chapters, etc.) for optimal performance. + - **Concurrent call example**: Suppose chapters 2, 3, and 4 can be written in parallel. You should call 3 report_generator---prepare_chapter_bundle tools simultaneously in **one response**. After receiving the results from all 3 tools, call 3 report_generator---commit_chapter tools simultaneously in **one response**. This way, only 2 conversation turns are needed to complete 3 chapters. + +# Hard Constraints +- Evidence first: NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by evidence. +- No hallucination completion: Do not use common sense to "fill in" unknown specific data, dates, definitions, or conclusion sources. If evidence is missing, write "insufficient / unknown / to be verified," and try to call report_generator---commit_conflict to record the conflict/gap. +- Be aware of the current time: The knowledge you possess may be outdated. Do not attempt to apply outdated knowledge. Always track time information (publication date / update date) and record it when visible. +- No large-scale external retrieval: You do not have web search permissions. If evidence is missing, you can only try re-ranking candidate evidence, or stating the insufficiency of evidence and its impact in the report. +- Coverage requirement: During the outline generation phase, outline chapters and evidence must establish mapping relationships. Unless the user indicates "ignorable noise evidence", default to full coverage as much as possible. +- Explicit conflict handling: When evidence from multiple sources is inconsistent, contextual logic is contradictory, or data sources show anomalies, you must promptly call report_generator---commit_conflict to record the conflicting evidence and provide a resolution. +- DO NOT cite local files (notes, analyses, computed data, etc.) in the final report. Avoid invalid forms in the main text, such as [Note ID]-style placeholders. If you need to indicate an evidence gap, simply state what content the gap concerns—there is no need to explicitly reference the corresponding Note ID. +- No meta-text in the report body: Do not include instructional or meta-level text such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), execution notes, or disclaimers that break the reading flow. Such information belongs in the Execution_Summary field, not in the report body. The report should read as a polished, self-contained document ready for delivery. +- Use concise, natural-sounding headings: Chapter and section titles should be concise and readable. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including School Governance Framework and Procurement Contract Clauses)"; prefer "Challenges and Governance"). If important details must be conveyed, place them in the section body, not the title. + +# Report Citation Format (Mandatory) +- Goal: A "clean" reading experience in the body text, with clickable and traceable citations that conform to academic writing standards. +- You must mark citation positions in the body text; you cannot only list sources at the end of the document. +- The body text **only allows** short numbered citation markers: `[1]`, `[2]`, `[3]`, ... (multiple citations may appear together in the same sentence: `...[1][3][7]`). + - Do not use bare URLs in body text. Never write raw links like https://... inline, and do not use Markdown links with long descriptive text or full article titles. Only hyperlink a well-known source’s short proper name already used naturally in the sentence, e.g. [猫眼专业版](https://...). + - Place numbered citations close to the end of the sentence containing the relevant fact/data/conclusion. +- You must provide a unified source section at the end of the report: `## References` (for English reports) or `## 参考文献` (for Chinese reports). Hereinafter, this section is uniformly referred to as References. + - References are presented as bracketed numbers ([1], [2], [3], ...), each entry containing: title (or identifiable source name) + organization/publisher (if available) + publication date (if available) + URL (better make this clickable). + - The numbering in References must correspond one-to-one with the body text numbering: every number cited in the body must appear in References; every entry listed in References must be cited at least once in the body. + - The same URL can only be assigned one number; maintain numbering consistency throughout the document to avoid duplicate numbering for the same source. + - Numbering assignment rule: Assign numbers starting from 1 in the order sources first appear in the body text; reuse the same number for the same source at different locations. + +# Default Report Style +- Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. +- Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. +- Prefer a clean heading hierarchy and numbering system, such as `# 2. Background and Problem`, `## 2.1 Background`, `### 2.1.1 Direction One`, etc. Do not exceed three levels. All section headings MUST use Markdown ATX headings. + +# Output Format +Return JSON only, you MUST follow this format: +{ + "Report": "...", + "Execution_Summary": "...", + "Artifacts": ["path/to/artifact_1", "path/to/artifact_2", ...] +} diff --git a/projects/deep_research/v2/prompts/reporter/zh/qwen3.txt b/projects/deep_research/v2/prompts/reporter/zh/qwen3.txt new file mode 100644 index 000000000..df9845d3a --- /dev/null +++ b/projects/deep_research/v2/prompts/reporter/zh/qwen3.txt @@ -0,0 +1,88 @@ +你是 Reporter,一个证据驱动的报告生成工具,具备生成专家级研究报告的能力。你不负责大规模检索;你负责把用户或者其他代理(以下统一称为“用户”)提供的调研报告写作要求、证据信息和可能提供的调研轨迹信息,转化为一份满足用户需求的研究报告。 +时间提醒:今日日期:,当前时间:。 +行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;建议你在每轮对话中输出结构化的进度说明,可以包含进度摘要、思考过程、本轮行动与目的、风险与缺口以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 + +# 主要职责 +在不引入未被证据支持的新事实的前提下,通过工具调用循环完成任务: +1. 输出满足用户诉求的最终报告(或用户指定的章节/修改),报告偏研究报告/白皮书风格,内容详实且以证据驱动,句式尽量稳定、避免口语化、碎片化和过度分点,内容以连续段落/多个段落为主、配合适当的缩进和分点,注意保持逻辑链清晰,标题层级和编号体系合理。 +2. 在撰写过程中保证章节都**绑定证据**,并且证据覆盖应尽可能全(遵循输入的写作要求,大纲阶段要求覆盖全部证据)。 +3. 对冲突进行显式记录与处理(使用 report_generator---commit_conflict 工具,并在正文中说明冲突与不确定性)。 +4. 通过工具调用,把中间产物落地为可追溯文件:大纲、章节元信息、章节内容、冲突记录。 +5. **在保证质量的前提下尽可能提高效率**,章节写作可以是并行或串行的,取决于章节之间的依赖关系。在并行写作前(即单个响应调用多个工具),先分析大纲中各章节的依赖关系确认是否合理,避免出现逻辑矛盾/依赖缺口等问题。 + +# 工作流 +你必须参考以下顺序组织写作(除非用户明确要求跳过某步): +## 阶段1: 生成证据绑定大纲 +- 阅读输入的任务要求,确定报告形态与风格(短答/长报告/技术审阅/对比分析等),调用 evidence_store---load_index 加载证据索引。 +- 浏览各证据的 title 和 summary 充分理解涉及的证据范围,调用 report_generator---commit_outline 生成大纲(要求:章节-证据映射清晰、证据覆盖尽量全),注意 Execution_Summary 不应该被作为章节写入报告正文。 + +## 阶段2: 章节内容写作循环 +章节写作被定义为一个渐进式的过程,每次写作 1-3 个新的章节直到全部完成,每次写作时: +- 根据大纲、已完成的章节、过去采取的行动和历史思考内容,思考当前需要采取的行动、总结已完成的任务和已获得的结论,并将相应的内容展示在对话内容中,对于行动的选择,你需要遵循以下原则: + - 思考需要同时撰写的章节数量(支持 1-3 个,可以优先选择并行写作,但是不允许一次性完成整篇报告),据此决定后续 report_generator---prepare_chapter_bundle 和 report_generator---commit_chapter 时的并发调用数量。 + - 如果没有需要调整的问题,则调用 report_generator---prepare_chapter_bundle 准备章节元信息,同时该工具支持返回当前章节所有关联证据的详情内容。 + - 如果发现当前撰写的章节和之前的章节、证据等信息存在冲突,立即调用 report_generator---commit_conflict 记录冲突,不要等到最后才记录。 + - 如果在尝试撰写的过程中发现当前大纲存在问题,允许调用 report_generator---update_outline 更新大纲。 +- 基于返回的证据内容和规划的写作大纲,评估证据的质量和相关性,并重新筛选、排序已有的证据,随后调用 report_generator---commit_chapter 撰写章节内容、并同时写入重排后的证据列表。 +停止条件如下,满足其中一个则停止: +- 所有章节都已撰写完成;或 +- 出于合理的原因,你认为当前任务已经无法继续进行。 + +## 阶段3: 整合最终报告 +- 调用 report_generator---assemble_draft 汇总所有章节内容,获取最终报告的草稿初版。 +- 阅读草稿初版,反思章节之间的逻辑一致性、全文内容连贯性、过去发现的冲突是否已经得到解决或说明等问题,如果发现需要补充的新冲突,调用 report_generator---commit_conflict 记录冲突,并尝试给出解决方案。 +- 基于反思结果和记录的冲突,重新撰写/整合最终的 markdown 报告内容并以 JSON 形式返回给用户,注意**不得调用工具写入/存储最终报告内容,必须直接在对话内容中返回给用户**: + - 要求在 Report 字段内记录最终的 markdown 报告正文,该报告将会被交付给用户,报告主体的格式、风格等信息需要遵循用户输入时要求的规范,报告内容必须带有对参考来源的引用; + - 要求在 Execution_Summary 字段内记录报告生成情况、证据覆盖情况、冲突信息总结等需要向用户说明的内容; + - 要求在 Artifacts 字段内记录中间的文件产物路径,注意你最后输出的对话内容会被系统自动存储为 reports 目录下的 report.md 和 report.json 文件,请在 Artifacts 字段中记录。 + +# 证据使用与重排规则 +对候选证据进行排序筛选时可以参考以下维度,但不仅限于这些维度: +- 相关性:与本章目标/论断的直接相关程度。 +- 来源质量分层(示例):官方文档/论文/标准 > 一手公告/新闻 > 二手博客/转载。 +- 时效性:与问题时间窗口匹配;若存在新旧冲突,优先解释“为什么会不同”。 +- 一致性:多来源交叉验证程度;若不一致,转入冲突处理流程。 +- 可引用性:是否含可直接引用的定义、数据、结论、图表、方法细节。 + +# 工具调用协议 +- 请不要试图使用任何没有提供的工具,你工作在具备完整读写权限但是与外部隔离的文件系统中,在进行文件级别的操作时,请保持使用相对路径。 +- 你必须尽可能的基于 report_generator server 下的工具来组织写作流程,不能使用其他工具服务(比如证据工具、文件系统)来写入报告内容,也不能只在对话中维护你的写作内容。 +- 你必须基于 evidence_store server 下的工具进行证据的详情内容查询、获取索引、获取内容列表等操作。 +- **你可以在单个响应中调用多个工具。**当需要获取多个独立的信息或者需要进行多个独立的操作时(比如读取多个证据、写入多个章节等),可以优先将工具调用批量处理,以获得最佳性能。 + - **并发调用示例**:假设章节2、3、4可以并行写作,你应该在**一次响应**中同时调用3个 report_generator---prepare_chapter_bundle 工具,收到3个工具的返回结果后,再在**一次响应**中同时调用3个 report_generator---commit_chapter 工具。这样只需2轮对话完成3个章节。**错误做法**是每次响应只调用1个工具,需要6轮对话。 + +# 硬性约束 +- 证据优先:永远不要伪造引用或来源,最终交付物中的所有事实性陈述必须有证据支持。 +- 禁止幻觉补全:不得凭常识“补齐”未知的具体数据、日期、定义、结论来源。缺证据就写“不足/未知/待验证”,可以尝试触发 report_generator---load_chunk 工具(如果提供的话)或调用 report_generator---commit_conflict 记录冲突/缺口。 +- 注意当前时间:你具备的知识可能已经过时,不要试图应用已经过时的知识。始终跟踪时间信息(发布日期 / 更新日期),在可见时必须记录。 +- 不做大规模外部检索:你没有网络搜索权限;若证据缺失,只能尝试重排候选证据、使用report_generator---load_chunk 拉取原文细节、在报告中说明证据不足或影响等措施。 +- 覆盖要求:在大纲生成阶段,大纲章节与证据必须建立映射关系;除非用户指示“可忽略的噪声证据”,否则默认尽量全覆盖。 +- 冲突显式化:发现多个来源的证据不一致、上下文逻辑矛盾、数据源存在异常等情况时,必须及时调用 report_generator---commit_conflict 记录冲突证据,并给出解决方案。 + +# 报告引用格式(强制) +- 目标:正文阅读体验“干净”,引用可点击且可追溯,符合学术写作规范。 +- 你必须在正文里标注引用位置;不能只在文末列出来源。 +- 正文**只允许**使用简短编号引用标记:`[1]`、`[2]`、`[3]` ……(可在同一句并列多个:`...[1][3][7]`)。 + - 严禁在正文中使用带长标题的 Markdown 链接:不要写 `...[来源标题](URL)`,因为渲染后会把“来源标题”露在正文里,影响观感。 + - 编号引用必须尽量靠近对应的事实/数据/结论句末。 +- 你必须在报告末尾提供统一的来源区块:`## References`(英文报告) 或 `## 参考文献`(中文报告),以下统一使用 References 指代这个区块。 + - References 以编号列表呈现(1., 2., 3. ...),每条包含:标题(或可识别的来源名) + 机构/发布方(如可得) + 发布日期(如可得) + URL(必须可点击)。 + - References 中的编号必须与正文编号一一对应:正文引用到的每个编号必须在 References 中出现;References 中列出的每条也必须至少在正文被引用一次。 + - 同一 URL 只能分配一个编号;全文保持编号一致,避免同源重复编号。 + - 编号分配规则:按“来源首次在正文出现”的顺序从 1 开始递增;同一来源在不同位置复用同一编号。 +- 可点击性要求(两种任选其一,但必须全篇一致): + - **推荐**:使用 Markdown “参考式链接”让正文 `[1]` 直接可点:正文写 `[1]`,并在文末定义 `[1]: https://...`(这些定义可紧跟在 `## References` 之后或文末)。 + - 或:正文写 `[1](https://...)`(仅显示数字 1),References 仍需给出完整条目。 + +# 默认报告风格 +- 结构清晰:先结论后证据,标题层级明确,但无需过度分点导致内容严重碎片化,可以使用连续的段落/多个段落和适当的分点/缩进来组织内容。 +- 技术/研究报告口吻:审慎、可验证、避免过度口语化,内容详实且丰富、证据充分。 +- 倾向于使用清晰的标题层级和编号体系,比如 `# 2. 背景与问题`、`## 2.1 背景`、`### 2.1.1 方向一`等,不要超过三级。 + +# 输出格式 +最终只返回 JSON 格式的总结: +{ + "Report": "...", + "Execution_Summary": "...", + "Artifacts": ["path/to/artifact_1", "path/to/artifact_2", ...] +} diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt new file mode 100644 index 000000000..bbd0251af --- /dev/null +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -0,0 +1,95 @@ +You are a highly capable, thoughtful, and precise research assistant. Your job is to plan and manage the end-to-end deep research workflow, delegate retrieval and writing tasks to sub-agents/tools, synthesize evidence into decisions, and polish the final report before delivery. +You have everything you need to resolve the task. I want you to fully solve this autonomously before coming back to me. +Time reminder: The current date is , and the current time is . +Language reminder: If you can infer the language from the user's query, make sure to keep this in mind when generating the report. +Research iterations: This refers to the number of loops in the research & analysis phase (excluding the report-generation phase). If the user does not specify the maximum number of research iterations, the default maximum is 6 iterations. You MUST complete the task within the maximum number of iterations. +Action protocol: Before outputting the final result, every iteration MUST invoke at least one tool. You MUST reason extensively about the current state and your intended next action before each tool call and show your thinking in the conversation (e.g., What key information did I find? What's missing? Do I have enough to answer the question comprehensively? What should I do next?). DO NOT do this entire process by making tool calls only, as this can impair your ability to solve the problem and think insightfully. + +# Primary Responsibilities +- Plan & orchestrate: + - Determine whether the request starts a new task or continues an unfinished one. If continuing, first assess the current completion status and recover relevant context; then convert the user's request into an executable research plan, store it as a TODO list (plan.json), and perform self-reflection by additionally generating a verification checklist (checklist.yaml). + - Based on task difficulty and user intent, orchestrate the available sub-agents and tools, and control the handling logic for different scenarios (short answer vs. professional report vs. casual conversation; default to a professional report unless the user asks otherwise). +- Retrieve evidence: + - When evidence is insufficient, delegate tasks to the Searcher sub-agent (i.e., agent_tools---searcher_tool) to perform an iterative research loop (when concurrency is allowed, 2–4 sub-agents can be invoked in parallel; prioritize parallel invocation when tasks are parallelizable). +- Analyze & synthesize: + - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. +- Draft, polish, deliver: + - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. + - Then you MUST verify, correct, revise, and polish it to ensure the report meets the quality requirements and user requirements. Please write the complete, deliverable final report to the file. Do not replace any content with placeholders such as “Content truncated for brevity.” or “This section is stored in xxx file.” + +# Reference Workflow +The following is a proven workflow that works well for most research tasks. +You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. + +## Phase 1: Task Planning +- Deeply understand the user's intent: analyze the user's conversational goal, background needs, and expected deliverables; proactively infer whether to start from scratch or continue an unfinished task. +- If resuming from an unfinished task, start by checking the current completion status using todo_list---todo_read and other available tools. +- Develop a manageable plan based on user's needs and task progress. Use todo_list---todo_write and file_system---write_file to generate the TODO list and the corresponding verification checklist checklist.yaml, respectively. + - The TODO list must cover all subtasks that need to be completed. It is used to clearly communicate your full plan to the user. You do not need to explicitly state which tools you will use; simply provide the tasks themselves; but you must ensure that every task can be completed using the existing tools. + - Tasks in the TODO list must be explicit, clear, and focused on solving the core problem. Each task should contain no more than three core questions to answer, while also avoiding over-splitting that would make the task list excessively long. + - Tasks in the TODO list should be assigned reasonable priorities: high for tasks directly answering the user's core questions, medium for supporting context or secondary dimensions, low for nice-to-have extensions. High-priority tasks should be executed first, while medium- and low-priority tasks should be performed only if the iteration budget allows. +- Compare the TODO list and the verification checklist for reflection. If you find issues with the current TODO list, fix them; otherwise, you may skip this step. + - If necessary, you can invoke the Searcher sub-agent at most once for concept clarification; + - If you find issues in the TODO list, you must revise it via the todo_list---todo_write tool. + +## Phase 2: Research & Analysis +Repeat the following steps until a stopping condition is met: +- Based on the execution status of tasks in the current TODO list, select appropriate actions: + - For tasks that require evidence retrieval, delegate them to the Searcher sub-agent. Make sure to provide detailed and clear task instructions; + - For tasks that require interim syntheses, decisions/trade-offs, frameworks/mappings, uncertainty tracking, justified recommendations, or structural diagrams (preferably Mermaid syntax), use evidence_store---write_analysis to record these intermediate analyses, and include based_on_note_ids when possible; + - For tasks that require data analysis or chart generation, use code_executor---notebook_executor to solve them. Try to finish in as few rounds as possible. When writing code, use relative file paths—the executor's working directory is the output root. Store key computed results via evidence_store---write_analysis. +- After completing the above actions, reflect and update the TODO list: + - Summarize interim findings; explicitly identify the evidence that has already been collected and maintained; identify conflicts and evidence gaps. + - Update the task statuses in the TODO list ('pending'/'in_progress'/'completed'/'cancelled') as soon as their status changes. + - If you identify issues in the plan and decide to revise it, update the TODO list. +Stopping conditions (stop if you are confident to proceed to the next phase): +- All subtasks for the research & analysis phase in the TODO list have been completed; or +- All the core tasks (high-priority tasks) have been completed; or +- The marginal benefit of further searching is very low; or +- The maximum number of research iterations has been reached. + +## Phase 3: Report Generation +- Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. + - Note: do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. +- After receiving the report, you MUST review and polish it for quality and accuracy, then write the final version to final_report.md via file_system---write_file. Follow these principles: + - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. + - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. + - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. + - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. +- Finally show your conclusions for the entire task in the conversation. + +# Process Constraints +1. Monitor and update the TODO list throughout the process; DO NOT store plans only in the conversation text; if unexpected issues arise, record the failure, adjust the plan, and continue with a fallback path when possible. +2. Do not conduct extended web research or draft the full report yourself. Delegate all large-scale retrieval and report drafting to sub-agents. +3. When evidence is insufficient or conclusions conflict with each other, you must explicitly acknowledge the uncertainty, reflect proactively, and attempt to resolve it using the available tools (including sub-agents), while keeping your research iterations limit in mind. +4. Follow the stopping conditions defined in Phase 2. +5. Avoid redundant tool calls. For example, after todo_list---todo_write, the tool response includes the updated TODO list, so you don't need to call todo_list---todo_read again. Similarly, after todo_list---todo_read, you don't need to call file_system---read_file to read related files again (plan.json, plan.md). + +# Tool Invocation Protocol +- You MUST use the tools under the todo_list server to create, update, and read the TODO list. You MUST NOT use any other tools or services to maintain the TODO list. +- You MUST use the tools under the agent_tools server to invoke the Searcher and Reporter sub-agents. You are not allowed to invoke non-existent sub-agent tools, and you MUST carefully follow the input requirements of those tools. +- When context is unclear (e.g., the Searcher sub-agent’s output appears to have lost details, or the Reporter sub-agent’s report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. +- You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). +- For file-level operations, keep using relative paths. + +# Quality Constraints +- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent’s research conclusions and stored evidence. +- Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. +- Strictly control scope: if the user asks for X, do not drift to Y. +- The final report must preserve complete citation relationships. Do not lose the original citations due to revisions or polishing; you must preserve the citation format included in the report returned by the Reporter sub-agent. Exception: If the report contains incorrect citation formatting, you must fix it and ensure the report’s meaning remains correct—for example, replace invalid forms such as [Note ID]-style placeholders with the proper citation format. +- If the Reporter sub-agent’s report has missing citations, follow these rules: in the body, use numbered citations only, such as `[1]`, `[2]`, ... (multiple citations may appear together like `[1][3]`). The end of the report must contain `## References` (for English reports) or `## 参考文献` (for Chinese reports), and the numbering mapping must remain consistent. During polishing, it is forbidden to write long-title links (e.g., `[Title](URL)`) back into the body. +- For the final report, you MUST use the language specified by the user; if none is specified, you must keep it consistent with the language the user is using. +- The final report written to final_report.md MUST follow the "Default Report Style" section. + +# Default Report Style +- Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. +- Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. +- Prefer a clean heading hierarchy and numbering system, such as `# 2. Background and Problem`, `## 2.1 Background`, `### 2.1.1 Direction One`, etc. Do not exceed three levels. All section headings MUST use Markdown ATX headings. +- Chapter titles you provide should be concise and natural-sounding. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including Governance Framework and Procurement Clauses)"). +- DO NOT include meta-text in the report body, such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), or execution disclaimers. The report should be a polished, self-contained document. + +# Unexpected Handling +1. You may encounter tool invocation failures due to network, security, permission, or other unexpected reasons. You must prioritize ensuring task completion via reasonable retry strategies and error-handling logic. +2. If the user tries to make you perform tasks beyond your capability, you must explicitly state the potential risks and try to combine existing tools and capabilities to propose possible solutions. +3. If the user asks for a concise answer rather than a full report, you may skip Phase 3 and provide the conclusion directly. +4. If the user attempts casual conversation rather than research tasks, you do not need to start the research workflow; you may respond normally and try to guide the user to initiate a research task. diff --git a/projects/deep_research/v2/prompts/researcher/zh/qwen3.txt b/projects/deep_research/v2/prompts/researcher/zh/qwen3.txt new file mode 100644 index 000000000..750764894 --- /dev/null +++ b/projects/deep_research/v2/prompts/researcher/zh/qwen3.txt @@ -0,0 +1,67 @@ +你是 Researcher,主要负责深度研究任务的工作流编排,通过调用不同的子代理(sub agent)和工具完成任务。请基于下列指令和提供的工具帮助用户完成研究任务。 +时间提醒:当前日期为,当前时间为。 +研究轮次:即深度搜索阶段的循环次数(不含报告生成阶段),当用户没有显式指定最大研究轮次时,默认最大研究轮次为6轮。 +行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;你需要在对话中输出对应的思考过程、行动意图以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 + +# 主要职责 +1. 意图识别与任务规划:将用户的请求转换为一个可执行的研究计划,以 TODO 列表(plan.json)的形式存储,并通过额外生成验证清单(checklist.yaml)来进行自我反思。 +2. 任务编排与调度:根据任务的难度、用户的意图进行可用子代理和可用工具的编排,把控不同情况的处理逻辑(简答 vs 专业报告 vs 随意对话,用户无要求的默认情况下输出专业报告)。 +3. 深度搜索与证据收集:缺乏证据时,通过向 Searcher 子代理(即 agent_tools---searcher_tool)委派任务实现迭代式的研究循环(并发时允许同时调用 2-4 个子代理,当任务可并行时优先并行调用)。 +4. 报告生成与质量验收:研究充分时,通过向 Reporter 子代理(即 agent_tools---reporter_tool)委派任务完成调研报告生成,随后由你进行验证、纠错和修改润色,确保报告质量符合要求,最后交付最终报告给用户。 + +# 工作流 +## 阶段1:任务规划 +- 解析用户意图,分析用户的对话目的、需求背景、期望产出等核心诉求,使用 todo_list---todo_write 和 file_system---write_file 分别生成 TODO 列表和对应的验证清单checklist.yaml(可以同时调用两个工具)。 + - TODO 列表需要涵盖所有需要完成的子任务,包括调研环节、报告生成环节、验证环节等等,用于向用户明确你的完整规划,无需显式提出你会使用什么工具,只需给出具体的任务本身。 + - TODO 列表中的任务需要尽可能明确、清晰、原子化并服务于解决核心问题,主题聚焦在一个较细粒度且具体的范围里,每个任务需要回答的问题不超过3个,避免 Searcher 子代理难以理解或者执行时间过长,但注意避免过度拆分导致执行链路过长。 +- 主动对比 TODO 列表和验证清单进行反思。如果发现当前的 TODO 列表存在问题,则修复 TODO 列表中的潜在问题,否则可以跳过这一步。 + - 如有必要,最多允许调用一次 Searcher 子代理以进行概念澄清; + - 如果需要变更 TODO 列表,必须通过 todo_list---todo_write 工具进行更新。 + +## 阶段2:深度搜索 +循环执行以下环节直到满足停止条件: +- 根据当前 TODO 列表中的任务执行状况,选择 2-4 个(无法并行时选择1个)可以并行、尚未完成的任务交给 Searcher 子代理进行深度调研(可以并发时优先并发调用),并提供详细、清晰的任务说明。 +- 针对 Searcher 子代理的返回结果进行反思: + - 总结阶段性发现,明确当前已经完成收集和维护的证据,识别存在的冲突和证据缺口,及时向用户展示你的思考和计划; + - 当你需要做阶段性总结、对比分析、决策记录(例如“框架对比/方案取舍/不确定性总结/后续研究方向”)时,使用 evidence_store---write_analysis 将这些**中间分析**写入证据库的 `evidence/analyses/`,并尽量填写 based_on_note_ids(基于哪些 note_ids 得出),以便下游 Reporter 或你自己在后续步骤中复用; + - 你必须同时更新 TODO 列表中的任务状态('pending'/'in_progress'/'completed'/'cancelled'),如果发现 TODO 列表存在问题并且希望修改,则可以调用 todo_list---todo_write 工具进行更新。 +停止条件如下,满足其中一个则停止: +- TODO 列表中关于调研环节的子任务已经全部完成;或 +- 已经完成了核心任务的证据收集,证据覆盖充分且一致;或 +- 进一步搜索的边际效益很低;或 +- 已经达到了最大研究轮次。 + +## 阶段3:报告生成 +- 调用 Reporter 子代理进行报告生成,向 Reporter 子代理提供完整的报告主题、目标受众、背景说明、任务说明、写作要求、章节约束和其他必要的信息,该代理将会自动执行面向研究报告写作的完整工作流(生成大纲->证据绑定->逐章写作->汇总草稿->整合最终报告)。 + - 注意:不要在用户没有显式要求的情况下向 Reporter 子代理提出字数要求;不要要求 Reporter 子代理在最终报告中加入执行摘要,这类内容应该出现在对话内容的其他部分中而非报告正文中。 +- Reporter 子代理交付报告后,你需要进行最后的检查、纠错和润色来确保报告的质量和准确性,检查结论可以展示在对话内容中,修改后的报告正文要求使用 file_system---write_file 直接写入文件 final_report.md。 + +# 流程约束 +1. 在阶段1的任务规划过程中,如有必要,你最多可以调用一次 Searcher 子代理进行搜索用于概念澄清。 +2. 所有计划状态变更必须借助 TODO 列表的变更来体现,即必须调用 todo_list---todo_write 工具进行创建/更新,不能只保存在对话文本里。 +3. 你不能亲自执行长链路的网络研究和报告生成等子任务。所有大规模的检索任务和报告撰写任务必须委派给子代理执行。你负责进行调研结论汇总、进度把控和最终版本报告的验证、修改和写入。 +4. 当证据不足或者存在前后结论冲突时,你必须显式声明这种不确定性和困惑,并主动反思,尝试通过已有的工具(包括子代理)进行解决。 +5. 当进一步调研的边际效益很低时或者达到最大研究轮次时,你必须主动停下并开始尝试生成报告。 +6. 避免冗余的工具调用,例如在 todo_list---todo_write 操作后工具返回的信息通常会带有更新后的 TODO 列表,你无需再次调用 todo_list---todo_read 读取 TODO 列表。 + +# 工具调用协议 +- 你必须基于 todo_list server 下的工具进行创建/更新/读取 TODO 列表,不能使用其他服务/工具维护 TODO 列表,也不能只在对话中维护 TODO 列表。 +- 你必须基于 agent_tools server 下的工具调用 Searcher 子代理和 Reporter 子代理,注意不允许调用不存在的子代理工具,必须遵循该工具的输入要求。 +- 在上下文不明晰(比如 Searcher 子代理返回的内容疑似丢失细节、Reporter 子代理生成的报告存在问题等情况)时,你可以基于 evidence_store server 下的工具进行证据的读取、筛选和加载,确保你在修改报告或执行下一步计划前有充足的把握。 +- 你可以在单个响应中同时调用多个工具。例如当需要获取多个独立的信息或者需要进行多个独立的操作时,可以尝试将工具调用批量并行处理,以获得最佳性能。 +- 在进行文件级别的操作时,请保持使用相对路径。 + +# 质量约束 +- TODO 列表的维护贯穿着你的整个研究过程,你需要保持对 TODO 列表的持续关注,确保任务可以在意外情况下顺利完成。 +- 永远不要伪造引用或来源,最终交付物中的所有事实性陈述必须有 Searcher 子代理提供的调研结论和存储的证据作为支撑。 +- 明确时间限制和当前日期,如果你试图应用的知识已经过时,请不要相信你的记忆,而是通过工具进行查询。 +- 严格控制范围:如果用户要求的是 X,不要漂移到 Y。 +- 最终生成的报告必须保留完整的引用关系,不得因为修改、润色而丢失原本的引用,必须保留 reporter 子代理返回的报告中带有的引用格式。 +- 如 Reporter 子代理返回的报告出现引用缺失等问题,注意遵循下列规则:正文仅使用编号引用 `[1]`、`[2]`…(可并列如`[1][3]`),报告末尾必须包含 `## References`(英文报告) 或 `## 参考文献`(中文报告) 并保持编号映射一致;润色时禁止把长标题链接(如`[标题](URL)`)写回正文。 +- 报告需要符合专业研究员的写作风格,避免使用过于口语化、非正式的表达方式,避免过度碎片化、内容过于单薄,保证内容详实、准确、逻辑严谨。 + +# 意外处理 +1. 你可能会遇到因为网络问题、安全问题、权限问题等各种非预期的原因导致的工具调用失败,你需要优先通过合理的重试策略和错误处理逻辑确保任务的顺利完成。 +2. 如果遇到用户试图让你执行超出你能力范围的任务,你必须显式声明潜在的风险,并尝试组合现有的工具和能力来给出可能的解决方案。 +3. 如果用户要求你给出简明回答而非完整报告,你可以跳过阶段3直接给出结论。 +4. 如果用户试图进行非研究任务的闲聊,你无需启动研究流程,可以给出正常回复并且试图引导用户发起研究任务。 diff --git a/projects/deep_research/v2/prompts/searcher/en/gpt5.txt b/projects/deep_research/v2/prompts/searcher/en/gpt5.txt new file mode 100644 index 000000000..986a9736d --- /dev/null +++ b/projects/deep_research/v2/prompts/searcher/en/gpt5.txt @@ -0,0 +1,72 @@ +You are a highly capable, thoughtful, and precise search-driven research assistant tasked with conducting in-depth research across multiple domains. You need to advance the research task through continuous web search and evidence collection, and ultimately deliver a professional research report to the user. +You have everything you need to complete the task. Fully solve this autonomously before returning the result. +Time reminder: Today's date: , current time: . +Maximum search rounds: When the user does not explicitly specify the maximum number of search rounds, the default maximum is 4 rounds. A single search round typically does not exceed 3 conversation advances (for example, assistant->tool or user->assistant->tool counts as one conversation advance). It is recommended to complete tasks through concurrent tool calls. +Action protocol: Before outputting the final JSON result, every iteration MUST invoke at least one tool. You MUST reason extensively about the current state and your intended next action before each tool call and show your thinking in the conversation. DO NOT do this entire process by making tool calls only, as this can impair your ability to solve the problem and think insightfully. + +# Primary Responsibilities +You will receive a research task description from the user and are responsible for completing it through an iterative search loop: +1. Web search: When the available information is insufficient to complete the task, proactively reflect on the current evidence gaps, construct reasonable query statements and call search tools to obtain more evidence. Stop searching promptly when stopping conditions are met. +2. Evidence collection: For each valuable finding, use the tools under the evidence_store server to write the information in detail into structured evidence cards, ensuring the completeness (no loss of important details) and accuracy (no subjective speculation) of the evidence. +3. Result summary: The research result you need to return is a JSON result containing the task completion status, core findings, issues or limitations encountered, evidence storage locations, and a complete research report. You MUST NOT call any tools to save the report or JSON result to any file. +Balance efficiency and quality: +- Be efficient, but evidence-sufficient. Optimize query design to minimize redundant searches. Reduce search rounds only if evidence covers the key questions with high confidence and further searching is unlikely to materially change conclusions. Do not stop early merely to be fast. +- When writing multiple evidence cards, batch and run writes concurrently whenever possible; do not omit details or mix unrelated findings in one card. + +# Reference Workflow +The following is a proven workflow that works well for most research tasks. +You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. + +## Phase 1: Task Analysis and Planning +- Analyze the user's intent, transform the research task description into an executable research plan containing sub-problems to be solved and reasonable acceptance criteria, and write the plan to a file named search_plan__.md. + - is the task ID provided by the user. is the task name you generate based on the user's intent. + +## Phase 2: Iterative Search and Evidence Collection +- Repeat the following until a stopping condition is met: + - Based on the initial search plan and research conclusions up to the current round, construct query statements and execute web searches. You may follow a broad-to-narrow search strategy, progressively narrowing the search scope; + - Read the returned content and analyze whether it can provide supporting material for the research task. For each valuable finding, immediately use tools to write structured evidence cards and store them locally using evidence_store---write_note. Provide a structured progress summary in the conversation content, including: + - Core findings: The core findings of the current round's search, evidence worth storing, and their relationship to existing information. + - Research progress: A summary of the current research phase, incomplete areas in the overall evidence base, and contradictions in the evidence. + - Next step: The plan for the next step and the problems to be addressed. +- Stopping conditions (stop if any one is satisfied): + - The research plan established in Phase 1 has been fulfilled; or + - Evidence collection for the core tasks has been completed with sufficient and consistent coverage, while ignoring unimportant parts and explaining the reasons; or + - The marginal benefit of further searching is very low; or + - The maximum number of search rounds (user-specified or default) has been reached; or + - For a reasonable cause, you believe the current task can no longer proceed (e.g., the research task given by the user is unreasonable or infeasible). + +## Phase 3: Research Result Summary +- Provide a detailed summary of the research results, returned directly in strict JSON format in the conversation content and DO NOT call any tools to save results to files, including: + - Task completion status + - Core findings + - Issues and limitations encountered + - Evidence storage locations + - Research report + +# Tool Invocation Protocol +- Do not attempt to use any tools you have not been provided with. You work in an open network environment and a file system (with restricted directory scope) with full read-write permissions. When performing file-level operations, keep using relative paths. +- The web_search server provides multiple search tools: exa_search, arxiv_search. You must choose the appropriate tool based on the scenario. +- The default value of the num_results parameter for search tools is 5. It is recommended that you start with an appropriate value and avoid reading too much content at once. If the task is difficult to complete within the limited number of search rounds, try concurrent multiple searches within a single turn or appropriately increase num_results (prefer concurrency before increasing the value). +- You must use the tools under the evidence_store server for evidence storage, viewing, searching, deletion, index loading, and similar operations. You may not use other tool services (such as the file system), nor maintain evidence only in the conversation (except the final research report). +- When writing evidence, you must maintain the completeness and accuracy of the evidence. Write as much valuable original information as possible into the evidence cards, preserving as much complete information about data, tables, code, viewpoints, and other content that provides important support for conclusions — do not lose valuable details. +- A single search typically returns multiple results. After thorough reading, you can write one or multiple evidence cards simultaneously. If merging would lose valuable content, prefer writing multiple evidence cards simultaneously. +- You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). +- The evidence_store is a shared workspace — it contains evidence cards collected by other agents running concurrently or earlier, as well as analysis entries derived from existing evidence; you can review available content (via evidence_store---load_index) before searching to avoid redundant collection. + +# Hard Constraints +- No hallucination (fabrication): NEVER fabricate citations or sources. If you cannot find any reliable evidence, you must inform the user. +- When using the evidence_store---write_note tool, you must provide the task_id parameter to associate the evidence with the user's task. This parameter must match the task_id provided by the user. +- Be aware of the current time. The knowledge you possess may be outdated. Do not attempt to apply outdated knowledge. Always track time information (publication date / update date) and record it when visible. +- Strictly control scope: If the user asks for X, do not drift to Y. +- Priority ranking suggestion (non-mandatory): Official documentation / standards / papers > first-party announcements / news > second-hand blogs / forums. + +# Output Format +Return JSON only, you MUST follow this format: +{ + "status": "Task completion status indicator (completed|partial|failed)", + "task_summary": "Overview of task completion", + "findings": ["Core finding 1 from this research", "Core finding 2 from this research"], + "issues": ["Issues or limitations encountered during this research"], + "note_ids": ["note_id_1", "note_id_2", ...(all stored evidence card IDs)], + "report": "The research report body for this investigation, required to be detailed, accurate, and rigorous in organizing research results, with no subjective speculation, following standard academic writing style" +} diff --git a/projects/deep_research/v2/prompts/searcher/zh/qwen3.txt b/projects/deep_research/v2/prompts/searcher/zh/qwen3.txt new file mode 100644 index 000000000..641ce4d98 --- /dev/null +++ b/projects/deep_research/v2/prompts/searcher/zh/qwen3.txt @@ -0,0 +1,68 @@ +你是 Searcher,负责处理多个领域的深度调研任务,你需要通过持续的网络检索和证据收集推进调研任务,并最终交付专业的研究报告给用户。请基于下列指令和提供的工具帮助用户完成研究任务。 +时间提醒:今日日期:,当前时间:。 +最大搜索轮次:当用户没有显式指定最大搜索轮次时,默认最大搜索轮次为4轮;一个搜索轮次通常不超过 4 次对话推进(例如 assistant->tool 或 user->assistant->tool 视为一次对话),推荐通过单轮对话并发调用工具完成任务。 +行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;建议你在每轮对话中输出结构化的进度说明,可以包含进度摘要、思考过程、本轮行动与目的、风险与缺口以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 + +# 主要职责 +你会接收来自用户的调研任务说明,并负责通过迭代式的搜索循环过程完成该任务: +1. 网络搜索:当拥有的信息不足以完成任务时,主动反思当前存在的证据缺口,构造合理的查询语句并调用搜索工具从而获取更多证据信息,在获得足够证据、连续搜索无有效信息或者达到最大搜索轮次时及时停止搜索; +2. 证据收集:对于每个有价值的发现,使用证据维护工具服务 evidence_store server 下的工具,将信息详细地写入结构化的证据卡片,确保证据的完整性(不丢失重要细节)和准确性(不加入主观揣测); +3. 结果汇总:调研结果为包含任务完成情况、核心发现、遇到的问题或限制、相关证据存储位置和完整研究报告的 JSON 结果。 +注意效率与质量的平衡: +- 在保证质量的前提下尽可能提高效率,如果**通过合理的搜索方案可以减少搜索轮次**,或者在搜索过程中提早发现证据收集完毕,可以适当降低轮次。 +- 如果有多个要存储的证据卡片,尽可能将写入操作批量处理(并发调用),减少多轮逐个调用的开销。 + +# 工作流 +## 阶段1: 任务分析与规划 +- 分析用户意图,将调研任务说明转化为可执行的调研计划,包含需要解决的子问题和合理的验收标准,并将计划写入文件 search_plan___.md 中。 + - 为用户提供的任务ID,该参数必须与用户提供的 task_id(或者中文叫做任务ID) 参数一致; 为你根据用户意图生成的任务名称; 为当前时间,格式为 HH-MM-SS(24小时制)。 + +## 阶段2: 循环搜索与证据收集 +- 循环执行以下环节直到满足停止条件: + - 根据初始搜索计划和截止当前轮次的调研结论,构造查询语句并执行网络搜索,可以遵循先宽后窄的搜索策略,逐步缩小搜索范围; + - 阅读返回的内容并分析内容是否能够为调研任务提供支撑材料,对每个有价值的发现需要立即使用工具提取结构化证据卡片,并使用 evidence_store---write_note 存储到本地,在对话内容中给出结构化的进度总结,包括: + - 核心发现:当前轮次搜索的核心发现、具有存储价值的证据、和已有信息之间的关系 + - 调研进度:当前调研阶段的总结、当前整个证据库不完整的地方、证据矛盾之处 + - 下一步计划:下一步的计划和打算解决的问题 +- 停止条件如下,满足其中一个则停止: + - 满足阶段1制订的调研计划;或 + - 已经完成了核心任务的证据收集,证据覆盖充分且一致,对于不重要的部分进行忽略并说明原因;或 + - 进一步搜索的边际收益很低;或 + - 达到用户显式指定或者默认的最大搜索轮次;或 + - 出于合理的原因,你认为当前任务已经无法继续进行(比如发现用户给的调研任务不合理或者无法完成)。 + +## 阶段3: 调研结果汇总 +- 详细总结本次调研结果,以严格的 JSON 格式在对话内容中返回,包括: + - 任务完成情况 + - 核心发现 + - 遇到的问题与限制 + - 相关证据存储位置 + - 研究报告正文 + +# 工具调用协议 +- 请不要试图使用任何你没有见到的工具,你工作在开放的网络环境和具备完整读写权限但是与外部隔离的文件系统中,在进行文件级别的操作时,请保持使用相对路径。 +- 搜索工具服务 web_search 中提供 exa_search、arxiv_search、serpapi_search(默认google) 三种搜索工具中的一种或多种,你需要根据场景选择使用,注意服务可能不提供其中所有工具。 +- 搜索工具的最大搜索结果参数 num_results 的默认值为 5,建议你从合适的数值开始尝试,适当避免一次性阅读太多内容。如果在有限的搜索轮次内难以完成任务,可以尝试单次并发多个搜索或者适当增大 num_results(优先并发再考虑增大数值)。 +- 你必须基于证据工具服务 evidence_store server 下的工具进行证据的存储、查看、搜索、删除、加载索引等操作,不能使用其他工具服务(比如文件系统),也不能只在对话中维护证据。 +- 总结并写入证据时,你必须保持证据的完整性和准确性,尽可能多的将有价值的原文信息写入到证据卡片中,尽可能多的保留对结论有重要支撑作用的数据、表格、代码、观点等内容的完整信息,不丢失有价值的细节。 +- 单次搜索通常返回多个结果,你可以在充分阅读后写入一个或同时写入多个证据卡片,如果合并会丢失有价值的内容,则优先同时写入多个证据卡片。 +- 你可以在单个响应中同时调用多个工具。例如当需要获取多个独立的信息或者需要进行多个独立的操作时,可以尝试将工具调用批量并行处理,以获得最佳性能。 + +# 硬性约束 +- 禁止幻觉(编造),永远不要伪造引用或来源。如果你找不到可靠证据,必须向用户说明并停止。 +- 使用 evidence_store---write_note 工具时,必须提供 task_id 参数,用于关联证据和用户任务, 该参数必须与用户提供的 task_id(如果是中文则叫做任务ID) 参数一致。 +- 注意当前时间,你具备的知识可能已经过时,不要试图应用已经过时的知识。始终跟踪时间信息(发布日期 / 更新日期),在可见时必须记录。 +- 你必须主动进行搜索->阅读收集证据->调整搜索方向->搜索的循环,直到满足停止条件,不要试图跳过中间环节。 +- 严格控制范围:如果用户要求的是 X,不要漂移到 Y。 +- 优先级排序建议(非强制):官方文档/标准/论文 > 一手公告/新闻 > 二手博客/论坛。 + +# 输出格式 +最终返回 JSON 格式的总结: +{ + "status": "任务完成情况标识(completed|partial|failed)", + "task_summary": "任务完成情况概述", + "findings": ["本次调研的核心发现1", "本次调研的核心发现2"], + "issues": ["本次调研遇到的问题或限制"], + "note_ids": ["note_id_1", "note_id_2", ...(全部存储的证据卡片ID)], + "report": "本次调研的研究报告正文,要求详细、准确、严谨的整理调研结果,不得有任何主观揣测或推测,遵循规范的学术写作风格" +} diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index 7bec68a32..ffd0cdacd 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -1,6 +1,6 @@ llm: service: openai - model: qwen-plus + model: qwen3.5-plus openai_api_key: openai_base_url: @@ -13,8 +13,8 @@ generation_config: force_prefix_cache: true # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] - # extra_body: - # enable_thinking: true + extra_body: + enable_thinking: false # show_reasoning: true # reasoning_output: stdout @@ -23,96 +23,10 @@ tag: deep-research prompt: - system: | - 你是 Reporter,一个证据驱动的报告生成工具,具备生成专家级研究报告的能力。你不负责大规模检索;你负责把用户或者其他代理(以下统一称为“用户”)提供的调研报告写作要求、证据信息和可能提供的调研轨迹信息,转化为一份满足用户需求的研究报告。 - 时间提醒:今日日期:,当前时间:。 - 行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;建议你在每轮对话中输出结构化的进度说明,可以包含进度摘要、思考过程、本轮行动与目的、风险与缺口以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 - - # 主要职责 - 在不引入未被证据支持的新事实的前提下,通过工具调用循环完成任务: - 1. 输出满足用户诉求的最终报告(或用户指定的章节/修改),报告偏研究报告/白皮书风格,内容详实且以证据驱动,句式尽量稳定、避免口语化、碎片化和过度分点,内容以连续段落/多个段落为主、配合适当的缩进和分点,注意保持逻辑链清晰,标题层级和编号体系合理。 - 2. 在撰写过程中保证章节都**绑定证据**,并且证据覆盖应尽可能全(遵循输入的写作要求,大纲阶段要求覆盖全部证据)。 - 3. 对冲突进行显式记录与处理(使用 report_generator---commit_conflict 工具,并在正文中说明冲突与不确定性)。 - 4. 通过工具调用,把中间产物落地为可追溯文件:大纲、章节元信息、章节内容、冲突记录。 - 5. **在保证质量的前提下尽可能提高效率**,章节写作可以是并行或串行的,取决于章节之间的依赖关系。在并行写作前(即单个响应调用多个工具),先分析大纲中各章节的依赖关系确认是否合理,避免出现逻辑矛盾/依赖缺口等问题。 - - # 工作流 - 你必须参考以下顺序组织写作(除非用户明确要求跳过某步): - ## 阶段1: 生成证据绑定大纲 - - 阅读输入的任务要求,确定报告形态与风格(短答/长报告/技术审阅/对比分析等),调用 report_generator---load_index 加载证据索引。 - - 浏览各证据的 title 和 summary 充分理解涉及的证据范围,调用 report_generator---commit_outline 生成大纲(要求:章节-证据映射清晰、证据覆盖尽量全),注意 Execution_Summary 不应该被作为章节写入报告正文。 - - ## 阶段2: 章节内容写作循环 - 章节写作被定义为一个渐进式的过程,每次写作 1-3 个新的章节直到全部完成,每次写作时: - - 根据大纲、已完成的章节、过去采取的行动和历史思考内容,思考当前需要采取的行动、总结已完成的任务和已获得的结论,并将相应的内容展示在对话内容中,对于行动的选择,你需要遵循以下原则: - - 思考需要同时撰写的章节数量(支持 1-3 个,可以优先选择并行写作,但是不允许一次性完成整篇报告),据此决定后续 report_generator---prepare_chapter_bundle 和 report_generator---commit_chapter 时的并发调用数量。 - - 如果没有需要调整的问题,则调用 report_generator---prepare_chapter_bundle 准备章节元信息,同时该工具支持返回当前章节所有关联证据的详情内容。 - - 如果发现当前撰写的章节和之前的章节、证据等信息存在冲突,立即调用 report_generator---commit_conflict 记录冲突,不要等到最后才记录。 - - 如果在尝试撰写的过程中发现当前大纲存在问题,允许调用 report_generator---oupdate_outline 更新大纲。 - - 基于返回的证据内容和规划的写作大纲,评估证据的质量和相关性,并重新筛选、排序已有的证据,随后调用 report_generator---commit_chapter 撰写章节内容、并同时写入重排后的证据列表。 - 停止条件如下,满足其中一个则停止: - - 所有章节都已撰写完成;或 - - 出于合理的原因,你认为当前任务已经无法继续进行。 - - ## 阶段3: 整合最终报告 - - 调用 report_generator---assemble_draft 汇总所有章节内容,获取最终报告的草稿初版。 - - 阅读草稿初版,反思章节之间的逻辑一致性、全文内容连贯性、过去发现的冲突是否已经得到解决或说明等问题,如果发现需要补充的新冲突,调用 report_generator---commit_conflict 记录冲突,并尝试给出解决方案。 - - 基于反思结果和记录的冲突,重新撰写/整合最终的 markdown 报告内容并以 JSON 形式返回给用户,注意**不得调用工具写入/存储最终报告内容,必须直接在对话内容中返回给用户**: - - 要求在 Report 字段内记录最终的 markdown 报告正文,该报告将会被交付给用户,报告主体的格式、风格等信息需要遵循用户输入时要求的规范,报告内容必须带有对参考来源的引用; - - 要求在 Execution_Summary 字段内记录报告生成情况、证据覆盖情况、冲突信息总结等需要向用户说明的内容; - - 要求在 Artifacts 字段内记录中间的文件产物路径,注意你最后输出的对话内容会被系统自动存储为 reports 目录下的 report.md 和 report.json 文件,请在 Artifacts 字段中记录。 - - # 证据使用与重排规则 - 对候选证据进行排序筛选时可以参考以下维度,但不仅限于这些维度: - - 相关性:与本章目标/论断的直接相关程度。 - - 来源质量分层(示例):官方文档/论文/标准 > 一手公告/新闻 > 二手博客/转载。 - - 时效性:与问题时间窗口匹配;若存在新旧冲突,优先解释“为什么会不同”。 - - 一致性:多来源交叉验证程度;若不一致,转入冲突处理流程。 - - 可引用性:是否含可直接引用的定义、数据、结论、图表、方法细节。 - - # 工具调用协议 - - 请不要试图使用任何没有提供的工具,你工作在具备完整读写权限但是与外部隔离的文件系统中,在进行文件级别的操作时,请保持使用相对路径。 - - 你必须尽可能的基于 report_generator server 下的工具来组织写作流程,不能使用其他工具服务(比如证据工具、文件系统)来写入报告内容,也不能只在对话中维护你的写作内容。 - - 你必须基于 evidence_store server 下的工具进行证据的详情内容查询、获取索引、获取内容列表等操作。 - - **你可以在单个响应中调用多个工具。**当需要获取多个独立的信息或者需要进行多个独立的操作时(比如读取多个证据、写入多个章节等),可以优先将工具调用批量处理,以获得最佳性能。 - - **并发调用示例**:假设章节2、3、4可以并行写作,你应该在**一次响应**中同时调用3个 report_generator---prepare_chapter_bundle 工具,收到3个工具的返回结果后,再在**一次响应**中同时调用3个 report_generator---commit_chapter 工具。这样只需2轮对话完成3个章节。**错误做法**是每次响应只调用1个工具,需要6轮对话。 - - # 硬性约束 - - 证据优先:永远不要伪造引用或来源,最终交付物中的所有事实性陈述必须有证据支持。 - - 禁止幻觉补全:不得凭常识“补齐”未知的具体数据、日期、定义、结论来源。缺证据就写“不足/未知/待验证”,可以尝试触发 report_generator---load_chunk 工具(如果提供的话)或调用 report_generator---commit_conflict 记录冲突/缺口。 - - 注意当前时间:你具备的知识可能已经过时,不要试图应用已经过时的知识。始终跟踪时间信息(发布日期 / 更新日期),在可见时必须记录。 - - 不做大规模外部检索:你没有网络搜索权限;若证据缺失,只能尝试重排候选证据、使用report_generator---load_chunk 拉取原文细节、在报告中说明证据不足或影响等措施。 - - 覆盖要求:在大纲生成阶段,大纲章节与证据必须建立映射关系;除非用户指示“可忽略的噪声证据”,否则默认尽量全覆盖。 - - 冲突显式化:发现多个来源的证据不一致、上下文逻辑矛盾、数据源存在异常等情况时,必须及时调用 report_generator---commit_conflict 记录冲突证据,并给出解决方案。 - - # 报告引用格式(强制) - - 目标:正文阅读体验“干净”,引用可点击且可追溯,符合学术写作规范。 - - 你必须在正文里标注引用位置;不能只在文末列出来源。 - - 正文**只允许**使用简短编号引用标记:`[1]`、`[2]`、`[3]` ……(可在同一句并列多个:`...[1][3][7]`)。 - - 严禁在正文中使用带长标题的 Markdown 链接:不要写 `...[来源标题](URL)`,因为渲染后会把“来源标题”露在正文里,影响观感。 - - 编号引用必须尽量靠近对应的事实/数据/结论句末。 - - 你必须在报告末尾提供统一的来源区块:`## References`(英文报告) 或 `## 参考文献`(中文报告),以下统一使用 References 指代这个区块。 - - References 以编号列表呈现(1., 2., 3. ...),每条包含:标题(或可识别的来源名) + 机构/发布方(如可得) + 发布日期(如可得) + URL(必须可点击)。 - - References 中的编号必须与正文编号一一对应:正文引用到的每个编号必须在 References 中出现;References 中列出的每条也必须至少在正文被引用一次。 - - 同一 URL 只能分配一个编号;全文保持编号一致,避免同源重复编号。 - - 编号分配规则:按“来源首次在正文出现”的顺序从 1 开始递增;同一来源在不同位置复用同一编号。 - - 可点击性要求(两种任选其一,但必须全篇一致): - - **推荐**:使用 Markdown “参考式链接”让正文 `[1]` 直接可点:正文写 `[1]`,并在文末定义 `[1]: https://...`(这些定义可紧跟在 `## References` 之后或文末)。 - - 或:正文写 `[1](https://...)`(仅显示数字 1),References 仍需给出完整条目。 - - # 默认报告风格 - - 结构清晰:先结论后证据,标题层级明确,但无需过度分点导致内容严重碎片化,可以使用连续的段落/多个段落和适当的分点/缩进来组织内容。 - - 技术/研究报告口吻:审慎、可验证、避免过度口语化,内容详实且丰富、证据充分。 - - 倾向于使用清晰的标题层级和编号体系,比如 `# 2. 背景与问题`、`## 2.1 背景`、`### 2.1.1 方向一`等,不要超过三级。 - - # 输出格式 - 最终返回 JSON 格式的总结: - { - "Report": "...", - "Execution_Summary": "...", - "Artifacts": ["path/to/artifact_1", "path/to/artifact_2", ...], - } - + root: prompts/ + agent: reporter + lang: en + family: gpt5 tools: file_system: @@ -128,6 +42,8 @@ tools: - load_index - get_note - list_notes + - get_analysis + - list_analyses report_generator: mcp: false reports_dir: reports diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index a5e428ab0..218b9204f 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -1,6 +1,6 @@ llm: service: openai - model: qwen3-max + model: gpt-5-2025-08-07 openai_api_key: openai_base_url: @@ -10,7 +10,7 @@ generation_config: stream_options: include_usage: true # Enable explicit prefix caching (auto-detects provider from openai_base_url) - force_prefix_cache: true + force_prefix_cache: false # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] # extra_body: @@ -21,73 +21,10 @@ tag: deep-research-researcher prompt: - system: | - 你是 Researcher,主要负责深度研究任务的工作流编排,通过调用不同的子代理(sub agent)和工具完成任务。请基于下列指令和提供的工具帮助用户完成研究任务。 - 时间提醒:当前日期为,当前时间为。 - 研究轮次:即深度搜索阶段的循环次数(不含报告生成阶段),当用户没有显式指定最大研究轮次时,默认最大研究轮次为6轮。 - 行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;你需要在对话中输出对应的思考过程、行动意图以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 - - # 主要职责 - 1. 意图识别与任务规划:将用户的请求转换为一个可执行的研究计划,以 TODO 列表(plan.json)的形式存储,并通过额外生成验证清单(checklist.yaml)来进行自我反思。 - 2. 任务编排与调度:根据任务的难度、用户的意图进行可用子代理和可用工具的编排,把控不同情况的处理逻辑(简答 vs 专业报告 vs 随意对话,用户无要求的默认情况下输出专业报告)。 - 3. 深度搜索与证据收集:缺乏证据时,通过向 Searcher 子代理(即 agent_tools---searcher_tool)委派任务实现迭代式的研究循环(并发时允许同时调用 2-4 个子代理,当任务可并行时优先并行调用)。 - 4. 报告生成与质量验收:研究充分时,通过向 Reporter 子代理(即 agent_tools---reporter_tool)委派任务完成调研报告生成,随后由你进行验证、纠错和修改润色,确保报告质量符合要求,最后交付最终报告给用户。 - - # 工作流 - ## 阶段1:任务规划 - - 解析用户意图,分析用户的对话目的、需求背景、期望产出等核心诉求,使用 todo_list---todo_write 和 file_system---write_file 分别生成 TDDO 列表和对应的验证清单checklist.yaml(可以同时调用两个工具)。 - - TODO 列表需要涵盖所有需要完成的子任务,包括调研环节、报告生成环节、验证环节等等,用于向用户明确你的完整规划,无需显式提出你会使用什么工具,只需给出具体的任务本身。 - - TODO 列表中的任务需要尽可能明确、清晰、原子化并服务于解决核心问题,主题聚焦在一个较细粒度且具体的范围里,每个任务需要回答的问题不超过3个,避免 Searcher 子代理难以理解或者执行时间过长,但注意避免过度拆分导致执行链路过长。 - - 主动对比 TODO 列表和验证清单进行反思。如果发现当前的 TODO 列表存在问题,则修复 TODO 列表中的潜在问题,否则可以跳过这一步。 - - 如有必要,最多允许调用一次 Searcher 子代理以进行概念澄清; - - 如果需要变更 TODO 列表,必须通过 todo_list---todo_write 工具进行更新。 - - ## 阶段2:深度搜索 - 循环执行以下环节直到满足停止条件: - - 根据当前 TODO 列表中的任务执行状况,选择 2-4 个(无法并行时选择1个)可以并行、尚未完成的任务交给 Searcher 子代理进行深度调研(可以并发时优先并发调用),并提供详细、清晰的任务说明。 - - 针对 Searcher 子代理的返回结果进行反思: - - 总结阶段性发现,明确当前已经完成收集和维护的证据,识别存在的冲突和证据缺口,及时向用户展示你的思考和计划; - - 你必须同时更新 TODO 列表中的任务状态('pending'/'in_progress'/'completed'/'cancelled'),如果发现 TODO 列表存在问题并且希望修改,则可以调用 todo_list---todo_write 工具进行更新。 - 停止条件如下,满足其中一个则停止: - - TODO 列表中关于调研环节的子任务已经全部完成;或 - - 已经完成了核心任务的证据收集,证据覆盖充分且一致;或 - - 进一步搜索的边际效益很低;或 - - 已经达到了最大研究轮次。 - - ## 阶段3:报告生成 - - 调用 Reporter 子代理进行报告生成,向 Reporter 子代理提供完整的报告主题、目标受众、背景说明、任务说明、写作要求、章节约束和其他必要的信息,该代理将会自动执行面向研究报告写作的完整工作流(生成大纲->证据绑定->逐章写作->汇总草稿->整合最终报告)。 - - 注意:不要在用户没有显式要求的情况下向 Reporter 子代理提出字数要求;不要要求 Reporter 子代理在最终报告中加入执行摘要,这类内容应该出现在对话内容的其他部分中而非报告正文中。 - - Reporter 子代理交付报告后,你需要进行最后的检查、纠错和润色来确保报告的质量和准确性,检查结论可以展示在对话内容中,修改后的报告正文要求使用 file_system---write_file 直接写入文件 final_report.md。 - - # 流程约束 - 1. 在阶段1的任务规划过程中,如有必要,你最多可以调用一次 Searcher 子代理进行搜索用于概念澄清。 - 2. 所有计划状态变更必须借助 TODO 列表的变更来体现,即必须调用 todo_list---todo_write 工具进行创建/更新,不能只保存在对话文本里。 - 3. 你不能亲自执行长链路的网络研究和报告生成等子任务。所有大规模的检索任务和报告撰写任务必须委派给子代理执行。你负责进行调研结论汇总、进度把控和最终版本报告的验证、修改和写入。 - 4. 当证据不足或者存在前后结论冲突时,你必须显式声明这种不确定性和困惑,并主动反思,尝试通过已有的工具(包括子代理)进行解决。 - 5. 当进一步调研的边际效益很低时或者达到最大研究轮次时,你必须主动停下并开始尝试生成报告。 - 6. 避免冗余的工具调用,例如在 todo_list---todo_write 操作后工具返回的信息通常会带有更新后的 TODO 列表,你无需再次调用 todo_list---todo_read 读取 TODO 列表。 - - # 工具调用协议 - - 你必须基于 todo_list server 下的工具进行创建/更新/读取 TODO 列表,不能使用其他服务/工具维护 TODO 列表,也不能只在对话中维护 TODO 列表。 - - 你必须基于 agent_tools server 下的工具调用 Searcher 子代理和 Reporter 子代理,注意不允许调用不存在的子代理工具,必须遵循该工具的输入要求。 - - 在上下文不明晰(比如 Searcher 子代理返回的内容疑似丢失细节、Reporter 子代理生成的报告存在问题等情况)时,你可以基于 evidence_store server 下的工具进行证据的读取、筛选和加载,确保你在修改报告或执行下一步计划前有充足的把握。 - - 你可以在单个响应中同时调用多个工具。例如当需要获取多个独立的信息或者需要进行多个独立的操作时,可以尝试将工具调用批量并行处理,以获得最佳性能。 - - 在进行文件级别的操作时,请保持使用相对路径。 - - # 质量约束 - - TODO 列表的维护贯穿着你的整个研究过程,你需要保持对 TODO 列表的持续关注,确保任务可以在意外情况下顺利完成。 - - 永远不要伪造引用或来源,最终交付物中的所有事实性陈述必须有 Searcher 子代理提供的调研结论和存储的证据作为支撑。 - - 明确时间限制和当前日期,如果你试图应用的知识已经过时,请不要相信你的记忆,而是通过工具进行查询。 - - 严格控制范围:如果用户要求的是 X,不要漂移到 Y。 - - 最终生成的报告必须保留完整的引用关系,不得因为修改、润色而丢失原本的引用,必须保留 reporter 子代理返回的报告中带有的引用格式。 - - 如 Reporter 子代理返回的报告出现引用缺失等问题,注意遵循下列规则:正文仅使用编号引用 `[1]`、`[2]`…(可并列如`[1][3]`),报告末尾必须包含 `## References`(英文报告) 或 `## 参考文献`(中文报告) 并保持编号映射一致;润色时禁止把长标题链接(如`[标题](URL)`)写回正文。 - - 报告需要符合专业研究员的写作风格,避免使用过于口语化、非正式的表达方式,避免过度碎片化、内容过于单薄,保证内容详实、准确、逻辑严谨。 - - # 意外处理 - 1. 你可能会遇到因为网络问题、安全问题、权限问题等各种非预期的原因导致的工具调用失败,你需要优先通过合理的重试策略和错误处理逻辑确保任务的顺利完成。 - 2. 如果遇到用户试图让你执行超出你能力范围的任务,你必须显式声明潜在的风险,并尝试组合现有的工具和能力来给出可能的解决方案。 - 2. 如果用户要求你给出简明回答而非完整报告,你可以跳过阶段3直接给出结论。 - 3. 如果用户试图进行非研究任务的闲聊,你无需启动研究流程,可以给出正常回复并且试图引导用户发起研究任务。 + root: prompts/ + agent: researcher + lang: en + family: gpt5 tools: @@ -97,6 +34,12 @@ tools: - write_file - read_file - list_files + code_executor: + mcp: false + implementation: python_env + notebook_timeout: 120 + include: + - notebook_executor todo_list: mcp: false auto_render_md: true @@ -110,6 +53,9 @@ tools: - load_index - get_note - list_notes + - write_analysis + - get_analysis + - list_analyses agent_tools: mcp: false enable_stats: true @@ -118,9 +64,9 @@ tools: definitions: - tool_name: searcher_tool description: > - 调用 Searcher 子代理执行特定主题下的深度调研任务。 - Searcher 具备自主执行研究循环直到收集到充足证据、给出调研报告的能力(搜索->解析->证据发现与存储->递进搜索->...)。 - 返回内容为包含任务完成情况、核心发现、遇到的问题或限制、研究报告正文、相关证据存储位置等信息的 JSON 结果。 + Invoke the Searcher sub-agent to perform an in-depth research task on a specific topic. + Searcher is capable of autonomously executing a research loop until sufficient evidence is collected and a research report is produced (search -> parse -> evidence discovery & storage -> progressive search -> ...). + Returns a JSON result containing: task completion status, core findings, issues or limitations encountered, research report body, and evidence storage locations. config_path: searcher.yaml parameters: type: object @@ -128,21 +74,21 @@ tools: request: type: string description: > - JSON 格式的调研任务描述,应包含: - - TODO 列表中对应的任务ID(必填) - - 具体的调研目标 - - 需要回答的问题 - - 约束条件(时间范围、来源偏好等,可选) - - 停止条件(可选) - - 其他要求(可选) - 建议的格式为: + A JSON-formatted research task description that should include: + - The corresponding task ID from the TODO list (required) + - Specific research objectives + - Questions to be answered + - Constraints (time range, source preferences, etc., optional) + - Stopping conditions (optional) + - Other requirements (optional) + Recommended format: { "task_id": "...", - "调研目标": "...", - "需要回答的问题": "...", - "约束条件": "...", - "停止条件": "...", - "其他要求": "...", + "research_objectives": "...", + "questions_to_answer": "...", + "constraints": "...", + "stopping_conditions": "...", + "other_requirements": "...", } required: [request] additionalProperties: false @@ -151,9 +97,9 @@ tools: max_output_chars: 200000 - tool_name: reporter_tool description: > - 调用 Reporter 子代理基于已收集的证据生成报告。 - Reporter 会读取已经存储的证据卡片,执行面向研究报告写作的复杂工作流。 - 返回内容为包含报告正文、执行总结、中间文件产物路径等信息的 JSON 结果。 + Invoke the Reporter sub-agent to generate a report based on collected evidence. + Reporter reads the stored evidence cards and executes a complex workflow for research report writing. + Returns a JSON result containing: report body, execution summary, and intermediate artifact file paths. config_path: reporter.yaml parameters: type: object @@ -161,19 +107,19 @@ tools: request: type: string description: > - JSON 格式的报告生成指令,应包含: - - 报告主题和目标受众 - - 完整的背景说明和任务说明 - - 需要覆盖的核心问题 - - 写作要求(风格、结构、长度等) - - 任何其他要求 - 建议的格式为: + A JSON-formatted report generation instruction that should include: + - Report topic and target audience + - Complete background description and task description + - Core questions to be covered + - Writing requirements (style, structure, length, language, etc.) + - Any other requirements + Recommended format: { - "报告主题和目标受众": "...", - "背景说明": "...", - "任务说明": "...", - "写作要求": "...", - "其他要求": "...", + "report_topic_and_audience": "...", + "background": "...", + "task_description": "...", + "writing_requirements": "...", + "other_requirements": "...", } required: [request] additionalProperties: false @@ -184,12 +130,24 @@ tools: - tools/evidence_tool.py +callbacks: + - callbacks/researcher_callback + +# Self-reflection checks before allowing the researcher to stop. +# Runs inside ResearcherCallback.after_tool_call. +self_reflection: + enabled: true + max_retries: 2 + quality_check: + enabled: true + model: qwen3.5-flash + handler: time_handler code_file: researcher -max_chat_round: 30 +max_chat_round: 40 -tool_call_timeout: 1200 +tool_call_timeout: 1800 output_dir: ./output diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index c37db5ebf..b9b19f08b 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -1,6 +1,6 @@ llm: service: openai - model: qwen-plus + model: qwen3.5-plus openai_api_key: openai_base_url: @@ -13,83 +13,18 @@ generation_config: force_prefix_cache: true # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] - # extra_body: - # enable_thinking: false + extra_body: + enable_thinking: false tag: deep-research prompt: - system: | - 你是 Searcher,负责处理多个领域的深度调研任务,你需要通过持续的网络检索和证据收集推进调研任务,并最终交付专业的研究报告给用户。请基于下列指令和提供的工具帮助用户完成研究任务。 - 时间提醒:今日日期:,当前时间:。 - 最大搜索轮次:当用户没有显式指定最大搜索轮次时,默认最大搜索轮次为4轮;一个搜索轮次通常不超过 4 次对话推进(例如 assistant->tool 或 user->assistant->tool 视为一次对话),推荐通过单轮对话并发调用工具完成任务。 - 行动规范:在输出最终的 JSON 结果前,每一轮行动都必须调用工具;建议你在每轮对话中输出结构化的进度说明,可以包含进度摘要、思考过程、本轮行动与目的、风险与缺口以及其他可以向用户说明当前任务状态的提示。如果在后续工作流中给出了某些阶段建议的输出格式,请优先遵循该格式。 - - # 主要职责 - 你会接收来自用户的调研任务说明,并负责通过迭代式的搜索循环过程完成该任务: - 1. 网络搜索:当拥有的信息不足以完成任务时,主动反思当前存在的证据缺口,构造合理的查询语句并调用搜索工具从而获取更多证据信息,在获得足够证据、连续搜索无有效信息或者达到最大搜索轮次时及时停止搜索; - 2. 证据收集:对于每个有价值的发现,使用证据维护工具服务 evidence_store server 下的工具,将信息详细地写入结构化的证据卡片,确保证据的完整性(不丢失重要细节)和准确性(不加入主观揣测); - 3. 结果汇总:调研结果为包含任务完成情况、核心发现、遇到的问题或限制、相关证据存储位置和完整研究报告的 JSON 结果。 - 注意效率与质量的平衡: - - 在保证质量的前提下尽可能提高效率,如果**通过合理的搜索方案可以减少搜索轮次**,或者在搜索过程中提早发现证据收集完毕,可以适当降低轮次。 - - 如果有多个要存储的证据卡片,尽可能将写入操作批量处理(并发调用),减少多轮逐个调用的开销。 - - # 工作流 - ## 阶段1: 任务分析与规划 - - 分析用户意图,将调研任务说明转化为可执行的调研计划,包含需要解决的子问题和合理的验收标准,并将计划写入文件 search_plan___.md 中。 - - 为用户提供的任务ID,该参数必须与用户提供的 task_id(或者中文叫做任务ID) 参数一致; 为你根据用户意图生成的任务名称; 为当前时间,格式为 HH-MM-SS(24小时制)。 - - ## 阶段2: 循环搜索与证据收集 - - 循环执行以下环节直到满足停止条件: - - 根据初始搜索计划和截止当前轮次的调研结论,构造查询语句并执行网络搜索,可以遵循先宽后窄的搜索策略,逐步缩小搜索范围; - - 阅读返回的内容并分析内容是否能够为调研任务提供支撑材料,对每个有价值的发现需要立即使用工具提取结构化证据卡片,并使用 evidence_store---write_note 存储到本地,在对话内容中给出结构化的进度总结,包括: - - 核心发现:当前轮次搜索的核心发现、具有存储价值的证据、和已有信息之间的关系 - - 调研进度:当前调研阶段的总结、当前整个证据库不完整的地方、证据矛盾之处 - - 下一步计划:下一步的计划和打算解决的问题 - - 停止条件如下,满足其中一个则停止: - - 满足阶段1制订的调研计划;或 - - 已经完成了核心任务的证据收集,证据覆盖充分且一致,对于不重要的部分进行忽略并说明原因;或 - - 进一步搜索的边际收益很低;或 - - 达到用户显式指定或者默认的最大搜索轮次;或 - - 出于合理的原因,你认为当前任务已经无法继续进行(比如发现用户给的调研任务不合理或者无法完成)。 - - ## 阶段3: 调研结果汇总 - - 详细总结本次调研结果,以严格的 JSON 格式在对话内容中返回,包括: - - 任务完成情况 - - 核心发现 - - 遇到的问题与限制 - - 相关证据存储位置 - - 研究报告正文 - - # 工具调用协议 - - 请不要试图使用任何你没有见到的工具,你工作在开放的网络环境和具备完整读写权限但是与外部隔离的文件系统中,在进行文件级别的操作时,请保持使用相对路径。 - - 搜索工具服务 web_search 中提供 exa_search、arxiv_search、serpapi_search(默认google) 三种搜索工具中的一种或多种,你需要根据场景选择使用,注意服务可能不提供其中所有工具。 - - 搜索工具的最大搜索结果参数 num_results 的默认值为 5,建议你从合适的数值开始尝试,适当避免一次性阅读太多内容。如果在有限的搜索轮次内难以完成任务,可以尝试单次并发多个搜索或者适当增大 num_results(优先并发再考虑增大数值)。 - - 你必须基于证据工具服务 evidence_store server 下的工具进行证据的存储、查看、搜索、删除、加载索引等操作,不能使用其他工具服务(比如文件系统),也不能只在对话中维护证据。 - - 总结并写入证据时,你必须保持证据的完整性和准确性,尽可能多的将有价值的原文信息写入到证据卡片中,尽可能多的保留对结论有重要支撑作用的数据、表格、代码、观点等内容的完整信息,不丢失有价值的细节。 - - 单次搜索通常返回多个结果,你可以在充分阅读后写入一个或同时写入多个证据卡片,如果合并会丢失有价值的内容,则优先同时写入多个证据卡片。 - - 你可以在单个响应中同时调用多个工具。例如当需要获取多个独立的信息或者需要进行多个独立的操作时,可以尝试将工具调用批量并行处理,以获得最佳性能。 - - # 硬性约束 - - 禁止幻觉(编造),永远不要伪造引用或来源。如果你找不到可靠证据,必须向用户说明并停止。 - - 使用 evidence_store---write_note 工具时,必须提供 task_id 参数,用于关联证据和用户任务, 该参数必须与用户提供的 task_id(如果是中文则叫做任务ID) 参数一致。 - - 注意当前时间,你具备的知识可能已经过时,不要试图应用已经过时的知识。始终跟踪时间信息(发布日期 / 更新日期),在可见时必须记录。 - - 你必须主动进行搜索->阅读收集证据->调整搜索方向->搜索的循环,直到满足停止条件,不要试图跳过中间环节。 - - 严格控制范围:如果用户要求的是 X,不要漂移到 Y。 - - 优先级排序建议(非强制):官方文档/标准/论文 > 一手公告/新闻 > 二手博客/论坛。 - - # 输出格式 - 最终返回 JSON 格式的总结: - { - "status": "任务完成情况标识(completed|partial|failed)", - "task_summary": "任务完成情况概述", - "findings": ["本次调研的核心发现1", "本次调研的核心发现2"], - "issues": ["本次调研遇到的问题或限制"], - "note_ids": ["note_id_1", "note_id_2", ...(全部存储的证据卡片ID)], - "report": "本次调研的研究报告正文,要求详细、准确、严谨的整理调研结果,不得有任何主观揣测或推测,遵循规范的学术写作风格" - } + root: prompts/ + agent: searcher + lang: en + family: gpt5 tools: @@ -112,7 +47,7 @@ tools: _max_concurrent_fetch: 5 enable_chunking: false enable_summarization: true - summarizer_model: qwen-flash + summarizer_model: qwen3.5-flash summarizer_base_url: summarizer_api_key: max_content_chars: 200000 diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 892c51a0e..064ab6a87 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -43,6 +43,11 @@ def _generate_note_id() -> str: return uuid.uuid4().hex[:6] +def _generate_analysis_id() -> str: + """Generate a short unique ID for an analysis.""" + return uuid.uuid4().hex[:6] + + def _sanitize_filename(name: str) -> str: """Sanitize a string for use as a filename.""" return re.sub(r'[^\w\-]', '_', name)[:64] @@ -64,9 +69,8 @@ def _render_note_card(note: Dict[str, Any]) -> str: "note_id": "abc123", "task_id": "task_1", # optional, links to plan task "title": "Key finding about X", - "claim": "The main claim or observation", - "supports": "Evidence text supporting the claim...", - "contradicts": "Evidence text contradicting the claim...", # optional + "content": "Detailed evidence text including findings, data, quotes...", + "contradicts": "Evidence text contradicting the finding...", # optional "sources": [ {"url": "...", "published_at": "...", "source_tier": "primary"} ], @@ -95,16 +99,10 @@ def _render_note_card(note: Dict[str, Any]) -> str: lines.append(f"- **Created**: {note.get('created_at', '')}") lines.append('') - # Claim - if note.get('claim'): - lines.append('## Claim') - lines.append(note['claim']) - lines.append('') - - # Supporting evidence - if note.get('supports'): - lines.append('## Supporting Evidence') - lines.append(note['supports']) + # Content (evidence body) + if note.get('content'): + lines.append('## Content') + lines.append(note['content']) lines.append('') # Contradicting evidence @@ -134,6 +132,108 @@ def _render_note_card(note: Dict[str, Any]) -> str: return '\n'.join(lines) +def _render_analysis_card(analysis: Dict[str, Any]) -> str: + """ + Render an analysis card as Markdown. + + Analysis structure: + { + "analysis_id": "abc123", + "task_id": "task_1", # optional + "title": "Interim analysis: ...", + "summary": "One-sentence summary", # optional + "content": "Markdown content", # required + "based_on_note_ids": ["cd9818", "1c108f"], # optional + "tags": ["tag1", "tag2"], + "quality_score": 85 (0-100), # optional + "created_at": "2025-01-19T10:00:00" + } + """ + lines: List[str] = [] + + # Header + lines.append(f"# {analysis.get('title', 'Untitled Analysis')}") + lines.append('') + + # Metadata + lines.append('## Metadata') + lines.append(f"- **Analysis ID**: `{analysis.get('analysis_id', '')}`") + if analysis.get('task_id'): + lines.append(f"- **Task ID**: `{analysis['task_id']}`") + if analysis.get('based_on_note_ids'): + ids_str = ', '.join(f'`{nid}`' + for nid in analysis.get('based_on_note_ids', [])) + lines.append(f'- **Based on Notes**: {ids_str}') + if analysis.get('tags'): + tags_str = ', '.join(f'`{t}`' for t in analysis['tags']) + lines.append(f'- **Tags**: {tags_str}') + if analysis.get('quality_score') is not None: + lines.append(f"- **Quality Score**: {analysis['quality_score']}/100") + lines.append(f"- **Created**: {analysis.get('created_at', '')}") + lines.append('') + + # Summary + if analysis.get('summary'): + lines.append('## Summary') + lines.append(analysis['summary']) + lines.append('') + + # Content + if analysis.get('content'): + lines.append('## Content') + lines.append(analysis['content']) + lines.append('') + + return '\n'.join(lines) + + +def _parse_analysis_from_md(content: str, analysis_id: str) -> Dict[str, Any]: + """ + Parse an analysis card from Markdown back to dict. + Best-effort parser for re-reading stored analyses. + """ + analysis: Dict[str, Any] = {'analysis_id': analysis_id} + + title_match = re.search(r'^# (.+)$', content, re.MULTILINE) + if title_match: + analysis['title'] = title_match.group(1).strip() + + sections = re.split(r'^## ', content, flags=re.MULTILINE) + for section in sections[1:]: + lines = section.strip().split('\n', 1) + if not lines: + continue + header = lines[0].strip() + body = lines[1].strip() if len(lines) > 1 else '' + + if header == 'Content': + analysis['content'] = body + elif header == 'Summary': + analysis['summary'] = body + elif header == 'Metadata': + for line in body.split('\n'): + if '**Task ID**' in line: + match = re.search(r'`([^`]+)`', line) + if match: + analysis['task_id'] = match.group(1) + elif '**Tags**' in line: + tags = re.findall(r'`([^`]+)`', line) + analysis['tags'] = tags + elif '**Based on Notes**' in line: + ids = re.findall(r'`([^`]+)`', line) + analysis['based_on_note_ids'] = ids + elif '**Quality Score**' in line: + match = re.search(r'(\d+)/100', line) + if match: + analysis['quality_score'] = int(match.group(1)) + elif '**Created**' in line: + match = re.search(r'\*\*Created\*\*: (.+)$', line) + if match: + analysis['created_at'] = match.group(1).strip() + + return analysis + + def _parse_note_from_md(content: str, note_id: str) -> Dict[str, Any]: """ Parse a note card from Markdown back to dict. @@ -155,10 +255,11 @@ def _parse_note_from_md(content: str, note_id: str) -> Dict[str, Any]: header = lines[0].strip() body = lines[1].strip() if len(lines) > 1 else '' - if header == 'Claim': - note['claim'] = body - elif header == 'Supporting Evidence': - note['supports'] = body + if header == 'Content': + note['content'] = body + elif header in ('Claim', 'Supporting Evidence'): + # Backward compat: merge legacy Claim/Supporting Evidence into content + note['content'] = (note.get('content', '') + '\n\n' + body).strip() elif header == 'Contradicting Evidence': note['contradicts'] = body elif header == 'Summary': @@ -208,6 +309,7 @@ class EvidenceTool(ToolBase): Storage: - evidence/index.json: Global index for fast lookups - evidence/notes/note_{id}.md: Individual evidence cards + - evidence/analyses/analysis_{id}.md: Interim analysis / synthesis / comparison / decision records - chunks/: Reserved for future chunk storage """ @@ -235,6 +337,11 @@ async def connect(self) -> None: """Initialize directory structure.""" _ensure_dir(self.output_dir) _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'notes')) + _ensure_dir( + os.path.join(self.output_dir, self._evidence_dir, 'analyses')) + # Backward-compat: older runs may have used evidence/conclusions/ + _ensure_dir( + os.path.join(self.output_dir, self._evidence_dir, 'conclusions')) _ensure_dir(os.path.join(self.output_dir, self._chunks_dir)) _ensure_dir(os.path.join(self.output_dir, self._lock_subdir)) @@ -244,6 +351,10 @@ def _paths(self) -> Dict[str, str]: os.path.join(self.output_dir, self._evidence_dir, 'index.json'), 'notes_dir': os.path.join(self.output_dir, self._evidence_dir, 'notes'), + 'analyses_dir': + os.path.join(self.output_dir, self._evidence_dir, 'analyses'), + 'legacy_conclusions_dir': + os.path.join(self.output_dir, self._evidence_dir, 'conclusions'), 'chunks_dir': os.path.join(self.output_dir, self._chunks_dir), 'lock_dir': @@ -270,28 +381,22 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Brief title describing this evidence (e.g., "Tesla Q3 revenue growth").', }, - 'claim': { + 'content': { 'type': 'string', 'description': - 'The main claim or observation this note captures. ' - 'It should be as detailed and comprehensive as possible.', - }, - 'supports': { - 'type': - 'string', - 'description': - ('Evidence text that supports this claim. ' - 'This should be a detailed and comprehensive description ' - 'of the evidence that supports the claim.' - 'Can include quotes, data, or reasoning. Multi-paragraph allowed.' - ), + ('The full evidence text for this note. ' + 'State the core finding or observation, then provide all ' + 'supporting details: specific data points, statistics, quotes, ' + 'case studies, reasoning, and any other substantive information. ' + 'Be thorough — preserve all valuable details from the source material. ' + 'Multi-paragraph allowed.'), }, 'contradicts': { 'type': 'string', 'description': - ('Optional: Evidence text that contradicts this claim. ' + ('Optional: Evidence text that contradicts this finding. ' 'Include if there are conflicting sources or caveats.' ), }, @@ -360,7 +465,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, }, 'required': [ - 'title', 'claim', 'supports', 'sources', 'summary', + 'title', 'content', 'sources', 'summary', 'task_id', 'tags' ], 'additionalProperties': @@ -448,6 +553,171 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'additionalProperties': False, }, ), + Tool( + tool_name='write_analysis', + server_name=self.SERVER_NAME, + description= + ('Write an interim **analysis** record to the evidence store. ' + 'Use this tool whenever you need to turn multiple evidence notes into reusable reasoning artifacts, e.g.: ' + '(1) synthesis / interim summaries; ' + '(2) comparisons and trade-off decisions (A vs B, pros/cons, why choose X); ' + '(3) framework building (typologies, evaluation rubrics, scoring criteria, checklists); ' + '(4) mapping & reconciliation (align competing definitions/metrics, resolve conflicts, record assumptions); ' + '(5) scenario framing and uncertainty tracking (what-if branches, key sensitivities/risks, open questions); ' + '(6) rankings/recommendations that require rationale (e.g., pick top 2–3 options and justify). ' + '(7) Structured / visual intermediate artifacts (e.g., mind-map-style hierarchical outlines, and ' + 'text-based flow/relationship diagrams—prefer Mermaid syntax when possible).' + '(8) other intermediate analysis that requires reasoning, justification and recording.' + 'This is **not** the final report; it is an intermediate analysis that should cite supporting evidence via ' + 'based_on_note_ids when possible so downstream writing can reuse it. ' + 'Returns the generated analysis_id.'), + parameters={ + 'type': 'object', + 'properties': { + 'title': { + 'type': + 'string', + 'description': + 'Brief title describing this analysis (e.g., "Interim comparison: Framework A vs B").', + }, + 'content': { + 'type': + 'string', + 'description': + ('The analysis content in Markdown. ' + 'This should capture synthesis/comparison, constraints, assumptions, and reasoning. ' + 'Multi-paragraph allowed.'), + }, + 'summary': { + 'type': + 'string', + 'description': + 'Optional: One-sentence summary of this analysis.', + }, + 'task_id': { + 'type': + 'string', + 'description': + 'Optional: The plan task this analysis relates to.', + }, + 'based_on_note_ids': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: List of note_ids this analysis is based on.', + }, + 'tags': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: Tags for categorization.', + }, + 'quality_score': { + 'type': + 'integer', + 'minimum': + 0, + 'maximum': + 100, + 'description': + 'Optional: Confidence/quality score (0-100).', + }, + }, + 'required': ['title', 'content', 'summary', 'tags'], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='get_analysis', + server_name=self.SERVER_NAME, + description='Retrieve a specific analysis by its ID.', + parameters={ + 'type': 'object', + 'properties': { + 'analysis_id': { + 'type': + 'string', + 'description': + 'The ID of the analysis to retrieve.', + }, + 'parse_analysis': { + 'type': + 'boolean', + 'description': + 'Optional: Whether to parse stored markdown back to structured dict.', + }, + }, + 'required': ['analysis_id'], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='list_analyses', + server_name=self.SERVER_NAME, + description= + ('List all analyses, optionally filtered by task_id or tags. ' + 'Returns a summary list (not full content).'), + parameters={ + 'type': 'object', + 'properties': { + 'task_id': { + 'type': 'string', + 'description': 'Optional: Filter by task ID.', + }, + 'tags': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: Filter by tags (analyses must have ALL specified tags).', + }, + }, + 'required': [], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='search_analyses', + server_name=self.SERVER_NAME, + description= + 'Search analyses by keyword in title, summary, or tags.', + parameters={ + 'type': 'object', + 'properties': { + 'keyword': { + 'type': 'string', + 'description': 'Keyword to search for.', + }, + }, + 'required': ['keyword'], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='delete_analysis', + server_name=self.SERVER_NAME, + description='Delete an analysis by its ID.', + parameters={ + 'type': 'object', + 'properties': { + 'analysis_id': { + 'type': 'string', + 'description': + 'The ID of the analysis to delete.', + }, + }, + 'required': ['analysis_id'], + 'additionalProperties': False, + }, + ), Tool( tool_name='load_index', server_name=self.SERVER_NAME, @@ -472,11 +742,24 @@ def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: data = _safe_read_json(paths['index']) if data is None or not isinstance(data, dict): return { - 'schema_version': 1, + 'schema_version': 2, 'updated_at': _now_iso(), 'notes': {}, # note_id -> {title, task_id, summary, sources, tags, quality_score, created_at} + 'analyses': + {}, # analysis_id -> {title, task_id, summary, based_on_note_ids, tags, quality_score, created_at, path} } + # Backward/forward compatible defaults + if 'notes' not in data or not isinstance(data.get('notes'), dict): + data['notes'] = {} + if 'analyses' not in data or not isinstance( + data.get('analyses'), dict): + data['analyses'] = {} + + # Backward-compat: older schema used "conclusions" key. + legacy = data.get('conclusions') + if isinstance(legacy, dict) and legacy and not data.get('analyses'): + data['analyses'] = legacy return data def _save_index_locked(self, paths: Dict[str, str], @@ -499,6 +782,22 @@ def _add_to_index(self, index: Dict[str, Any], note: Dict[str, 'created_at': note.get('created_at', ''), } + def _add_analysis_to_index(self, index: Dict[str, Any], + analysis: Dict[str, Any], + analysis_path: str) -> None: + """Add an analysis' metadata to the index.""" + aid = analysis['analysis_id'] + index['analyses'][aid] = { + 'title': analysis.get('title', ''), + 'task_id': analysis.get('task_id', ''), + 'summary': analysis.get('summary', ''), + 'based_on_note_ids': analysis.get('based_on_note_ids', []), + 'tags': analysis.get('tags', []), + 'quality_score': analysis.get('quality_score'), + 'created_at': analysis.get('created_at', ''), + 'path': os.path.relpath(analysis_path, self.output_dir), + } + def _remove_from_index(self, index: Dict[str, Any], note_id: str) -> bool: """Remove a note from the index. Returns True if found and removed.""" if note_id in index.get('notes', {}): @@ -506,6 +805,14 @@ def _remove_from_index(self, index: Dict[str, Any], note_id: str) -> bool: return True return False + def _remove_analysis_from_index(self, index: Dict[str, Any], + analysis_id: str) -> bool: + """Remove an analysis from the index. Returns True if found and removed.""" + if analysis_id in index.get('analyses', {}): + del index['analyses'][analysis_id] + return True + return False + def _store_chunk(self, chunk_id: str, content: str, metadata: Dict[str, Any]) -> str: """ @@ -543,8 +850,7 @@ def _load_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: async def write_note( self, title: str, - claim: str, - supports: str, + content: str, contradicts: Optional[str] = None, sources: Optional[List[Dict[str, Any]]] = None, summary: Optional[str] = None, @@ -562,8 +868,7 @@ async def write_note( note: Dict[str, Any] = { 'note_id': note_id, 'title': title.strip(), - 'claim': claim.strip(), - 'supports': supports.strip(), + 'content': content.strip(), 'created_at': _now_iso(), } @@ -602,6 +907,211 @@ async def write_note( 'path': os.path.relpath(note_path, self.output_dir), }) + async def write_analysis( + self, + title: str, + content: str, + summary: Optional[str] = None, + task_id: Optional[str] = None, + based_on_note_ids: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + quality_score: Optional[int] = None, + ) -> str: + """Write a new interim analysis.""" + paths = self._paths() + _ensure_dir(paths['analyses_dir']) + _ensure_dir(paths['lock_dir']) + + analysis_id = _generate_analysis_id() + analysis: Dict[str, Any] = { + 'analysis_id': analysis_id, + 'title': title.strip(), + 'content': content.strip(), + 'created_at': _now_iso(), + } + if summary: + analysis['summary'] = summary.strip() + if task_id: + analysis['task_id'] = task_id.strip() + if based_on_note_ids: + analysis['based_on_note_ids'] = [ + nid.strip() for nid in based_on_note_ids if nid.strip() + ] + if tags: + analysis['tags'] = [t.strip() for t in tags if t.strip()] + if quality_score is not None: + analysis['quality_score'] = max(0, min(100, quality_score)) + + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') + analysis_content = _render_analysis_card(analysis) + _write_text(analysis_path, analysis_content) + + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_index_locked(paths) + self._add_analysis_to_index(index, analysis, analysis_path) + self._save_index_locked(paths, index) + + return _json_dumps({ + 'status': + 'ok', + 'analysis_id': + analysis_id, + 'path': + os.path.relpath(analysis_path, self.output_dir), + }) + + async def get_analysis(self, + analysis_id: str, + parse_analysis: Optional[bool] = False) -> str: + """Retrieve an analysis by ID.""" + paths = self._paths() + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], + f'conclusion_{analysis_id}.md') + + if not os.path.exists(analysis_path) and os.path.exists(legacy_path): + analysis_path = legacy_path + + if not os.path.exists(analysis_path): + return _json_dumps({ + 'status': 'error', + 'message': f'Analysis {analysis_id} not found.' + }) + + with open(analysis_path, 'r', encoding='utf-8') as f: + content = f.read() + + if not parse_analysis: + return _json_dumps({'status': 'ok', 'raw_content': content}) + analysis = _parse_analysis_from_md(content, analysis_id) + return _json_dumps({ + 'status': 'ok', + 'analysis_id': analysis_id, + 'analysis': analysis, + 'raw_content': content, + }) + + async def list_analyses(self, + task_id: Optional[str] = None, + tags: Optional[List[str]] = None) -> str: + """List analyses with optional filters.""" + paths = self._paths() + _ensure_dir(paths['lock_dir']) + + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_index_locked(paths) + + analyses_meta = index.get('analyses', {}) + results = [] + for aid, meta in analyses_meta.items(): + if task_id and meta.get('task_id') != task_id: + continue + if tags: + a_tags = set(meta.get('tags', [])) + if not all(t in a_tags for t in tags): + continue + results.append({ + 'analysis_id': + aid, + 'title': + meta.get('title', ''), + 'task_id': + meta.get('task_id', ''), + 'summary': + meta.get('summary', ''), + 'based_on_note_ids': + meta.get('based_on_note_ids', []), + 'tags': + meta.get('tags', []), + 'quality_score': + meta.get('quality_score'), + 'created_at': + meta.get('created_at', ''), + 'path': + meta.get('path', ''), + }) + + results.sort(key=lambda x: x.get('created_at', ''), reverse=True) + return _json_dumps({ + 'status': 'ok', + 'count': len(results), + 'analyses': results, + }) + + async def search_analyses(self, keyword: str) -> str: + """Search analyses by keyword.""" + paths = self._paths() + _ensure_dir(paths['lock_dir']) + + keyword_lower = keyword.lower().strip() + if not keyword_lower: + return _json_dumps({ + 'status': 'error', + 'message': 'Keyword is required.' + }) + + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_index_locked(paths) + + analyses_meta = index.get('analyses', {}) + results = [] + for aid, meta in analyses_meta.items(): + searchable = ' '.join([ + meta.get('title', ''), + meta.get('summary', ''), + ]).lower() + a_tags = meta.get('tags', []) + searchable += ' ' + ' '.join(a_tags).lower() + if keyword_lower in searchable: + results.append({ + 'analysis_id': aid, + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'task_id': meta.get('task_id', ''), + 'quality_score': meta.get('quality_score'), + }) + + return _json_dumps({ + 'status': 'ok', + 'keyword': keyword, + 'count': len(results), + 'analyses': results, + }) + + async def delete_analysis(self, analysis_id: str) -> str: + """Delete an analysis by ID.""" + paths = self._paths() + _ensure_dir(paths['lock_dir']) + + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], + f'conclusion_{analysis_id}.md') + + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_index_locked(paths) + removed = self._remove_analysis_from_index(index, analysis_id) + + if not removed and not os.path.exists( + analysis_path) and not os.path.exists(legacy_path): + return _json_dumps({ + 'status': + 'error', + 'message': + f'Analysis {analysis_id} not found.' + }) + + self._save_index_locked(paths, index) + + if os.path.exists(analysis_path): + os.remove(analysis_path) + if os.path.exists(legacy_path): + os.remove(legacy_path) + + return _json_dumps({'status': 'ok', 'deleted': analysis_id}) + async def get_note(self, note_id: str, parse_note: Optional[bool] = False) -> str: @@ -765,9 +1275,12 @@ async def load_index(self) -> str: index = self._load_index_locked(paths) notes = index.get('notes', {}) + analyses = index.get('analyses', {}) return _json_dumps({ 'status': 'ok', 'updated_at': index.get('updated_at', ''), 'total_notes': len(notes), + 'total_analyses': len(analyses), 'notes': notes, + 'analyses': analyses, }) diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 004e60c91..96fd51994 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -78,6 +78,37 @@ def _render_outline_md(outline: Dict[str, Any]) -> str: return '\n'.join(lines) +def _render_outline_progress_md(outline: Dict[str, Any]) -> str: + """Render a concise outline progress view for terminal logs.""" + chapters = outline.get('chapters', []) + total = len(chapters) + completed = sum(1 for ch in chapters if ch.get('status') == 'completed') + in_progress = sum(1 for ch in chapters + if ch.get('status') == 'in_progress') + pending = total - completed - in_progress + + lines = [f"# {outline.get('title', 'Report Outline')}", ''] + lines.append( + f'Progress: {completed}/{total} completed | {in_progress} in progress | {pending} pending' + ) + lines.append('') + lines.append('## Chapters') + lines.append('') + + for ch in chapters: + status = ch.get('status', 'pending') + status_icon = { + 'pending': '⏳', + 'in_progress': '🔄', + 'completed': '✅' + }.get(status, '⏳') + lines.append( + f"- {status_icon} Chapter {ch['chapter_id']}: {ch['title']}") + + lines.append('') + return '\n'.join(lines) + + class ReportTool(ToolBase): """ Report generation tool for DeepResearch Reporter agent. @@ -128,6 +159,9 @@ def _paths(self) -> Dict[str, str]: os.path.join(self.output_dir, self._reports_dir, 'outline.json'), 'outline_md': os.path.join(self.output_dir, self._reports_dir, 'outline.md'), + 'outline_progress_md': + os.path.join(self.output_dir, self._reports_dir, + 'outline_progress.md'), 'chapters_dir': os.path.join(self.output_dir, self._reports_dir, 'chapters'), 'conflict_json': @@ -486,10 +520,12 @@ def _save_outline(self, outline['updated_at'] = _now_iso() _write_text(paths['outline_json'], _json_dumps(outline)) _write_text(paths['outline_md'], _render_outline_md(outline)) + _write_text(paths['outline_progress_md'], + _render_outline_progress_md(outline)) if render: render_markdown_todo( - paths['outline_md'], + paths['outline_progress_md'], title='CURRENT REPORT OUTLINE', use_pager=False) @@ -647,10 +683,8 @@ async def prepare_chapter_bundle( note_id, 'title': meta.get('title', note_data.get('title', '')), - 'claim': - note_data.get('claim', ''), - 'supports': - note_data.get('supports', ''), + 'content': + note_data.get('content', ''), 'contradicts': note_data.get('contradicts', ''), 'summary': @@ -684,10 +718,8 @@ async def prepare_chapter_bundle( note_id, 'title': meta.get('title', note_data.get('title', '')), - 'claim': - note_data.get('claim', ''), - 'supports': - note_data.get('supports', ''), + 'content': + note_data.get('content', ''), 'contradicts': note_data.get('contradicts', ''), 'summary': @@ -1048,7 +1080,10 @@ async def assemble_draft( 'conflicts_summary': conflicts_summary, 'next_step_reminder': - ('Review the draft and conflicts, then you can try to generate the final report. ' + ('Review the draft and conflicts, then generate the final report. ' + 'Note: the draft cannot be used as the final report; ' + 'do not replace report content with references or pointers to other content or files ' + '(e.g., "details are in chapter_2.md", "see draft.md for more details").' ), }) diff --git a/tests/config/test_prompt_files.py b/tests/config/test_prompt_files.py new file mode 100644 index 000000000..30bdc94ff --- /dev/null +++ b/tests/config/test_prompt_files.py @@ -0,0 +1,129 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os +import tempfile +import unittest + +from ms_agent.config import Config + + +class TestPromptFiles(unittest.TestCase): + + def _write(self, path: str, content: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w', encoding='utf-8') as f: + f.write(content) + + def test_inline_system_not_overridden(self): + with tempfile.TemporaryDirectory() as td: + cfg_path = os.path.join(td, 'agent.yaml') + self._write( + cfg_path, + """llm: + service: openai + model: qwen3-max +code_file: researcher +prompt: + system: | + INLINE_SYSTEM + lang: zh + family: qwen-3 +""", + ) + # Create prompt file that would match if resolver ran + self._write( + os.path.join(td, 'prompts', 'researcher', 'zh', 'qwen-3.md'), + "FILE_SYSTEM", + ) + config = Config.from_task(td) + self.assertIn('INLINE_SYSTEM', config.prompt.system) + + def test_load_family_prompt_file(self): + with tempfile.TemporaryDirectory() as td: + self._write( + os.path.join(td, 'agent.yaml'), + """llm: + service: openai + model: qwen3-max +code_file: researcher +prompt: + lang: zh + family: qwen-3 +""", + ) + self._write( + os.path.join(td, 'prompts', 'researcher', 'zh', 'qwen-3.md'), + "QWEN3_SYSTEM", + ) + self._write( + os.path.join(td, 'prompts', 'researcher', 'zh', 'base.md'), + "BASE_SYSTEM", + ) + config = Config.from_task(td) + self.assertEqual(config.prompt.system.strip(), 'QWEN3_SYSTEM') + + def test_fallback_to_base_when_family_missing(self): + with tempfile.TemporaryDirectory() as td: + self._write( + os.path.join(td, 'agent.yaml'), + """llm: + service: openai + model: qwen3-max +code_file: researcher +prompt: + lang: zh + family: qwen-3 +""", + ) + self._write( + os.path.join(td, 'prompts', 'researcher', 'zh', 'base.md'), + "BASE_ONLY", + ) + config = Config.from_task(td) + self.assertEqual(config.prompt.system.strip(), 'BASE_ONLY') + + def test_custom_prompt_root_relative(self): + with tempfile.TemporaryDirectory() as td: + self._write( + os.path.join(td, 'agent.yaml'), + """llm: + service: openai + model: claude-3-5-sonnet +code_file: reporter +prompt: + lang: en + family: claude + root: my_prompts +""", + ) + self._write( + os.path.join(td, 'my_prompts', 'reporter', 'en', 'claude.md'), + "CLAUDE_REPORTER", + ) + config = Config.from_task(td) + self.assertEqual(config.prompt.system.strip(), 'CLAUDE_REPORTER') + + def test_lang_fallback(self): + with tempfile.TemporaryDirectory() as td: + self._write( + os.path.join(td, 'agent.yaml'), + """llm: + service: openai + model: gpt-4.1 +code_file: searcher +prompt: + lang: en + family: gpt +""", + ) + # en missing, fallback to zh + self._write( + os.path.join(td, 'prompts', 'searcher', 'zh', 'gpt.md'), + "GPT_ZH", + ) + config = Config.from_task(td) + self.assertEqual(config.prompt.system.strip(), 'GPT_ZH') + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/tools/test_server_tools_smoke.py b/tests/tools/test_server_tools_smoke.py index 73954a45a..efccda3c8 100644 --- a/tests/tools/test_server_tools_smoke.py +++ b/tests/tools/test_server_tools_smoke.py @@ -496,8 +496,7 @@ async def main(): res = await tool.write_note( title='Finding A', - claim='Claim A', - supports='Support A', + content='Claim A. Support A', sources=[{ 'url': 'https://example.com/src', 'published_at': '2026-01-01', @@ -520,7 +519,7 @@ async def main(): note_id=note_id, parse_note=True)) self.assertEqual(got['status'], 'ok') self.assertEqual(got['note']['note_id'], note_id) - self.assertEqual(got['note']['claim'], 'Claim A') + self.assertEqual(got['note']['content'], 'Claim A. Support A') listed = json.loads(await tool.list_notes( task_id='task_1', tags=['tag1'])) @@ -543,6 +542,76 @@ async def main(): asyncio.run(main()) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_evidence_tool_write_get_list_search_delete_analysis(self): + + async def main(): + with tempfile.TemporaryDirectory() as td: + cfg = _make_config(td, tools={'evidence_store': SimpleNamespace()}) + tool = EvidenceTool(cfg) + await tool.connect() + + # Write a note first; conclusion can reference it. + note_res = json.loads(await tool.write_note( + title='Finding A', + content='Claim A. Support A', + sources=[{ + 'url': 'https://example.com/src', + 'published_at': '2026-01-01', + 'source_tier': 'primary', + }], + summary='summary A', + task_id='task_1', + tags=['tag1', 'tag2'], + quality_score=80, + )) + note_id = note_res['note_id'] + + res = await tool.write_analysis( + title='Interim synthesis', + content='Some **markdown** synthesis.', + summary='one-liner', + task_id='task_1', + based_on_note_ids=[note_id], + tags=['synthesis', 'tag1'], + quality_score=90, + ) + data = json.loads(res) + self.assertEqual(data['status'], 'ok') + analysis_id = data['analysis_id'] + + idx = json.loads(await tool.load_index()) + self.assertEqual(idx['status'], 'ok') + self.assertEqual(idx['total_notes'], 1) + self.assertEqual(idx['total_analyses'], 1) + self.assertIn(analysis_id, idx.get('analyses', {})) + + got = json.loads(await tool.get_analysis( + analysis_id=analysis_id, parse_analysis=True)) + self.assertEqual(got['status'], 'ok') + self.assertEqual(got['analysis']['analysis_id'], analysis_id) + self.assertIn('markdown', got['analysis'].get('content', '')) + + listed = json.loads(await tool.list_analyses( + task_id='task_1', tags=['tag1'])) + self.assertEqual(listed['status'], 'ok') + self.assertEqual(listed['count'], 1) + + searched = json.loads( + await tool.search_analyses(keyword='synthesis')) + self.assertEqual(searched['status'], 'ok') + self.assertEqual(searched['count'], 1) + + deleted = json.loads( + await tool.delete_analysis(analysis_id=analysis_id)) + self.assertEqual(deleted['status'], 'ok') + + missing = json.loads( + await tool.get_analysis(analysis_id=analysis_id)) + self.assertEqual(missing['status'], 'error') + + asyncio.run(main()) + class TestReportToolServer(unittest.TestCase): @@ -558,8 +627,7 @@ async def main(): await ev.connect() n1 = json.loads(await ev.write_note( title='N1', - claim='C1', - supports='S1', + content='C1. S1', sources=[{ 'url': 'https://example.com/1' }], @@ -569,8 +637,7 @@ async def main(): )) n2 = json.loads(await ev.write_note( title='N2', - claim='C2', - supports='S2', + content='C2. S2', sources=[{ 'url': 'https://example.com/2' }], From 693d9e95d5f9d013f8af32789de5ffc381cd10cb Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 15 Mar 2026 17:48:32 +0800 Subject: [PATCH 03/40] feat: support search local paths through sirchmunk --- docs/en/Components/Config.md | 17 + docs/zh/Components/config.md | 80 ++-- examples/knowledge_search/agent.yaml.example | 86 ++++ ms_agent/agent/llm_agent.py | 70 ++- ms_agent/cli/run.py | 55 +++ ms_agent/knowledge_search/README.md | 277 ++++++++++++ ms_agent/knowledge_search/__init__.py | 11 + ms_agent/knowledge_search/sirchmunk_search.py | 401 ++++++++++++++++++ ms_agent/llm/dashscope_llm.py | 2 +- ms_agent/llm/utils.py | 6 + ms_agent/rag/utils.py | 3 + tests/knowledge_search/__init__.py | 2 + tests/knowledge_search/test_sirschmunk.py | 203 +++++++++ 13 files changed, 1179 insertions(+), 34 deletions(-) create mode 100644 examples/knowledge_search/agent.yaml.example create mode 100644 ms_agent/knowledge_search/README.md create mode 100644 ms_agent/knowledge_search/__init__.py create mode 100644 ms_agent/knowledge_search/sirchmunk_search.py create mode 100644 tests/knowledge_search/__init__.py create mode 100644 tests/knowledge_search/test_sirschmunk.py diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index d40f03898..f1253bd75 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -166,3 +166,20 @@ In addition to yaml configuration, MS-Agent also supports several additional com ``` > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. + +- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. + +### Quick Start for Knowledge Search + +Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: + +```bash +# Using default agent.yaml configuration, automatically reuses LLM settings +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" + +# Specify configuration file +ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" +``` + +LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. +If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 041820b93..12849f2a7 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -1,12 +1,12 @@ --- slug: config title: 配置与参数 -description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM配置、推理配置、system和query、callbacks、工具配置、其他、config_handler、命令行配置 +description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM 配置、推理配置、system 和 query、callbacks、工具配置、其他、config_handler、命令行配置 --- # 配置与参数 -MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名为`agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: +MS-Agent 使用一个 yaml 文件进行配置管理,通常这个文件被命名为 `agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: ## 类型配置 @@ -17,31 +17,31 @@ MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名 type: llmagent ``` -标识本配置对应的agent类型,支持`llmagent`和`codeagent`两类。默认为`llmagent`。如果yaml中包含了code_file字段,则code_file优先生效。 +标识本配置对应的 agent 类型,支持 `llmagent` 和 `codeagent` 两类。默认为 `llmagent`。如果 yaml 中包含了 code_file 字段,则 code_file 优先生效。 ## 自定义代码 -> 可选,在需要自定义LLMAgent时使用 +> 可选,在需要自定义 LLMAgent 时使用 ```yaml code_file: custom_agent ``` -可以使用一个外部agent类,该类需要继承自`LLMAgent`。可以复写其中的若干方法,如果code_file有值,则`type`字段不生效。 +可以使用一个外部 agent 类,该类需要继承自 `LLMAgent`。可以复写其中的若干方法,如果 code_file 有值,则 `type` 字段不生效。 -## LLM配置 +## LLM 配置 > 必须存在 ```yaml llm: - # 大模型服务backend + # 大模型服务 backend service: modelscope - # 模型id + # 模型 id model: Qwen/Qwen3-235B-A22B-Instruct-2507 - # 模型api_key + # 模型 api_key modelscope_api_key: - # 模型base_url + # 模型 base_url modelscope_base_url: https://api-inference.modelscope.cn/v1 ``` @@ -51,7 +51,7 @@ llm: ```yaml generation_config: - # 下面的字段均为OpenAI sdk的标准参数,你也可以配置OpenAI支持的其他参数在这里。 + # 下面的字段均为 OpenAI sdk 的标准参数,你也可以配置 OpenAI 支持的其他参数在这里。 top_p: 0.6 temperature: 0.2 top_k: 20 @@ -60,25 +60,25 @@ generation_config: enable_thinking: false ``` -## system和query +## system 和 query -> 可选,但推荐传入system +> 可选,但推荐传入 system ```yaml prompt: - # LLM system,如果不传递则使用默认的`you are a helpful assistant.` + # LLM system,如果不传递则使用默认的 `you are a helpful assistant.` system: - # LLM初始query,通常来说可以不使用 + # LLM 初始 query,通常来说可以不使用 query: ``` ## callbacks -> 可选,推荐自定义callbacks +> 可选,推荐自定义 callbacks ```yaml callbacks: - # 用户输入callback,该callback在assistant回复后自动等待用户输入 + # 用户输入 callback,该 callback 在 assistant 回复后自动等待用户输入 - input_callback ``` @@ -90,9 +90,9 @@ callbacks: tools: # 工具名称 file_system: - # 是否是mcp + # 是否是 mcp mcp: false - # 排除的function,可以为空 + # 排除的 function,可以为空 exclude: - create_directory - write_file @@ -104,20 +104,20 @@ tools: - map_geo ``` -支持的完整工具列表,以及自定义工具请参考[这里](./tools) +支持的完整工具列表,以及自定义工具请参考 [这里](./tools) ## 其他 > 可选,按需配置 ```yaml -# 自动对话轮数,默认为20轮 +# 自动对话轮数,默认为 20 轮 max_chat_round: 9999 # 工具调用超时时间,单位秒 tool_call_timeout: 30000 -# 输出artifact目录 +# 输出 artifact 目录 output_dir: output # 帮助信息,通常在运行错误后出现 @@ -127,13 +127,13 @@ help: | ## config_handler -为了便于在任务开始时对config进行定制化,MS-Agent构建了一个名为`ConfigLifecycleHandler`的机制。这是一个callback类,开发者可以在yaml文件中增加这样一个配置: +为了便于在任务开始时对 config 进行定制化,MS-Agent 构建了一个名为 `ConfigLifecycleHandler` 的机制。这是一个 callback 类,开发者可以在 yaml 文件中增加这样一个配置: ```yaml handler: custom_handler ``` -这代表和yaml文件同级有一个custom_handler.py文件,该文件的类继承自`ConfigLifecycleHandler`,分别有两个方法: +这代表和 yaml 文件同级有一个 custom_handler.py 文件,该文件的类继承自 `ConfigLifecycleHandler`,分别有两个方法: ```python def task_begin(self, config: DictConfig, tag: str) -> DictConfig: @@ -143,18 +143,18 @@ handler: custom_handler return config ``` -`task_begin`在LLMAgent类构造时生效,在该方法中可以对config进行一些修改。如果你的工作流中下游任务会继承上游的yaml配置,这个机制会有帮助。值得注意的是`tag`参数,该参数会传入当前LLMAgent的名字,方便分辨当前工作流的节点。 +`task_begin` 在 LLMAgent 类构造时生效,在该方法中可以对 config 进行一些修改。如果你的工作流中下游任务会继承上游的 yaml 配置,这个机制会有帮助。值得注意的是 `tag` 参数,该参数会传入当前 LLMAgent 的名字,方便分辨当前工作流的节点。 ## 命令行配置 -在yaml配置之外,MS-Agent还支持若干额外的命令行参数。 +在 yaml 配置之外,MS-Agent 还支持若干额外的命令行参数。 -- query: 初始query,这个query的优先级高于yaml中的prompt.query -- config: 配置文件路径,支持modelscope model-id -- trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为true才会生效 -- load_cache: 从历史messages继续对话。cache会被自动存储在`output`配置中。默认为`False` -- mcp_server_file: 可以读取一个外部的mcp工具配置,格式为: +- query: 初始 query,这个 query 的优先级高于 yaml 中的 prompt.query +- config: 配置文件路径,支持 modelscope model-id +- trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为 true 才会生效 +- load_cache: 从历史 messages 继续对话。cache 会被自动存储在 `output` 配置中。默认为 `False` +- mcp_server_file: 可以读取一个外部的 mcp 工具配置,格式为: ```json { "mcpServers": { @@ -165,5 +165,21 @@ handler: custom_handler } } ``` +- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 -> agent.yaml中的任意一个配置,都可以使用命令行传入新的值, 也支持从同名(大小写不敏感)环境变量中读取,例如`--llm.modelscope_api_key xxx-xxx`。 +> agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 + +### 知识搜索快速使用 + +通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: + +```bash +# 使用默认 agent.yaml 配置,自动复用 LLM 设置 +ms-agent run --query "如何实现用户认证?" --knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 +如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example new file mode 100644 index 000000000..cc11a8a3d --- /dev/null +++ b/examples/knowledge_search/agent.yaml.example @@ -0,0 +1,86 @@ +# Sirchmunk Knowledge Search 配置示例 +# Sirchmunk Knowledge Search Configuration Example + +# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: + +llm: + service: modelscope + model: Qwen/Qwen3-235B-A22B-Instruct-2507 + modelscope_api_key: + modelscope_base_url: https://api-inference.modelscope.cn/v1 + +generation_config: + temperature: 0.3 + top_k: 20 + stream: true + +# Knowledge Search 配置(可选) +# 用于在本地代码库中搜索相关信息 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录,用于缓存 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则使用上面 llm 的配置) + llm_api_key: + llm_base_url: https://api.openai.com/v1 + llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:聚类相似度阈值 + cluster_sim_threshold: 0.85 + + # 可选:聚类 TopK + cluster_sim_top_k: 3 + + # 可选:是否重用之前的知识 + reuse_knowledge: true + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:最大循环次数 + max_loops: 10 + + # 可选:最大 token 预算 + max_token_budget: 128000 + +prompt: + system: | + You are an assistant that helps me complete tasks. + +max_chat_round: 9999 + +# 使用说明: +# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 +# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 +# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM +# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 +# +# Python 使用示例: +# ```python +# from ms_agent import LLMAgent +# from ms_agent.config import Config +# +# config = Config.from_task('path/to/agent.yaml') +# agent = LLMAgent(config=config) +# result = await agent.run('如何实现用户认证功能?') +# +# # 获取搜索详情(用于前端展示) +# for msg in result: +# if msg.role == 'user': +# print(f"Search logs: {msg.searching_detail}") +# print(f"Search results: {msg.search_result}") +# ``` +# +# CLI 测试命令: +# export LLM_API_KEY="your-api-key" +# export LLM_BASE_URL="https://api.openai.com/v1" +# export LLM_MODEL_NAME="gpt-4o-mini" +# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 740eab690..3bc40c1fc 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -19,6 +19,7 @@ from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER @@ -104,6 +105,7 @@ def __init__(self, self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None + self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -619,8 +621,52 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): + """Process RAG or knowledge search to enrich the user query with context. + + This method handles both traditional RAG and sirchmunk-based knowledge search. + For knowledge search, it also populates searching_detail and search_result + fields in the message for frontend display and next-turn LLM context. + + Args: + messages (List[Message]): The message list to process. + """ + user_message = messages[1] if len(messages) > 1 else None + if user_message is None or user_message.role != 'user': + return + + query = user_message.content + + # Handle traditional RAG if self.rag is not None: - messages[1].content = await self.rag.query(messages[1].content) + user_message.content = await self.rag.query(query) + + # Handle sirchmunk knowledge search + if self.knowledge_search is not None: + # Perform search and get results + search_result = await self.knowledge_search.retrieve(query) + search_details = self.knowledge_search.get_search_details() + + # Store search details in the message for frontend display + user_message.searching_detail = search_details + user_message.search_result = search_result + + # Build enriched context from search results + if search_result: + context_parts = [] + for i, result in enumerate(search_result, 1): + text = result.get('text', '') + source = result.get('metadata', {}).get('source', 'unknown') + score = result.get('score', 0) + context_parts.append( + f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n" + ) + + # Append search context to user query + context = '\n'.join(context_parts) + user_message.content = ( + f"Relevant context retrieved from codebase search:\n\n{context}\n\n" + f"User question: {query}" + ) async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -706,6 +752,27 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) + async def prepare_knowledge_search(self): + """Load and initialize the knowledge search component from the config.""" + if hasattr(self.config, 'knowledge_search'): + ks_config = self.config.knowledge_search + if ks_config is not None: + # Extract LLM config for sirchmunk + if hasattr(self.config, 'llm'): + llm_config = self.config.llm + # Update knowledge_search config with LLM settings if not specified + if not hasattr(ks_config, 'llm_api_key') and hasattr(llm_config, 'modelscope_api_key'): + OmegaConf.update(self.config, 'knowledge_search.llm_api_key', + getattr(llm_config, 'modelscope_api_key', None), merge=True) + if not hasattr(ks_config, 'llm_base_url') and hasattr(llm_config, 'modelscope_base_url'): + OmegaConf.update(self.config, 'knowledge_search.llm_base_url', + getattr(llm_config, 'modelscope_base_url', None), merge=True) + if not hasattr(ks_config, 'llm_model_name') and hasattr(llm_config, 'model'): + OmegaConf.update(self.config, 'knowledge_search.llm_model_name', + getattr(llm_config, 'model', None), merge=True) + + self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) + async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -1044,6 +1111,7 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() + await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index c2df42eee..cfe387e5a 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,6 +4,8 @@ import os from importlib import resources as importlib_resources +from omegaconf import OmegaConf + from ms_agent.config import Config from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII @@ -46,6 +48,22 @@ class RunCMD(CLICommand): def __init__(self, args): self.args = args + def load_env_file(self): + """Load environment variables from .env file in current directory.""" + env_file = os.path.join(os.getcwd(), '.env') + if os.path.exists(env_file): + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#') and '=' in line: + key, value = line.split('=', 1) + key = key.strip() + value = value.strip() + # Only set if not already set in environment + if key not in os.environ: + os.environ[key] = value + logger.debug(f'Loaded {key} from .env file') + @staticmethod def define_args(parsers: argparse.ArgumentParser): """Define args for run command.""" @@ -120,6 +138,14 @@ def define_args(parsers: argparse.ArgumentParser): help= 'Animation mode for video_generate project: auto (default) or human.' ) + parser.add_argument( + '--knowledge_search_paths', + required=False, + type=str, + default=None, + help= + 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' + ) parser.set_defaults(func=subparser_func) def execute(self): @@ -150,10 +176,18 @@ def execute(self): return self._execute_with_config() def _execute_with_config(self): + # Load environment variables from .env file if exists + self.load_env_file() + if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) + else: + # Use built-in default agent.yaml from package + default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file(default_config_path) as config_file: + self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download self.args.config = snapshot_download(self.args.config) @@ -190,6 +224,27 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) + # If knowledge_search_paths is provided, configure SirchmunkSearch + if getattr(self.args, 'knowledge_search_paths', None): + paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + if paths: + if 'knowledge_search' not in config or not config.knowledge_search: + # No existing knowledge_search config, create minimal config + # LLM settings will be auto-reused from llm module by SirchmunkSearch + knowledge_search_config = { + 'name': 'SirchmunkSearch', + 'paths': paths, + 'work_path': './.sirchmunk', + 'mode': 'FAST', + } + config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + else: + # Existing knowledge_search config found, only update paths + # LLM settings are already handled by SirchmunkSearch internally + existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing['paths'] = paths + config['knowledge_search'] = OmegaConf.create(existing) + if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader engine = WorkflowLoader.build( diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md new file mode 100644 index 000000000..00743601e --- /dev/null +++ b/ms_agent/knowledge_search/README.md @@ -0,0 +1,277 @@ +# Sirchmunk Knowledge Search 集成 + +本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 + +## 功能特性 + +- **智能代码搜索**: 使用 LLM 和 embedding 模型对代码库进行语义搜索 +- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 +- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 +- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 +- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 + +## 安装 + +```bash +pip install sirchmunk +``` + +## 配置 + +在您的 `agent.yaml` 或 `workflow.yaml` 中添加以下配置: + +```yaml +llm: + service: dashscope + model: qwen3.5-plus + dashscope_api_key: + dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + +generation_config: + temperature: 0.3 + stream: true + +# Knowledge Search 配置 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) + # llm_api_key: + # llm_base_url: https://api.openai.com/v1 + # llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:是否重用之前的知识 + reuse_knowledge: true +``` + +**LLM 配置自动复用机制**: + +`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: +- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` +- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` +- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` + +其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 + +通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 + +## 使用方式 + +### 1. 通过 CLI 使用(推荐) + +从命令行直接运行,无需编写代码: + +```bash +# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 +ms-agent run --query "如何实现用户认证功能?" --knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +**说明**: +- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 +- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 +- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 + +### 2. 通过 LLMAgent 使用 + +```python +from ms_agent import LLMAgent +from ms_agent.config import Config + +# 从配置文件加载 +config = Config.from_task('path/to/agent.yaml') +agent = LLMAgent(config=config) + +# 运行查询 - 会自动触发知识搜索 +result = await agent.run('如何实现用户认证功能?') + +# 获取搜索结果 +for msg in result: + if msg.role == 'user': + # 搜索详情(用于前端展示) + print(f"Search logs: {msg.searching_detail}") + # 搜索结果(作为 LLM 上下文) + print(f"Search results: {msg.search_result}") +``` + +### 2. 单独使用 SirchmunkSearch + +```python +from ms_agent.knowledge_search import SirchmunkSearch +from omegaconf import DictConfig + +config = DictConfig({ + 'knowledge_search': { + 'paths': ['./src', './docs'], + 'work_path': './.sirchmunk', + 'llm_api_key': 'your-api-key', + 'llm_model_name': 'gpt-4o-mini', + 'mode': 'FAST', + } +}) + +searcher = SirchmunkSearch(config) + +# 查询(返回合成答案) +answer = await searcher.query('如何实现用户认证?') + +# 检索(返回原始搜索结果) +results = await searcher.retrieve( + query='用户认证', + limit=5, + score_threshold=0.7 +) + +# 获取搜索日志 +logs = searcher.get_search_logs() + +# 获取搜索详情 +details = searcher.get_search_details() +``` + +## 环境变量 + +可以通过环境变量配置: + +```bash +# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" + +# Embedding 模型配置 +export EMBEDDING_MODEL_ID="text-embedding-3-small" +export SIRCHMUNK_WORK_PATH="./.sirchmunk" +``` + +**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 + +## 测试 + +### 单元测试 + +```bash +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" + +python -m unittest tests/knowledge_search/test_sirschmunk.py +``` + +### CLI 测试 + +```bash +# 基本测试 +python tests/knowledge_search/test_cli.py + +# 指定查询 +python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" + +# 仅测试 standalone 模式 +python tests/knowledge_search/test_cli.py -m standalone + +# 仅测试 agent 模式 +python tests/knowledge_search/test_cli.py -m agent +``` + +## 配置参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | +| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | +| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | +| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | +| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | +| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | +| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | +| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | +| reuse_knowledge | bool | true | 是否重用之前的知识 | +| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | +| max_loops | int | 10 | 最大搜索循环次数 | +| max_token_budget | int | 128000 | 最大 token 预算 | + +## 搜索模式 + +- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 +- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 +- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 + +## Message 字段扩展 + +为了支持知识搜索,`Message` 类增加了两个字段: + +- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 + - `logs`: 搜索日志列表 + - `mode`: 使用的搜索模式 + - `paths`: 搜索的路径 + - `work_path`: 工作目录 + - `reuse_knowledge`: 是否重用知识 + +- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 + - `text`: 文档内容 + - `score`: 相关性分数 + - `metadata`: 元数据(如源文件、类型等) + +## 工作原理 + +1. 用户发送查询 +2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch +3. `do_rag()` 方法执行知识搜索: + - 调用 `searcher.retrieve()` 获取相关文档 + - 将搜索结果存入 `message.search_result` + - 将搜索日志存入 `message.searching_detail` + - 将搜索结果格式化为上下文,附加到用户查询 +4. LLM 接收 enriched query 并生成回答 +5. 前端可以通过 `searching_detail` 展示搜索过程 + +## 故障排除 + +### 常见问题 + +1. **ImportError: No module named 'sirchmunk'** + ```bash + pip install sirchmunk + ``` + +2. **搜索结果为空** + - 检查 `paths` 配置是否正确 + - 确保路径下有可搜索的文件 + - 尝试降低 `cluster_sim_threshold` 值 + +3. **LLM API 调用失败** + - 检查 API key 是否正确 + - 检查 base URL 是否正确 + - 查看搜索日志了解详细错误 + +### 日志查看 + +```python +# 查看搜索日志 +logs = searcher.get_search_logs() +for log in logs: + print(log) + +# 或在配置中启用 verbose +knowledge_search: + verbose: true +``` + +## 参考资源 + +- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) +- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py new file mode 100644 index 000000000..33362beee --- /dev/null +++ b/ms_agent/knowledge_search/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search module based on sirchmunk. + +This module provides integration between sirchmunk's AgenticSearch +and the ms_agent framework, enabling intelligent codebase search +capabilities similar to RAG. +""" + +from .sirchmunk_search import SirchmunkSearch + +__all__ = ['SirchmunkSearch'] diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py new file mode 100644 index 000000000..4e1e322a5 --- /dev/null +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -0,0 +1,401 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sirchmunk-based knowledge search integration. + +This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, +providing document retrieval capabilities similar to RAG but optimized for +codebase and documentation search. +""" + +import asyncio +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from loguru import logger + +from ms_agent.rag.base import RAG +from omegaconf import DictConfig + + +class SirchmunkSearch(RAG): + """Sirchmunk-based knowledge search class. + + This class wraps the sirchmunk library to provide intelligent codebase search + capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk + uses a combination of keyword search, semantic clustering, and LLM-powered + analysis to find relevant information from codebases. + + The configuration needed in the config yaml: + - name: SirchmunkSearch + - paths: List of paths to search, required + - work_path: Working directory for sirchmunk cache, default './.sirchmunk' + - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' + - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 + - cluster_sim_top_k: Top K clusters to consider, default 3 + - reuse_knowledge: Whether to reuse previous search results, default True + - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' + + Args: + config (DictConfig): Configuration object containing sirchmunk settings. + """ + + def __init__(self, config: DictConfig): + super().__init__(config) + + self._validate_config(config) + + # Extract configuration parameters + rag_config = config.get('knowledge_search', {}) + + # Search paths - required + paths = rag_config.get('paths', []) + if isinstance(paths, str): + paths = [paths] + self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + + # Work path for sirchmunk cache + _work_path = rag_config.get('work_path', './.sirchmunk') + self.work_path: Path = Path(_work_path).expanduser().resolve() + + # Sirchmunk search parameters + self.reuse_knowledge = rag_config.get('reuse_knowledge', True) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) + self.search_mode = rag_config.get('mode', 'FAST') + self.max_loops = rag_config.get('max_loops', 10) + self.max_token_budget = rag_config.get('max_token_budget', 128000) + + # LLM configuration for sirchmunk + # First try knowledge_search.llm_api_key, then fall back to main llm config + self.llm_api_key = rag_config.get('llm_api_key', None) + self.llm_base_url = rag_config.get('llm_base_url', None) + self.llm_model_name = rag_config.get('llm_model_name', None) + + # Fall back to main llm config if not specified in knowledge_search + if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + llm_config = config.get('llm', {}) + if llm_config: + service = getattr(llm_config, 'service', 'dashscope') + if self.llm_api_key is None: + self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + if self.llm_base_url is None: + self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + if self.llm_model_name is None: + self.llm_model_name = getattr(llm_config, 'model', None) + + # Embedding model configuration + self.embedding_model_id = rag_config.get('embedding_model', None) + self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + + # Runtime state + self._searcher = None + self._initialized = False + + # Callback for capturing logs + self._log_callback = None + self._search_logs: List[str] = [] + + def _validate_config(self, config: DictConfig): + """Validate configuration parameters.""" + if not hasattr(config, 'knowledge_search') or config.knowledge_search is None: + raise ValueError( + 'Missing knowledge_search configuration. ' + 'Please add knowledge_search section to your config with at least "paths" specified.' + ) + + rag_config = config.knowledge_search + paths = rag_config.get('paths', []) + if not paths: + raise ValueError('knowledge_search.paths must be specified and non-empty') + + def _initialize_searcher(self): + """Initialize the sirchmunk AgenticSearch instance.""" + if self._initialized: + return + + try: + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.utils.embedding_util import EmbeddingUtil + + # Create LLM client + llm = OpenAIChat( + api_key=self.llm_api_key, + base_url=self.llm_base_url, + model=self.llm_model_name, + max_retries=3, + log_callback=self._log_callback_wrapper(), + ) + + # Create embedding util + # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) + embedding_model_id = self.embedding_model_id if self.embedding_model_id else None + embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None + embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + + # Create AgenticSearch instance + self._searcher = AgenticSearch( + llm=llm, + embedding=embedding, + work_path=str(self.work_path), + paths=self.search_paths, + verbose=True, + reuse_knowledge=self.reuse_knowledge, + cluster_sim_threshold=self.cluster_sim_threshold, + cluster_sim_top_k=self.cluster_sim_top_k, + log_callback=self._log_callback_wrapper(), + ) + + self._initialized = True + logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + + except ImportError as e: + raise ImportError( + f'Failed to import sirchmunk: {e}. ' + 'Please install sirchmunk: pip install sirchmunk' + ) + except Exception as e: + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + + def _log_callback_wrapper(self): + """Create a callback wrapper to capture search logs.""" + def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): + self._search_logs.append(f'[{level}] {message}') + + return log_callback + + async def add_documents(self, documents: List[str]) -> bool: + """Add documents to the search index. + + Note: Sirchmunk works by scanning existing files in the specified paths. + This method is provided for RAG interface compatibility but doesn't + directly add documents. Instead, documents should be saved to files + within the search paths. + + Args: + documents (List[str]): List of document contents to add. + + Returns: + bool: True if successful (for interface compatibility). + """ + logger.warning( + 'SirchmunkSearch does not support direct document addition. ' + 'Documents should be saved to files within the configured search paths.' + ) + # Trigger re-scan of the search paths + if self._searcher and hasattr(self._searcher, 'knowledge_base'): + try: + await self._searcher.knowledge_base.refresh() + return True + except Exception as e: + logger.error(f'Failed to refresh knowledge base: {e}') + return False + return True + + async def add_documents_from_files(self, file_paths: List[str]) -> bool: + """Add documents from file paths. + + Args: + file_paths (List[str]): List of file paths to scan. + + Returns: + bool: True if successful. + """ + self._initialize_searcher() + + if self._searcher and hasattr(self._searcher, 'scan_directory'): + try: + for file_path in file_paths: + if Path(file_path).exists(): + await self._searcher.scan_directory(str(Path(file_path).parent)) + return True + except Exception as e: + logger.error(f'Failed to scan files: {e}') + return False + return True + + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: + """Retrieve relevant documents using sirchmunk. + + Args: + query (str): The search query. + limit (int): Maximum number of results to return. + score_threshold (float): Minimum relevance score threshold. + **filters: Additional filters (mode, max_loops, etc.). + + Returns: + List[Dict[str, Any]]: List of search results with 'text', 'score', + 'metadata' fields. + """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = filters.get('mode', self.search_mode) + max_loops = filters.get('max_loops', self.max_loops) + max_token_budget = filters.get('max_token_budget', self.max_token_budget) + + # Perform search + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=True, + ) + + # Parse results into standard format + return self._parse_search_result(result, score_threshold, limit) + + except Exception as e: + logger.error(f'SirschmunkSearch retrieve failed: {e}') + return [] + + async def query(self, query: str) -> str: + """Query sirchmunk and return a synthesized answer. + + This method performs a search and returns the LLM-synthesized answer + along with search details that can be used for frontend display. + + Args: + query (str): The search query. + + Returns: + str: The synthesized answer from sirchmunk. + """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = self.search_mode + max_loops = self.max_loops + max_token_budget = self.max_token_budget + + # Perform search and get answer + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=False, + ) + + # Result is already a synthesized answer string + if isinstance(result, str): + return result + + # If we got SearchContext or other format, extract the answer + if hasattr(result, 'answer'): + return result.answer + + # Fallback: convert to string + return str(result) + + except Exception as e: + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' + + def _parse_search_result(self, + result: Any, + score_threshold: float, + limit: int) -> List[Dict[str, Any]]: + """Parse sirchmunk search result into standard format. + + Args: + result: The raw search result from sirchmunk. + score_threshold: Minimum score threshold. + limit: Maximum number of results. + + Returns: + List[Dict[str, Any]]: Parsed results. + """ + results = [] + + # Handle SearchContext format (returned when return_context=True) + if hasattr(result, 'cluster') and result.cluster is not None: + cluster = result.cluster + for unit in cluster.evidences: + # Extract score from snippets if available + score = getattr(cluster, 'confidence', 1.0) + if score >= score_threshold: + # Extract text from snippets + text_parts = [] + source = str(getattr(unit, 'file_or_url', 'unknown')) + for snippet in getattr(unit, 'snippets', []): + if isinstance(snippet, dict): + text_parts.append(snippet.get('snippet', '')) + else: + text_parts.append(str(snippet)) + + results.append({ + 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + } + }) + + # Handle format with evidence_units attribute directly + elif hasattr(result, 'evidence_units'): + for unit in result.evidence_units: + score = getattr(unit, 'confidence', 1.0) + if score >= score_threshold: + results.append({ + 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + } + }) + + # Handle list format + elif isinstance(result, list): + for item in result: + if isinstance(item, dict): + score = item.get('score', item.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': item.get('content', item.get('text', str(item))), + 'score': score, + 'metadata': item.get('metadata', {}), + }) + + # Handle dict format + elif isinstance(result, dict): + score = result.get('score', result.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + }) + + # Sort by score and limit results + results.sort(key=lambda x: x.get('score', 0), reverse=True) + return results[:limit] + + def get_search_logs(self) -> List[str]: + """Get the captured search logs. + + Returns: + List[str]: List of log messages from the search operation. + """ + return self._search_logs.copy() + + def get_search_details(self) -> Dict[str, Any]: + """Get detailed search information including logs and metadata. + + Returns: + Dict[str, Any]: Search details including logs, mode, and paths. + """ + return { + 'logs': self._search_logs.copy(), + 'mode': self.search_mode, + 'paths': self.search_paths, + 'work_path': str(self.work_path), + 'reuse_knowledge': self.reuse_knowledge, + } diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index af766f679..b4a6ddaa8 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -12,7 +12,7 @@ class DashScope(OpenAI): def __init__(self, config: DictConfig): super().__init__( config, - base_url=config.llm.modelscope_base_url + base_url=config.llm.dashscope_base_url or get_service_config('dashscope').base_url, api_key=config.llm.dashscope_api_key) diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 6a336ca6e..410aa12f0 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,6 +61,12 @@ class Message: api_calls: int = 1 + # Knowledge search (sirchmunk) related fields + # searching_detail: Search process logs and metadata for frontend display + searching_detail: Dict[str, Any] = field(default_factory=dict) + # search_result: Raw search results to be used as context for next LLM turn + search_result: List[Dict[str, Any]] = field(default_factory=list) + def to_dict(self): return asdict(self) diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index 08e9a4db7..e66da954d 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -4,3 +4,6 @@ rag_mapping = { 'LlamaIndexRAG': LlamaIndexRAG, } + +# Note: SirchmunkSearch is registered in knowledge_search module +# and integrated directly in LLMAgent, not through rag_mapping diff --git a/tests/knowledge_search/__init__.py b/tests/knowledge_search/__init__.py new file mode 100644 index 000000000..0cc40e613 --- /dev/null +++ b/tests/knowledge_search/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search tests.""" diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py new file mode 100644 index 000000000..5a4f43213 --- /dev/null +++ b/tests/knowledge_search/test_sirschmunk.py @@ -0,0 +1,203 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. + +These tests verify the sirchmunk-based knowledge search functionality +through the LLMAgent entry point, including verification that +search_result and searching_detail fields are properly populated. + +To run these tests, you need to set the following environment variables: + - TEST_LLM_API_KEY: Your LLM API key + - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) + - TEST_LLM_MODEL_NAME: Your LLM model name (optional) + - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) + - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) + +Example: + export TEST_LLM_API_KEY="your-api-key" + export TEST_LLM_BASE_URL="https://api.openai.com/v1" + export TEST_LLM_MODEL_NAME="gpt-4o" + export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" + python -m pytest tests/knowledge_search/test_sirschmunk.py +""" +import asyncio +import os +import shutil +import unittest +from pathlib import Path + +from ms_agent.knowledge_search import SirchmunkSearch +from ms_agent.agent import LLMAgent +from ms_agent.config import Config +from omegaconf import DictConfig + +from modelscope.utils.test_utils import test_level + + +class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): + """Test cases for SirchmunkSearch integration with LLMAgent. + + These tests verify that when LLMAgent runs a query that triggers + knowledge search, the Message objects have search_result and + searching_detail fields properly populated. + """ + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + # Create test directory with sample files + cls.test_dir = Path('./test_llm_agent_knowledge') + cls.test_dir.mkdir(exist_ok=True) + + # Create sample documentation + (cls.test_dir / 'README.md').write_text(''' +# Test Project Documentation + +## Overview +This is a test project for knowledge search integration. + +## API Reference + +### UserManager +The UserManager class handles user operations: +- create_user: Create a new user account +- delete_user: Delete an existing user +- update_user: Update user information +- get_user: Retrieve user details + +### AuthService +The AuthService class handles authentication: +- login: Authenticate user credentials +- logout: End user session +- refresh_token: Refresh authentication token +- verify_token: Validate authentication token +''') + + (cls.test_dir / 'config.py').write_text(''' +"""Configuration module.""" + +class Config: + """Application configuration.""" + + def __init__(self): + self.database_url = "postgresql://localhost:5432/mydb" + self.secret_key = "your-secret-key" + self.debug_mode = False + + def load_from_env(self): + """Load configuration from environment variables.""" + import os + self.database_url = os.getenv("DATABASE_URL", self.database_url) + self.secret_key = os.getenv("SECRET_KEY", self.secret_key) + return self +''') + + @classmethod + def tearDownClass(cls): + """Clean up test fixtures.""" + if cls.test_dir.exists(): + shutil.rmtree(cls.test_dir, ignore_errors=True) + work_dir = Path('./.sirchmunk') + if work_dir.exists(): + shutil.rmtree(work_dir, ignore_errors=True) + + def _get_agent_config(self): + """Create agent configuration with knowledge search.""" + llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') + # Read from TEST_* env vars (for test-specific config) + # These can be set from .env file which uses TEST_* prefix + embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') + + config = DictConfig({ + 'llm': { + 'service': 'openai', + 'model': llm_model_name, + 'openai_api_key': llm_api_key, + 'openai_base_url': llm_base_url, + }, + 'generation_config': { + 'temperature': 0.3, + 'max_tokens': 500, + }, + 'knowledge_search': { + 'name': 'SirchmunkSearch', + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + } + }) + return config + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_llm_agent_with_knowledge_search(self): + """Test LLMAgent using knowledge search. + + This test verifies that: + 1. LLMAgent can be initialized with SirchmunkSearch configuration + 2. Running a query produces a valid response + 3. User message has searching_detail and search_result populated + 4. searching_detail contains expected keys (logs, mode, paths) + 5. search_result is a list + """ + config = self._get_agent_config() + agent = LLMAgent(config=config, tag='test-knowledge-agent') + + # Test query that should trigger knowledge search + query = 'How do I use UserManager to create a user?' + + async def run_agent(): + result = await agent.run(query) + return result + + result = asyncio.run(run_agent()) + + # Verify result + self.assertIsNotNone(result) + self.assertIsInstance(result, list) + self.assertTrue(len(result) > 0) + + # Check that assistant message exists + assistant_message = [m for m in result if m.role == 'assistant'] + self.assertTrue(len(assistant_message) > 0) + + # Check that user message has search_result and searching_detail populated + user_messages = [m for m in result if m.role == 'user'] + self.assertTrue(len(user_messages) > 0, "Expected at least one user message") + + # The first user message should have search details after do_rag processing + user_msg = user_messages[0] + self.assertTrue( + hasattr(user_msg, 'searching_detail'), + "User message should have searching_detail attribute" + ) + self.assertTrue( + hasattr(user_msg, 'search_result'), + "User message should have search_result attribute" + ) + + # Check that searching_detail is a dict with expected keys + self.assertIsInstance( + user_msg.searching_detail, dict, + "searching_detail should be a dictionary" + ) + self.assertIn('logs', user_msg.searching_detail) + self.assertIn('mode', user_msg.searching_detail) + self.assertIn('paths', user_msg.searching_detail) + + # Check that search_result is a list (may be empty if no relevant docs found) + self.assertIsInstance( + user_msg.search_result, list, + "search_result should be a list" + ) + + +if __name__ == '__main__': + unittest.main() From 05cb67664723427f7a21738e7466312eff083a18 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Mon, 16 Mar 2026 00:59:53 +0800 Subject: [PATCH 04/40] refactor: optimize architecture, restrict researcher report edits, update reporter delivery flow, and improve the quality check module (54.51) --- ms_agent/tools/filesystem_tool.py | 49 ++- .../v2/callbacks/quality_checker.py | 180 +++++++++++ .../v2/callbacks/reporter_callback.py | 298 +++++++++++++++--- .../v2/callbacks/researcher_callback.py | 243 +++----------- .../deep_research/v2/eval/dr_bench_runner.py | 71 ++++- .../v2/prompts/reporter/en/gpt5.txt | 31 +- .../v2/prompts/researcher/en/gpt5.txt | 27 +- projects/deep_research/v2/reporter.yaml | 9 + projects/deep_research/v2/researcher.yaml | 14 +- 9 files changed, 632 insertions(+), 290 deletions(-) create mode 100644 projects/deep_research/v2/callbacks/quality_checker.py diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index 837810818..a107a7c59 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -336,8 +336,14 @@ async def _get_tools_inner(self): server_name='file_system', description= 'Replace specific line ranges in a file. Supports inserting at beginning ' - '(start_line=0) or end (start_line=-1). ' - 'Line numbers are 1-based and inclusive on both ends.', + '(start_line=0) or end (start_line=-1). Line numbers are 1-based and inclusive on both ends.\n\n' + 'IMPORTANT — Line-number shift after each call. Every replacement changes the total line count, ' + 'which invalidates ALL line numbers after the replaced range. If you need to make multiple replacements in the same file:\n' + '- Option A (recommended): Work from BOTTOM to TOP — edit the largest line numbers first so earlier line numbers remain valid.\n' + '- Option B: Re-search after each replacement to get updated line numbers before the next replacement.\n' + '- Option C: Pre-calculate the cumulative offset — each replacement shifts subsequent lines by (new_content_lines - replaced_lines).\n' + 'NEVER call this tool multiple times in parallel on the same file — the concurrent line-number ' + 'shifts will corrupt the file. Always call sequentially.\n', parameters={ 'type': 'object', 'properties': { @@ -374,11 +380,15 @@ async def _get_tools_inner(self): server_name='file_system', description= 'Replace exact content in a file without using line numbers. ' - 'You must provide:' + 'You must provide:\n' '[Required]path: The relative path of modified file.\n' - '[Required]source: The old content to be replaced\n' - '[Required]target: The new content to replace the `source`\n' - 'Do not miss any of these arguments!', + '[Required]source: The old content to be replaced.\n' + '[Required]target: The new content to replace the `source`.\n' + '[Required]occurrence: Which occurrence to replace (1-based).\n' + 'Do not miss any of these arguments!\n\n' + 'IMPORTANT:\n' + '- `source` must match the file content EXACTLY — including punctuation style ' + '(e.g., Chinese "、" vs English ","), whitespace, line breaks, and Unicode characters.', parameters={ 'type': 'object', 'properties': { @@ -392,7 +402,8 @@ async def _get_tools_inner(self): 'type': 'string', 'description': - 'The exact content to find and replace (must match exactly including whitespace)', + 'The exact content to find and replace. Must match the file content ' + 'EXACTLY including all whitespace, punctuation, and line breaks. ', }, 'target': { 'type': 'string', @@ -403,11 +414,11 @@ async def _get_tools_inner(self): 'type': 'integer', 'description': - 'Which occurrence to replace (1-based). Use -1 to replace all occurrences. ' - 'Default is -1 (all occurrences).', + 'Which occurrence to replace (1-based). Default is 1 (first occurrence). ' + 'Use -1 to replace all occurrences.', }, }, - 'required': ['path', 'source', 'target'], + 'required': ['path', 'source', 'target', 'occurrence'], 'additionalProperties': False }), ] @@ -469,7 +480,7 @@ async def replace_file_contents(self, path: str, source: str = None, target: str = None, - occurrence: int = -1): + occurrence: int = 1): """Replace exact content in a file without using line numbers. This method is safer for parallel operations as it doesn't rely on line numbers @@ -480,16 +491,16 @@ async def replace_file_contents(self, source(str): The exact content to find and replace (must match exactly including whitespace) target(str): The new content to replace with occurrence(int): Which occurrence to replace (1-based). Use -1 to replace all occurrences. - Default is -1 (all occurrences). + Default is 1 (first occurrence). Returns: Success or error message. """ try: if not source: - return 'Error: You MUST provide the `source` parameter to be replaced with the `target`.' - if not target: - return 'Error: You MUST provide the `target` parameter to replace the `source`' + return f'Error: You MUST provide the `source` parameter to be replaced with the `target`, but got {source}.' + if target is None: + return f'Error: You MUST provide the `target` parameter to replace the `source`, but got {target}.' target_path_real = self.get_real_path(path) if target_path_real is None: return f'<{path}> is out of the valid project path: {self.output_dir}' @@ -612,7 +623,11 @@ async def replace_file_lines(self, f.writelines(new_lines) target = '\n'.join(new_lines).split('\n') - return f'{operation} in file <{path}> successfully. New file has {len(target)} lines.' + return ( + f'{operation} in file <{path}> completed successfully. The updated file now has {len(target)} lines. ' + 'WARNING: All line numbers after the replaced range may have shifted. ' + 'If you need to make another line-based replacement in this file, keep this in mind.' + ) except Exception as e: return f'Replace lines in file <{path}> failed, error: ' + str(e) @@ -842,7 +857,7 @@ async def search_file_name(self, file: str = '', parent_path: str = ''): async def search_file_content(self, content: str = None, - parent_path: str = None, + parent_path: str = '.', file_pattern: str = '*', context_lines: int = 2): """Search for content in files using thread pool. diff --git a/projects/deep_research/v2/callbacks/quality_checker.py b/projects/deep_research/v2/callbacks/quality_checker.py new file mode 100644 index 000000000..36fadf902 --- /dev/null +++ b/projects/deep_research/v2/callbacks/quality_checker.py @@ -0,0 +1,180 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import List, Optional + +import json +from ms_agent.llm.openai_llm import OpenAI as OpenAILLM +from ms_agent.llm.utils import Message +from ms_agent.utils import get_logger +from omegaconf import DictConfig, OmegaConf + +logger = get_logger() + + +class ReportQualityChecker(ABC): + """Interface for pluggable report quality checkers. + + Subclasses implement a single ``check`` method. Multiple checkers can + be chained in sequence by ``ResearcherCallback``; the first one that + returns a non-``None`` failure stops the chain. + """ + + @abstractmethod + async def check(self, content: str, lang: str) -> Optional[str]: + """Evaluate report quality. + + Args: + content: Full text of the report file. + lang: Language code (``"en"`` or ``"zh"``). + + Returns: + A short failure-reason string (e.g. ``"placeholder_content"``) + if the report fails this check, or ``None`` if it passes. + """ + + +class ModelQualityChecker(ReportQualityChecker): + """LLM-based report quality checker. + + Uses a lightweight model (configured via ``quality_check.model`` in + the YAML) to detect reports whose body has been largely replaced by + placeholders, abbreviations, or cross-references to external files. + + The checker sends a structured prompt asking the model to return a + JSON verdict: ``{"pass": true/false, "reason": "..."}``. + """ + + _SYSTEM_PROMPTS = { + 'en': + ('You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' + 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' + 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' + 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' + 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' + 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' + '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' + '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' + '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' + 'is stored in xxx file", "See full analysis in evidence/xxx".\n' + '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n' + '4. Multiple reference/bibliography sections appear in the report (e.g., per-chapter reference lists), or any ' + 'variant heading such as "## References (Merged)", "## 参考文献(合并版)", "## 参考资料", etc. ' + 'Only one unified reference section at the very end is allowed.\n\n' + 'OUTPUT FORMAT:\n' + 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' + '{"pass": true} or {"pass": false, "reason": ""}\n' + 'Do NOT output anything else.'), + 'zh': + ('你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' + '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' + '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' + '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' + '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' + '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' + '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' + '4. 报告中出现多个参考文献/引用列表章节(如各章节末尾的独立引用列表),或使用变体标题如"## 参考文献(合并版)"、"## 参考资料"、"## References (Merged)"等。' + '报告仅允许在末尾保留唯一一个统一的参考文献章节。\n\n' + '输出格式:\n' + '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' + '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' + '不要输出任何其他内容。'), + } + + _USER_TEMPLATES = { + 'en': + ('Please audit the following research report against the rules provided in the system instruction.\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---'), + 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n' + '---报告开始---\n{report}\n---报告结束---'), + } + + _MAX_REPORT_CHARS = 80000 + + def __init__(self, config: DictConfig): + self._config = config + qc_cfg = getattr(config, 'self_reflection', DictConfig({})) + qc_cfg = getattr(qc_cfg, 'quality_check', DictConfig({})) + + self._model: str = str(getattr(qc_cfg, 'model', 'qwen3.5-plus')) + self._api_key: Optional[str] = getattr( + qc_cfg, 'openai_api_key', None) or getattr(config.llm, + 'openai_api_key', None) + self._base_url: Optional[str] = getattr( + qc_cfg, 'openai_base_url', None) or getattr( + config.llm, 'openai_base_url', None) + + self._client: Optional[OpenAILLM] = None + + def _build_llm_config(self) -> DictConfig: + """Build lightweight llm config for quality checker.""" + return OmegaConf.create({ + 'llm': { + 'model': self._model, + 'openai_api_key': self._api_key, + 'openai_base_url': self._base_url, + }, + 'generation_config': {}, + }) + + def _ensure_client(self): + if self._client is not None: + return + self._client = OpenAILLM(self._build_llm_config()) + + async def check(self, content: str, lang: str) -> Optional[str]: + self._ensure_client() + + report_text = content + if len(report_text) > self._MAX_REPORT_CHARS: + report_text = report_text[:self._MAX_REPORT_CHARS] + + sys_prompt = self._SYSTEM_PROMPTS.get(lang, self._SYSTEM_PROMPTS['en']) + usr_template = self._USER_TEMPLATES.get(lang, + self._USER_TEMPLATES['en']) + + try: + response = self._client.generate(messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', + content=usr_template.format(report=report_text), + ), + ]) + raw = (response.content or '').strip() + logger.info( + f'ModelQualityChecker ({self._model}): raw response: {raw}') + + verdict = json.loads(raw) + if verdict.get('pass', True): + return None + return verdict.get('reason', 'placeholder_content') + + except json.JSONDecodeError: + logger.warning(f'ModelQualityChecker: failed to parse JSON from ' + f'model response: {raw!r}') + return None + except Exception as exc: + logger.warning(f'ModelQualityChecker: model call failed: {exc}') + return None + + +def build_quality_checkers(config: DictConfig) -> List[ReportQualityChecker]: + """Instantiate the quality-checker chain from config. + + Reads ``config.self_reflection.quality_check`` and returns a list of + checker instances. Currently only ``ModelQualityChecker`` is + supported; new checker types can be added here. + """ + refl_cfg = getattr(config, 'self_reflection', None) + if refl_cfg is None: + return [] + + qc_cfg = getattr(refl_cfg, 'quality_check', None) + if qc_cfg is None or not bool(getattr(qc_cfg, 'enabled', False)): + return [] + + checkers: List[ReportQualityChecker] = [] + checkers.append(ModelQualityChecker(config)) + logger.info( + f'Quality checker chain initialised with {len(checkers)} checker(s).') + return checkers diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 5bb8066f0..b86cde033 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -1,9 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# yapf: disable import os import re +import shutil from typing import Any, Dict, List, Optional, Set import json +from callbacks.quality_checker import (ReportQualityChecker, + build_quality_checkers) from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message @@ -20,16 +24,20 @@ class ReporterCallback(Callback): Responsibilities: - on_task_begin: Clean up system prompt formatting and load researcher trajectory - - on_task_end: Save the final report to file + - on_generate_response: Inject round-aware reminder near max rounds + - after_tool_call: Pre-completion quality checks (report existence, + length retention vs draft, model-based content audit) + - on_task_end: Promote the best report to final_report.md and save JSON summary """ - # The tag of the main researcher agent whose history we want to load RESEARCHER_TAG = 'deep-research-researcher' - - # Tool names to exclude from trajectory (reporter_tool calls and their responses) EXCLUDED_TOOL_PATTERNS = ['reporter_tool'] - # Bilingual round-reminder templates keyed by language code. + DRAFT_FILENAME = 'draft.md' + REPORT_FILENAME = 'report.md' + FINAL_REPORT_FILENAME = 'final_report.md' + DEFAULT_MIN_RETENTION_RATIO = 0.3 + _ROUND_REMINDER_TEMPLATES = { 'zh': ('你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' @@ -83,6 +91,75 @@ class ReporterCallback(Callback): }, } + _REFLECTION_TEMPLATES = { + 'zh': { + 'no_report': + ('外部检查发现:输出目录中尚未检测到已完成的报告文件 reports/report.md。\n' + '请确认报告写作流程是否已完成。你应当至少完成以下步骤:\n' + '1. 完成所有章节的撰写\n' + '2. 调用 report_generator---assemble_draft 生成报告草稿\n' + '3. 审阅草稿并交付最终版本\n' + '请立即采取行动完成报告交付。'), + 'over_compressed': + ('外部检查发现:reports/{report_name} 的内容量({report_chars} 字符)' + '仅为 reports/draft.md({draft_chars} 字符)的 {ratio:.0%},有可能存在内容丢失风险,请对报告内容进行检查并采取合理的行动。\n' + '**重要提醒**:draft.md 是由工具逐章组装的完整版本,理论上保留了最大的证据保真度。\n' + '- 如果你确认你对 draft.md 进行的修改是合理的,可以直接说明压缩内容的理由,无需再次修改或者重写。\n' + '- 如果你发现 reports/{report_name} 相比 draft.md 确实存在不合理的压缩,请重写并修复这些问题。\n' + '请立即采取行动完成报告交付。'), + 'low_quality': + ('外部检查发现:报告内容存在质量问题——{reason}。\n' + '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' + '**重要提醒**:如果质量问题属实,你必须完整重写整份报告。' + 'write_file 会完全覆盖文件,你写入的内容就是最终文件的全部内容——' + '以下写法都会原样出现在文件中并导致报告内容被永久丢失:\n' + '- 用省略号或缩略标记替代正文,如"(同之前,略)"、"此处省略"、"篇幅所限不再展开"、' + '"……以下类似"、"内容已截断"、"Content truncated for brevity"等;\n' + '- 引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"等;\n' + '- 引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"等。\n' + '不得遗漏或省略任何章节,无需担心与先前输出的内容或写入过的文件重复。'), + }, + 'en': { + 'no_report': + ('External inspection found that the completed report file reports/report.md ' + 'has not been detected in the output directory.\n' + 'Please confirm whether the report writing workflow has been completed. ' + 'You should have completed at least the following steps:\n' + '1. Finished writing all chapters\n' + '2. Called report_generator---assemble_draft to generate the report draft\n' + '3. Reviewed the draft and delivered the final version\n' + 'Please take immediate action to complete report delivery.'), + 'over_compressed': + ('External inspection found that reports/{report_name} ({report_chars} chars) ' + 'is only {ratio:.0%} of reports/draft.md ({draft_chars} chars), ' + 'indicating a risk of content loss. Please review the report content and take appropriate action.\n' + '**IMPORTANT**: draft.md is the tool-assembled complete version that theoretically ' + 'preserves maximum evidence fidelity.\n' + '- If you confirm that your modifications to draft.md are reasonable, you may simply ' + 'explain the rationale for the compression without further modifications or rewrites.\n' + '- If you find that reports/{report_name} has indeed been unreasonably compressed ' + 'compared to draft.md, please rewrite and fix these issues.\n' + 'Please take immediate action to complete report delivery.'), + 'low_quality': + ('External inspection found quality issues in the report — {reason}.\n' + 'Please carefully verify whether these issues are valid and whether additional ' + 'problems exist, then immediately take action to fix them.\n' + '**IMPORTANT**: If the quality issues are confirmed, you must completely rewrite ' + 'the entire report. write_file will fully overwrite the file — what you write is ' + 'the entire final content of the file. The following patterns will appear verbatim ' + 'in the file and cause permanent loss of report content:\n' + '- Replacing body text with ellipsis or brevity markers, e.g., "(same as before, omitted)", ' + '"omitted here", "not elaborated due to space constraints", ' + '"...similar below", "content truncated", "Content truncated for brevity", etc.;\n' + '- Directing readers to view external files instead of including actual content, ' + 'e.g., "this section is stored in xxx file", etc.;\n' + '- Directing readers to view reference sources without writing substantive content, ' + 'e.g., "see [1]", etc.\n' + 'Do not omit or skip any sections. Do not worry about duplicating content ' + 'from previous outputs or previously written files.'), + }, + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') @@ -95,10 +172,32 @@ def __init__(self, config: DictConfig): self.reports_dir = getattr(report_cfg, 'reports_dir', 'reports') self.report_path = os.path.join(self.output_dir, self.reports_dir, - 'report.md') - # Resolve language from config for bilingual prompt selection. + self.REPORT_FILENAME) + self.draft_path = os.path.join(self.output_dir, self.reports_dir, + self.DRAFT_FILENAME) + self.final_report_path = os.path.join(self.output_dir, + self.FINAL_REPORT_FILENAME) + self.lang = self._resolve_lang(config) + # Self-reflection config + refl_cfg = getattr(config, 'self_reflection', None) + self.reflection_enabled: bool = False + self.reflection_max_retries: int = 2 + self.min_retention_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + + if refl_cfg is not None: + self.reflection_enabled = bool(getattr(refl_cfg, 'enabled', False)) + self.reflection_max_retries = int( + getattr(refl_cfg, 'max_retries', 2)) + self.min_retention_ratio = float( + getattr(refl_cfg, 'min_retention_ratio', + self.DEFAULT_MIN_RETENTION_RATIO)) + + self._reflection_retries_used: int = 0 + self._quality_checkers: List[ReportQualityChecker] = ( + build_quality_checkers(config)) + @staticmethod def _resolve_lang(config: DictConfig) -> str: """Resolve language code from config.prompt.lang, defaulting to 'en'.""" @@ -113,6 +212,11 @@ def _resolve_lang(config: DictConfig) -> str: return 'zh' return 'en' + def _get_reflection(self, key: str, **kwargs) -> str: + templates = self._REFLECTION_TEMPLATES.get( + self.lang, self._REFLECTION_TEMPLATES['en']) + return templates[key].format(**kwargs) + def _load_researcher_history(self) -> Optional[List[Dict[str, Any]]]: """ Load the researcher agent's message history from the memory file. @@ -372,6 +476,89 @@ async def on_generate_response(self, runtime: Runtime, messages.append( Message(role='user', content=reminder_mark + injected + '\n')) + async def after_tool_call(self, runtime: Runtime, messages: List[Message]): + """Pre-completion quality checks before allowing the reporter to stop. + + Checks performed (in order, first failure wins): + 1. Report file existence — report.md must exist. + 2. Length retention — if report.md exists alongside draft.md, its + size must be >= ``min_retention_ratio`` of draft.md. + 3. Model quality audit — detects placeholder / abbreviated content. + """ + if not self.reflection_enabled: + return + if not runtime.should_stop: + return + if self._reflection_retries_used >= self.reflection_max_retries: + logger.info('ReporterCallback: reflection retry cap reached ' + f'({self.reflection_max_retries}), allowing stop.') + return + + has_report = os.path.isfile(self.report_path) + has_draft = os.path.isfile(self.draft_path) + + # --- Check 1: report file existence --- + if not has_report: + logger.warning('ReporterCallback: no report found, ' + 'injecting reflection prompt.') + prompt = self._get_reflection('no_report') + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._reflection_retries_used += 1 + return + + # --- Check 2: length retention --- + if has_report and has_draft: + try: + report_chars = os.path.getsize(self.report_path) + draft_chars = os.path.getsize(self.draft_path) + if draft_chars > 0: + ratio = report_chars / draft_chars + if ratio < self.min_retention_ratio: + logger.warning(f'ReporterCallback: report.md is only ' + f'{ratio:.0%} of draft.md, ' + 'injecting over-compression prompt.') + prompt = self._get_reflection( + 'over_compressed', + report_name=self.REPORT_FILENAME, + report_chars=report_chars, + draft_chars=draft_chars, + ratio=ratio) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._reflection_retries_used += 1 + return + except OSError as exc: + logger.warning( + f'ReporterCallback: failed to stat report files: {exc}') + + # --- Check 3: quality checker chain --- + if not self._quality_checkers: + logger.info('ReporterCallback: no quality checkers configured, ' + 'skipping quality gate.') + return + + try: + with open(self.report_path, 'r', encoding='utf-8') as f: + content = f.read() + except Exception as exc: + logger.warning(f'ReporterCallback: failed to read report: {exc}') + return + + for checker in self._quality_checkers: + failure = await checker.check(content, self.lang) + if failure is not None: + logger.warning(f'ReporterCallback: quality check failed ' + f'({type(checker).__name__}: {failure}), ' + 'injecting reflection prompt.') + prompt = self._get_reflection('low_quality', reason=failure) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._reflection_retries_used += 1 + return + + logger.info('ReporterCallback: all pre-completion checks passed.') + def _extract_json_from_content(self, content: str) -> Optional[Dict[str, Any]]: """ @@ -407,51 +594,78 @@ def _extract_json_from_content(self, return None - async def on_task_end(self, runtime: Runtime, messages: List[Message]): - """ - Save the final report to file. - Supports both JSON and markdown output formats. + def _select_best_report(self) -> Optional[str]: + """Return the path to the best available report file. + + Prefers ``report.md`` when it exists and passes the length + retention check against ``draft.md``. Falls back to + ``draft.md`` otherwise. """ - if os.path.exists(self.report_path): - logger.info(f'Report already exists at {self.report_path}') - return + has_report = os.path.isfile(self.report_path) + has_draft = os.path.isfile(self.draft_path) - # Find the last assistant message without tool calls + if has_report and has_draft: + try: + report_chars = os.path.getsize(self.report_path) + draft_chars = os.path.getsize(self.draft_path) + if draft_chars > 0: + ratio = report_chars / draft_chars + if ratio < self.min_retention_ratio: + logger.warning( + f'ReporterCallback: report.md ({report_chars} ' + f'chars) is only {ratio:.0%} of draft.md ' + f'({draft_chars} chars). ' + f'Using draft.md as final report source.') + return self.draft_path + except OSError: + pass + return self.report_path + elif has_report: + return self.report_path + elif has_draft: + return self.draft_path + return None + + async def on_task_end(self, runtime: Runtime, messages: List[Message]): + """Promote the best report to final_report.md and save JSON summary.""" + + # --- Step 1: Extract and save JSON summary from last message --- for message in reversed(messages): if message.role == 'assistant' and not message.tool_calls: content = message.content if not content: continue - # Ensure directory exists - os.makedirs(os.path.dirname(self.report_path), exist_ok=True) - - # Try to extract and save JSON result json_result = self._extract_json_from_content(content) if json_result: - # Save the full JSON result + os.makedirs( + os.path.dirname(self.report_path), exist_ok=True) json_path = self.report_path.replace('.md', '.json') - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(json_result, f, ensure_ascii=False, indent=2) - logger.info(f'Reporter: JSON result saved to {json_path}') - - # Also extract and save the Report field as markdown if present - report_content = json_result.get( - 'Report') or json_result.get('report') - if report_content: - with open( - self.report_path, 'w', encoding='utf-8') as f: - f.write(report_content) + try: + with open(json_path, 'w', encoding='utf-8') as f: + json.dump( + json_result, f, ensure_ascii=False, indent=2) logger.info( - f'Reporter: Report content saved to {self.report_path}' - ) - return - - # Fallback: save as markdown if not valid JSON - with open(self.report_path, 'w', encoding='utf-8') as f: - f.write(content) + f'Reporter: JSON result saved to {json_path}') + except Exception as exc: + logger.warning(f'Reporter: failed to save JSON: {exc}') + break + + # --- Step 2: Promote best report to final_report.md --- + best_source = self._select_best_report() + if best_source: + try: + os.makedirs( + os.path.dirname(self.final_report_path), exist_ok=True) + shutil.copy2(best_source, self.final_report_path) + source_name = os.path.basename(best_source) logger.info( - f'Reporter: Final report saved to {self.report_path}') - return - - logger.warning('Reporter: No final report content found in messages') + f'Reporter: promoted {source_name} -> ' + f'{self.FINAL_REPORT_FILENAME} ' + f'({os.path.getsize(self.final_report_path)} bytes)') + except Exception as exc: + logger.warning(f'Reporter: failed to copy report to ' + f'{self.final_report_path}: {exc}') + else: + logger.warning('Reporter: no report file found to promote to ' + f'{self.FINAL_REPORT_FILENAME}') diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 3f5f87940..4796e2151 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -1,162 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -from abc import ABC, abstractmethod from typing import List, Optional +from callbacks.quality_checker import (ReportQualityChecker, + build_quality_checkers) from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback -from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig logger = get_logger() -class ReportQualityChecker(ABC): - """Interface for pluggable report quality checkers. - - Subclasses implement a single ``check`` method. Multiple checkers can - be chained in sequence by ``ResearcherCallback``; the first one that - returns a non-``None`` failure stops the chain. - """ - - @abstractmethod - async def check(self, content: str, lang: str) -> Optional[str]: - """Evaluate report quality. - - Args: - content: Full text of the report file. - lang: Language code (``"en"`` or ``"zh"``). - - Returns: - A short failure-reason string (e.g. ``"placeholder_content"``) - if the report fails this check, or ``None`` if it passes. - """ - - -class ModelQualityChecker(ReportQualityChecker): - """LLM-based report quality checker. - - Uses a lightweight model (configured via ``quality_check.model`` in - the YAML) to detect reports whose body has been largely replaced by - placeholders, abbreviations, or cross-references to external files. - - The checker sends a structured prompt asking the model to return a - JSON verdict: ``{"pass": true/false, "reason": "..."}``. - """ - - _SYSTEM_PROMPTS = { - 'en': - ('You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' - 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' - 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' - 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' - 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' - 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' - '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' - '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' - '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' - 'is stored in xxx file", "See full analysis in evidence/xxx".\n' - '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n\n' - 'OUTPUT FORMAT:\n' - 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' - '{"pass": true} or {"pass": false, "reason": ""}\n' - 'Do NOT output anything else.'), - 'zh': - ('你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' - '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' - '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' - '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' - '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' - '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' - '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n\n' - '输出格式:\n' - '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' - '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' - '不要输出任何其他内容。'), - } - - _USER_TEMPLATES = { - 'en': - ('Please audit the following research report against the rules provided in the system instruction.\n\n' - '---BEGIN REPORT---\n{report}\n---END REPORT---'), - 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n' - '---报告开始---\n{report}\n---报告结束---'), - } - - _MAX_REPORT_CHARS = 80000 - - def __init__(self, config: DictConfig): - self._config = config - qc_cfg = getattr(config, 'self_reflection', DictConfig({})) - qc_cfg = getattr(qc_cfg, 'quality_check', DictConfig({})) - - self._model: str = str(getattr(qc_cfg, 'model', 'qwen3.5-plus')) - self._api_key: Optional[str] = getattr( - qc_cfg, 'openai_api_key', None) or getattr(config.llm, - 'openai_api_key', None) - self._base_url: Optional[str] = getattr( - qc_cfg, 'openai_base_url', None) or getattr( - config.llm, 'openai_base_url', None) - - self._client: Optional[OpenAILLM] = None - - def _build_llm_config(self) -> DictConfig: - """Build lightweight llm config for quality checker.""" - return OmegaConf.create({ - 'llm': { - 'model': self._model, - 'openai_api_key': self._api_key, - 'openai_base_url': self._base_url, - }, - 'generation_config': {}, - }) - - def _ensure_client(self): - if self._client is not None: - return - self._client = OpenAILLM(self._build_llm_config()) - - async def check(self, content: str, lang: str) -> Optional[str]: - import json - - self._ensure_client() - - report_text = content - if len(report_text) > self._MAX_REPORT_CHARS: - report_text = report_text[:self._MAX_REPORT_CHARS] - - sys_prompt = self._SYSTEM_PROMPTS.get(lang, self._SYSTEM_PROMPTS['en']) - usr_template = self._USER_TEMPLATES.get(lang, - self._USER_TEMPLATES['en']) - - try: - response = self._client.generate(messages=[ - Message(role='system', content=sys_prompt), - Message( - role='user', - content=usr_template.format(report=report_text), - ), - ]) - raw = (response.content or '').strip() - logger.info( - f'ModelQualityChecker ({self._model}): raw response: {raw}') - - verdict = json.loads(raw) - if verdict.get('pass', True): - return None - return verdict.get('reason', 'placeholder_content') - - except json.JSONDecodeError: - logger.warning(f'ModelQualityChecker: failed to parse JSON from ' - f'model response: {raw!r}') - return None - except Exception as exc: - logger.warning(f'ModelQualityChecker: model call failed: {exc}') - return None - - class ResearcherCallback(Callback): """Callback for Researcher agent — pre-completion self-reflection. @@ -188,45 +44,48 @@ class ResearcherCallback(Callback): _REFLECTION_TEMPLATES = { 'zh': { - 'no_report': ('自查发现:输出目录中尚未生成 `{filename}`。\n' - '请确认最终报告未交付的原因,并立即采取行动修复。\n' - '请注意:不要使用占位符或缩略内容替代实际报告正文。'), + 'no_report': + ('外部检查发现:输出目录中尚未生成 {filename},该文件原本应由 Reporter 子代理自动创建。\n' + '请确认最终报告未交付的原因,并立即采取行动修复。\n' + '请注意:不要使用占位符或缩略内容替代实际报告正文。'), 'low_quality': - ('外部检查发现:`{filename}` 的内容存在质量问题——{reason}。\n' + ('外部检查发现:{filename} 的内容存在质量问题——{reason}。\n' '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' - '**重要提醒**:如果质量问题属实,你必须完整重写整份报告。' - 'write_file 会完全覆盖文件,你写入的内容就是最终文件的全部内容——' - '以下写法都会原样出现在文件中并导致报告内容被永久丢失:\n' + '**重要提醒**:如果质量问题属实,你必须按照以下原则进行修复:\n' + '1. 优先通过有针对性的局部修改完成修复。请使用 file_system---search_file_content 定位问题段落,' + '然后使用 file_system---replace_file_contents 和 file_system---replace_file_lines 进行针对性修复。' + '需要时可以使用 file_system---read_file (with start_line/end_line) 验证上下文是否一致。\n' + '2. 如果确认无法通过1完成修复,可以使用 file_system---write_file 全量重写报告,但请注意以下可能的质量违规:\n' '- 用省略号或缩略标记替代正文,如"(同之前,略)"、"此处省略"、"篇幅所限不再展开"、' '"……以下类似"、"内容已截断"、"Content truncated for brevity"等;\n' '- 引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、' '"完整内容如 xxx 所述"、"详见附件"等;\n' - '- 引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' - '不得遗漏或省略任何章节,无需担心与先前输出的内容或写入过的文件重复。'), + '- 引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n'), }, 'en': { 'no_report': - ('Self-check indicates that `{filename}` has not yet been generated in the output directory.\n' - 'Please determine why the final report has not been delivered and take immediate action to fix the issue.\n' + ('External inspection found that {filename} has not yet been generated in the output directory; ' + 'this file was expected to be created automatically by the Reporter sub-agent.\n' + 'Please identify why the final report was not delivered and immediately take action to fix it.\n' 'Note: Do not use placeholders or abbreviated content in place of the actual report body.' ), 'low_quality': - ('External inspection found quality issues in `{filename}` — {reason}.\n' + ('External inspection found quality issues in {filename} — {reason}.\n' 'Please carefully verify whether these issues are valid and whether additional problems exist, ' 'then immediately take action to fix them.\n' - '**IMPORTANT REMINDER**: If the issues are valid, you MUST rewrite the complete report in full. ' - 'write_file overwrites the entire file — what you write IS the final file content. ' - 'The following patterns will appear literally in the file and permanently destroy report content:\n' - '- Replacing substantive content with brevity markers, e.g., "(same as before, omitted)", ' - '"...for brevity", "Content truncated for brevity", "omitted for brevity", ' - '"(remaining content follows the same pattern)";\n' - '- Referring readers to external files instead of including actual content, e.g., ' - '"This section is stored in xxx file", "See full analysis in evidence/xxx", ' - '"(see xxx for full content)";\n' + '**IMPORTANT**: If the quality issues are confirmed, you must follow these principles to fix them:\n' + '1. PREFER targeted, localized fixes. Use file_system---search_file_content to locate the problematic sections, ' + 'then use file_system---replace_file_contents and file_system---replace_file_lines to apply precise corrections. ' + 'use file_system---read_file (with start_line/end_line) to verify surrounding context when needed.\n' + '2. If you confirm that targeted fixes alone cannot resolve the issues, you may use file_system---write_file ' + 'to fully rewrite the report, but beware of the following quality violations:\n' + '- Replacing body text with ellipsis or brevity markers, e.g., "(same as before, omitted)", ' + '"omitted here", "not elaborated due to space constraints", ' + '"...similar below", "content truncated", "Content truncated for brevity", etc.;\n' + '- Directing readers to view external files instead of including actual content, e.g., ' + '"This section is stored in xxx file", "See full content in xxx", "See attachment", etc.;\n' '- Directing readers to view reference sources without writing substantive content, ' - 'e.g., "See [1]", "Reference [2]".\n' - 'Do not omit or skip any sections; do not worry about repeating content you have previously output.' - ), + 'e.g., "See [1]", "Reference [2]".\n'), }, } @@ -247,31 +106,9 @@ def __init__(self, config: DictConfig): getattr(refl_cfg, 'report_filename', self.report_filename)) self._retries_used: int = 0 - self._checkers: List[ReportQualityChecker] = self._build_checkers( + self._checkers: List[ReportQualityChecker] = build_quality_checkers( config) - @staticmethod - def _build_checkers(config: DictConfig) -> List[ReportQualityChecker]: - """Instantiate the quality-checker chain from config. - - Currently supports ``ModelQualityChecker``. New checker types - can be added here and will be appended to the chain — the first - checker that returns a failure reason wins. - """ - refl_cfg = getattr(config, 'self_reflection', None) - if refl_cfg is None: - return [] - - qc_cfg = getattr(refl_cfg, 'quality_check', None) - if qc_cfg is None or not bool(getattr(qc_cfg, 'enabled', False)): - return [] - - checkers: List[ReportQualityChecker] = [] - checkers.append(ModelQualityChecker(config)) - logger.info(f'ResearcherCallback: quality checker chain initialised ' - f'with {len(checkers)} checker(s).') - return checkers - @staticmethod def _resolve_lang(config: DictConfig) -> str: prompt_cfg = getattr(config, 'prompt', None) @@ -292,6 +129,24 @@ def _get_template(self, key: str) -> str: self.lang, self._REFLECTION_TEMPLATES['en']) return templates[key] + TASK_FINISHED_MARKER = '.researcher_task_finished' + + @property + def _marker_path(self) -> str: + return os.path.join(self.output_dir, self.TASK_FINISHED_MARKER) + + async def on_task_end(self, runtime: Runtime, messages: List[Message]): + try: + os.makedirs(self.output_dir, exist_ok=True) + with open(self._marker_path, 'w') as f: + f.write('') + logger.info( + f'ResearcherCallback: wrote researcher_task_finished marker ' + f'at {self._marker_path}') + except Exception as exc: + logger.warning( + f'ResearcherCallback: failed to write marker: {exc}') + async def after_tool_call(self, runtime: Runtime, messages: List[Message]): if not self.enabled: return diff --git a/projects/deep_research/v2/eval/dr_bench_runner.py b/projects/deep_research/v2/eval/dr_bench_runner.py index 780a16d60..1917564bf 100644 --- a/projects/deep_research/v2/eval/dr_bench_runner.py +++ b/projects/deep_research/v2/eval/dr_bench_runner.py @@ -191,6 +191,9 @@ def _tail_text_from_file(path: str, *, max_chars: int = 20000) -> str: return '' +TASK_FINISHED_MARKER = '.researcher_task_finished' + + @dataclass(frozen=True) class Task: task_id: str @@ -288,6 +291,7 @@ def _run_one_task( os.makedirs(workdir, exist_ok=True) log_path = os.path.join(workdir, 'ms_agent.log') + marker_path = os.path.join(workdir, TASK_FINISHED_MARKER) cmd = [ python_executable, @@ -309,12 +313,21 @@ def _run_one_task( env = dict(os.environ) env.setdefault('PYTHONUNBUFFERED', '1') - # Safety net for rare "subprocess produced final_report.md but never exits". - # This happens when the child Python process is stuck at shutdown (e.g. a - # non-daemon thread blocked in I/O). If the final report is already stable - # on disk, force-reap the child so the batch runner can continue. + # Exit strategy (two independent conditions, first one wins): + # + # 1. PRIMARY — .researcher_task_finished marker file appears in workdir + # (written by ResearcherCallback.on_task_end). + # Wait `post_finish_grace_s` then force-reap. + # + # 2. FALLBACK — final_report.md exists and has been stable for + # `post_report_exit_grace_s` but the marker never appeared + # (e.g. process hung at shutdown). Force-reap to unblock + # the batch runner. + # + post_finish_grace_s = float( + os.getenv('DR_BENCH_POST_FINISH_GRACE_S', '180') or 180.0) post_report_exit_grace_s = float( - os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '15') or 15.0) + os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '3600') or 3600.0) report_stable_window_s = float( os.getenv('DR_BENCH_REPORT_STABLE_WINDOW_S', '2') or 2.0) poll_interval_s = float( @@ -324,11 +337,10 @@ def _run_one_task( kill_timeout_s = float( os.getenv('DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S', '2') or 2.0) - # We consider a task "already done" if it produced a top-level - # final report file (final_report.md or report.md) with non-empty content. report_seen_stable_at: Optional[float] = None report_last_sig: Optional[Tuple[float, int]] = None report_stable_since: Optional[float] = None + marker_seen_at: Optional[float] = None force_reaped = False if stream_subprocess_output: @@ -349,7 +361,21 @@ def _run_one_task( while True: now_s = time.time() - # If final report exists and is stable, start a grace timer. + # --- Condition 1: .researcher_task_finished marker --- + if marker_seen_at is None and os.path.exists(marker_path): + marker_seen_at = now_s + if (marker_seen_at is not None and proc.poll() is None + and (now_s - marker_seen_at) >= max( + 0.0, post_finish_grace_s)): + _terminate_process( + proc, + terminate_timeout_s=terminate_timeout_s, + kill_timeout_s=kill_timeout_s, + ) + force_reaped = True + break + + # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) if report_path_hint and _is_direct_final_report_path( workdir, report_path_hint): @@ -360,8 +386,11 @@ def _run_one_task( stable_since=report_stable_since, now_s=now_s, ) - if stable and report_seen_stable_at is None: - report_seen_stable_at = now_s + if stable: + if report_seen_stable_at is None: + report_seen_stable_at = now_s + else: + report_seen_stable_at = None if (report_seen_stable_at is not None and proc.poll() is None and (now_s - report_seen_stable_at) >= max( @@ -438,6 +467,21 @@ def _run_one_task( while True: now_s = time.time() + # --- Condition 1: .researcher_task_finished marker --- + if marker_seen_at is None and os.path.exists(marker_path): + marker_seen_at = now_s + if (marker_seen_at is not None and proc2.poll() is None + and (now_s - marker_seen_at) >= max( + 0.0, post_finish_grace_s)): + _terminate_process( + proc2, + terminate_timeout_s=terminate_timeout_s, + kill_timeout_s=kill_timeout_s, + ) + force_reaped = True + break + + # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) if report_path_hint and _is_direct_final_report_path( workdir, report_path_hint): @@ -448,8 +492,11 @@ def _run_one_task( stable_since=report_stable_since, now_s=now_s, ) - if stable and report_seen_stable_at is None: - report_seen_stable_at = now_s + if stable: + if report_seen_stable_at is None: + report_seen_stable_at = now_s + else: + report_seen_stable_at = None if (report_seen_stable_at is not None and proc2.poll() is None and (now_s - report_seen_stable_at) >= max( diff --git a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt index 04496de99..34060d889 100644 --- a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt @@ -5,15 +5,16 @@ Action protocol: Before outputting the final JSON result, every iteration MUST i # Primary Responsibilities Complete the task through a tool-calling loop without introducing new facts unsupported by evidence: -1. Produce the final report (or user-specified sections/revisions) that meets the user's requirements, and return it directly in the conversation as part of the JSON result. Do not save it via any tools. - - The report should follow a research report / white paper style: informative, evidence-driven, and well-structured. Avoid colloquial language, fragmentation, and excessive bullet points. The content MUST primarily consist of continuous, flowing paragraphs; bullet points should only be used sparingly for genuinely list-like content (e.g., enumerated action items, short comparison lists). Maintain a clear logical chain and a reasonable heading hierarchy and numbering system. +1. Produce the final report (or user-specified sections/revisions) that meets the user's requirements and save it to reports/report.md. + - The report should follow a research report / white paper style: informative, evidence-driven, and well-structured. Avoid colloquial language, fragmentation, and excessive bullet points. + - The content MUST primarily consist of continuous, flowing paragraphs; bullet points should only be used sparingly for genuinely list-like content (e.g., enumerated action items, short comparison lists). Maintain a clear logical chain and a reasonable heading hierarchy and numbering system. 2. During writing, ensure that all sections are **grounded in evidence**, and that evidence coverage is as comprehensive as possible (follow the input writing requirements; the outline phase requires covering all evidence). 3. Explicitly record and handle conflicts (using the report_generator---commit_conflict tool, and explain conflicts and uncertainties in the body text). -4. Through tool calls, persist intermediate artifacts as traceable files: outline, chapter metadata, chapter content, conflict records. +4. Through tool calls, persist the working artifacts as traceable files: outline, chapter metadata, chapter content, conflict records, final report. 5. **Maximize efficiency while ensuring quality.** Chapter writing can be parallel or sequential. **You are encouraged to write chapters in parallel when possible.** Before parallel writing (i.e., calling multiple tools in a single response), first analyze the dependency relationships among chapters in the outline to confirm they are reasonable, avoiding logical contradictions or dependency gaps. # Reference Workflow -The following is a proven workflow that works well for most research tasks. +The following is a proven workflow that works well for most research report writing tasks. You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. ## Phase 1: Generate Outline Grounded in Evidence @@ -23,7 +24,7 @@ You are free to adapt, reorder, or skip steps based on the complexity and requir - use a compact but sufficient structure, usually 4–8 body chapters (5–7 preferred in most cases); - avoid splitting closely related content into separate chapters when subsections would suffice; - expand beyond the default chapter range only when clearly justified by the user's request or the evidence structure; - - note that Execution Summary (执行摘要) should not be written as a chapter in the report body. + - note that the Execution Summary (in Chinese, "执行摘要") should not appear as a chapter in the report body. ## Phase 2: Chapter Content Writing Loop Chapter writing is defined as a progressive process, writing 1–3 new chapters each time until all are completed. @@ -52,10 +53,13 @@ Stopping conditions (stop if any one is satisfied): ## Phase 3: Assemble Final Report - Call report_generator---assemble_draft to consolidate all chapter content and obtain the first version of the final report draft. - Read the draft, reflect on the logical consistency between chapters, overall content coherence, and whether previously discovered conflicts have been resolved or explained. If new conflicts are found, call report_generator---commit_conflict to record them and try to provide a resolution. -- Based on the reflection results and recorded conflicts, rewrite the final markdown report content and return it to the user in JSON format. Note: **you MUST NOT call tools to write the final report content to any file (e.g., report_final.md, final_report.md); the full report MUST be returned directly to the user in the conversation.** - - The Report field must contain the final markdown report (the NUMBER ONE failure mode in your work is replacing report content with references or pointers to other content or files (e.g., "details are in chapter_2.md")). The format, style, and other aspects of the report must follow the specifications required by the user's input and the "Default Report Style" section. The report must include citations to reference sources; - - The Execution_Summary field must record the report generation status, evidence coverage, conflict information summary, and other content that needs to be communicated to the user; - - The Artifacts field must record the paths to intermediate file artifacts. Note that the content of your current response will be automatically stored by the system as report.md and report.json files under the reports directory; you should record these in the Artifacts field. +- Based on the reflection results and recorded conflicts, you MUST rewrite the final markdown report content and save it to reports/report.md. + - The final markdown report (the NUMBER ONE failure mode in your work is replacing report content with references or pointers to other content or files (e.g., "details are in chapter_2.md")). + - The format, style, and other aspects of the report must **follow the specifications required by the user's input and the "Default Report Style" section**. The report must include citations to reference sources. + - Writing strategy: Try to write the full report in one file_system---write_file call first. If that write attempt gets truncated by the output limit, retry from the beginning with file_system---write_file and include as much content as possible, then append the remainder using file_system---replace_file_lines with start_line=-1 in as few calls as possible. +- After delivering the final report, return a work summary in JSON format in the conversation: + - The Execution_Summary field must include the report generation status, evidence coverage, a summary of conflicts, and any other information that should be communicated to the user. + - The Artifacts field must include the paths to intermediate file artifacts. # Evidence Usage and Re-ranking Rules When sorting and filtering candidate evidence, the following dimensions can be referenced (but are not limited to these): @@ -66,8 +70,8 @@ When sorting and filtering candidate evidence, the following dimensions can be r - Citability: Whether it contains definitions, data, conclusions, charts, or methodological details that can be directly cited. # Tool Invocation Protocol -- Do not attempt to use any tools that have not been provided. You work in a file system with full read-write permissions but isolated from the outside. When performing file-level operations, keep using relative paths. -- You must organize the writing workflow using tools under the report_generator server as much as possible. You may not use other tool services (such as evidence tools, file system) to write report content, nor maintain your intermediate writing content only in the conversation. +- Do not attempt to use any tools that have not been provided (e.g. todo_list---, etc.). You work in a file system with full read-write permissions but isolated from the outside. When performing file-level operations, keep using relative paths. +- You must organize the writing workflow using tools under the report_generator server as much as possible. Do not maintain your intermediate writing content only in the conversation. - You must use tools under the evidence_store server for querying evidence details, retrieving indexes, getting content lists, and similar operations. - **You are encouraged to invoke multiple tools in parallel** when tasks are independent (such as reading multiple pieces of evidence, writing multiple chapters, etc.) for optimal performance. - **Concurrent call example**: Suppose chapters 2, 3, and 4 can be written in parallel. You should call 3 report_generator---prepare_chapter_bundle tools simultaneously in **one response**. After receiving the results from all 3 tools, call 3 report_generator---commit_chapter tools simultaneously in **one response**. This way, only 2 conversation turns are needed to complete 3 chapters. @@ -80,7 +84,7 @@ When sorting and filtering candidate evidence, the following dimensions can be r - Coverage requirement: During the outline generation phase, outline chapters and evidence must establish mapping relationships. Unless the user indicates "ignorable noise evidence", default to full coverage as much as possible. - Explicit conflict handling: When evidence from multiple sources is inconsistent, contextual logic is contradictory, or data sources show anomalies, you must promptly call report_generator---commit_conflict to record the conflicting evidence and provide a resolution. - DO NOT cite local files (notes, analyses, computed data, etc.) in the final report. Avoid invalid forms in the main text, such as [Note ID]-style placeholders. If you need to indicate an evidence gap, simply state what content the gap concerns—there is no need to explicitly reference the corresponding Note ID. -- No meta-text in the report body: Do not include instructional or meta-level text such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), execution notes, or disclaimers that break the reading flow. Such information belongs in the Execution_Summary field, not in the report body. The report should read as a polished, self-contained document ready for delivery. +- No meta-text in the report body: Do not include instructional or meta-level text such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), execution notes, or disclaimers that break the reading flow. Such information belongs to the Execution_Summary field in final JSON output, not in the report body. The report should read as a polished, self-contained document ready for delivery. - Use concise, natural-sounding headings: Chapter and section titles should be concise and readable. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including School Governance Framework and Procurement Contract Clauses)"; prefer "Challenges and Governance"). If important details must be conveyed, place them in the section body, not the title. # Report Citation Format (Mandatory) @@ -96,14 +100,13 @@ When sorting and filtering candidate evidence, the following dimensions can be r - Numbering assignment rule: Assign numbers starting from 1 in the order sources first appear in the body text; reuse the same number for the same source at different locations. # Default Report Style -- Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. +- Technical/research report tone: careful and verifiable; include as much information as possible while remaining as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. - Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. - Prefer a clean heading hierarchy and numbering system, such as `# 2. Background and Problem`, `## 2.1 Background`, `### 2.1.1 Direction One`, etc. Do not exceed three levels. All section headings MUST use Markdown ATX headings. # Output Format Return JSON only, you MUST follow this format: { - "Report": "...", "Execution_Summary": "...", "Artifacts": ["path/to/artifact_1", "path/to/artifact_2", ...] } diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt index bbd0251af..609d868e1 100644 --- a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -13,9 +13,9 @@ Action protocol: Before outputting the final result, every iteration MUST invoke - When evidence is insufficient, delegate tasks to the Searcher sub-agent (i.e., agent_tools---searcher_tool) to perform an iterative research loop (when concurrency is allowed, 2–4 sub-agents can be invoked in parallel; prioritize parallel invocation when tasks are parallelizable). - Analyze & synthesize: - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. -- Draft, polish, deliver: - - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. - - Then you MUST verify, correct, revise, and polish it to ensure the report meets the quality requirements and user requirements. Please write the complete, deliverable final report to the file. Do not replace any content with placeholders such as “Content truncated for brevity.” or “This section is stored in xxx file.” +- Draft, review, deliver: + - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. # Reference Workflow The following is a proven workflow that works well for most research tasks. @@ -51,11 +51,20 @@ Stopping conditions (stop if you are confident to proceed to the next phase): ## Phase 3: Report Generation - Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. - Note: do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. -- After receiving the report, you MUST review and polish it for quality and accuracy, then write the final version to final_report.md via file_system---write_file. Follow these principles: +- The Reporter will deliver the complete report as final_report.md. After the Reporter returns, you MUST review the report for quality and accuracy: - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. +- If the report passes your review without issues: proceed directly to your conclusion. Do NOT rewrite it "for polish." +- If issues are found, **strongly prefer targeted corrections** over full rewrites: + - **Standard workflow**: use file_system---search_file_content to locate the problem, then file_system---replace_file_contents to fix it. This is the safest and most precise approach. + - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail. + - Parallel editing: for multiple independent fixes in the same file, use file_system---search_file_content and file_system---replace_file_contents in parallel when `source` spans do not overlap. However, NEVER call file_system---replace_file_lines in parallel on the same file — line numbers shift after each call. + - **Deleting or replacing line ranges**: use file_system---replace_file_lines with start_line/end_line to delete or replace a block of lines (e.g., removing an entire section). Use file_system---search_file_content first to locate the line numbers (start line and end line). + - **Inspect before editing**: use file_system---read_file (with start_line/end_line) to verify surrounding context when needed. + - **Last resort only**: file_system---write_file overwrites the entire file — use it only when targeted tools cannot address the issue (e.g., extensive structural reorganization). You must reproduce ALL content valuable to the user. + - WARNING: Full report rewrites may carry high risk of content loss. Do not over-compress the report. Do not replace any content with placeholders such as "Content truncated for brevity." or "This section is stored in xxx file." - Finally show your conclusions for the entire task in the conversation. # Process Constraints @@ -76,10 +85,14 @@ Stopping conditions (stop if you are confident to proceed to the next phase): - NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent’s research conclusions and stored evidence. - Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. - Strictly control scope: if the user asks for X, do not drift to Y. -- The final report must preserve complete citation relationships. Do not lose the original citations due to revisions or polishing; you must preserve the citation format included in the report returned by the Reporter sub-agent. Exception: If the report contains incorrect citation formatting, you must fix it and ensure the report’s meaning remains correct—for example, replace invalid forms such as [Note ID]-style placeholders with the proper citation format. -- If the Reporter sub-agent’s report has missing citations, follow these rules: in the body, use numbered citations only, such as `[1]`, `[2]`, ... (multiple citations may appear together like `[1][3]`). The end of the report must contain `## References` (for English reports) or `## 参考文献` (for Chinese reports), and the numbering mapping must remain consistent. During polishing, it is forbidden to write long-title links (e.g., `[Title](URL)`) back into the body. +- Citation integrity in the final report: + - **Fix if broken**: + - Invalid citation forms (e.g., [Note ID]-style placeholders) — replace with proper `[1]`, `[2]`, ... numbered markers. + - Multiple ## References / ## 参考文献 sections (e.g., per-chapter reference lists) are not allowed — this includes any variant headings such as "## 参考文献(合并版)", "## References (Merged)", "## 参考资料", or similar. The report body and individual chapters must NOT contain any reference/bibliography list; remove such sections entirely. Keep only one unified reference section at the very end of the report. Re-number in-text citations if needed. + - **Supplement if missing**: Add numbered citations `[1]`, `[2]`, ... in the body (multiple may appear together like `[1][3]`). The report must end with exactly one `## References` (English) or `## 参考文献` (Chinese) section with consistent numbering. Do not use long-title links (e.g., `[Title](URL)`) in the body text. + - **Preserve by default**: Do not alter correct citations delivered by the Reporter sub-agent. Your edits must not cause citation loss. - For the final report, you MUST use the language specified by the user; if none is specified, you must keep it consistent with the language the user is using. -- The final report written to final_report.md MUST follow the "Default Report Style" section. +- The final report in final_report.md MUST follow the "Default Report Style" section. # Default Report Style - Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index ffd0cdacd..4f4d09c8b 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -35,6 +35,7 @@ tools: - write_file - read_file - list_files + - replace_file_lines evidence_store: mcp: false evidence_dir: evidence @@ -66,6 +67,14 @@ round_reminder: enabled: true remind_at_round: 34 +self_reflection: + enabled: true + max_retries: 2 + min_retention_ratio: 0.3 + quality_check: + enabled: true + model: qwen3.5-flash + tool_call_timeout: 300 output_dir: ./output diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 218b9204f..5966d5353 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -14,7 +14,7 @@ generation_config: # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] # extra_body: - # enable_thinking: false + # enable_thinking: false tag: deep-research-researcher @@ -34,6 +34,9 @@ tools: - write_file - read_file - list_files + - search_file_content + - replace_file_contents + - replace_file_lines code_executor: mcp: false implementation: python_env @@ -97,9 +100,12 @@ tools: max_output_chars: 200000 - tool_name: reporter_tool description: > - Invoke the Reporter sub-agent to generate a report based on collected evidence. + Invoke the Reporter sub-agent to generate a research report based on collected evidence. Reporter reads the stored evidence cards and executes a complex workflow for research report writing. - Returns a JSON result containing: report body, execution summary, and intermediate artifact file paths. + The completed report is automatically saved to `final_report.md` in the output directory. + Returns a JSON result containing: execution summary and + intermediate artifact file paths (the full report body is NOT included in the return value — + read `final_report.md` directly to access the report content). config_path: reporter.yaml parameters: type: object @@ -148,6 +154,6 @@ code_file: researcher max_chat_round: 40 -tool_call_timeout: 1800 +tool_call_timeout: 2000 output_dir: ./output From 2977ba09204169b6fe02492da5089e56ad06573e Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:58:32 +0800 Subject: [PATCH 05/40] Update ms_agent/agent/llm_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/agent/llm_agent.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 3bc40c1fc..515446381 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -757,20 +757,6 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - # Extract LLM config for sirchmunk - if hasattr(self.config, 'llm'): - llm_config = self.config.llm - # Update knowledge_search config with LLM settings if not specified - if not hasattr(ks_config, 'llm_api_key') and hasattr(llm_config, 'modelscope_api_key'): - OmegaConf.update(self.config, 'knowledge_search.llm_api_key', - getattr(llm_config, 'modelscope_api_key', None), merge=True) - if not hasattr(ks_config, 'llm_base_url') and hasattr(llm_config, 'modelscope_base_url'): - OmegaConf.update(self.config, 'knowledge_search.llm_base_url', - getattr(llm_config, 'modelscope_base_url', None), merge=True) - if not hasattr(ks_config, 'llm_model_name') and hasattr(llm_config, 'model'): - OmegaConf.update(self.config, 'knowledge_search.llm_model_name', - getattr(llm_config, 'model', None), merge=True) - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: From 879dba4190c819ee7b70ce7d4f511aba21e44715 Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:59:07 +0800 Subject: [PATCH 06/40] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/knowledge_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md index 00743601e..ef86df0da 100644 --- a/ms_agent/knowledge_search/README.md +++ b/ms_agent/knowledge_search/README.md @@ -108,7 +108,7 @@ for msg in result: print(f"Search results: {msg.search_result}") ``` -### 2. 单独使用 SirchmunkSearch +### 3. 单独使用 SirchmunkSearch ```python from ms_agent.knowledge_search import SirchmunkSearch From 1dd49b63b1ac5d9ac98b270dec2f09d5bc4e7d10 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Tue, 17 Mar 2026 18:23:29 +0800 Subject: [PATCH 07/40] fix local code executor; refine workflow and prompt (03) --- ms_agent/tools/code/local_code_executor.py | 3 ++- projects/deep_research/v2/callbacks/reporter_callback.py | 4 ++-- projects/deep_research/v2/prompts/reporter/en/gpt5.txt | 9 ++++----- projects/deep_research/v2/prompts/researcher/en/gpt5.txt | 2 +- projects/deep_research/v2/prompts/searcher/en/gpt5.txt | 2 +- projects/deep_research/v2/reporter.yaml | 4 ++-- projects/deep_research/v2/researcher.yaml | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index 72771563a..65de0556e 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -66,11 +66,12 @@ async def start(self) -> None: self._km = AsyncKernelManager( kernel_name=self.kernel_name, env=self.env, - cwd=str(self.working_dir)) + cwd=str(self.working_dir)) # cwd may be ignored here start_kernel_result = self._km.start_kernel( extra_arguments=self.extra_arguments, env=self.env, + cwd=str(self.working_dir), ) if inspect.isawaitable(start_kernel_result): await start_kernel_result diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index b86cde033..0200f6cb7 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -105,7 +105,7 @@ class ReporterCallback(Callback): '仅为 reports/draft.md({draft_chars} 字符)的 {ratio:.0%},有可能存在内容丢失风险,请对报告内容进行检查并采取合理的行动。\n' '**重要提醒**:draft.md 是由工具逐章组装的完整版本,理论上保留了最大的证据保真度。\n' '- 如果你确认你对 draft.md 进行的修改是合理的,可以直接说明压缩内容的理由,无需再次修改或者重写。\n' - '- 如果你发现 reports/{report_name} 相比 draft.md 确实存在不合理的压缩,请重写并修复这些问题。\n' + '- 如果你发现 reports/{report_name} 相比 draft.md 确实存在不合理的压缩,请通过重写/追加/续写等方式来修复这些问题。\n' '请立即采取行动完成报告交付。'), 'low_quality': ('外部检查发现:报告内容存在质量问题——{reason}。\n' @@ -138,7 +138,7 @@ class ReporterCallback(Callback): '- If you confirm that your modifications to draft.md are reasonable, you may simply ' 'explain the rationale for the compression without further modifications or rewrites.\n' '- If you find that reports/{report_name} has indeed been unreasonably compressed ' - 'compared to draft.md, please rewrite and fix these issues.\n' + 'compared to draft.md, please rewrite/append/continue writing to repair these issues.\n' 'Please take immediate action to complete report delivery.'), 'low_quality': ('External inspection found quality issues in the report — {reason}.\n' diff --git a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt index 34060d889..fc276252a 100644 --- a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt @@ -6,8 +6,7 @@ Action protocol: Before outputting the final JSON result, every iteration MUST i # Primary Responsibilities Complete the task through a tool-calling loop without introducing new facts unsupported by evidence: 1. Produce the final report (or user-specified sections/revisions) that meets the user's requirements and save it to reports/report.md. - - The report should follow a research report / white paper style: informative, evidence-driven, and well-structured. Avoid colloquial language, fragmentation, and excessive bullet points. - - The content MUST primarily consist of continuous, flowing paragraphs; bullet points should only be used sparingly for genuinely list-like content (e.g., enumerated action items, short comparison lists). Maintain a clear logical chain and a reasonable heading hierarchy and numbering system. + - The report should follow a research report / white paper style: informative, evidence-driven, and well-structured. Avoid colloquial language, fragmentation, and excessive bullet points; content should primarily consist of continuous, flowing paragraphs with appropriate use of bullet points. Maintain a clear logical chain and a reasonable heading hierarchy and numbering system. 2. During writing, ensure that all sections are **grounded in evidence**, and that evidence coverage is as comprehensive as possible (follow the input writing requirements; the outline phase requires covering all evidence). 3. Explicitly record and handle conflicts (using the report_generator---commit_conflict tool, and explain conflicts and uncertainties in the body text). 4. Through tool calls, persist the working artifacts as traceable files: outline, chapter metadata, chapter content, conflict records, final report. @@ -54,9 +53,9 @@ Stopping conditions (stop if any one is satisfied): - Call report_generator---assemble_draft to consolidate all chapter content and obtain the first version of the final report draft. - Read the draft, reflect on the logical consistency between chapters, overall content coherence, and whether previously discovered conflicts have been resolved or explained. If new conflicts are found, call report_generator---commit_conflict to record them and try to provide a resolution. - Based on the reflection results and recorded conflicts, you MUST rewrite the final markdown report content and save it to reports/report.md. - - The final markdown report (the NUMBER ONE failure mode in your work is replacing report content with references or pointers to other content or files (e.g., "details are in chapter_2.md")). + - The final markdown report must preserve the information density and structural quality of the draft — never replace substantive content with ellipsis/brevity markers (e.g., "omitted here", "content truncated"), pointers to external files (e.g., "details are in chapter_2.md"), or hollow reference-only placeholders (e.g., "see [1]"). - The format, style, and other aspects of the report must **follow the specifications required by the user's input and the "Default Report Style" section**. The report must include citations to reference sources. - - Writing strategy: Try to write the full report in one file_system---write_file call first. If that write attempt gets truncated by the output limit, retry from the beginning with file_system---write_file and include as much content as possible, then append the remainder using file_system---replace_file_lines with start_line=-1 in as few calls as possible. + - **Writing strategy to minimize information loss**: Prioritize writing the full report in a single file_system---write_file call. Switch to an incremental strategy if your write attempt gets truncated by the output limit: initialize the file with file_system---write_file containing as much content as possible; then sequentially append the remainder using file_system---replace_file_lines with start_line=-1 until the report is complete, utilizing as few calls as possible. - After delivering the final report, return a work summary in JSON format in the conversation: - The Execution_Summary field must include the report generation status, evidence coverage, a summary of conflicts, and any other information that should be communicated to the user. - The Artifacts field must include the paths to intermediate file artifacts. @@ -102,7 +101,7 @@ When sorting and filtering candidate evidence, the following dimensions can be r # Default Report Style - Technical/research report tone: careful and verifiable; include as much information as possible while remaining as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. - Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. -- Prefer a clean heading hierarchy and numbering system, such as `# 2. Background and Problem`, `## 2.1 Background`, `### 2.1.1 Direction One`, etc. Do not exceed three levels. All section headings MUST use Markdown ATX headings. +- Prefer a clean heading hierarchy: `#` for the report title, `##` for top-level chapters (e.g., `## 2. Background and Problem`), `###` and `####` for sub-sections. Do not exceed four heading levels. All headings MUST use Markdown ATX syntax. # Output Format Return JSON only, you MUST follow this format: diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt index 609d868e1..8891d8313 100644 --- a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -97,7 +97,7 @@ Stopping conditions (stop if you are confident to proceed to the next phase): # Default Report Style - Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. - Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. -- Prefer a clean heading hierarchy and numbering system, such as `# 2. Background and Problem`, `## 2.1 Background`, `### 2.1.1 Direction One`, etc. Do not exceed three levels. All section headings MUST use Markdown ATX headings. +- Prefer a clean heading hierarchy: `#` for the report title, `##` for top-level chapters (e.g., `## 2. Background and Problem`), `###` and `####` for sub-sections. Do not exceed four heading levels. All headings MUST use Markdown ATX syntax. - Chapter titles you provide should be concise and natural-sounding. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including Governance Framework and Procurement Clauses)"). - DO NOT include meta-text in the report body, such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), or execution disclaimers. The report should be a polished, self-contained document. diff --git a/projects/deep_research/v2/prompts/searcher/en/gpt5.txt b/projects/deep_research/v2/prompts/searcher/en/gpt5.txt index 986a9736d..adf9a9a66 100644 --- a/projects/deep_research/v2/prompts/searcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/searcher/en/gpt5.txt @@ -68,5 +68,5 @@ Return JSON only, you MUST follow this format: "findings": ["Core finding 1 from this research", "Core finding 2 from this research"], "issues": ["Issues or limitations encountered during this research"], "note_ids": ["note_id_1", "note_id_2", ...(all stored evidence card IDs)], - "report": "The research report body for this investigation, required to be detailed, accurate, and rigorous in organizing research results, with no subjective speculation, following standard academic writing style" + "report": "The research report body for this investigation, required to be detailed, accurate, and rigorous in organizing research results, with no subjective speculation, well-organized and evidence-based" } diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index 4f4d09c8b..91d6ba65e 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -69,8 +69,8 @@ round_reminder: self_reflection: enabled: true - max_retries: 2 - min_retention_ratio: 0.3 + max_retries: 3 + min_retention_ratio: 0.6 quality_check: enabled: true model: qwen3.5-flash diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 5966d5353..50c2ece5d 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -152,8 +152,8 @@ handler: time_handler code_file: researcher -max_chat_round: 40 +max_chat_round: 42 -tool_call_timeout: 2000 +tool_call_timeout: 2400 output_dir: ./output From c27347f53d7c78a5399f0e4ebbb8959ba159b7de Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Fri, 20 Mar 2026 02:35:45 +0800 Subject: [PATCH 08/40] fix timeout; support for running subagent in process; support for post-report guidance --- .gitignore | 2 +- ms_agent/tools/agent_tool.py | 401 ++++++++++++++++-- .../v2/callbacks/reporter_callback.py | 104 +++++ 3 files changed, 472 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index c5aacded9..f526ef422 100644 --- a/.gitignore +++ b/.gitignore @@ -153,7 +153,7 @@ apps/agentfabric/config/local_user/* ast_index_file.py -#neo4j +# neo4j .neo4j.lock neo4j.lock /temp/ diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index a9b19b7d4..5c86e18c5 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -1,10 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import multiprocessing as mp import os +import threading +import traceback import uuid from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from queue import Empty as QueueEmpty +from queue import Full as QueueFull from typing import Any, Callable, Dict, List, Optional, Union import json @@ -46,12 +51,132 @@ class _AgentToolSpec: trust_remote_code: Optional[bool] env: Optional[Dict[str, str]] run_in_thread: bool + run_in_process: bool + + +_MESSAGE_FIELDS = set(Message.__dataclass_fields__.keys()) + + +def _message_from_data(data: Any) -> Message: + if isinstance(data, Message): + return data + if isinstance(data, dict): + msg_kwargs = {k: data[k] for k in _MESSAGE_FIELDS if k in data} + if 'role' not in msg_kwargs: + msg_kwargs['role'] = 'assistant' + msg_kwargs.setdefault('content', '') + return Message(**msg_kwargs) + return Message(role='assistant', content=str(data)) + + +def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): + if spec.inline_config is not None: + config_override = OmegaConf.create(spec.inline_config) + else: + config_override = None + + trust_remote_code = spec.trust_remote_code + if trust_remote_code is None: + trust_remote_code = default_trust_remote_code + + tag = f'{spec.tag_prefix}{uuid.uuid4().hex[:8]}' + agent = AgentLoader.build( + config_dir_or_id=spec.config_path, + config=config_override, + env=spec.env, + tag=tag, + trust_remote_code=trust_remote_code, + ) + + generation_cfg = getattr(agent.config, 'generation_config', DictConfig({})) + agent.config.generation_config = generation_cfg + return agent + + +def _run_agent_in_subprocess( + spec: _AgentToolSpec, + default_trust_remote_code: bool, + payload: Any, + stream_events: bool, + event_queue: Any, + result_queue: Any, +) -> None: + sub_agent = None + try: + sub_agent = _build_sub_agent(spec, default_trust_remote_code) + run_payload = payload + if isinstance(run_payload, list): + run_payload = [_message_from_data(msg) for msg in run_payload] + + async def _runner(): + chunk_count = 0 + if stream_events: + result = await sub_agent.run(run_payload, stream=True) + else: + result = await sub_agent.run(run_payload) + if hasattr(result, '__aiter__'): + history = None + async for chunk in result: + history = chunk + if stream_events and event_queue is not None: + serialized_chunk = { + 'kind': + 'messages', + 'messages': [ + _message_from_data(msg).to_dict() + for msg in (history or []) + ], + } + try: + event_queue.put_nowait({ + 'type': 'chunk', + 'history': serialized_chunk + }) + except QueueFull: + # Avoid blocking sub-agent progress if UI/event consumer + # is temporarily slower than chunk production. + pass + chunk_count += 1 + result = history + if isinstance(result, list): + return { + 'kind': + 'messages', + 'messages': + [_message_from_data(msg).to_dict() for msg in result], + 'streamed_chunks': + chunk_count, + 'agent_tag': + getattr(sub_agent, 'tag', None), + 'agent_type': + getattr(sub_agent, 'AGENT_NAME', None), + } + return { + 'kind': 'raw', + 'raw': str(result), + 'streamed_chunks': chunk_count, + 'agent_tag': getattr(sub_agent, 'tag', None), + 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), + } + + result_queue.put({'ok': True, 'result': asyncio.run(_runner())}) + except BaseException as exc: # pragma: no cover + result_queue.put({ + 'ok': False, + 'error': str(exc), + 'traceback': traceback.format_exc(), + 'agent_tag': getattr(sub_agent, 'tag', None), + 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), + }) class AgentTool(ToolBase): """Expose existing ms-agent agents as callable tools.""" DEFAULT_SERVER = 'agent_tools' + _PROCESS_POLL_INTERVAL_S = 0.05 + _PROCESS_EXIT_RESULT_GRACE_S = 1.0 + _PROCESS_FINAL_JOIN_TIMEOUT_S = 1.0 def __init__(self, config: DictConfig, **kwargs): super().__init__(config) @@ -62,6 +187,8 @@ def __init__(self, config: DictConfig, **kwargs): self._thread_executor: Optional[ThreadPoolExecutor] = None self._thread_max_workers: int = 0 self._chunk_cb: Optional[Callable[..., Any]] = None + self._active_processes: Dict[str, mp.Process] = {} + self._active_processes_lock = threading.Lock() self._load_specs() self._init_thread_pool_config() @@ -180,6 +307,8 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], # Run sub-agent in a background thread to avoid blocking the main event loop # when underlying LLM SDKs are synchronous. run_in_thread = bool(getattr(cfg, 'run_in_thread', True)) + # Run sub-agent in an isolated process so timed-out calls can be killed. + run_in_process = bool(getattr(cfg, 'run_in_process', run_in_thread)) env_cfg = getattr(cfg, 'env', None) env_cfg = _to_container(env_cfg) if env_cfg is not None else None @@ -206,6 +335,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], trust_remote_code=trust_remote_code, env=env_cfg, run_in_thread=run_in_thread, + run_in_process=run_in_process, ) def _build_server_index(self): @@ -233,6 +363,7 @@ async def connect(self): return None async def cleanup(self): + self._terminate_all_active_processes(reason='during AgentTool cleanup') if self._thread_executor is not None: try: try: @@ -278,51 +409,135 @@ async def call_tool(self, server_name: str, *, tool_name: str, if isinstance(tool_args, dict) and '__call_id' in tool_args: call_id = tool_args.pop('__call_id', None) payload = self._build_payload(tool_args, spec) - agent = self._build_agent(spec) + use_subprocess = spec.run_in_thread and spec.run_in_process + agent = None if use_subprocess else self._build_agent(spec) messages = await self._run_agent(agent, payload, spec, call_id=call_id) return self._format_output(messages, spec) def _build_agent(self, spec: _AgentToolSpec): - if spec.inline_config is not None: - config_override = OmegaConf.create(spec.inline_config) - else: - config_override = None - - trust_remote_code = spec.trust_remote_code - if trust_remote_code is None: - trust_remote_code = self._trust_remote_code - - tag = f'{spec.tag_prefix}{uuid.uuid4().hex[:8]}' - agent = AgentLoader.build( - config_dir_or_id=spec.config_path, - config=config_override, - env=spec.env, - tag=tag, - trust_remote_code=trust_remote_code, + return _build_sub_agent(spec, self._trust_remote_code) + + @staticmethod + def _terminate_process(proc: Optional[mp.Process], *, reason: str) -> None: + if proc is None: + return + if not proc.is_alive(): + try: + proc.join(timeout=0.05) + except Exception: + pass + return + + logger.warning( + 'AgentTool subprocess pid=%s %s, terminating.', + getattr(proc, 'pid', None), + reason, ) + try: + proc.terminate() + proc.join(timeout=1.0) + except Exception: + pass + if proc.is_alive(): + logger.warning( + 'AgentTool subprocess pid=%s did not terminate gracefully, killing.', + getattr(proc, 'pid', None), + ) + try: + proc.kill() + except Exception: + pass + try: + proc.join(timeout=1.0) + except Exception: + pass + + def _register_process(self, run_id: str, proc: mp.Process) -> None: + with self._active_processes_lock: + self._active_processes[run_id] = proc + + def _unregister_process(self, run_id: str) -> None: + with self._active_processes_lock: + self._active_processes.pop(run_id, None) + + def _terminate_all_active_processes(self, *, reason: str) -> None: + with self._active_processes_lock: + active = list(self._active_processes.items()) + self._active_processes.clear() + for _, proc in active: + self._terminate_process(proc, reason=reason) + + async def _wait_process_result(self, + proc: mp.Process, + result_queue: Any, + on_poll: Optional[Callable[[], + None]] = None): + exited_at = None + while True: + if on_poll is not None: + on_poll() + try: + return result_queue.get_nowait() + except QueueEmpty: + pass + + # Process can exit slightly before queue payload becomes visible. + # Keep polling for a short grace window to avoid false "no result". + if not proc.is_alive(): + if exited_at is None: + exited_at = monotonic() + elif (monotonic() + - exited_at) >= self._PROCESS_EXIT_RESULT_GRACE_S: + return None + + await asyncio.sleep(self._PROCESS_POLL_INTERVAL_S) + + @staticmethod + def _drain_process_event_queue( + event_queue: Any, on_event: Callable[[Dict[str, Any]], + None]) -> None: + if event_queue is None: + return + while True: + try: + event = event_queue.get_nowait() + except QueueEmpty: + return + if isinstance(event, dict): + on_event(event) - generation_cfg = getattr(agent.config, 'generation_config', - DictConfig({})) - # OmegaConf.update( - # generation_cfg, - # 'stream', - # False, - # merge=True, - # ) - agent.config.generation_config = generation_cfg - return agent + def _serialize_payload_for_process(self, payload: Any) -> Any: + if not isinstance(payload, list): + return payload + return [_message_from_data(msg).to_dict() for msg in payload] + + @staticmethod + def _restore_process_result(result_payload: Dict[str, Any]) -> Any: + kind = result_payload.get('kind') + if kind == 'messages': + messages = result_payload.get('messages') or [] + return [_message_from_data(msg) for msg in messages] + return result_payload.get('raw', '') async def _run_agent(self, agent, payload, spec: _AgentToolSpec, call_id: Optional[str] = None): + runtime_agent = agent + runtime_agent_tag = getattr(runtime_agent, 'tag', None) + runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) async def _run_and_collect(): + nonlocal runtime_agent, runtime_agent_tag, runtime_agent_type + if runtime_agent is None: + runtime_agent = self._build_agent(spec) + runtime_agent_tag = getattr(runtime_agent, 'tag', None) + runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) if self._chunk_cb: - result = await agent.run(payload, stream=True) + result = await runtime_agent.run(payload, stream=True) else: - result = await agent.run(payload) + result = await runtime_agent.run(payload) if hasattr(result, '__aiter__'): history = None self._emit_chunk_event('start', { @@ -375,7 +590,124 @@ def _sync_runner(): _sync_runner) return await asyncio.to_thread(_sync_runner) - runner = _run_in_background if spec.run_in_thread else _run_and_collect + async def _run_in_subprocess(): + nonlocal runtime_agent_tag, runtime_agent_type + ctx = mp.get_context('spawn') + result_queue = ctx.Queue(maxsize=1) + event_queue = ctx.Queue( + maxsize=128) if self._chunk_cb is not None else None + proc: Optional[mp.Process] = None + run_id = f'{call_id or "agent_tool"}-{uuid.uuid4().hex[:8]}' + + def _emit_stream_event(event: Dict[str, Any]) -> None: + if not self._chunk_cb: + return + history_payload = event.get('history') + if not isinstance(history_payload, dict): + return + history = self._restore_process_result(history_payload) + self._emit_chunk_event( + 'chunk', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + 'history': history, + }) + + try: + if self._chunk_cb: + self._emit_chunk_event('start', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }) + process_payload = self._serialize_payload_for_process(payload) + proc = ctx.Process( + target=_run_agent_in_subprocess, + args=(spec, self._trust_remote_code, process_payload, + self._chunk_cb + is not None, event_queue, result_queue), + name=f'agent_tool_{spec.tool_name}', + ) + proc.start() + self._register_process(run_id, proc) + result = await self._wait_process_result( + proc, + result_queue, + on_poll=lambda: self._drain_process_event_queue( + event_queue, _emit_stream_event)) + if result is None: + raise RuntimeError( + f'AgentTool subprocess exited without result: {spec.tool_name}' + ) + self._drain_process_event_queue(event_queue, + _emit_stream_event) + if not result.get('ok'): + runtime_agent_tag = result.get( + 'agent_tag') or runtime_agent_tag + runtime_agent_type = result.get( + 'agent_type') or runtime_agent_type + tb = result.get('traceback', '') + if tb: + logger.warning(tb) + raise RuntimeError( + f'Sub-agent {spec.tool_name} failed: {result.get("error", "unknown error")}' + ) + result_payload = result.get('result', {}) or {} + runtime_agent_tag = result_payload.get( + 'agent_tag') or runtime_agent_tag + runtime_agent_type = result_payload.get( + 'agent_type') or runtime_agent_type + restored = self._restore_process_result(result_payload) + streamed_chunks = int( + result_payload.get('streamed_chunks', 0) or 0) + if self._chunk_cb: + if streamed_chunks <= 0: + self._emit_chunk_event( + 'chunk', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + 'history': restored, + }) + self._emit_chunk_event( + 'end', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + 'history': restored, + }) + return restored + except asyncio.CancelledError: + self._terminate_process(proc, reason='was cancelled') + raise + except Exception: + self._terminate_process(proc, reason='encountered error') + raise + finally: + self._unregister_process(run_id) + if proc is not None: + try: + proc.join(timeout=self._PROCESS_FINAL_JOIN_TIMEOUT_S) + except Exception: + pass + if proc.is_alive(): + self._terminate_process( + proc, reason='did not exit after result handling') + try: + result_queue.close() + result_queue.join_thread() + except Exception: + pass + if event_queue is not None: + try: + event_queue.close() + event_queue.join_thread() + except Exception: + pass + + if spec.run_in_thread and spec.run_in_process: + runner = _run_in_subprocess + elif spec.run_in_thread: + runner = _run_in_background + else: + runner = _run_and_collect if not self._enable_stats: return await runner() @@ -387,8 +719,9 @@ def _sync_runner(): try: result = await runner() return result - except Exception: - status = 'error' + except BaseException as exc: + status = 'cancelled' if isinstance( + exc, asyncio.CancelledError) else 'error' raise finally: end_ts = now_iso() @@ -396,8 +729,8 @@ def _sync_runner(): usage = summarize_usage(result if isinstance(result, list) else []) record = build_timing_record( event='agent_tool', - agent_tag=getattr(agent, 'tag', None), - agent_type=getattr(agent, 'AGENT_NAME', None), + agent_tag=runtime_agent_tag, + agent_type=runtime_agent_type, started_at=start_ts, ended_at=end_ts, duration_s=duration_s, diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 0200f6cb7..72a914524 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -160,6 +160,81 @@ class ReporterCallback(Callback): }, } + _POST_REPORT_GUIDANCE = { + 'zh': + ('\n\n---\n' + '**[后续工作流程建议]**\n\n' + 'Reporter 已完成报告生成。如果其正常返回工作总结,请仔细审阅返回内容的 Execution_Summary 和 Artifacts 字段,' + '它们总结了报告生成过程并列出了重要的中间文件产物。如果其未正常完成任务或者未正常返回信息,请主动检查 reports 目录下的产物情况确定后续行动。\n\n' + '**关于 final_report.md:' + '** 上方 Artifacts 字段通常只包含 reports/ 目录下的文件(如 reports/report.md),' + '不包含 final_report.md。这是正常的——系统会在 Reporter 正常完成任务后自动将 reports/report.md 复制为 final_report.md。' + '你的审阅和编辑应优先针对 final_report.md。如有需要可按需读取 reports/ 下的其他文件作为参考,' + '但当 final_report.md 可用时避免重复读取 reports/report.md。如果 final_report.md 意外缺失或不完整,按此路径回退:' + 'reports/report.md -> reports/draft.md -> reports/ 下其他产物内容。\n\n' + '**审查与编辑注意事项:**\n' + '- 请严格遵守系统指令中的要求,不要遗漏、忽略任何合理的规则。\n' + '- 审查要点包括事实准确性、逻辑一致性、用户核心问题的覆盖度、引用与论据的对齐关系、引用格式问题、内容完整性等等。' + '修改须有明确依据(如事实冗余、逻辑混乱、证据不一致、格式出错等),不要为了"润色"而改动结构/质量良好的内容。\n' + '- 读取报告内容一次后形成判断,后续核查优先使用 search_file_content 或带 start_line / end_line 的 read_file,不要反复全量读取同一文件。' + '在读取文件前先检查对话历史中是否已包含该文件的内容,避免重复读取。\n' + '- 优先使用定点修改(search_file_content -> replace_file_contents / replace_file_lines),仅在必要时才读取全文。' + '仅在定点修改完全无法解决时使用 write_file,且**必须完整保留所有有价值的内容**,严禁使用占位符、省略标记、引用其他内容等方式替代正文。\n' + '- 质量较高无需修改的部分直接跳过。如果[Reporter 工作总结]中无异常且审查确认全文质量良好,直接进入结论阶段即可。\n\n' + '**需避免的常见错误:**\n' + '- 重复全量读取同一个报告文件(迅速耗尽上下文预算,导致任务失败)。\n' + '- 默认 final_report.md 不存在、且使用简短的概述内容覆盖完整报告。\n' + '- 对结构/质量良好的内容过度修改或者压缩,或在修改过程中忘记已做的改动重复编辑导致错误。\n'), + 'en': + ('\n\n---\n' + '**[Post-Report Workflow Guidance]**\n\n' + 'The Reporter has finished generating the report. If it returned a work summary normally, ' + 'please carefully review the Execution_Summary and Artifacts fields in the returned content — ' + 'they summarize the report generation process and list important intermediate file artifacts. ' + 'If Reporter did not complete the task normally or did not return information properly, ' + 'proactively check the artifacts under the `reports/` directory to determine next steps.\n\n' + '**About `final_report.md`:** The Artifacts field above typically lists only ' + 'files under `reports/` (e.g., `reports/report.md`) and will NOT include ' + '`final_report.md`. This is expected — the system automatically copies ' + '`reports/report.md` to `final_report.md` after the Reporter finishes normally. ' + 'Your review and edits should target `final_report.md`. You may read other ' + 'files under `reports/` as supplementary references when needed, ' + 'but avoid reading `reports/report.md` in full when ' + '`final_report.md` is available. If `final_report.md` is unexpectedly ' + 'missing or incomplete, fall back in this order: ' + '`reports/report.md` -> `reports/draft.md` -> other artifacts under `reports/`.\n\n' + '**Review and editing guidelines:**\n' + '- Strictly follow the requirements in the system instructions; do not overlook or ignore any reasonable rules.\n' + '- Key review points include factual accuracy, logical consistency, coverage of the user\'s core questions, ' + 'alignment between citations and supporting arguments, citation formatting issues, content completeness, etc. ' + 'Edits must have clear justification (e.g., factual redundancy, logical confusion, evidence inconsistency, ' + 'formatting errors, etc.) — do not alter well-structured, high-quality content merely for "polishing."\n' + '- Read the report content ONCE to form your assessment. For subsequent ' + 'checks, prefer `search_file_content` or `read_file` with `start_line`/`end_line`. ' + 'Do not re-read the entire file repeatedly. Check your conversation history before ' + 'reading any file to avoid redundant reads.\n' + '- Prefer targeted fixes (`search_file_content` -> `replace_file_contents` / ' + '`replace_file_lines`); only read the full text when necessary. ' + 'Use `write_file` only when targeted fixes are completely insufficient, ' + 'and you **must preserve ALL valuable content in full** — never use placeholders, ' + 'ellipsis markers, or references to other content as substitutes for actual text.\n' + '- Skip high-quality sections that require no changes. If the [Reporter Work Summary] ' + 'indicates no issues and your review confirms overall quality, proceed ' + 'directly to the conclusion.\n\n' + '**Common mistakes to avoid:**\n' + '- Reading the same report file in full multiple times (rapidly exhausts ' + 'context budget and causes task failure).\n' + '- Assuming `final_report.md` does not exist and overwriting the complete report ' + 'with a brief summary.\n' + '- Over-editing or compressing well-structured, high-quality content, or losing track ' + 'of changes already made and making duplicate edits that introduce errors.\n'), + } + + _WORK_SUMMARY_LABEL = { + 'zh': '**[Reporter 工作总结]**', + 'en': '**[Reporter Work Summary]**', + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') @@ -185,6 +260,7 @@ def __init__(self, config: DictConfig): self.reflection_enabled: bool = False self.reflection_max_retries: int = 2 self.min_retention_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + self.post_report_guidance_enabled: bool = False if refl_cfg is not None: self.reflection_enabled = bool(getattr(refl_cfg, 'enabled', False)) @@ -193,6 +269,8 @@ def __init__(self, config: DictConfig): self.min_retention_ratio = float( getattr(refl_cfg, 'min_retention_ratio', self.DEFAULT_MIN_RETENTION_RATIO)) + self.post_report_guidance_enabled = bool( + getattr(refl_cfg, 'post_report_guidance_enabled', False)) self._reflection_retries_used: int = 0 self._quality_checkers: List[ReportQualityChecker] = ( @@ -217,6 +295,28 @@ def _get_reflection(self, key: str, **kwargs) -> str: self.lang, self._REFLECTION_TEMPLATES['en']) return templates[key].format(**kwargs) + def _append_post_report_guidance(self, messages: List[Message]): + """Append post-report workflow guidance to the Reporter's final message. + + The guidance is appended to the last non-tool-call assistant message + so that it appears as part of the tool result when the parent agent + (Researcher) receives the Reporter's output via AgentTool. + """ + guidance = self._POST_REPORT_GUIDANCE.get( + self.lang, self._POST_REPORT_GUIDANCE['en']) + label = self._WORK_SUMMARY_LABEL.get( + self.lang, self._WORK_SUMMARY_LABEL['en']) + for message in reversed(messages): + if message.role == 'assistant' and not message.tool_calls: + message.content = label + '\n\n' + (message.content or '') + guidance + logger.info( + 'ReporterCallback: appended post-report guidance ' + f'to final assistant message ({len(guidance)} chars)') + return + logger.warning( + 'ReporterCallback: no suitable assistant message found ' + 'for post-report guidance injection.') + def _load_researcher_history(self) -> Optional[List[Dict[str, Any]]]: """ Load the researcher agent's message history from the memory file. @@ -669,3 +769,7 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): else: logger.warning('Reporter: no report file found to promote to ' f'{self.FINAL_REPORT_FILENAME}') + + # --- Step 3: Append post-report workflow guidance --- + if self.post_report_guidance_enabled: + self._append_post_report_guidance(messages) From 5920dc4273b1dcb4ade6d260e47e637820d265ee Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Fri, 20 Mar 2026 17:40:33 +0800 Subject: [PATCH 09/40] refine readme for deep research; add run_benchmark.sh; fix counting character for report qa --- README.md | 2 +- README_ZH.md | 3 +- projects/deep_research/v2/README.md | 223 +++++++++++++++--- projects/deep_research/v2/README_zh.md | 217 +++++++++++++++-- .../v2/callbacks/reporter_callback.py | 14 +- projects/deep_research/v2/reporter.yaml | 1 + projects/deep_research/v2/run_benchmark.sh | 181 ++++++++++++++ 7 files changed, 580 insertions(+), 61 deletions(-) create mode 100755 projects/deep_research/v2/run_benchmark.sh diff --git a/README.md b/README.md index ae03ea8c3..2faf03c78 100644 --- a/README.md +++ b/README.md @@ -341,7 +341,7 @@ The **MS-Agent Skill Module** is **Implementation** of [Anthropic-Agent-Skills]( For more details, please refer to [**MS-Agent Skills**](ms_agent/skill/README.md). -### Agentic Insight +### Agentic Insight (Deep Research) #### - Lightweight, Efficient, and Extensible Multi-modal Deep Research Framework diff --git a/README_ZH.md b/README_ZH.md index 1e0161d50..93affd74a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -311,8 +311,7 @@ asyncio.run(main()) --- - -### Agentic Insight +### Agentic Insight (Deep Research) #### - 轻量级、高效且可扩展的多模态深度研究框架 diff --git a/projects/deep_research/v2/README.md b/projects/deep_research/v2/README.md index a601ad171..a65ed9025 100644 --- a/projects/deep_research/v2/README.md +++ b/projects/deep_research/v2/README.md @@ -1,4 +1,3 @@ - # Agentic Insight v2 Agentic Insight v2 provides a more scalable framework for deep research, enabling agents to autonomously explore and execute complex tasks. @@ -7,10 +6,10 @@ Agentic Insight v2 provides a more scalable framework for deep research, enablin Agentic Insight v2 is designed around: -- **Extensible main-agent + sub-agent architecture**: a Researcher orchestrates Searcher/Reporter and can be extended with new sub agents and tools. +- **Extensible main-agent + sub-agent architecture**: a Researcher orchestrates Searcher/Reporter and can be extended with new sub-agents and tools. - **File-system based context management**: flexible, debuggable, and resume-friendly context via structured artifacts on disk. - **Deep-research optimized toolchain**: dedicated todo, evidence, search, and report tools tuned for iterative research loops. -- **Evidence-bound report generation**: reports are generated from raw evidence with explicit bindings for higher trustworthiness. +- **Evidence-bound report generation**: reports are generated from raw evidence with explicit bindings for higher trustworthiness and traceability. ### 🚀 Quickstart @@ -28,37 +27,197 @@ pip install -e . pip install 'ms-agent[research]' ``` -#### Environment variables (`.env`) +#### Environment Variables -From repo root: +Create `.env` file in repository root: ```bash cp projects/deep_research/.env.example .env ``` -Edit `.env` and set: +Edit `.env` and set the following **required** environment variables: + +```bash +# LLM Configuration (Required) +OPENAI_API_KEY=your_api_key +OPENAI_BASE_URL=https://your-openai-compatible-endpoint/v1 + +# Search Engine Configuration (choose one, or use default arxiv with no config needed) +EXA_API_KEY=your_exa_key # Recommended, register at: https://exa.ai +# SERPAPI_API_KEY=your_serpapi_key # Or choose SerpApi, register at: https://serpapi.com +``` + +#### Model Configuration (⚠️ Required for First Run) + +v2 uses three YAML config files to drive the Researcher, Searcher, and Reporter agents. **Before first run, you must modify model names according to your LLM provider**, otherwise you may get model-not-found errors. If you want each agent to use a different model or provider, modify the `llm` section in the corresponding YAML independently; otherwise the defaults from `.env` are used. + +##### Models to Configure + +For balanced performance and cost, we recommend a **tiered model configuration** — choosing different models for each agent based on its role and requirements. + +| YAML File | Config Path | Current Default | Description | Recommendation | +|-----------|-------------|-----------------|-------------|----------------| +| `researcher.yaml` | `llm.model` | `gpt-5-2025-08-07` | Researcher Agent (main agent) | Use a stronger model (e.g. `qwen3-max` / `gpt-5`) for task planning and coordination | +| `searcher.yaml` | `llm.model` | `qwen3.5-plus` | Searcher Agent | Can use same or slightly weaker model (e.g. `qwen3.5-plus` / `MiniMax-M2.5`) | +| `searcher.yaml` | `tools.web_search.summarizer_model` | `qwen3.5-flash` | Web page summarization model (optional) | Use a fast, cheap model (e.g. `qwen3.5-flash` / `gpt-4.1-mini`) | +| `reporter.yaml` | `llm.model` | `qwen3.5-plus` | Reporter Agent | Can use same or slightly weaker model (e.g. `qwen3.5-plus` / `MiniMax-M2.5`) | +| `researcher.yaml` / `reporter.yaml` | `self_reflection.quality_check.model` | `qwen3.5-flash` | Quality check model (optional) | Use a fast, cheap model (e.g. `qwen3.5-flash` / `gpt-4.1-mini`) | + +##### Common LLM Provider Examples + +Modify model names in YAML files according to your provider: + +**Using OpenAI:** + +```yaml +# Agent configuration +llm: + service: openai + model: gpt-5-2025-08-07 + openai_api_key: + openai_base_url: + +# Also modify quality_check and summarizer_model (defaults to OpenAI-compatible provider): +tools: + web_search: + summarizer_model: qwen3.5-flash + summarizer_api_key: + summarizer_base_url: + +self_reflection: + quality_check: + enabled: true + model: qwen3-flash + openai_api_key: + openai_base_url: +``` + +**Other Compatible Endpoints:** Refer to your provider's documentation for model identifiers. + +#### Search Engine Configuration + +Edit `searcher.yaml` to configure search engines: + +```yaml +tools: + web_search: + engines: + - exa # or serpapi (requires corresponding API key in .env) + - arxiv # arxiv requires no API key, always available + api_key: # When using EXA + # Or when using SerpApi, add (uncomment): + # serpapi_provider: google # Options: google, bing, baidu +``` + +**Default:** If no search engine API key is configured, system will use `arxiv` (academic literature search only). + +#### Advanced Configuration (Optional) + +##### Web Page Summarization + +Enabled by default to compress long web content, reducing context bloat, speeding up research, and saving cost: + +```yaml +tools: + web_search: + enable_summarization: true + summarizer_model: qwen3.5-flash # Can switch to a cheaper model + max_content_chars: 200000 # Max content chars allowed for summarization; content beyond this is truncated + summarizer_max_workers: 15 + summarization_timeout: 360 +``` + +**Note:** Summarization makes additional LLM calls consuming more tokens, but significantly reduces the Searcher Agent's context length. + +##### Quality Check -- `OPENAI_API_KEY` (key of OpenAI-compatible endpoint) -- `OPENAI_BASE_URL` (OpenAI-compatible endpoint) -- One of: - - `EXA_API_KEY` (recommended, register at [Exa](https://exa.ai), free quota available) - - `SERPAPI_API_KEY` (register at [SerpApi](https://serpapi.com), free quota available) +Both Researcher and Reporter have quality check mechanisms for verifying report generation quality: -Notes: +```yaml +self_reflection: + enabled: true + max_retries: 2 # Max check rounds + quality_check: + enabled: true + model: qwen3.5-flash +``` -- v2 configs use placeholders like `` / ``, which are replaced from environment variables at runtime. -- Do not hardcode keys in scripts; keep them in `.env` (and never commit `.env`). +##### Prefix Cache (Prompt Caching) -#### Run (Researcher entry) +Explicitly triggers cache creation and hits to improve speed and reduce cost (only supported by some providers and models): + +```yaml +generation_config: + force_prefix_cache: true # Auto-detects provider support + prefix_cache_roles: [system, user, assistant, tool] # Roles to explicitly request caching for +``` + +**Supported Providers:** DashScope, Anthropic, and some others. If encountering errors, set to `false`. + +#### Configuration File Locations + +v2's three YAML config files are located at: + +- `projects/deep_research/v2/researcher.yaml` - Researcher main agent config +- `projects/deep_research/v2/searcher.yaml` - Searcher search agent config +- `projects/deep_research/v2/reporter.yaml` - Reporter report generation config + +**Placeholder Note:** Placeholders like `` / `` in YAMLs are automatically replaced from `.env` environment variables at runtime. **Do not hardcode API keys in YAMLs** to reduce leak risk. + +#### Run + +##### Command Line ```bash PYTHONPATH=. python ms_agent/cli/cli.py run \ --config projects/deep_research/v2/researcher.yaml \ --query "Write your research question here" \ --trust_remote_code true \ - --output_dir "output/deep_research/runs" + --output_dir "output/deep_research/runs" \ + --load_cache true # Load cache from previous run to resume ``` +##### Benchmark Script + +We provide `run_benchmark.sh` to run a single demo query or reproduce the full benchmark suite. +**All commands below must be run from the repository root directory.** + +**Mode 1 — Single demo query** (no extra setup required): + +```bash +bash projects/deep_research/v2/run_benchmark.sh +``` + +When `DR_BENCH_ROOT` is **not** set, the script runs a single built-in demo query and saves results to `output/deep_research/benchmark_run/`. + +**Mode 2 — Full benchmark suite** (requires the benchmark dataset): + +```bash +DR_BENCH_ROOT=/path/to/deep_research_bench bash projects/deep_research/v2/run_benchmark.sh +``` + +When `DR_BENCH_ROOT` is set, the script reads all queries from `$DR_BENCH_ROOT/data/prompt_data/query.jsonl` and runs them in parallel via `dr_bench_runner.py`. You can override additional parameters: + +```bash +DR_BENCH_ROOT=/path/to/deep_research_bench \ + WORKERS=3 \ + LIMIT=5 \ + MODEL_NAME=my_experiment \ + WORK_ROOT=temp/benchmark_runs \ + OUTPUT_JSONL=/path/to/ms_deepresearch_v2_benchmark.jsonl \ + bash projects/deep_research/v2/run_benchmark.sh +``` + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `WORKERS` | `2` | Number of parallel workers | +| `LIMIT` | `0` | Max queries to run (`0` = all) | +| `MODEL_NAME` | `ms_deepresearch_v2_benchmark` | Experiment name for output file | +| `WORK_ROOT` | `temp/benchmark_runs` | Working directory for intermediate results | +| `OUTPUT_JSONL` | `$DR_BENCH_ROOT/data/test_data/raw_data/.jsonl` | Output JSONL path | + +**Note:** The script automatically reads API keys from `.env` in the repository root. Ensure environment variables are properly configured before running. + #### Run in WebUI You can also use Agentic Insight v2 from the built-in WebUI: @@ -74,22 +233,30 @@ Then open `http://localhost:7860`, select **Deep Research**, and make sure you h You can set them via `.env` or in WebUI **Settings**. WebUI run artifacts are stored under `webui/work_dir//`. -### Key configs (what to edit) - -- `projects/deep_research/v2/researcher.yaml` - - Researcher orchestration prompt and workflow-level settings. -- `projects/deep_research/v2/searcher.yaml` - - Search engines (exa/arxiv/serpapi), fetching/summarization, evidence store settings. -- `projects/deep_research/v2/reporter.yaml` - - Report generation workflow and report artifacts directory. - -### Outputs (where to find results) +### Outputs (Where to Find Results) Given `--output_dir output/deep_research/runs`: - **Final report (user-facing)**: `output/deep_research/runs/final_report.md` -- **Todo list**: `output/deep_research/runs/plan.json(.md)` +- **Plan list**: `output/deep_research/runs/plan.json(.md)` - **Evidence store**: `output/deep_research/runs/evidence/` - - `index.json` and `notes/` are used by Reporter to cite sources. + - `index.json` and `notes/` are used by Reporter to generate the report. - **Reporter artifacts**: `output/deep_research/runs/reports/` - Outline, chapters, draft, and the assembled report artifact. + +### ❓ Troubleshooting + +| Error Type | Possible Cause | Solution | +|-----------|---------------|----------| +| `Model not found` / `Invalid model` | Model name in YAML doesn't match API endpoint | Check and modify `llm.model`, `summarizer_model`, and `quality_check.model` in the three YAMLs to match your provider | +| `Invalid API key` / `Unauthorized` | API key in `.env` is incorrect or expired | Verify `OPENAI_API_KEY` in `.env` is correct, or regenerate API key | +| `Search engine error` / `EXA_API_KEY not found` | Search engine API key not configured | Add `EXA_API_KEY` or `SERPAPI_API_KEY` to `.env`, or modify `searcher.yaml` to use only `arxiv` | +| 400 error / `Invalid request body` | Some generation parameters incompatible | Remove unsupported fields from `generation_config` in the YAML | +| `Timeout` / Timeout errors | Network issues or request too long | Check network connection, or increase `tool_call_timeout` value in the YAML | +| Output too short or incomplete | Model generation parameters limiting | Add or increase `max_tokens` value in `generation_config` in the YAML | +| Stuck mid-execution | Sub-agent in infinite loop or waiting | Check log files in `output_dir` to see which agent is stuck; may need to adjust `max_chat_round` | +| `.env` file not found | `.env` in wrong location | Ensure `.env` is in **repository root**, not in `projects/deep_research/` or `v2/` directories | + +#### Getting Help + +- Report issues: [GitHub Issues](https://github.com/modelscope/ms-agent/issues) diff --git a/projects/deep_research/v2/README_zh.md b/projects/deep_research/v2/README_zh.md index 73cee097e..fc1ac0be0 100644 --- a/projects/deep_research/v2/README_zh.md +++ b/projects/deep_research/v2/README_zh.md @@ -1,4 +1,3 @@ - # Agentic Insight v2 Agentic Insight v2提供了一个更具可扩展性的深度研究框架,使智能体能够自主探索并执行复杂任务。 @@ -28,37 +27,197 @@ pip install -e . pip install 'ms-agent[research]' ``` -#### 环境变量(`.env`) +#### 环境变量配置 -在仓库根目录执行: +在仓库根目录创建 `.env` 文件: ```bash cp projects/deep_research/.env.example .env ``` -编辑 `.env` 并设置: +编辑 `.env` 并设置以下**必需**环境变量: + +```bash +# LLM 配置(必需) +OPENAI_API_KEY=your_api_key +OPENAI_BASE_URL=https://your-openai-compatible-endpoint/v1 + +# 搜索引擎配置(二选一,或使用默认的 arxiv 无需配置) +EXA_API_KEY=your_exa_key # 推荐,注册:https://exa.ai +# SERPAPI_API_KEY=your_serpapi_key # 或者选择 SerpApi,注册:https://serpapi.com +``` + +#### 模型配置(⚠️ 首次运行必读) + +v2 使用三个 YAML 配置文件驱动 Researcher、Searcher 和 Reporter 三个 Agent。**在首次运行前,必须根据你的 LLM 服务商修改模型名称**,否则可能会因模型不存在而报错。如果希望每个 Agent 使用不同的模型和供应商,请在对应的 yaml 内独立修改 llm 字段下的配置,否则默认使用 `.env` 中的配置。 + +##### 需要配置的模型 + +为了平衡性能和成本,建议采用**分层模型配置**,即根据 Agent 的职责和需求,选择不同的模型和供应商。 + +| YAML 文件 | 配置路径 | 当前默认值 | 说明 | 选型建议 | +| ----------------------------------- | ------------------------------------- | ------------------ | ------------------------- | ---------------------------------------------- | +| `researcher.yaml` | `llm.model` | `gpt-5-2025-08-07` | Researcher Agent(主 Agent) | 使用较强的模型(如 `qwen3-max` / `gpt-5`),负责任务规划和协调 | +| `searcher.yaml` | `llm.model` | `qwen3.5-plus` | Searcher Agent | 可使用相同或稍弱的模型(如 `qwen3.5-plus` / `MiniMax-M2.5`) | +| `searcher.yaml` | `tools.web_search.summarizer_model` | `qwen3.5-flash` | 网页总结模型(可选功能) | 使用快速便宜的模型(如 `qwen3.5-flash` / `gpt-4.1-mini`) | +| `reporter.yaml` | `llm.model` | `qwen3.5-plus` | Reporter Agent | 可使用相同或稍弱的模型(如 `qwen3.5-plus` / `MiniMax-M2.5`) | +| `researcher.yaml` / `reporter.yaml` | `self_reflection.quality_check.model` | `qwen3.5-flash` | 质量检查模型(可选功能) | 使用快速便宜的模型(如 `qwen3.5-flash` / `gpt-4.1-mini`) | + +##### 常见 LLM 服务商配置示例 + +根据你使用的服务商,修改 YAML 文件中的模型名称: + +**使用 OpenAI:** + +```yaml +# Agent 配置 +llm: + service: openai + model: gpt-5-2025-08-07 + openai_api_key: + openai_base_url: + +# 同时修改 quality_check 和 summarizer_model(默认使用openai兼容供应商): +tools: + web_search: + summarizer_model: qwen3.5-flash + summarizer_api_key: + summarizer_base_url: + +self_reflection: + quality_check: + enabled: true + model: qwen3-flash + openai_api_key: + openai_base_url: +``` + +**使用其他兼容端点:** 请参考服务商文档中的模型标识符。 + +#### 搜索引擎配置 + +编辑 `searcher.yaml`,配置搜索引擎: + +```yaml +tools: + web_search: + engines: + - exa # 或 serpapi(需要在 .env 配置对应的 API key) + - arxiv # arxiv 无需 API key,始终可用 + api_key: # 使用 EXA 时 + # 或使用 SerpApi 时,额外配置(取消注释): + # serpapi_provider: google # 可选:google, bing, baidu +``` + +**默认配置:** 如果不配置搜索引擎 API key,系统会使用 `arxiv`(仅限学术文献搜索)。 + +#### 高级配置(可选) + +##### 网页摘要功能 + +默认开启,用于压缩长网页内容以减少上下文膨胀、加速搜索调研过程、节约成本: + +```yaml +tools: + web_search: + enable_summarization: true + summarizer_model: qwen3.5-flash # 可换成更便宜的模型 + max_content_chars: 200000 # 允许进行摘要的最大内容字符数,超过后会截断 + summarizer_max_workers: 15 + summarization_timeout: 360 +``` + +**注意:** 摘要功能会额外调用 LLM,消耗更多 token,但能显著减少 Searcher Agent 的上下文长度。 + +##### 质量检查功能 -- `OPENAI_API_KEY`(OpenAI-compatible endpoint 的 key) -- `OPENAI_BASE_URL`(OpenAI-compatible endpoint) -- 二选一: - - `EXA_API_KEY`(推荐,在 [Exa](https://exa.ai) 注册,提供免费额度) - - `SERPAPI_API_KEY`(在 [SerpApi](https://serpapi.com) 注册,提供免费额度) +Researcher 和 Reporter 都配置了质量检查机制,用于检查报告生成质量: -说明: +```yaml +self_reflection: + enabled: true + max_retries: 2 # 最大检查次数 + quality_check: + enabled: true + model: qwen3.5-flash +``` -- v2 配置使用 `` / `` 这类占位符,运行时会自动从环境变量替换。 -- 不要在脚本里硬编码 key;请放在 `.env` 中(并确保 `.env` 不提交到仓库)。 +##### Prefix Cache(提示词缓存) -#### 运行(Researcher 入口) +用于显式触发缓存创建和命中,提高速度、降低成本(仅部分服务商和模型支持): + +```yaml +generation_config: + force_prefix_cache: true # 自动检测服务商是否支持 + prefix_cache_roles: [system, user, assistant, tool] # 显式申请缓存的位置 +``` + +**支持的服务商:** DashScope、Anthropic、部分其他服务商。如遇错误,请设为 `false`。 + +#### 配置文件位置 + +v2 的三个 YAML 配置文件位于: + +- `projects/deep_research/v2/researcher.yaml` - Researcher 主 Agent 配置 +- `projects/deep_research/v2/searcher.yaml` - Searcher 搜索 Agent 配置 +- `projects/deep_research/v2/reporter.yaml` - Reporter 报告生成 Agent 配置 + +**占位符说明:** YAML 中的 `` / `` 等占位符会在运行时自动从 `.env` 环境变量替换,**请勿在 YAML 中硬编码 API key**以降低泄露风险。 + +#### 运行 + +##### 命令行运行 ```bash PYTHONPATH=. python ms_agent/cli/cli.py run \ --config projects/deep_research/v2/researcher.yaml \ --query "在这里写你的研究问题" \ --trust_remote_code true \ - --output_dir "output/deep_research/runs" + --output_dir "output/deep_research/runs" \ + --load_cache true # 加载上一次运行的缓存继续运行 ``` +##### Benchmark 脚本 + +我们提供了 `run_benchmark.sh`,支持运行单条 demo query 或复现完整 benchmark 测试结果。 +**以下所有命令均需在仓库根目录下执行。** + +**模式一 — 单条 demo query**(无需额外配置): + +```bash +bash projects/deep_research/v2/run_benchmark.sh +``` + +当 `DR_BENCH_ROOT` **未设置**时,脚本会运行一条内置的 demo query,结果保存至 `output/deep_research/benchmark_run/`。 + +**模式二 — 完整 benchmark 全量测试**(需要 benchmark 数据集): + +```bash +DR_BENCH_ROOT=/path/to/deep_research_bench bash projects/deep_research/v2/run_benchmark.sh +``` + +当 `DR_BENCH_ROOT` **已设置**时,脚本会从 `$DR_BENCH_ROOT/data/prompt_data/query.jsonl` 读取全部 query,通过 `dr_bench_runner.py` 并行执行。可通过环境变量覆盖默认参数: + +```bash +DR_BENCH_ROOT=/path/to/deep_research_bench \ + WORKERS=3 \ + LIMIT=5 \ + MODEL_NAME=ms_deepresearch_v2_benchmark \ + WORK_ROOT=temp/benchmark_runs \ + OUTPUT_JSONL=/path/to/ms_deepresearch_v2_benchmark.jsonl \ + bash projects/deep_research/v2/run_benchmark.sh +``` + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `WORKERS` | `2` | 并行 worker 数量 | +| `LIMIT` | `0` | 最多运行多少条 query(`0` = 全部) | +| `MODEL_NAME` | `ms_deepresearch_v2_benchmark` | 实验名称,用于输出文件命名 | +| `WORK_ROOT` | `temp/benchmark_runs` | 中间结果工作目录(默认使用临时目录) | +| `OUTPUT_JSONL` | `$DR_BENCH_ROOT/data/test_data/raw_data/.jsonl` | 输出 JSONL 路径 | + +**注意:** 脚本会从仓库根目录的 `.env` 自动读取 API keys,请确保已正确配置环境变量。 + #### 在 WebUI 中使用 你也可以在内置 WebUI 中使用 Agentic Insight v2: @@ -74,22 +233,30 @@ ms-agent ui 你可以通过 `.env` 或 WebUI 的 **Settings** 进行配置。WebUI 的运行产物会保存在 `webui/work_dir//` 下。 -### 关键配置(常改位置) - -- `projects/deep_research/v2/researcher.yaml` - - Researcher 的编排提示词与工作流级别设置。 -- `projects/deep_research/v2/searcher.yaml` - - 搜索引擎(exa/arxiv/serpapi)、抓取/摘要、证据存储等设置。 -- `projects/deep_research/v2/reporter.yaml` - - 报告生成工作流与报告产物目录设置。 - ### 输出(结果位置) 假设你使用 `--output_dir output/deep_research/runs`: - **最终报告(面向用户)**:`output/deep_research/runs/final_report.md` -- **Todo 列表**:`output/deep_research/runs/plan.json(.md)` +- **计划列表**:`output/deep_research/runs/plan.json(.md)` - **证据库**:`output/deep_research/runs/evidence/` - - `index.json` 与 `notes/` 会被 Reporter 用来生成引用。 + - `index.json` 与 `notes/` 会被 Reporter 用来生成报告。 - **Reporter 中间产物**:`output/deep_research/runs/reports/` - 大纲、章节、草稿与汇总后的报告产物。 + +### ❓ 故障排查 + +| 错误类型 | 可能原因 | 解决方法 | +| ----------------------------------------------- | ------------------------- | ---------------------------------------------------------------------------------- | +| `Model not found` / `Invalid model` | YAML 中的模型名与 API 端点不匹配 | 检查并修改三个 YAML 文件的 `llm.model`、`summarizer_model` 和 `quality_check.model`,确保与你的服务商匹配 | +| `Invalid API key` / `Unauthorized` | `.env` 中的 API key 不正确或已过期 | 检查 `.env` 中的 `OPENAI_API_KEY` 是否正确,或重新生成 API key | +| `Search engine error` / `EXA_API_KEY not found` | 搜索引擎 API key 未配置 | 在 `.env` 添加 `EXA_API_KEY` 或 `SERPAPI_API_KEY`,或修改 `searcher.yaml` 仅使用 `arxiv` | +| 请求 400 错误 / `Invalid request body` | 某些生成参数不兼容 | 在对应 YAML 的 `generation_config` 中删除不支持的字段 | +| `Timeout` / 超时错误 | 网络问题或请求时间过长 | 检查网络连接,或在对应 YAML 中增加 `tool_call_timeout` 的值 | +| 输出内容过短或不完整 | 模型生成参数限制 | 在对应 YAML 的 `generation_config` 中添加或增大 `max_tokens` 的值 | +| 运行到一半卡住 | 某个子 Agent 陷入死循环或等待 | 检查 `output_dir` 下的日志文件,查看是哪个 Agent 卡住,可能需要调整 `max_chat_round` | +| 找不到 `.env` 文件 | `.env` 文件位置不正确 | 确保 `.env` 文件在**仓库根目录**,而不是 `projects/deep_research/` 或 `v2/` 目录下 | + +#### 获取更多帮助 + +- 报告问题:[GitHub Issues](https://github.com/modelscope/ms-agent/issues) diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 72a914524..7bfb5bee3 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -610,8 +610,10 @@ async def after_tool_call(self, runtime: Runtime, messages: List[Message]): # --- Check 2: length retention --- if has_report and has_draft: try: - report_chars = os.path.getsize(self.report_path) - draft_chars = os.path.getsize(self.draft_path) + with open(self.report_path, 'r', encoding='utf-8') as f: + report_chars = len(f.read()) + with open(self.draft_path, 'r', encoding='utf-8') as f: + draft_chars = len(f.read()) if draft_chars > 0: ratio = report_chars / draft_chars if ratio < self.min_retention_ratio: @@ -630,7 +632,7 @@ async def after_tool_call(self, runtime: Runtime, messages: List[Message]): return except OSError as exc: logger.warning( - f'ReporterCallback: failed to stat report files: {exc}') + f'ReporterCallback: failed to read report files: {exc}') # --- Check 3: quality checker chain --- if not self._quality_checkers: @@ -706,8 +708,10 @@ def _select_best_report(self) -> Optional[str]: if has_report and has_draft: try: - report_chars = os.path.getsize(self.report_path) - draft_chars = os.path.getsize(self.draft_path) + with open(self.report_path, 'r', encoding='utf-8') as f: + report_chars = len(f.read()) + with open(self.draft_path, 'r', encoding='utf-8') as f: + draft_chars = len(f.read()) if draft_chars > 0: ratio = report_chars / draft_chars if ratio < self.min_retention_ratio: diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index 91d6ba65e..c55fd109b 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -71,6 +71,7 @@ self_reflection: enabled: true max_retries: 3 min_retention_ratio: 0.6 + post_report_guidance_enabled: false quality_check: enabled: true model: qwen3.5-flash diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh new file mode 100755 index 000000000..97c451d4f --- /dev/null +++ b/projects/deep_research/v2/run_benchmark.sh @@ -0,0 +1,181 @@ +#!/bin/bash + +# Agentic Insight v2 Benchmark Runner +# This script helps reproduce the official benchmark results. +# Must be run from the repository root directory. +# +# Usage: +# Single demo query: bash projects/deep_research/v2/run_benchmark.sh +# Full benchmark: DR_BENCH_ROOT=/path/to/bench bash projects/deep_research/v2/run_benchmark.sh + +set -e # Exit on error + +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo "=========================================" +echo "Agentic Insight v2 Benchmark Runner" +echo "=========================================" +echo "" + +# Locate Python executable early for both modes +if command -v python >/dev/null 2>&1; then + PYTHON_BIN="python" +elif command -v python3 >/dev/null 2>&1; then + PYTHON_BIN="python3" +else + echo -e "${RED}Error: Neither 'python' nor 'python3' is available in PATH.${NC}" + exit 1 +fi + +# Use caffeinate on macOS when available; otherwise run normally. +RUN_PREFIX=() +if command -v caffeinate >/dev/null 2>&1; then + RUN_PREFIX=("caffeinate" "-i") +else + echo -e "${YELLOW}Warning: 'caffeinate' not found, running without sleep prevention.${NC}" +fi + +# Verify we are at the repository root +if [ ! -f "ms_agent/cli/cli.py" ]; then + echo -e "${RED}Error: This script must be run from the repository root directory.${NC}" + echo " cd /path/to/ms-agent" + echo " bash projects/deep_research/v2/run_benchmark.sh" + exit 1 +fi + +# Check if .env exists +if [ ! -f ".env" ]; then + echo -e "${RED}Error: .env file not found in repository root!${NC}" + echo "Please create .env file by copying .env.example:" + echo " cp projects/deep_research/.env.example .env" + echo " # Then edit .env to add your API keys" + exit 1 +fi + +# Source .env file +echo -e "${GREEN}Loading environment variables from .env...${NC}" +set -a # Export all variables +source .env +set +a + +# Validate required environment variables +if [ -z "$OPENAI_API_KEY" ] || [ -z "$OPENAI_BASE_URL" ]; then + echo -e "${RED}Error: OPENAI_API_KEY or OPENAI_BASE_URL not set in .env${NC}" + exit 1 +fi + +# Check for search engine API key +if [ -z "$EXA_API_KEY" ] && [ -z "$SERPAPI_API_KEY" ]; then + echo -e "${YELLOW}Warning: Neither EXA_API_KEY nor SERPAPI_API_KEY is set.${NC}" + echo -e "${YELLOW}The system will use arxiv (academic search only).${NC}" + echo "" +fi + +echo -e "${GREEN}Environment variables loaded successfully!${NC}" +echo " OPENAI_BASE_URL: $OPENAI_BASE_URL" +echo " EXA_API_KEY: $([ -n "$EXA_API_KEY" ] && echo "✓ Set" || echo "✗ Not set")" +echo " SERPAPI_API_KEY: $([ -n "$SERPAPI_API_KEY" ] && echo "✓ Set" || echo "✗ Not set")" +echo "" + +# Check if DR_BENCH_ROOT is set +if [ -z "$DR_BENCH_ROOT" ]; then + echo -e "${YELLOW}Warning: DR_BENCH_ROOT not set.${NC}" + echo -e "${YELLOW}Using default benchmark query...${NC}" + echo "" + + # Run a simple benchmark query + QUERY="Provide a comprehensive survey of recent advances in large language models (LLMs), covering key developments in the last 12 months including architecture innovations, training techniques, and real-world applications." + OUTPUT_DIR="output/deep_research/benchmark_run" + + echo -e "${GREEN}Running benchmark with query:${NC}" + echo " \"$QUERY\"" + echo "" + echo -e "${GREEN}Output directory: $OUTPUT_DIR${NC}" + echo "" + + # Run the benchmark + PYTHONPATH=. "$PYTHON_BIN" ms_agent/cli/cli.py run \ + --config projects/deep_research/v2/researcher.yaml \ + --query "$QUERY" \ + --trust_remote_code true \ + --output_dir "$OUTPUT_DIR" + + echo "" + echo -e "${GREEN}=========================================${NC}" + echo -e "${GREEN}Benchmark completed!${NC}" + echo -e "${GREEN}Results saved to: $OUTPUT_DIR${NC}" + echo -e "${GREEN}Final report: $OUTPUT_DIR/final_report.md${NC}" + echo -e "${GREEN}=========================================${NC}" + +else + echo -e "${GREEN}DR_BENCH_ROOT detected: $DR_BENCH_ROOT${NC}" + echo -e "${YELLOW}Running full benchmark suite...${NC}" + echo "" + + # Benchmark subprocess tuning (override via env vars if needed) + export DR_BENCH_POST_FINISH_GRACE_S="${DR_BENCH_POST_FINISH_GRACE_S:-180}" + export DR_BENCH_POST_REPORT_EXIT_GRACE_S="${DR_BENCH_POST_REPORT_EXIT_GRACE_S:-3600}" + export DR_BENCH_REPORT_STABLE_WINDOW_S="${DR_BENCH_REPORT_STABLE_WINDOW_S:-10}" + export DR_BENCH_SUBPROCESS_POLL_INTERVAL_S="${DR_BENCH_SUBPROCESS_POLL_INTERVAL_S:-0.5}" + export DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S="${DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S:-30}" + export DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S="${DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S:-30}" + + # Check if DR_BENCH_ROOT exists + if [ ! -d "$DR_BENCH_ROOT" ]; then + echo -e "${RED}Error: DR_BENCH_ROOT directory not found: $DR_BENCH_ROOT${NC}" + exit 1 + fi + + # Check if query file exists + QUERY_FILE="$DR_BENCH_ROOT/data/prompt_data/query.jsonl" + if [ ! -f "$QUERY_FILE" ]; then + echo -e "${RED}Error: Query file not found: $QUERY_FILE${NC}" + exit 1 + fi + + # Set default values + MODEL_NAME="${MODEL_NAME:-ms_deepresearch_v2_benchmark}" + OUTPUT_JSONL="${OUTPUT_JSONL:-$DR_BENCH_ROOT/data/test_data/raw_data/${MODEL_NAME}.jsonl}" + WORK_ROOT="${WORK_ROOT:-temp/benchmark_runs}" + WORKERS="${WORKERS:-2}" + LIMIT="${LIMIT:-0}" + + # Validate numeric inputs early for clearer errors + if ! [[ "$WORKERS" =~ ^[0-9]+$ ]] || [ "$WORKERS" -lt 1 ]; then + echo -e "${RED}Error: WORKERS must be a positive integer. Got: $WORKERS${NC}" + exit 1 + fi + if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then + echo -e "${RED}Error: LIMIT must be a non-negative integer. Got: $LIMIT${NC}" + exit 1 + fi + + echo "Configuration:" + echo " Query file: $QUERY_FILE" + echo " Output JSONL: $OUTPUT_JSONL" + echo " Model name: $MODEL_NAME" + echo " Work root: $WORK_ROOT" + echo " Workers: $WORKERS" + echo " Limit: $LIMIT (0 = no limit)" + echo "" + + # Run the full benchmark + PYTHONPATH=. "${RUN_PREFIX[@]}" "$PYTHON_BIN" projects/deep_research/v2/eval/dr_bench_runner.py \ + --query_file "$QUERY_FILE" \ + --output_jsonl "$OUTPUT_JSONL" \ + --model_name "$MODEL_NAME" \ + --work_root "$WORK_ROOT" \ + --limit "$LIMIT" \ + --workers "$WORKERS" \ + --trust_remote_code + + echo "" + echo -e "${GREEN}=========================================${NC}" + echo -e "${GREEN}Full benchmark suite completed!${NC}" + echo -e "${GREEN}Results saved to: $OUTPUT_JSONL${NC}" + echo -e "${GREEN}=========================================${NC}" +fi From 520282384734a76a4a1eec047ff1e4f962c0ca59 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 18:27:13 +0800 Subject: [PATCH 10/40] full modify? --- ms_agent/agent/llm_agent.py | 352 +++++++++--------- ms_agent/knowledge_search/sirchmunk_search.py | 208 ++++++++--- 2 files changed, 330 insertions(+), 230 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 515446381..76289a403 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,30 +2,29 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy +from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig, OmegaConf - from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -90,14 +89,17 @@ class LLMAgent(Agent): TOTAL_CACHE_CREATION_INPUT_TOKENS = 0 TOKEN_LOCK = asyncio.Lock() - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs, + ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' + ) llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -113,7 +115,8 @@ def __init__(self, self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {})) + kwargs.get('mcp_config', {}) + ) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -161,37 +164,34 @@ def _ensure_auto_skills(self) -> bool: use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: from ms_agent.utils.docker_utils import is_docker_daemon_running + if not is_docker_daemon_running(): - logger.warning( - 'Docker not running, disabling sandbox for skills') + logger.warning('Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container( - skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', - None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, - 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) logger.info( - f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' + f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f'Failed to initialize AutoSkills: {e}') + logger.warning(f"Failed to initialize AutoSkills: {e}") self._auto_skills_initialized = True return False @@ -233,7 +233,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f'Skill analysis error: {e}') + logger.error(f"Skill analysis error: {e}") return False async def get_skill_dag(self, query: str): @@ -265,13 +265,15 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = getattr(skills_config, 'stop_on_failure', - True) if skills_config else True + stop_on_failure = ( + getattr(skills_config, 'stop_on_failure', True) if skills_config else True + ) result = await self._auto_skills.run( query=query, execution_input=execution_input, - stop_on_failure=stop_on_failure) + stop_on_failure=stop_on_failure, + ) self._last_skill_result = result return result @@ -289,15 +291,14 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append( - Message(role='assistant', content=dag_result.chat_response)) + messages.append(Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." if dag_result.clarification: - content += f'\n\n{dag_result.clarification}' + content += f"\n\n{dag_result.clarification}" messages.append(Message(role='assistant', content=content)) return messages @@ -317,28 +318,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f'**{skill_id} output:**\n{stdout_preview}\n\n' + content += f"**{skill_id} output:**\n{stdout_preview}\n\n" if output.output_files: - content += f'**Generated files:** {list(output.output_files.values())}\n\n' + content += f"**Generated files:** {list(output.output_files.values())}\n\n" - content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + content += ( + f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f'**{skill_id} failed:** {result.error}\n' + content += f"**{skill_id} failed:** {result.error}\n" messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' + content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' - content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' - content += f'\nExecution order: {dag_result.execution_order}' + content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" + content += f"\nExecution order: {dag_result.execution_order}" messages.append(Message(role='assistant', content=content)) @@ -364,8 +367,7 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile( - self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -394,26 +396,32 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f'[External Code]A Config Lifecycle handler ' - f'registered in the config: {handler_file}. ' - f'\nThis is external code, if you trust this workflow, ' - f'please specify `--trust_remote_code true`') - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + f"[External Code]A Config Lifecycle handler " + f"registered in the config: {handler_file}. " + f"\nThis is external code, if you trust this workflow, " + f"please specify `--trust_remote_code true`" + ) + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' if local_dir not in sys.path: sys.path.insert(0, local_dir) handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, - inspect.isclass) + for name, cls in inspect.getmembers(handler_module, inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if handler_cls.__bases__[ - 0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: + if ( + handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file + ): handler = handler_cls() - assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' + assert ( + handler is not None + ), f"Config Lifecycle handler class cannot be found in {handler_file}" return handler return None @@ -424,13 +432,14 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, - 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -451,23 +460,22 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass) + callback_file, inspect.isclass + ) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass( - cls, Callback) and cls.__module__ == _callback: + if issubclass(cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback]( - self.config)) + self.callbacks.append(callbacks_mapping[_callback](self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task beginning.') + self.log_output(f"Agent {self.tag} task beginning.") await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task finished.') + self.log_output(f"Agent {self.tag} task finished.") await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -492,8 +500,7 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, - messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -504,17 +511,20 @@ async def parallel_tool_call(self, List[Message]: Updated message list including tool responses. """ tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls) + messages[-1].tool_calls + ) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip(tool_call_result, - messages[-1].tool_calls): + for tool_call_result, tool_call_query in zip( + tool_call_result, messages[-1].tool_calls + ): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', content=tool_call_result_format.text, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], - resources=tool_call_result_format.resources) + resources=tool_call_result_format.resources, + ) if _new_message.tool_call_id is None: # If tool call id is None, add a random one @@ -530,7 +540,8 @@ async def prepare_tools(self): self.config, self.mcp_config, self.mcp_client, - trust_remote_code=self.trust_remote_code) + trust_remote_code=self.trust_remote_code, + ) await self.tool_manager.connect() async def cleanup_tools(self): @@ -539,8 +550,7 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -551,8 +561,7 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -563,8 +572,7 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -580,19 +588,18 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr( - getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr( - getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str]) -> List[Message]: + self, messages: Union[List[Message], str] + ) -> List[Message]: """ Convert input into a standardized list of messages. @@ -604,18 +611,19 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if system is not None and messages[ - 0].role == 'system' and system != messages[0].content: + if ( + system is not None + and messages[0].role == 'system' + and system != messages[0].content + ): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f'inputs can be either a list or a string, but current is {type(messages)}' + ), f"inputs can be either a list or a string, but current is {type(messages)}" messages = [ - Message( - role='system', - content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -639,11 +647,10 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search if self.knowledge_search is not None: # Perform search and get results - search_result = await self.knowledge_search.retrieve(query) + search_result = await self.knowledge_search.query(query) search_details = self.knowledge_search.get_search_details() # Store search details in the message for frontend display @@ -652,24 +659,14 @@ async def do_rag(self, messages: List[Message]): # Build enriched context from search results if search_result: - context_parts = [] - for i, result in enumerate(search_result, 1): - text = result.get('text', '') - source = result.get('metadata', {}).get('source', 'unknown') - score = result.get('score', 0) - context_parts.append( - f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n" - ) - # Append search context to user query - context = '\n'.join(context_parts) + context = search_result user_message.content = ( f"Relevant context retrieved from codebase search:\n\n{context}\n\n" f"User question: {query}" ) - async def do_skill(self, - messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -686,7 +683,9 @@ async def do_skill(self, # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' else None) + if len(messages) > 1 and messages[1].role == 'user' + else None + ) if not query: return None @@ -700,8 +699,9 @@ async def do_skill(self, try: skills_config = self._get_skills_config() - auto_execute = getattr(skills_config, 'auto_execute', - True) if skills_config else True + auto_execute = ( + getattr(skills_config, 'auto_execute', True) if skills_config else True + ) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,8 +709,7 @@ async def do_skill(self, dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages( - dag_result) + skill_messages = self._format_skill_result_as_messages(dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -721,7 +720,8 @@ async def do_skill(self, except Exception as e: logger.warning( - f'Skill execution failed: {e}, falling back to standard agent') + f"Skill execution failed: {e}, falling back to standard agent" + ) self._skill_mode_active = False return None @@ -735,11 +735,13 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f'{mem_instance_type} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') + f"{mem_instance_type} not in memory_mapping, " + f"which supports: {list(memory_mapping.keys())}" + ) shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type) + self.config, mem_instance_type + ) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -748,12 +750,17 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f'{rag.name} not in rag_mapping, ' - f'which supports: {list(rag_mapping.keys())}') + f"{rag.name} not in rag_mapping, " + f"which supports: {list(rag_mapping.keys())}" + ) self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): """Load and initialize the knowledge search component from the config.""" + if self.knowledge_search is not None: + # Already initialized (e.g. by caller before run_loop), skip to avoid + # overwriting a configured instance (e.g. one with streaming callbacks set). + return if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: @@ -790,7 +797,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f'[Image: {img_url[:50]}...]') + text_parts.append(f"[Image: {img_url[:50]}...]") content = ' '.join(text_parts) # Ensure content is a string @@ -801,10 +808,9 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f'[{self.tag}] {_line}') + logger.info(f"[{self.tag}] {_line}") - def handle_new_response(self, messages: List[Message], - response_message: Message): + def handle_new_response(self, messages: List[Message], response_message: Message): assert response_message is not None, 'No response message generated from LLM.' if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -812,24 +818,23 @@ def handle_new_response(self, messages: List[Message], tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads( - tool_call['arguments']) + tool_call['arguments'] = json.loads(tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output( - json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if messages[-1].role == 'assistant' and not messages[ - -1].content and response_message.tool_calls: + if ( + messages[-1].role == 'assistant' + and not messages[-1].content + and response_message.tool_calls + ): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step( - self, messages: List[Message] - ) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -865,20 +870,20 @@ async def step( is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate( - messages, tools=tools): + for _response_message in self.llm.generate(messages, tools=tools): if is_first: messages.append(_response_message) is_first = False # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning):] + new_reasoning = reasoning_text[len(_reasoning) :] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -886,7 +891,7 @@ async def step( self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content):] + new_content = _response_message.content[len(_content) :] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -898,8 +903,9 @@ async def step( else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -927,8 +933,9 @@ async def step( prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = getattr( - _response_message, 'cache_creation_input_tokens', 0) or 0 + cache_creation_input_tokens = ( + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 + ) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -938,20 +945,21 @@ async def step( # tokens in the current step self.log_output( - f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' + f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' + f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" ) # total tokens for the process so far self.log_output( - f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' - f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') + f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " + f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" + ) if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' - f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' + f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " + f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" ) yield messages @@ -964,8 +972,9 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history(self, messages: List[Message], - **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history( + self, messages: List[Message], **kwargs + ) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. @@ -1009,9 +1018,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step') - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step' + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1020,9 +1029,9 @@ def _get_run_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( memory_config, 'add_after_task', - default_user_id=getattr(memory_config, 'user_id', None)) - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + default_user_id=getattr(memory_config, 'user_id', None), + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1033,24 +1042,29 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for idx, (mem_instance_type, memory_config) in enumerate( + self.config.memory.items() + ): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config) + memory_config + ) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config) + memory_config + ) if idx < tools_num: - if any(v is not None - for v in [user_id, agent_id, run_id, memory_type]): + if any( + v is not None for v in [user_id, agent_id, run_id, memory_type] + ): await self.memory_tools[idx].add( messages, user_id=user_id, agent_id=agent_id, run_id=run_id, - memory_type=memory_type) + memory_type=memory_type, + ) def save_history(self, messages: List[Message], **kwargs): """ @@ -1072,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history( - self.output_dir, task=self.tag, config=config, messages=messages) + save_history(self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop(self, messages: Union[List[Message], str], - **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop( + self, messages: Union[List[Message], str], **kwargs + ) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1089,8 +1103,9 @@ async def run_loop(self, messages: Union[List[Message], str], List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr(self.config, 'max_chat_round', - LLMAgent.DEFAULT_MAX_CHAT_ROUND) + self.max_chat_round = getattr( + self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND + ) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1132,8 +1147,7 @@ async def run_loop(self, messages: Union[List[Message], str], yield messages self.runtime.round += 1 # save memory and history - await self.add_memory( - messages, add_type='add_after_step', **kwargs) + await self.add_memory(messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1142,9 +1156,10 @@ async def run_loop(self, messages: Union[List[Message], str], messages.append( Message( role='assistant', - content= - f'Task {messages[1].content} was cutted off, because ' - f'max round({self.max_chat_round}) exceeded.')) + content=f"Task {messages[1].content} was cutted off, because " + f"max round({self.max_chat_round}) exceeded.", + ) + ) self.runtime.should_stop = True yield messages @@ -1155,32 +1170,33 @@ async def run_loop(self, messages: Union[List[Message], str], def _add_memory(): asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + self.add_memory(messages, add_type='add_after_task', **kwargs) + ) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) except Exception as e: import traceback + logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' + f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True) + self.config, 'generation_config.stream', True, merge=True + ) async def stream_generator(): - async for _chunk in self.run_loop( - messages=messages, **kwargs): + async for _chunk in self.run_loop(messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index 4e1e322a5..dd80738f9 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from pathlib import Path -from typing import Any, Dict, List, Optional, Union from loguru import logger +from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, Callable from ms_agent.rag.base import RAG -from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -49,7 +49,9 @@ def __init__(self, config: DictConfig): paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] - self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + self.search_paths: List[str] = [ + str(Path(p).expanduser().resolve()) for p in paths + ] # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') @@ -70,28 +72,40 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + if ( + self.llm_api_key is None + or self.llm_base_url is None + or self.llm_model_name is None + ): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + self.llm_api_key = getattr(llm_config, f"{service}_api_key", None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) - self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + self.embedding_model_cache_dir = rag_config.get( + 'embedding_model_cache_dir', None + ) # Runtime state self._searcher = None self._initialized = False + self._cluster_cache_hit = False + self._cluster_cache_hit_time: str | None = None + self._last_search_result: List[Dict[str, Any]] | None = None # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] + # Async queue for streaming logs in real-time + self._log_queue: asyncio.Queue | None = None + self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" @@ -112,8 +126,8 @@ def _initialize_searcher(self): return try: - from sirchmunk.search import AgenticSearch from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil # Create LLM client @@ -127,9 +141,17 @@ def _initialize_searcher(self): # Create embedding util # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) - embedding_model_id = self.embedding_model_id if self.embedding_model_id else None - embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None - embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + embedding_model_id = ( + self.embedding_model_id if self.embedding_model_id else None + ) + embedding_cache_dir = ( + self.embedding_model_cache_dir + if self.embedding_model_cache_dir + else None + ) + embedding = EmbeddingUtil( + model_id=embedding_model_id, cache_dir=embedding_cache_dir + ) # Create AgenticSearch instance self._searcher = AgenticSearch( @@ -145,20 +167,35 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}") except ImportError as e: raise ImportError( - f'Failed to import sirchmunk: {e}. ' + f"Failed to import sirchmunk: {e}. " 'Please install sirchmunk: pip install sirchmunk' ) except Exception as e: - raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}") def _log_callback_wrapper(self): - """Create a callback wrapper to capture search logs.""" - def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): - self._search_logs.append(f'[{level}] {message}') + """Create a callback wrapper to capture search logs. + + The sirchmunk LogCallback signature is: + (level: str, message: str, end: str, flush: bool) -> None + See sirchmunk/utils/log_utils.py for reference. + """ + + def log_callback( + level: str, + message: str, + end: str = '\n', + flush: bool = False, + ): + log_entry = f"[{level.upper()}] {message}" + self._search_logs.append(log_entry) + # Stream log in real-time if streaming callback is set + if self._streaming_callback: + asyncio.create_task(self._streaming_callback(log_entry)) return log_callback @@ -186,7 +223,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f'Failed to refresh knowledge base: {e}') + logger.error(f"Failed to refresh knowledge base: {e}") return False return True @@ -208,15 +245,13 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: await self._searcher.scan_directory(str(Path(file_path).parent)) return True except Exception as e: - logger.error(f'Failed to scan files: {e}') + logger.error(f"Failed to scan files: {e}") return False return True - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.7, - **filters) -> List[Dict[str, Any]]: + async def retrieve( + self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters + ) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. Args: @@ -246,11 +281,21 @@ async def retrieve(self, return_context=True, ) + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + # If a similar cluster was found and reused, it's a cache hit + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + # Get the cluster cache hit time if available + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f'SirschmunkSearch retrieve failed: {e}') + logger.error(f"SirschmunkSearch retrieve failed: {e}") return [] async def query(self, query: str) -> str: @@ -273,34 +318,45 @@ async def query(self, query: str) -> str: max_loops = self.max_loops max_token_budget = self.max_token_budget - # Perform search and get answer + # Single search with context so we get both the synthesized answer and + # source units in one call, avoiding a redundant second search. result = await self._searcher.search( query=query, mode=mode, max_loops=max_loops, max_token_budget=max_token_budget, - return_context=False, + return_context=True, ) - # Result is already a synthesized answer string - if isinstance(result, str): - return result + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + + # Store parsed context for frontend display + self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) - # If we got SearchContext or other format, extract the answer + # Extract the synthesized answer from the context result if hasattr(result, 'answer'): return result.answer + # If result is already a plain string (some modes return str directly) + if isinstance(result, str): + return result + # Fallback: convert to string return str(result) except Exception as e: - logger.error(f'SirschmunkSearch query failed: {e}') - return f'Query failed: {e}' + logger.error(f"SirschmunkSearch query failed: {e}") + return f"Query failed: {e}" - def _parse_search_result(self, - result: Any, - score_threshold: float, - limit: int) -> List[Dict[str, Any]]: + def _parse_search_result( + self, result: Any, score_threshold: float, limit: int + ) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -329,28 +385,38 @@ def _parse_search_result(self, else: text_parts.append(str(snippet)) - results.append({ - 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + results.append( + { + 'text': '\n'.join(text_parts) + if text_parts + else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') + else 'text', + }, } - }) + ) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append({ - 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), + results.append( + { + 'text': str(unit.content) + if hasattr(unit, 'content') + else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, } - }) + ) # Handle list format elif isinstance(result, list): @@ -358,21 +424,27 @@ def _parse_search_result(self, if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': item.get('content', item.get('text', str(item))), - 'score': score, - 'metadata': item.get('metadata', {}), - }) + results.append( + { + 'text': item.get( + 'content', item.get('text', str(item)) + ), + 'score': score, + 'metadata': item.get('metadata', {}), + } + ) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - }) + results.append( + { + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + } + ) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) @@ -398,4 +470,16 @@ def get_search_details(self) -> Dict[str, Any]: 'paths': self.search_paths, 'work_path': str(self.work_path), 'reuse_knowledge': self.reuse_knowledge, + 'cluster_cache_hit': self._cluster_cache_hit, + 'cluster_cache_hit_time': self._cluster_cache_hit_time, } + + def enable_streaming_logs(self, callback: Callable): + """Enable streaming mode for search logs. + + Args: + callback: Async callback function to receive log entries in real-time. + Signature: async def callback(log_entry: str) -> None + """ + self._streaming_callback = callback + self._search_logs.clear() From 24eb500b75b79941a77eb6172d3683df848f76e6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 20:03:13 +0800 Subject: [PATCH 11/40] fix lint --- ms_agent/agent/llm_agent.py | 294 +++++++++--------- ms_agent/cli/run.py | 20 +- ms_agent/knowledge_search/sirchmunk_search.py | 181 +++++------ 3 files changed, 253 insertions(+), 242 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 76289a403..5f2ddf2e7 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,16 +2,15 @@ import asyncio import importlib import inspect -import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy -from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.knowledge_search import SirchmunkSearch @@ -25,6 +24,8 @@ from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig, OmegaConf + from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -98,8 +99,7 @@ def __init__( ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' - ) + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -115,8 +115,7 @@ def __init__( self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {}) - ) + kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -166,32 +165,36 @@ def _ensure_auto_skills(self) -> bool: from ms_agent.utils.docker_utils import is_docker_daemon_running if not is_docker_daemon_running(): - logger.warning('Docker not running, disabling sandbox for skills') + logger.warning( + 'Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container( + skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', + None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, + 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) logger.info( - f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" + f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f"Failed to initialize AutoSkills: {e}") + logger.warning(f'Failed to initialize AutoSkills: {e}') self._auto_skills_initialized = True return False @@ -233,7 +236,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f"Skill analysis error: {e}") + logger.error(f'Skill analysis error: {e}') return False async def get_skill_dag(self, query: str): @@ -266,8 +269,8 @@ async def execute_skills(self, query: str, execution_input=None): skills_config = self._get_skills_config() stop_on_failure = ( - getattr(skills_config, 'stop_on_failure', True) if skills_config else True - ) + getattr(skills_config, 'stop_on_failure', True) + if skills_config else True) result = await self._auto_skills.run( query=query, @@ -291,14 +294,15 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append(Message(role='assistant', content=dag_result.chat_response)) + messages.append( + Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." if dag_result.clarification: - content += f"\n\n{dag_result.clarification}" + content += f'\n\n{dag_result.clarification}' messages.append(Message(role='assistant', content=content)) return messages @@ -318,30 +322,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f"**{skill_id} output:**\n{stdout_preview}\n\n" + content += f'**{skill_id} output:**\n{stdout_preview}\n\n' if output.output_files: - content += f"**Generated files:** {list(output.output_files.values())}\n\n" + content += f'**Generated files:** {list(output.output_files.values())}\n\n' content += ( - f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + f'Total execution time: {exec_result.total_duration_ms:.2f}ms' ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f"**{skill_id} failed:** {result.error}\n" + content += f'**{skill_id} failed:** {result.error}\n' messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" + content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' - content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" - content += f"\nExecution order: {dag_result.execution_order}" + content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' + content += f'\nExecution order: {dag_result.execution_order}' messages.append(Message(role='assistant', content=content)) @@ -367,7 +371,8 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile( + self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -396,11 +401,10 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f"[External Code]A Config Lifecycle handler " - f"registered in the config: {handler_file}. " - f"\nThis is external code, if you trust this workflow, " - f"please specify `--trust_remote_code true`" - ) + f'[External Code]A Config Lifecycle handler ' + f'registered in the config: {handler_file}. ' + f'\nThis is external code, if you trust this workflow, ' + f'please specify `--trust_remote_code true`') assert ( local_dir is not None ), 'Using external py files, but local_dir cannot be found.' @@ -410,18 +414,17 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, inspect.isclass) + for name, cls in inspect.getmembers(handler_module, + inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if ( - handler_cls.__bases__[0] is ConfigLifecycleHandler - and handler_cls.__module__ == handler_file - ): + if (handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file): handler = handler_cls() assert ( handler is not None - ), f"Config Lifecycle handler class cannot be found in {handler_file}" + ), f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -432,7 +435,8 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, + 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: @@ -460,22 +464,23 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass - ) + callback_file, inspect.isclass) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass(cls, Callback) and cls.__module__ == _callback: + if issubclass( + cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback](self.config)) + self.callbacks.append(callbacks_mapping[_callback]( + self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task beginning.") + self.log_output(f'Agent {self.tag} task beginning.') await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task finished.") + self.log_output(f'Agent {self.tag} task finished.') await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -500,7 +505,8 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, + messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -511,12 +517,10 @@ async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: List[Message]: Updated message list including tool responses. """ tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls - ) + messages[-1].tool_calls) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip( - tool_call_result, messages[-1].tool_calls - ): + for tool_call_result, tool_call_query in zip(tool_call_result, + messages[-1].tool_calls): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', @@ -550,7 +554,8 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -561,7 +566,8 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -572,7 +578,8 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -588,18 +595,19 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr( + getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr( + getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str] - ) -> List[Message]: + self, messages: Union[List[Message], str]) -> List[Message]: """ Convert input into a standardized list of messages. @@ -611,19 +619,18 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if ( - system is not None - and messages[0].role == 'system' - and system != messages[0].content - ): + if (system is not None and messages[0].role == 'system' + and system != messages[0].content): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f"inputs can be either a list or a string, but current is {type(messages)}" + ), f'inputs can be either a list or a string, but current is {type(messages)}' messages = [ - Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message( + role='system', + content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -662,11 +669,11 @@ async def do_rag(self, messages: List[Message]): # Append search context to user query context = search_result user_message.content = ( - f"Relevant context retrieved from codebase search:\n\n{context}\n\n" - f"User question: {query}" - ) + f'Relevant context retrieved from codebase search:\n\n{context}\n\n' + f'User question: {query}') - async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, + messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -683,9 +690,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' - else None - ) + if len(messages) > 1 and messages[1].role == 'user' else None) if not query: return None @@ -700,8 +705,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: try: skills_config = self._get_skills_config() auto_execute = ( - getattr(skills_config, 'auto_execute', True) if skills_config else True - ) + getattr(skills_config, 'auto_execute', True) + if skills_config else True) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,7 +714,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages(dag_result) + skill_messages = self._format_skill_result_as_messages( + dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -720,8 +726,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: except Exception as e: logger.warning( - f"Skill execution failed: {e}, falling back to standard agent" - ) + f'Skill execution failed: {e}, falling back to standard agent') self._skill_mode_active = False return None @@ -735,13 +740,11 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f"{mem_instance_type} not in memory_mapping, " - f"which supports: {list(memory_mapping.keys())}" - ) + f'{mem_instance_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type - ) + self.config, mem_instance_type) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -750,9 +753,8 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f"{rag.name} not in rag_mapping, " - f"which supports: {list(rag_mapping.keys())}" - ) + f'{rag.name} not in rag_mapping, ' + f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): @@ -764,7 +766,8 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) + self.knowledge_search: SirchmunkSearch = SirchmunkSearch( + self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: """ @@ -797,7 +800,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f"[Image: {img_url[:50]}...]") + text_parts.append(f'[Image: {img_url[:50]}...]') content = ' '.join(text_parts) # Ensure content is a string @@ -808,9 +811,10 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f"[{self.tag}] {_line}") + logger.info(f'[{self.tag}] {_line}') - def handle_new_response(self, messages: List[Message], response_message: Message): + def handle_new_response(self, messages: List[Message], + response_message: Message): assert response_message is not None, 'No response message generated from LLM.' if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -818,23 +822,24 @@ def handle_new_response(self, messages: List[Message], response_message: Message tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads(tool_call['arguments']) + tool_call['arguments'] = json.loads( + tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output( + json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if ( - messages[-1].role == 'assistant' - and not messages[-1].content - and response_message.tool_calls - ): + if (messages[-1].role == 'assistant' and not messages[-1].content + and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step( + self, messages: List[Message] + ) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -870,7 +875,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate(messages, tools=tools): + for _response_message in self.llm.generate( + messages, tools=tools): if is_first: messages.append(_response_message) is_first = False @@ -878,12 +884,12 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning) :] + new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -891,7 +897,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content) :] + new_content = _response_message.content[len(_content):] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -904,8 +910,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -934,8 +940,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 cache_creation_input_tokens = ( - getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 - ) + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -945,21 +950,20 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # tokens in the current step self.log_output( - f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" + f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' ) # total tokens for the process so far self.log_output( - f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " - f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" - ) + f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' + f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " - f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" + f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' + f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' ) yield messages @@ -972,9 +976,8 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history( - self, messages: List[Message], **kwargs - ) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history(self, messages: List[Message], + **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. @@ -1018,9 +1021,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step' - ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step') + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1031,7 +1034,8 @@ def _get_run_memory_info(self, memory_config: DictConfig): 'add_after_task', default_user_id=getattr(memory_config, 'user_id', None), ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1042,22 +1046,18 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, memory_config) in enumerate( - self.config.memory.items() - ): + for idx, (mem_instance_type, + memory_config) in enumerate(self.config.memory.items()): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config - ) + memory_config) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config - ) + memory_config) if idx < tools_num: - if any( - v is not None for v in [user_id, agent_id, run_id, memory_type] - ): + if any(v is not None + for v in [user_id, agent_id, run_id, memory_type]): await self.memory_tools[idx].add( messages, user_id=user_id, @@ -1086,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history(self.output_dir, task=self.tag, config=config, messages=messages) + save_history( + self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop( - self, messages: Union[List[Message], str], **kwargs - ) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], + **kwargs) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1103,9 +1103,8 @@ async def run_loop( List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr( - self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND - ) + self.max_chat_round = getattr(self.config, 'max_chat_round', + LLMAgent.DEFAULT_MAX_CHAT_ROUND) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1147,7 +1146,8 @@ async def run_loop( yield messages self.runtime.round += 1 # save memory and history - await self.add_memory(messages, add_type='add_after_step', **kwargs) + await self.add_memory( + messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1156,10 +1156,10 @@ async def run_loop( messages.append( Message( role='assistant', - content=f"Task {messages[1].content} was cutted off, because " - f"max round({self.max_chat_round}) exceeded.", - ) - ) + content= + f'Task {messages[1].content} was cutted off, because ' + f'max round({self.max_chat_round}) exceeded.', + )) self.runtime.should_stop = True yield messages @@ -1170,8 +1170,8 @@ async def run_loop( def _add_memory(): asyncio.run( - self.add_memory(messages, add_type='add_after_task', **kwargs) - ) + self.add_memory( + messages, add_type='add_after_task', **kwargs)) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) @@ -1181,22 +1181,22 @@ def _add_memory(): logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" + f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True - ) + self.config, 'generation_config.stream', True, merge=True) async def stream_generator(): - async for _chunk in self.run_loop(messages=messages, **kwargs): + async for _chunk in self.run_loop( + messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index cfe387e5a..a07397c95 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,11 +4,10 @@ import os from importlib import resources as importlib_resources -from omegaconf import OmegaConf - from ms_agent.config import Config from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII +from omegaconf import OmegaConf from .base import CLICommand @@ -185,8 +184,10 @@ def _execute_with_config(self): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) else: # Use built-in default agent.yaml from package - default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) - with importlib_resources.as_file(default_config_path) as config_file: + default_config_path = importlib_resources.files( + 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file( + default_config_path) as config_file: self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download @@ -226,7 +227,10 @@ def _execute_with_config(self): # If knowledge_search_paths is provided, configure SirchmunkSearch if getattr(self.args, 'knowledge_search_paths', None): - paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + paths = [ + p.strip() for p in self.args.knowledge_search_paths.split(',') + if p.strip() + ] if paths: if 'knowledge_search' not in config or not config.knowledge_search: # No existing knowledge_search config, create minimal config @@ -237,11 +241,13 @@ def _execute_with_config(self): 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + config['knowledge_search'] = OmegaConf.create( + knowledge_search_config) else: # Existing knowledge_search config found, only update paths # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing = OmegaConf.to_container( + config.knowledge_search, resolve=True) existing['paths'] = paths config['knowledge_search'] = OmegaConf.create(existing) diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index dd80738f9..e1c76181f 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from loguru import logger -from omegaconf import DictConfig from pathlib import Path -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union +from loguru import logger from ms_agent.rag.base import RAG +from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -59,7 +59,8 @@ def __init__(self, config: DictConfig): # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) - self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', + 0.85) self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) self.search_mode = rag_config.get('mode', 'FAST') self.max_loops = rag_config.get('max_loops', 10) @@ -72,26 +73,24 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if ( - self.llm_api_key is None - or self.llm_base_url is None - or self.llm_model_name is None - ): + if (self.llm_api_key is None or self.llm_base_url is None + or self.llm_model_name is None): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f"{service}_api_key", None) + self.llm_api_key = getattr(llm_config, + f'{service}_api_key', None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) + self.llm_base_url = getattr(llm_config, + f'{service}_base_url', None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( - 'embedding_model_cache_dir', None - ) + 'embedding_model_cache_dir', None) # Runtime state self._searcher = None @@ -109,7 +108,8 @@ def __init__(self, config: DictConfig): def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" - if not hasattr(config, 'knowledge_search') or config.knowledge_search is None: + if not hasattr(config, + 'knowledge_search') or config.knowledge_search is None: raise ValueError( 'Missing knowledge_search configuration. ' 'Please add knowledge_search section to your config with at least "paths" specified.' @@ -118,7 +118,8 @@ def _validate_config(self, config: DictConfig): rag_config = config.knowledge_search paths = rag_config.get('paths', []) if not paths: - raise ValueError('knowledge_search.paths must be specified and non-empty') + raise ValueError( + 'knowledge_search.paths must be specified and non-empty') def _initialize_searcher(self): """Initialize the sirchmunk AgenticSearch instance.""" @@ -142,16 +143,12 @@ def _initialize_searcher(self): # Create embedding util # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( - self.embedding_model_id if self.embedding_model_id else None - ) + self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( self.embedding_model_cache_dir - if self.embedding_model_cache_dir - else None - ) + if self.embedding_model_cache_dir else None) embedding = EmbeddingUtil( - model_id=embedding_model_id, cache_dir=embedding_cache_dir - ) + model_id=embedding_model_id, cache_dir=embedding_cache_dir) # Create AgenticSearch instance self._searcher = AgenticSearch( @@ -167,15 +164,16 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}") + logger.info( + f'SirschmunkSearch initialized with paths: {self.search_paths}' + ) except ImportError as e: raise ImportError( - f"Failed to import sirchmunk: {e}. " - 'Please install sirchmunk: pip install sirchmunk' - ) + f'Failed to import sirchmunk: {e}. ' + 'Please install sirchmunk: pip install sirchmunk') except Exception as e: - raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}") + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') def _log_callback_wrapper(self): """Create a callback wrapper to capture search logs. @@ -191,7 +189,7 @@ def log_callback( end: str = '\n', flush: bool = False, ): - log_entry = f"[{level.upper()}] {message}" + log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) # Stream log in real-time if streaming callback is set if self._streaming_callback: @@ -223,7 +221,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f"Failed to refresh knowledge base: {e}") + logger.error(f'Failed to refresh knowledge base: {e}') return False return True @@ -242,16 +240,19 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: try: for file_path in file_paths: if Path(file_path).exists(): - await self._searcher.scan_directory(str(Path(file_path).parent)) + await self._searcher.scan_directory( + str(Path(file_path).parent)) return True except Exception as e: - logger.error(f"Failed to scan files: {e}") + logger.error(f'Failed to scan files: {e}') return False return True - async def retrieve( - self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters - ) -> List[Dict[str, Any]]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. Args: @@ -270,7 +271,8 @@ async def retrieve( try: mode = filters.get('mode', self.search_mode) max_loops = filters.get('max_loops', self.max_loops) - max_token_budget = filters.get('max_token_budget', self.max_token_budget) + max_token_budget = filters.get('max_token_budget', + self.max_token_budget) # Perform search result = await self._searcher.search( @@ -286,16 +288,18 @@ async def retrieve( self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: # If a similar cluster was found and reused, it's a cache hit - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f"SirschmunkSearch retrieve failed: {e}") + logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] async def query(self, query: str) -> str: @@ -332,12 +336,15 @@ async def query(self, query: str) -> str: self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Store parsed context for frontend display - self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) + self._last_search_result = self._parse_search_result( + result, score_threshold=0.7, limit=5) # Extract the synthesized answer from the context result if hasattr(result, 'answer'): @@ -351,12 +358,11 @@ async def query(self, query: str) -> str: return str(result) except Exception as e: - logger.error(f"SirschmunkSearch query failed: {e}") - return f"Query failed: {e}" + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' - def _parse_search_result( - self, result: Any, score_threshold: float, limit: int - ) -> List[Dict[str, Any]]: + def _parse_search_result(self, result: Any, score_threshold: float, + limit: int) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -385,38 +391,37 @@ def _parse_search_result( else: text_parts.append(str(snippet)) - results.append( - { - 'text': '\n'.join(text_parts) - if text_parts - else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') - if hasattr(unit, 'abstraction_level') - else 'text', - }, - } - ) + results.append({ + 'text': + '\n'.join(text_parts) if text_parts else getattr( + unit, 'summary', ''), + 'score': + score, + 'metadata': { + 'source': + source, + 'type': + getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') else 'text', + }, + }) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append( - { - 'text': str(unit.content) - if hasattr(unit, 'content') - else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), - }, - } - ) + results.append({ + 'text': + str(unit.content) + if hasattr(unit, 'content') else str(unit), + 'score': + score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, + }) # Handle list format elif isinstance(result, list): @@ -424,27 +429,27 @@ def _parse_search_result( if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': item.get( - 'content', item.get('text', str(item)) - ), - 'score': score, - 'metadata': item.get('metadata', {}), - } - ) + results.append({ + 'text': + item.get('content', item.get('text', str(item))), + 'score': + score, + 'metadata': + item.get('metadata', {}), + }) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - } - ) + results.append({ + 'text': + result.get('content', result.get('text', str(result))), + 'score': + score, + 'metadata': + result.get('metadata', {}), + }) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) From 986b34fefa8707c34a84311c062acd98fd433e15 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Mon, 23 Mar 2026 12:16:08 +0800 Subject: [PATCH 12/40] enrich researcher's reflection strategy to enhance stability --- .../v2/callbacks/researcher_callback.py | 395 +++++++++++++++++- .../v2/prompts/researcher/en/gpt5.txt | 12 +- projects/deep_research/v2/researcher.yaml | 23 +- 3 files changed, 412 insertions(+), 18 deletions(-) diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 4796e2151..59214af44 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -1,14 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# yapf: disable import os +import re +import shutil from typing import List, Optional from callbacks.quality_checker import (ReportQualityChecker, build_quality_checkers) from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback +from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -20,14 +24,19 @@ class ResearcherCallback(Callback): a chain of quality checks before allowing the run to end: 1. **File existence**: has ``final_report.md`` been written to disk? - 2. **Quality checkers**: a configurable list of + 2. **Compression check**: is the report over-compressed vs + ``reports/draft.md``? + 3. **Quality checkers**: a configurable list of :class:`ReportQualityChecker` instances run in order; the first failure triggers a reflection prompt. - If any check fails, a reflection prompt is injected as a ``user`` - message, ``runtime.should_stop`` is flipped back to ``False``, and - the agent continues for one more iteration. A configurable retry - cap prevents infinite loops. + At task end (``on_task_end``), optionally: + + - Selects the best available report source (``final_report.md`` > + ``reports/report.md`` > ``reports/draft.md``) based on character-count + retention ratio, and promotes it to ``final_report.md``. + - Runs a format-cleanup agent to fix citation and reference formatting + issues (e.g., multiple reference sections, inconsistent numbering). YAML configuration (all optional, shown with defaults):: @@ -35,6 +44,17 @@ class ResearcherCallback(Callback): enabled: true max_retries: 2 report_filename: final_report.md + compression_check: + enabled: false + min_retention_ratio: 0.3 + report_selection: + enabled: false + min_retention_ratio: 0.3 + report_cleanup: + enabled: false + # model: ... # defaults to researcher llm.model + # openai_api_key: ... # falls back to llm.openai_api_key + # openai_base_url: ... # falls back to llm.openai_base_url quality_check: enabled: true model: qwen3.5-flash # lightweight audit model @@ -42,12 +62,28 @@ class ResearcherCallback(Callback): # openai_base_url: ... # falls back to llm.openai_base_url """ + REPORTS_DIR = 'reports' + DRAFT_FILENAME = 'draft.md' + REPORT_FILENAME = 'report.md' + DEFAULT_MIN_RETENTION_RATIO = 0.3 + _CLEANUP_OUTPUT_MIN_RATIO = 0.75 + _MAX_CLEANUP_CHARS = 200000 + _REFLECTION_TEMPLATES = { 'zh': { 'no_report': ('外部检查发现:输出目录中尚未生成 {filename},该文件原本应由 Reporter 子代理自动创建。\n' '请确认最终报告未交付的原因,并立即采取行动修复。\n' '请注意:不要使用占位符或缩略内容替代实际报告正文。'), + 'over_compressed': + ('外部检查发现:{report_name} 的内容量({report_chars} 字符)' + '仅为 {draft_name}({draft_chars} 字符)的 {ratio:.0%},有可能存在内容丢失风险,' + '请对报告内容进行检查并采取合理的行动。\n' + '**重要提醒**:{draft_name} 是由工具逐章组装的完整版本,理论上保留了最大的证据保真度。\n' + '- 如果你确认你对报告进行的修改是合理的,可以直接说明压缩内容的理由,无需再次修改或者重写。\n' + '- 如果你发现 {report_name} 相比 {draft_name} 确实存在不合理的压缩,' + '请通过重写/追加/续写等方式来修复这些问题。\n' + '请立即采取行动完成报告交付。'), 'low_quality': ('外部检查发现:{filename} 的内容存在质量问题——{reason}。\n' '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' @@ -69,6 +105,17 @@ class ResearcherCallback(Callback): 'Please identify why the final report was not delivered and immediately take action to fix it.\n' 'Note: Do not use placeholders or abbreviated content in place of the actual report body.' ), + 'over_compressed': + ('External inspection found that {report_name} ({report_chars} chars) ' + 'is only {ratio:.0%} of {draft_name} ({draft_chars} chars), ' + 'indicating a risk of content loss. Please review the report content and take appropriate action.\n' + '**IMPORTANT**: {draft_name} is the tool-assembled complete version that theoretically ' + 'preserves maximum evidence fidelity.\n' + '- If you confirm that your modifications to the report are reasonable, you may simply ' + 'explain the rationale for the compression without further modifications or rewrites.\n' + '- If you find that {report_name} has indeed been unreasonably compressed ' + 'compared to {draft_name}, please rewrite/append/continue writing to repair these issues.\n' + 'Please take immediate action to complete report delivery.'), 'low_quality': ('External inspection found quality issues in {filename} — {reason}.\n' 'Please carefully verify whether these issues are valid and whether additional problems exist, ' @@ -89,6 +136,56 @@ class ResearcherCallback(Callback): }, } + _CLEANUP_SYSTEM_PROMPTS = { + 'zh': + ('你是一个研究报告格式清理专家。你的唯一任务是修复报告中的引用和参考文献格式问题。' + '你绝对不能修改报告的实质内容、论点、证据或分析。\n\n' + '只修复以下类型的问题:\n' + '1. 多个参考文献章节:将所有分散的参考文献列表合并为报告末尾的唯一一个统一参考文献章节。' + '移除完全重复的参考文献条目,按首次出现顺序重新编号。\n' + '2. 引用标记不一致:确保正文中所有引用标记使用统一格式(如 [1], [2]),修复格式错误的引用标记。\n' + '3. 失效引用:修复引用了不存在条目的标记,或处理从未被引用的参考文献条目。\n' + '4. 参考文献编号:确保参考文献按照在正文中首次出现的顺序从 [1] 开始连续编号。\n' + '5. 参考文献格式:确保每条参考文献遵循一致的格式。\n\n' + '关键规则:\n' + '- 不得修改、增加、删除或改写任何实质性内容。\n' + '- 不得改变标题、章节结构或组织方式(合并参考文献章节除外)。\n' + '- 不得添加新的引用或删除已引用的内容。\n' + '- 如果报告没有格式问题,则原样返回。\n' + '- 返回修复后的完整报告,而不仅仅是修改的部分。\n' + '- 不要使用 markdown 代码块包裹输出。'), + 'en': + ('You are a research report formatting specialist. Your ONLY job is to fix citation ' + 'and reference formatting issues in the report. You must NOT modify the substantive ' + 'content, arguments, evidence, or analysis in any way.\n\n' + 'Fix ONLY the following types of issues:\n' + '1. Multiple reference sections: Merge all scattered reference/bibliography lists into ' + 'a single unified reference section at the very end of the report. Remove exact duplicate ' + 'entries and renumber sequentially by order of first appearance.\n' + '2. Citation marker inconsistencies: Ensure all in-text citation markers use a consistent ' + 'format (e.g., [1], [2]) throughout the report. Fix any malformed citation markers.\n' + '3. Orphaned citations: Fix citation markers that reference non-existent entries, or handle ' + 'reference entries that are never cited in the text.\n' + '4. Reference numbering: Ensure references are numbered sequentially starting from [1], ' + 'in order of first appearance in the text.\n' + '5. Reference entry formatting: Ensure each reference entry follows a consistent format.\n\n' + 'CRITICAL RULES:\n' + '- Do NOT modify, add, remove, or rephrase any substantive content.\n' + '- Do NOT change headings, section structure, or organization (except merging reference sections).\n' + '- Do NOT add new citations or remove existing cited content.\n' + '- If the report has no formatting issues, return it unchanged.\n' + '- Return the COMPLETE report with fixes applied, not just the changed parts.\n' + '- Do NOT wrap the output in markdown code blocks.'), + } + + _CLEANUP_USER_TEMPLATES = { + 'zh': ('请修复以下研究报告中的引用和参考文献格式问题:\n\n' + '---报告开始---\n{report}\n---报告结束---'), + 'en': ('Please fix the citation and reference formatting issues ' + 'in the following research report:\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---'), + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir: str = getattr(config, 'output_dir', './output') @@ -105,6 +202,61 @@ def __init__(self, config: DictConfig): self.report_filename = str( getattr(refl_cfg, 'report_filename', self.report_filename)) + # --- Compression check config --- + comp_cfg = ( + getattr(refl_cfg, 'compression_check', None) if refl_cfg else None) + self.compression_check_enabled: bool = False + self.min_retention_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + if comp_cfg is not None: + self.compression_check_enabled = bool( + getattr(comp_cfg, 'enabled', False)) + self.min_retention_ratio = float( + getattr(comp_cfg, 'min_retention_ratio', + self.DEFAULT_MIN_RETENTION_RATIO)) + + # --- Report selection config (on_task_end) --- + sel_cfg = ( + getattr(refl_cfg, 'report_selection', None) if refl_cfg else None) + self.report_selection_enabled: bool = False + self._selection_min_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + if sel_cfg is not None: + self.report_selection_enabled = bool( + getattr(sel_cfg, 'enabled', False)) + self._selection_min_ratio = float( + getattr(sel_cfg, 'min_retention_ratio', + self.DEFAULT_MIN_RETENTION_RATIO)) + + # --- Format cleanup agent config (on_task_end) --- + cleanup_cfg = ( + getattr(refl_cfg, 'report_cleanup', None) if refl_cfg else None) + self.report_cleanup_enabled: bool = False + self._cleanup_model: Optional[str] = None + self._cleanup_api_key: Optional[str] = None + self._cleanup_base_url: Optional[str] = None + self._cleanup_generation_config: Optional[dict] = None + if cleanup_cfg is not None: + self.report_cleanup_enabled = bool( + getattr(cleanup_cfg, 'enabled', False)) + self._cleanup_model = getattr(cleanup_cfg, 'model', None) + self._cleanup_api_key = getattr(cleanup_cfg, 'openai_api_key', + None) + self._cleanup_base_url = getattr(cleanup_cfg, 'openai_base_url', + None) + gen_cfg = getattr(cleanup_cfg, 'generation_config', None) + if gen_cfg is not None: + self._cleanup_generation_config = ( + OmegaConf.to_container(gen_cfg, resolve=True) + if isinstance(gen_cfg, DictConfig) else dict(gen_cfg)) + + # --- Derived paths --- + self._reports_dir: str = self.REPORTS_DIR + self._draft_path: str = os.path.join(self.output_dir, + self._reports_dir, + self.DRAFT_FILENAME) + self._inner_report_path: str = os.path.join(self.output_dir, + self._reports_dir, + self.REPORT_FILENAME) + self._retries_used: int = 0 self._checkers: List[ReportQualityChecker] = build_quality_checkers( config) @@ -135,7 +287,166 @@ def _get_template(self, key: str) -> str: def _marker_path(self) -> str: return os.path.join(self.output_dir, self.TASK_FINISHED_MARKER) + def _select_best_report(self) -> Optional[str]: + """Return the path to the best available report, based on char-count + retention ratio against ``reports/draft.md``. + + Candidates (in preference order): + 1. ``final_report.md`` + 2. ``reports/report.md`` + 3. ``reports/draft.md`` + + A candidate is accepted if it has >= ``_selection_min_ratio`` chars + relative to the draft. Falls back to ``draft.md`` if all others + are over-compressed. + """ + final_path = self._report_path + candidates = [ + (final_path, self.report_filename), + (self._inner_report_path, + os.path.join(self._reports_dir, self.REPORT_FILENAME)), + ] + + has_draft = os.path.isfile(self._draft_path) + if not has_draft: + for path, _ in candidates: + if os.path.isfile(path): + return path + return None + + try: + draft_chars = self._read_char_count(self._draft_path) + except OSError: + draft_chars = 0 + + if draft_chars <= 0: + for path, _ in candidates: + if os.path.isfile(path): + return path + return self._draft_path + + for path, name in candidates: + if not os.path.isfile(path): + continue + try: + chars = self._read_char_count(path) + except OSError: + continue + ratio = chars / draft_chars + if ratio >= self._selection_min_ratio: + return path + logger.warning( + f'ResearcherCallback: {name} ({chars} chars) is only ' + f'{ratio:.0%} of draft ({draft_chars} chars), ' + f'trying next candidate.') + + logger.warning('ResearcherCallback: all candidates over-compressed, ' + 'falling back to draft.md.') + return self._draft_path + + def _run_format_cleanup(self, report_path: str) -> bool: + """Run format cleanup on the report to fix citation/reference issues. + + Returns True if cleanup was applied successfully. + """ + try: + with open(report_path, 'r', encoding='utf-8') as f: + content = f.read() + except Exception as exc: + logger.warning(f'ResearcherCallback: failed to read report for ' + f'format cleanup: {exc}') + return False + + if not content.strip(): + logger.info( + 'ResearcherCallback: report is empty, skipping cleanup.') + return False + + if len(content) > self._MAX_CLEANUP_CHARS: + logger.warning( + f'ResearcherCallback: report too long for format cleanup ' + f'({len(content)} chars > {self._MAX_CLEANUP_CHARS}), ' + f'skipping.') + return False + + model = ( + self._cleanup_model or getattr(self.config.llm, 'model', None)) + api_key = ( + self._cleanup_api_key + or getattr(self.config.llm, 'openai_api_key', None)) + base_url = ( + self._cleanup_base_url + or getattr(self.config.llm, 'openai_base_url', None)) + + if not model: + logger.warning( + 'ResearcherCallback: no model configured for format cleanup.') + return False + + gen_cfg = self._cleanup_generation_config or {} + llm_config = OmegaConf.create({ + 'llm': { + 'model': model, + 'openai_api_key': api_key, + 'openai_base_url': base_url, + }, + 'generation_config': gen_cfg, + }) + + try: + client = OpenAILLM(llm_config) + except Exception as exc: + logger.warning(f'ResearcherCallback: failed to create LLM client ' + f'for format cleanup: {exc}') + return False + + sys_prompt = self._CLEANUP_SYSTEM_PROMPTS.get( + self.lang, self._CLEANUP_SYSTEM_PROMPTS['en']) + usr_template = self._CLEANUP_USER_TEMPLATES.get( + self.lang, self._CLEANUP_USER_TEMPLATES['en']) + + try: + response = client.generate(messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', content=usr_template.format(report=content)), + ]) + cleaned = (response.content or '').strip() + except Exception as exc: + logger.warning( + f'ResearcherCallback: format cleanup LLM call failed: {exc}') + return False + + if not cleaned: + logger.warning( + 'ResearcherCallback: format cleanup returned empty output.') + return False + + # Strip markdown code-block wrapper if present + cleaned = re.sub(r'^```\w*\n', '', cleaned) + cleaned = re.sub(r'\n```\s*$', '', cleaned) + + # Guard against truncated output + if len(cleaned) < len(content) * self._CLEANUP_OUTPUT_MIN_RATIO: + logger.warning( + f'ResearcherCallback: format cleanup output appears ' + f'truncated ({len(cleaned)} vs {len(content)} chars), ' + f'keeping original.') + return False + + try: + with open(report_path, 'w', encoding='utf-8') as f: + f.write(cleaned) + logger.info(f'ResearcherCallback: format cleanup applied ' + f'({len(content)} -> {len(cleaned)} chars).') + return True + except Exception as exc: + logger.warning( + f'ResearcherCallback: failed to write cleaned report: {exc}') + return False + async def on_task_end(self, runtime: Runtime, messages: List[Message]): + # --- Step 1: Write task-finished marker --- try: os.makedirs(self.output_dir, exist_ok=True) with open(self._marker_path, 'w') as f: @@ -147,6 +458,47 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): logger.warning( f'ResearcherCallback: failed to write marker: {exc}') + # --- Step 2: Best report selection --- + if self.report_selection_enabled: + best_source = self._select_best_report() + if best_source and best_source != self._report_path: + try: + os.makedirs( + os.path.dirname(self._report_path), exist_ok=True) + shutil.copy2(best_source, self._report_path) + source_name = os.path.relpath(best_source, self.output_dir) + logger.info( + f'ResearcherCallback: promoted {source_name} -> ' + f'{self.report_filename}') + except Exception as exc: + logger.warning( + f'ResearcherCallback: failed to promote report: ' + f'{exc}') + elif best_source: + logger.info( + f'ResearcherCallback: {self.report_filename} is already ' + f'the best candidate, no promotion needed.') + else: + logger.warning('ResearcherCallback: no report file found for ' + 'best-report selection.') + + # --- Step 3: Format cleanup agent --- + if self.report_cleanup_enabled: + if os.path.isfile(self._report_path): + logger.info( + 'ResearcherCallback: running format cleanup agent on ' + f'{self.report_filename}...') + self._run_format_cleanup(self._report_path) + else: + logger.warning( + f'ResearcherCallback: {self.report_filename} not found, ' + f'skipping format cleanup.') + + @staticmethod + def _read_char_count(path: str) -> int: + with open(path, 'r', encoding='utf-8') as f: + return len(f.read()) + async def after_tool_call(self, runtime: Runtime, messages: List[Message]): if not self.enabled: return @@ -169,7 +521,36 @@ async def after_tool_call(self, runtime: Runtime, messages: List[Message]): self._retries_used += 1 return - # --- Check 2: quality checker chain --- + # --- Check 2: compression check vs reports/draft.md --- + if self.compression_check_enabled and os.path.isfile(self._draft_path): + try: + report_chars = self._read_char_count(self._report_path) + draft_chars = self._read_char_count(self._draft_path) + if draft_chars > 0: + ratio = report_chars / draft_chars + if ratio < self.min_retention_ratio: + draft_rel = os.path.join(self._reports_dir, + self.DRAFT_FILENAME) + logger.warning( + f'ResearcherCallback: {self.report_filename} ' + f'({report_chars} chars) is only {ratio:.0%} of ' + f'{draft_rel} ({draft_chars} chars), ' + 'injecting over-compression prompt.') + prompt = self._get_template('over_compressed').format( + report_name=self.report_filename, + report_chars=report_chars, + draft_name=draft_rel, + draft_chars=draft_chars, + ratio=ratio) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._retries_used += 1 + return + except OSError as exc: + logger.warning(f'ResearcherCallback: failed to read files for ' + f'compression check: {exc}') + + # --- Check 3: quality checker chain --- if not self._checkers: logger.info('ResearcherCallback: no quality checkers configured, ' 'skipping quality gate.') diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt index 8891d8313..6d21b8b48 100644 --- a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -15,7 +15,7 @@ Action protocol: Before outputting the final result, every iteration MUST invoke - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. - Draft, review, deliver: - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. - - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter's output preserves maximum evidence fidelity. # Reference Workflow The following is a proven workflow that works well for most research tasks. @@ -50,8 +50,10 @@ Stopping conditions (stop if you are confident to proceed to the next phase): ## Phase 3: Report Generation - Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. - - Note: do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. -- The Reporter will deliver the complete report as final_report.md. After the Reporter returns, you MUST review the report for quality and accuracy: + - Do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. + - The Reporter sub-agent writes the final report to reports/report.md and returns only the execution summary and artifact file paths. The system automatically copies reports/report.md to final_report.md. Therefore, final_report.md may not be listed in the Reporter's Artifacts field. Do not ask the Reporter to create final_report.md directly. Review and edit final_report.md. +- After the Reporter returns, you MUST review the report for quality and accuracy: + - **Read once, then act.** Read the full report content once to form your assessment. For subsequent checks, use file_system---search_file_content or file_system---read_file with start_line/end_line — do NOT re-read the entire report file repeatedly. Always check your conversation history before re-reading a file that may already be in context. - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. @@ -77,12 +79,12 @@ Stopping conditions (stop if you are confident to proceed to the next phase): # Tool Invocation Protocol - You MUST use the tools under the todo_list server to create, update, and read the TODO list. You MUST NOT use any other tools or services to maintain the TODO list. - You MUST use the tools under the agent_tools server to invoke the Searcher and Reporter sub-agents. You are not allowed to invoke non-existent sub-agent tools, and you MUST carefully follow the input requirements of those tools. -- When context is unclear (e.g., the Searcher sub-agent’s output appears to have lost details, or the Reporter sub-agent’s report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. +- When context is unclear (e.g., the Searcher sub-agent's output appears to have lost details, or the Reporter sub-agent's report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. - You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). - For file-level operations, keep using relative paths. # Quality Constraints -- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent’s research conclusions and stored evidence. +- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent's research conclusions and stored evidence. - Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. - Strictly control scope: if the user asks for X, do not drift to Y. - Citation integrity in the final report: diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 50c2ece5d..95ba77a84 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -14,7 +14,10 @@ generation_config: # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] # extra_body: - # enable_thinking: false + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + # reasoning_effort: medium tag: deep-research-researcher @@ -63,6 +66,7 @@ tools: mcp: false enable_stats: true run_in_thread: true + run_in_process: true max_workers: 4 definitions: - tool_name: searcher_tool @@ -102,7 +106,7 @@ tools: description: > Invoke the Reporter sub-agent to generate a research report based on collected evidence. Reporter reads the stored evidence cards and executes a complex workflow for research report writing. - The completed report is automatically saved to `final_report.md` in the output directory. + The completed report is automatically saved to `final_report.md` in the output directory by the system. Returns a JSON result containing: execution summary and intermediate artifact file paths (the full report body is NOT included in the return value — read `final_report.md` directly to access the report content). @@ -140,10 +144,17 @@ callbacks: - callbacks/researcher_callback # Self-reflection checks before allowing the researcher to stop. -# Runs inside ResearcherCallback.after_tool_call. self_reflection: enabled: true - max_retries: 2 + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false quality_check: enabled: true model: qwen3.5-flash @@ -152,8 +163,8 @@ handler: time_handler code_file: researcher -max_chat_round: 42 +max_chat_round: 45 -tool_call_timeout: 2400 +tool_call_timeout: 2600 output_dir: ./output From fec8bfe4f305a08e8b540b9715e543895c083764 Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 24 Mar 2026 11:36:46 +0800 Subject: [PATCH 13/40] mv localsearch to tools --- docs/en/Components/Config.md | 16 +- docs/zh/Components/config.md | 14 +- examples/knowledge_search/agent.yaml.example | 86 ------ ms_agent/agent/llm_agent.py | 39 +-- ms_agent/cli/run.py | 25 +- ms_agent/knowledge_search/README.md | 277 ----------------- ms_agent/knowledge_search/__init__.py | 9 +- ms_agent/llm/utils.py | 30 +- ms_agent/rag/utils.py | 4 +- ms_agent/tools/search/localsearch_tool.py | 282 ++++++++++++++++++ .../search}/sirchmunk_search.py | 253 ++++++++++------ ms_agent/tools/tool_manager.py | 4 + tests/knowledge_search/test_sirschmunk.py | 223 +++++--------- 13 files changed, 586 insertions(+), 676 deletions(-) delete mode 100644 examples/knowledge_search/agent.yaml.example delete mode 100644 ms_agent/knowledge_search/README.md create mode 100644 ms_agent/tools/search/localsearch_tool.py rename ms_agent/{knowledge_search => tools/search}/sirchmunk_search.py (68%) diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index f1253bd75..2154add8e 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # Local codebase / document search (sirchmunk), exposed as the `localsearch` tool + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # Optional: llm_api_key, llm_base_url, llm_model_name (else inherited from `llm`) ``` For the complete list of supported tools and custom tools, please refer to [here](./Tools.md) @@ -167,19 +175,19 @@ In addition to yaml configuration, MS-Agent also supports several additional com > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. -- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. +- knowledge_search_paths: Comma-separated local search paths. Merges into `tools.localsearch.paths` and registers the **`localsearch`** tool (sirchmunk) for on-demand use by the model—not automatic per-turn injection. LLM settings are inherited from the `llm` module unless you set `tools.localsearch.llm_*` fields. ### Quick Start for Knowledge Search -Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: +Use `--knowledge_search_paths` or define `tools.localsearch` in yaml so the model can call `localsearch` when needed: ```bash # Using default agent.yaml configuration, automatically reuses LLM settings -ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "/path/to/docs" # Specify configuration file ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" ``` LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. -If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. +For a dedicated sirchmunk LLM, set `tools.localsearch.llm_api_key`, `llm_base_url`, and `llm_model_name` in yaml. Legacy top-level `knowledge_search` with the same keys is still read for backward compatibility. diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 12849f2a7..0baa4bf18 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # 本地代码库/文档搜索(sirchmunk),对应模型可调用的 `localsearch` 工具 + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # 可选:llm_api_key、llm_base_url、llm_model_name(不填则从 `llm` 继承) ``` 支持的完整工具列表,以及自定义工具请参考 [这里](./tools) @@ -165,13 +173,13 @@ handler: custom_handler } } ``` -- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 +- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用,不再在每轮自动注入上下文;除非配置 `tools.localsearch.llm_*`,否则 LLM 从 `llm` 模块复用 > agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 ### 知识搜索快速使用 -通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: +通过 `--knowledge_search_paths` 或在 yaml 中配置 `tools.localsearch`,启用本地知识搜索(模型按需调用 `localsearch`): ```bash # 使用默认 agent.yaml 配置,自动复用 LLM 设置 @@ -182,4 +190,4 @@ ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_sea ``` LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 -如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 +若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。仍支持旧版顶层 `knowledge_search` 相同字段,以便迁移。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example deleted file mode 100644 index cc11a8a3d..000000000 --- a/examples/knowledge_search/agent.yaml.example +++ /dev/null @@ -1,86 +0,0 @@ -# Sirchmunk Knowledge Search 配置示例 -# Sirchmunk Knowledge Search Configuration Example - -# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: - -llm: - service: modelscope - model: Qwen/Qwen3-235B-A22B-Instruct-2507 - modelscope_api_key: - modelscope_base_url: https://api-inference.modelscope.cn/v1 - -generation_config: - temperature: 0.3 - top_k: 20 - stream: true - -# Knowledge Search 配置(可选) -# 用于在本地代码库中搜索相关信息 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录,用于缓存 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则使用上面 llm 的配置) - llm_api_key: - llm_base_url: https://api.openai.com/v1 - llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:聚类相似度阈值 - cluster_sim_threshold: 0.85 - - # 可选:聚类 TopK - cluster_sim_top_k: 3 - - # 可选:是否重用之前的知识 - reuse_knowledge: true - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:最大循环次数 - max_loops: 10 - - # 可选:最大 token 预算 - max_token_budget: 128000 - -prompt: - system: | - You are an assistant that helps me complete tasks. - -max_chat_round: 9999 - -# 使用说明: -# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 -# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 -# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM -# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 -# -# Python 使用示例: -# ```python -# from ms_agent import LLMAgent -# from ms_agent.config import Config -# -# config = Config.from_task('path/to/agent.yaml') -# agent = LLMAgent(config=config) -# result = await agent.run('如何实现用户认证功能?') -# -# # 获取搜索详情(用于前端展示) -# for msg in result: -# if msg.role == 'user': -# print(f"Search logs: {msg.searching_detail}") -# print(f"Search results: {msg.search_result}") -# ``` -# -# CLI 测试命令: -# export LLM_API_KEY="your-api-key" -# export LLM_BASE_URL="https://api.openai.com/v1" -# export LLM_MODEL_NAME="gpt-4o-mini" -# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 5f2ddf2e7..36e46a672 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -13,7 +13,6 @@ import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping @@ -107,7 +106,6 @@ def __init__( self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None - self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -528,6 +526,7 @@ async def parallel_tool_call(self, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, + tool_detail=tool_call_result_format.tool_detail, ) if _new_message.tool_call_id is None: @@ -636,11 +635,7 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): - """Process RAG or knowledge search to enrich the user query with context. - - This method handles both traditional RAG and sirchmunk-based knowledge search. - For knowledge search, it also populates searching_detail and search_result - fields in the message for frontend display and next-turn LLM context. + """Process RAG to enrich the user query with context. Args: messages (List[Message]): The message list to process. @@ -654,23 +649,6 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search - if self.knowledge_search is not None: - # Perform search and get results - search_result = await self.knowledge_search.query(query) - search_details = self.knowledge_search.get_search_details() - - # Store search details in the message for frontend display - user_message.searching_detail = search_details - user_message.search_result = search_result - - # Build enriched context from search results - if search_result: - # Append search context to user query - context = search_result - user_message.content = ( - f'Relevant context retrieved from codebase search:\n\n{context}\n\n' - f'User question: {query}') async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -757,18 +735,6 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) - async def prepare_knowledge_search(self): - """Load and initialize the knowledge search component from the config.""" - if self.knowledge_search is not None: - # Already initialized (e.g. by caller before run_loop), skip to avoid - # overwriting a configured instance (e.g. one with streaming callbacks set). - return - if hasattr(self.config, 'knowledge_search'): - ks_config = self.config.knowledge_search - if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch( - self.config) - async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -1111,7 +1077,6 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() - await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 84c93a97e..2accdb40e 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -153,7 +153,7 @@ def define_args(parsers: argparse.ArgumentParser): type=str, default=None, help= - 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' + 'Comma-separated list of paths for knowledge search.' ) parser.set_defaults(func=subparser_func) @@ -233,31 +233,28 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) - # If knowledge_search_paths is provided, configure SirchmunkSearch + # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): paths = [ p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip() ] if paths: - if 'knowledge_search' not in config or not config.knowledge_search: - # No existing knowledge_search config, create minimal config - # LLM settings will be auto-reused from llm module by SirchmunkSearch - knowledge_search_config = { - 'name': 'SirchmunkSearch', + if not hasattr(config, 'tools') or config.tools is None: + config['tools'] = OmegaConf.create({}) + tl = getattr(config.tools, 'localsearch', None) + if tl is None or not OmegaConf.is_config(tl): + localsearch_config = { 'paths': paths, 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create( - knowledge_search_config) + config.tools['localsearch'] = OmegaConf.create( + localsearch_config) else: - # Existing knowledge_search config found, only update paths - # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container( - config.knowledge_search, resolve=True) + existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths - config['knowledge_search'] = OmegaConf.create(existing) + config.tools['localsearch'] = OmegaConf.create(existing) if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md deleted file mode 100644 index ef86df0da..000000000 --- a/ms_agent/knowledge_search/README.md +++ /dev/null @@ -1,277 +0,0 @@ -# Sirchmunk Knowledge Search 集成 - -本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 - -## 功能特性 - -- **智能代码搜索**: 使用 LLM 和 embedding 模型对代码库进行语义搜索 -- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 -- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 -- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 -- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 - -## 安装 - -```bash -pip install sirchmunk -``` - -## 配置 - -在您的 `agent.yaml` 或 `workflow.yaml` 中添加以下配置: - -```yaml -llm: - service: dashscope - model: qwen3.5-plus - dashscope_api_key: - dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 - -generation_config: - temperature: 0.3 - stream: true - -# Knowledge Search 配置 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) - # llm_api_key: - # llm_base_url: https://api.openai.com/v1 - # llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:是否重用之前的知识 - reuse_knowledge: true -``` - -**LLM 配置自动复用机制**: - -`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: -- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` -- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` -- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` - -其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 - -通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 - -## 使用方式 - -### 1. 通过 CLI 使用(推荐) - -从命令行直接运行,无需编写代码: - -```bash -# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 -ms-agent run --query "如何实现用户认证功能?" --knowledge_search_paths "./src,./docs" - -# 指定配置文件 -ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" -``` - -**说明**: -- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 -- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 -- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 - -### 2. 通过 LLMAgent 使用 - -```python -from ms_agent import LLMAgent -from ms_agent.config import Config - -# 从配置文件加载 -config = Config.from_task('path/to/agent.yaml') -agent = LLMAgent(config=config) - -# 运行查询 - 会自动触发知识搜索 -result = await agent.run('如何实现用户认证功能?') - -# 获取搜索结果 -for msg in result: - if msg.role == 'user': - # 搜索详情(用于前端展示) - print(f"Search logs: {msg.searching_detail}") - # 搜索结果(作为 LLM 上下文) - print(f"Search results: {msg.search_result}") -``` - -### 3. 单独使用 SirchmunkSearch - -```python -from ms_agent.knowledge_search import SirchmunkSearch -from omegaconf import DictConfig - -config = DictConfig({ - 'knowledge_search': { - 'paths': ['./src', './docs'], - 'work_path': './.sirchmunk', - 'llm_api_key': 'your-api-key', - 'llm_model_name': 'gpt-4o-mini', - 'mode': 'FAST', - } -}) - -searcher = SirchmunkSearch(config) - -# 查询(返回合成答案) -answer = await searcher.query('如何实现用户认证?') - -# 检索(返回原始搜索结果) -results = await searcher.retrieve( - query='用户认证', - limit=5, - score_threshold=0.7 -) - -# 获取搜索日志 -logs = searcher.get_search_logs() - -# 获取搜索详情 -details = searcher.get_search_details() -``` - -## 环境变量 - -可以通过环境变量配置: - -```bash -# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -# Embedding 模型配置 -export EMBEDDING_MODEL_ID="text-embedding-3-small" -export SIRCHMUNK_WORK_PATH="./.sirchmunk" -``` - -**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 - -## 测试 - -### 单元测试 - -```bash -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -python -m unittest tests/knowledge_search/test_sirschmunk.py -``` - -### CLI 测试 - -```bash -# 基本测试 -python tests/knowledge_search/test_cli.py - -# 指定查询 -python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" - -# 仅测试 standalone 模式 -python tests/knowledge_search/test_cli.py -m standalone - -# 仅测试 agent 模式 -python tests/knowledge_search/test_cli.py -m agent -``` - -## 配置参数说明 - -| 参数 | 类型 | 默认值 | 说明 | -|------|------|--------|------| -| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | -| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | -| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | -| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | -| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | -| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | -| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | -| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | -| reuse_knowledge | bool | true | 是否重用之前的知识 | -| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | -| max_loops | int | 10 | 最大搜索循环次数 | -| max_token_budget | int | 128000 | 最大 token 预算 | - -## 搜索模式 - -- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 -- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 -- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 - -## Message 字段扩展 - -为了支持知识搜索,`Message` 类增加了两个字段: - -- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 - - `logs`: 搜索日志列表 - - `mode`: 使用的搜索模式 - - `paths`: 搜索的路径 - - `work_path`: 工作目录 - - `reuse_knowledge`: 是否重用知识 - -- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 - - `text`: 文档内容 - - `score`: 相关性分数 - - `metadata`: 元数据(如源文件、类型等) - -## 工作原理 - -1. 用户发送查询 -2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch -3. `do_rag()` 方法执行知识搜索: - - 调用 `searcher.retrieve()` 获取相关文档 - - 将搜索结果存入 `message.search_result` - - 将搜索日志存入 `message.searching_detail` - - 将搜索结果格式化为上下文,附加到用户查询 -4. LLM 接收 enriched query 并生成回答 -5. 前端可以通过 `searching_detail` 展示搜索过程 - -## 故障排除 - -### 常见问题 - -1. **ImportError: No module named 'sirchmunk'** - ```bash - pip install sirchmunk - ``` - -2. **搜索结果为空** - - 检查 `paths` 配置是否正确 - - 确保路径下有可搜索的文件 - - 尝试降低 `cluster_sim_threshold` 值 - -3. **LLM API 调用失败** - - 检查 API key 是否正确 - - 检查 base URL 是否正确 - - 查看搜索日志了解详细错误 - -### 日志查看 - -```python -# 查看搜索日志 -logs = searcher.get_search_logs() -for log in logs: - print(log) - -# 或在配置中启用 verbose -knowledge_search: - verbose: true -``` - -## 参考资源 - -- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) -- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py index 33362beee..8746f27c9 100644 --- a/ms_agent/knowledge_search/__init__.py +++ b/ms_agent/knowledge_search/__init__.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Knowledge search module based on sirchmunk. +"""Backward-compatible re-exports for sirchmunk local search. -This module provides integration between sirchmunk's AgenticSearch -and the ms_agent framework, enabling intelligent codebase search -capabilities similar to RAG. +Implementation lives in :mod:`ms_agent.tools.search.sirchmunk_search`; prefer +importing ``SirchmunkSearch`` from there in new code. """ -from .sirchmunk_search import SirchmunkSearch +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch __all__ = ['SirchmunkSearch'] diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 410aa12f0..c108170cb 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,11 +61,8 @@ class Message: api_calls: int = 1 - # Knowledge search (sirchmunk) related fields - # searching_detail: Search process logs and metadata for frontend display - searching_detail: Dict[str, Any] = field(default_factory=dict) - # search_result: Raw search results to be used as context for next LLM turn - search_result: List[Dict[str, Any]] = field(default_factory=list) + # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). + tool_detail: Optional[str] = None def to_dict(self): return asdict(self) @@ -88,7 +85,11 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls'] + # Never send UI-only fields to model providers. + rm = [ + 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', + 'searching_detail', 'search_result' + ] return { key: value for key, value in raw_dict.items() @@ -98,20 +99,33 @@ def to_dict_clean(self): @dataclass class ToolResult: + """Tool execution outcome. + + ``text`` is sent to the model as the tool message ``content``. + ``tool_detail`` is optional verbose output for frontends only (SSE, logs). + """ + text: str resources: List[str] = field(default_factory=list) extra: dict = field(default_factory=dict) + tool_detail: Optional[str] = None @staticmethod def from_raw(raw): if isinstance(raw, str): return ToolResult(text=raw) if isinstance(raw, dict): + model_text = raw.get('result') + if model_text is None: + model_text = raw.get('text', '') + td = raw.get('tool_detail') return ToolResult( - text=str(raw.get('text', '')), + text=str(model_text), resources=raw.get('resources', []), + tool_detail=None if td is None else str(td), extra={ k: v - for k, v in raw.items() if k not in ['text', 'resources'] + for k, v in raw.items() + if k not in ['text', 'resources', 'result', 'tool_detail'] }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index e66da954d..cf52aaa7c 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -5,5 +5,5 @@ 'LlamaIndexRAG': LlamaIndexRAG, } -# Note: SirchmunkSearch is registered in knowledge_search module -# and integrated directly in LLMAgent, not through rag_mapping +# Note: Sirchmunk local search is the ``localsearch`` tool +# (ms_agent.tools.search); it is not wired through rag_mapping. diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py new file mode 100644 index 000000000..d8ab34ef7 --- /dev/null +++ b/ms_agent/tools/search/localsearch_tool.py @@ -0,0 +1,282 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, + effective_localsearch_settings, +) +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SERVER = 'localsearch' +_TOOL = 'localsearch' + +# Tool-facing description: aligned with sirchmunk AgenticSearch.search() capabilities. +_LOCALSEARCH_DESCRIPTION = """Search local files, codebases, and documents on disk. + +USE THIS TOOL WHEN: +- The user asks about content in local files or directories +- You need to find information in source code, config files, or documents +- The query references a local path, project structure, or codebase +- You need to search PDF, DOCX, XLSX, PPTX, CSV, JSON, YAML, Markdown, etc. + +DO NOT USE THIS TOOL WHEN: +- The user is asking a general knowledge question +- The user is greeting you or making casual conversation (e.g., "你好", "hello") +- You need information from the internet or recent events +- The query has no relation to local files or code + +Returns: +Search results after summarizing as formatted text with file paths, code snippets, and explanations where +available. Retrieved excerpts and meta are included in the tool output. + +Configured search roots for this agent (absolute paths; default search scope when `paths` is omitted): +{configured_roots} +""" + + +def _resolved_localsearch_paths_from_config(config) -> List[str]: + """Match ``SirchmunkSearch`` path resolution for consistent tool text and checks.""" + block = effective_localsearch_settings(config) + if not block: + return [] + paths = block.get('paths', []) + if isinstance(paths, str): + paths = [paths] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(Path(str(p).strip()).expanduser().resolve())) + return out + + +def _format_configured_roots(paths: List[str]) -> str: + if not paths: + return ( + '(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') + return '\n'.join(f'- {p}' for p in paths) + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _as_str_list(value: Any, name: str) -> Optional[List[str]]: + if value is None: + return None + if isinstance(value, str): + return [value] if value.strip() else None + if isinstance(value, list): + out = [str(x).strip() for x in value if str(x).strip()] + return out or None + raise TypeError(f'{name} must be a string or list of strings') + + +class LocalSearchTool(ToolBase): + """Expose sirchmunk as a callable tool when ``tools.localsearch`` is configured.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + tools_root = getattr(config, 'tools', None) + tool_cfg = getattr(tools_root, 'localsearch', None) if tools_root else None + if tool_cfg is not None: + self.exclude_func(tool_cfg) + self._searcher: Optional[SirchmunkSearch] = None + self._configured_roots: List[str] = ( + _resolved_localsearch_paths_from_config(config)) + + def _tool_description(self) -> str: + return _LOCALSEARCH_DESCRIPTION.format( + configured_roots=_format_configured_roots(self._configured_roots)) + + def _paths_param_description(self) -> str: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Optional. Narrow search to specific files or directories under the ' + 'configured roots below. Each path must exist on disk and lie under ' + 'one of these roots (or be exactly one of them).\n' + f'Configured roots:\n{roots}') + + def _ensure_searcher(self) -> SirchmunkSearch: + if self._searcher is None: + self._searcher = SirchmunkSearch(self.config) + return self._searcher + + async def connect(self) -> None: + return None + + async def _get_tools_inner(self) -> Dict[str, List[Tool]]: + return { + _SERVER: [ + Tool( + tool_name=_TOOL, + server_name=_SERVER, + description=self._tool_description(), + parameters={ + 'type': + 'object', + 'properties': { + 'query': { + 'type': + 'string', + 'description': + 'Search keywords or natural-language question about local content.', + }, + 'paths': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + self._paths_param_description(), + }, + 'mode': { + 'type': + 'string', + 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], + 'description': + 'Search mode; omit to use agent default (usually FAST).', + }, + 'max_depth': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 20, + 'description': + 'Max directory depth for filesystem search.', + }, + 'top_k_files': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 20, + 'description': + 'Max files for evidence / filename hits.', + }, + 'include': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to include (e.g. *.py, *.md).', + }, + 'exclude': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to exclude (e.g. *.pyc).', + }, + }, + 'required': ['query'], + }, + ) + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict): + del server_name + if tool_name != _TOOL: + return f'Unknown tool: {tool_name}' + + args = tool_args or {} + query = str(args.get('query', '')).strip() + if not query: + return 'Error: `query` is required and cannot be empty.' + + try: + paths_arg = _as_str_list(args.get('paths'), 'paths') + mode = args.get('mode') + if mode is not None: + mode = str(mode).strip().upper() or None + + max_depth = args.get('max_depth') + if max_depth is not None: + max_depth = int(max_depth) + max_depth = max(1, min(20, max_depth)) + + top_k = args.get('top_k_files') + if top_k is not None: + top_k = int(top_k) + top_k = max(1, min(20, top_k)) + + include = _as_str_list(args.get('include'), 'include') + exclude = _as_str_list(args.get('exclude'), 'exclude') + + searcher = self._ensure_searcher() + resolved_paths = None + if paths_arg: + resolved_paths = searcher.resolve_tool_paths(paths_arg) + if not resolved_paths: + roots = _format_configured_roots( + self._configured_roots) + return ( + 'Error: `paths` are invalid. Each path must exist on disk and lie ' + 'under one of these configured roots:\n' + roots) + + answer = await searcher.query( + query, + paths=resolved_paths, + mode=mode, + max_depth=max_depth, + top_k_files=top_k, + include=include, + exclude=exclude, + ) + details = searcher.get_search_details() + excerpts = searcher.get_last_retrieved_chunks() + + lines = ['## Local search (sirchmunk)', '', str(answer), ''] + + if excerpts: + lines.append('## Retrieved excerpts') + lines.append('') + for i, item in enumerate(excerpts[:12], 1): + meta = item.get('metadata') or {} + src = meta.get('source', '?') + text = (item.get('text') or '')[:4000] + lines.append(f'### [{i}] {src}') + lines.append(text) + lines.append('') + + summary = { + 'mode': details.get('mode'), + 'paths': details.get('paths'), + 'work_path': details.get('work_path'), + 'cluster_cache_hit': details.get('cluster_cache_hit'), + } + lines.append('## Meta') + lines.append(_json_dumps(summary)) + + full_text = '\n'.join(lines) + # Model sees answer + source paths only; UI gets full excerpts + meta. + result_parts = [str(answer).strip()] + if excerpts: + result_parts.append('\nSource paths:') + for item in excerpts[:12]: + meta = item.get('metadata') or {} + result_parts.append( + f'- {meta.get("source", "?")}') + result_text = '\n'.join(result_parts) + + return { + 'result': result_text, + 'tool_detail': full_text, + } + except (TypeError, ValueError) as exc: + return f'Invalid tool arguments: {exc}' + except Exception as exc: + logger.warning(f'localsearch failed: {exc}') + return f'Local search failed: {exc}' + diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py similarity index 68% rename from ms_agent/knowledge_search/sirchmunk_search.py rename to ms_agent/tools/search/sirchmunk_search.py index e1c76181f..cd86f9d63 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -1,51 +1,83 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Sirchmunk-based knowledge search integration. +"""Sirchmunk backend for the ``localsearch`` tool. -This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, -providing document retrieval capabilities similar to RAG but optimized for -codebase and documentation search. +Configuration lives under ``tools.localsearch`` (same namespace as other tools). +Legacy top-level ``knowledge_search`` is still accepted for backward compatibility. """ import asyncio +import json from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from loguru import logger -from ms_agent.rag.base import RAG from omegaconf import DictConfig -class SirchmunkSearch(RAG): - """Sirchmunk-based knowledge search class. +def _paths_from_block(block: Any) -> List[str]: + if block is None: + return [] + paths = block.get('paths', []) if hasattr(block, 'get') else [] + if isinstance(paths, str): + paths = [paths] if str(paths).strip() else [] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(p).strip()) + return out - This class wraps the sirchmunk library to provide intelligent codebase search - capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk - uses a combination of keyword search, semantic clustering, and LLM-powered - analysis to find relevant information from codebases. - The configuration needed in the config yaml: - - name: SirchmunkSearch - - paths: List of paths to search, required - - work_path: Working directory for sirchmunk cache, default './.sirchmunk' - - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' - - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 - - cluster_sim_top_k: Top K clusters to consider, default 3 - - reuse_knowledge: Whether to reuse previous search results, default True - - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' +def effective_localsearch_settings(config: DictConfig) -> Optional[Any]: + """Resolve the active localsearch / sirchmunk settings node. + + Precedence: ``tools.localsearch`` with non-empty ``paths``, else legacy + ``knowledge_search`` with non-empty ``paths``. Returns ``None`` if local + search is not configured. + """ + tools = getattr(config, 'tools', None) + tl = None + if tools is not None: + tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr( + tools, 'localsearch', None) + ks = getattr(config, 'knowledge_search', None) + + if tl is not None and _paths_from_block(tl): + return tl + if ks is not None and _paths_from_block(ks): + return ks + return None + + +class SirchmunkSearch: + """Sirchmunk-based local search (used by :class:`LocalSearchTool`). + + Configure in yaml under ``tools.localsearch`` (recommended), for example:: + + tools: + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + embedding_model: text-embedding-3-small + cluster_sim_threshold: 0.85 + cluster_sim_top_k: 3 + reuse_knowledge: true + mode: FAST + + Legacy: the same keys may be placed under top-level ``knowledge_search``. Args: - config (DictConfig): Configuration object containing sirchmunk settings. + config: Full agent config; sirchmunk options read from the effective + block returned by :func:`effective_localsearch_settings`. """ def __init__(self, config: DictConfig): - super().__init__(config) - self._validate_config(config) + rag_config = effective_localsearch_settings(config) + assert rag_config is not None - # Extract configuration parameters - rag_config = config.get('knowledge_search', {}) - - # Search paths - required paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] @@ -53,11 +85,9 @@ def __init__(self, config: DictConfig): str(Path(p).expanduser().resolve()) for p in paths ] - # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() - # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) @@ -66,13 +96,10 @@ def __init__(self, config: DictConfig): self.max_loops = rag_config.get('max_loops', 10) self.max_token_budget = rag_config.get('max_token_budget', 128000) - # LLM configuration for sirchmunk - # First try knowledge_search.llm_api_key, then fall back to main llm config self.llm_api_key = rag_config.get('llm_api_key', None) self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - # Fall back to main llm config if not specified in knowledge_search if (self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None): llm_config = config.get('llm', {}) @@ -87,39 +114,57 @@ def __init__(self, config: DictConfig): if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) - # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( 'embedding_model_cache_dir', None) - # Runtime state self._searcher = None self._initialized = False self._cluster_cache_hit = False self._cluster_cache_hit_time: str | None = None self._last_search_result: List[Dict[str, Any]] | None = None - # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] - # Async queue for streaming logs in real-time self._log_queue: asyncio.Queue | None = None self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): - """Validate configuration parameters.""" - if not hasattr(config, - 'knowledge_search') or config.knowledge_search is None: + block = effective_localsearch_settings(config) + if block is None: raise ValueError( - 'Missing knowledge_search configuration. ' - 'Please add knowledge_search section to your config with at least "paths" specified.' - ) - - rag_config = config.knowledge_search - paths = rag_config.get('paths', []) + 'Missing localsearch configuration. Add ' + '`tools.localsearch` with non-empty `paths` (or legacy ' + '`knowledge_search.paths`).') + paths = _paths_from_block(block) if not paths: raise ValueError( - 'knowledge_search.paths must be specified and non-empty') + 'tools.localsearch.paths (or legacy knowledge_search.paths) ' + 'must be specified and non-empty') + + def resolve_tool_paths( + self, paths: Optional[List[str]]) -> Optional[List[str]]: + """Restrict per-call paths to configured search roots.""" + if not paths: + return None + roots = [Path(p).resolve() for p in self.search_paths] + cleaned: List[str] = [] + for raw in paths: + if raw is None or not str(raw).strip(): + continue + p = Path(str(raw).strip()).expanduser().resolve() + if not p.exists(): + logger.warning(f'localsearch: path does not exist, skipped: {p}') + continue + allowed = any( + p == r or p.is_relative_to(r) for r in roots) + if not allowed: + logger.warning( + f'localsearch: path outside configured search roots, ' + f'skipped: {p}') + continue + cleaned.append(str(p)) + return cleaned or None def _initialize_searcher(self): """Initialize the sirchmunk AgenticSearch instance.""" @@ -131,7 +176,6 @@ def _initialize_searcher(self): from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil - # Create LLM client llm = OpenAIChat( api_key=self.llm_api_key, base_url=self.llm_base_url, @@ -140,8 +184,6 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - # Create embedding util - # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( @@ -150,7 +192,6 @@ def _initialize_searcher(self): embedding = EmbeddingUtil( model_id=embedding_model_id, cache_dir=embedding_cache_dir) - # Create AgenticSearch instance self._searcher = AgenticSearch( llm=llm, embedding=embedding, @@ -191,7 +232,6 @@ def log_callback( ): log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) - # Stream log in real-time if streaming callback is set if self._streaming_callback: asyncio.create_task(self._streaming_callback(log_entry)) @@ -215,7 +255,6 @@ async def add_documents(self, documents: List[str]) -> bool: 'SirchmunkSearch does not support direct document addition. ' 'Documents should be saved to files within the configured search paths.' ) - # Trigger re-scan of the search paths if self._searcher and hasattr(self._searcher, 'knowledge_base'): try: await self._searcher.knowledge_base.refresh() @@ -274,7 +313,6 @@ async def retrieve(self, max_token_budget = filters.get('max_token_budget', self.max_token_budget) - # Perform search result = await self._searcher.search( query=query, mode=mode, @@ -283,56 +321,101 @@ async def retrieve(self, return_context=True, ) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - # If a similar cluster was found and reused, it's a cache hit self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) - # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] - async def query(self, query: str) -> str: - """Query sirchmunk and return a synthesized answer. - - This method performs a search and returns the LLM-synthesized answer - along with search details that can be used for frontend display. + async def query( + self, + query: str, + *, + paths: Optional[List[str]] = None, + mode: Optional[str] = None, + max_depth: Optional[int] = None, + top_k_files: Optional[int] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> str: + """Query sirchmunk and return a synthesized answer (or filename hits). + + Optional arguments are forwarded to ``AgenticSearch.search`` where supported. + ``paths`` must already be restricted to configured search roots (see + :meth:`resolve_tool_paths`). Args: - query (str): The search query. + query: The search query. + paths: Override search roots (subset of configured paths), or None. + mode: ``FAST``, ``DEEP``, or ``FILENAME_ONLY``; None uses config default. + max_depth: Directory depth cap for filesystem search. + top_k_files: Max files for evidence / filename ranking. + include: Glob patterns to include (e.g. ``*.py``). + exclude: Glob patterns to exclude (e.g. ``node_modules``). Returns: - str: The synthesized answer from sirchmunk. + Answer string, or JSON string for ``FILENAME_ONLY`` list results. """ self._initialize_searcher() self._search_logs.clear() try: - mode = self.search_mode - max_loops = self.max_loops - max_token_budget = self.max_token_budget - - # Single search with context so we get both the synthesized answer and - # source units in one call, avoiding a redundant second search. - result = await self._searcher.search( + mode_eff = mode if mode is not None else self.search_mode + if isinstance(mode_eff, str): + mode_eff = mode_eff.strip().upper() + allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') + if mode_eff not in allowed_modes: + return ( + f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') + + kw: Dict[str, Any] = dict( query=query, - mode=mode, - max_loops=max_loops, - max_token_budget=max_token_budget, + paths=paths, + mode=mode_eff, + max_loops=self.max_loops, + max_token_budget=self.max_token_budget, return_context=True, ) + if max_depth is not None: + kw['max_depth'] = max_depth + if top_k_files is not None: + kw['top_k_files'] = top_k_files + if include is not None: + kw['include'] = include + if exclude is not None: + kw['exclude'] = exclude + + result = await self._searcher.search(**kw) + + if isinstance(result, list): + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + self._last_search_result = [] + for item in result[:20]: + if isinstance(item, dict): + src = (item.get('path') or item.get('file_path') + or item.get('file') or '') + self._last_search_result.append({ + 'text': + json.dumps(item, ensure_ascii=False), + 'score': + 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + }) + return json.dumps(result, ensure_ascii=False, indent=2) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: @@ -342,19 +425,16 @@ async def query(self, query: str) -> str: self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Store parsed context for frontend display self._last_search_result = self._parse_search_result( result, score_threshold=0.7, limit=5) - # Extract the synthesized answer from the context result - if hasattr(result, 'answer'): + if hasattr(result, 'answer') and getattr(result, 'answer', + None) is not None: return result.answer - # If result is already a plain string (some modes return str directly) if isinstance(result, str): return result - # Fallback: convert to string return str(result) except Exception as e: @@ -375,14 +455,11 @@ def _parse_search_result(self, result: Any, score_threshold: float, """ results = [] - # Handle SearchContext format (returned when return_context=True) if hasattr(result, 'cluster') and result.cluster is not None: cluster = result.cluster for unit in cluster.evidences: - # Extract score from snippets if available score = getattr(cluster, 'confidence', 1.0) if score >= score_threshold: - # Extract text from snippets text_parts = [] source = str(getattr(unit, 'file_or_url', 'unknown')) for snippet in getattr(unit, 'snippets', []): @@ -406,7 +483,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) @@ -423,7 +499,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle list format elif isinstance(result, list): for item in result: if isinstance(item, dict): @@ -438,7 +513,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, item.get('metadata', {}), }) - # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: @@ -451,10 +525,13 @@ def _parse_search_result(self, result: Any, score_threshold: float, result.get('metadata', {}), }) - # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] + def get_last_retrieved_chunks(self) -> List[Dict[str, Any]]: + """Parsed evidence chunks from the last `query` or `retrieve` call.""" + return list(self._last_search_result or []) + def get_search_logs(self) -> List[str]: """Get the captured search logs. diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..e7885b92d 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -17,6 +17,8 @@ from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient +from ms_agent.tools.search.localsearch_tool import LocalSearchTool +from ms_agent.tools.search.sirchmunk_search import effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool @@ -88,6 +90,8 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + if effective_localsearch_settings(config) is not None: + self.extra_tools.append(LocalSearchTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py index 5a4f43213..705c7184e 100644 --- a/tests/knowledge_search/test_sirschmunk.py +++ b/tests/knowledge_search/test_sirschmunk.py @@ -1,23 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. +"""Tests for SirchmunkSearch and localsearch tool integration. -These tests verify the sirchmunk-based knowledge search functionality -through the LLMAgent entry point, including verification that -search_result and searching_detail fields are properly populated. - -To run these tests, you need to set the following environment variables: - - TEST_LLM_API_KEY: Your LLM API key - - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) - - TEST_LLM_MODEL_NAME: Your LLM model name (optional) - - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) - - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) - -Example: +Example (full sirchmunk run): export TEST_LLM_API_KEY="your-api-key" - export TEST_LLM_BASE_URL="https://api.openai.com/v1" - export TEST_LLM_MODEL_NAME="gpt-4o" - export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" - export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" python -m pytest tests/knowledge_search/test_sirschmunk.py """ import asyncio @@ -26,92 +11,41 @@ import unittest from pathlib import Path -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.agent import LLMAgent -from ms_agent.config import Config +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch +from ms_agent.llm.utils import Message +from ms_agent.tools.tool_manager import ToolManager from omegaconf import DictConfig -from modelscope.utils.test_utils import test_level - - -class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): - """Test cases for SirchmunkSearch integration with LLMAgent. - - These tests verify that when LLMAgent runs a query that triggers - knowledge search, the Message objects have search_result and - searching_detail fields properly populated. - """ +class SirchmunkKnowledgeSearchTest(unittest.TestCase): + """Sirchmunk config, ToolManager registration""" @classmethod def setUpClass(cls): - """Set up test fixtures.""" - # Create test directory with sample files cls.test_dir = Path('./test_llm_agent_knowledge') cls.test_dir.mkdir(exist_ok=True) - - # Create sample documentation - (cls.test_dir / 'README.md').write_text(''' -# Test Project Documentation - -## Overview -This is a test project for knowledge search integration. - -## API Reference - -### UserManager -The UserManager class handles user operations: -- create_user: Create a new user account -- delete_user: Delete an existing user -- update_user: Update user information -- get_user: Retrieve user details - -### AuthService -The AuthService class handles authentication: -- login: Authenticate user credentials -- logout: End user session -- refresh_token: Refresh authentication token -- verify_token: Validate authentication token -''') - - (cls.test_dir / 'config.py').write_text(''' -"""Configuration module.""" - -class Config: - """Application configuration.""" - - def __init__(self): - self.database_url = "postgresql://localhost:5432/mydb" - self.secret_key = "your-secret-key" - self.debug_mode = False - - def load_from_env(self): - """Load configuration from environment variables.""" - import os - self.database_url = os.getenv("DATABASE_URL", self.database_url) - self.secret_key = os.getenv("SECRET_KEY", self.secret_key) - return self -''') + (cls.test_dir / 'README.md').write_text( + '# Demo\n\nUserManager.create_user creates a user.\n') @classmethod def tearDownClass(cls): - """Clean up test fixtures.""" if cls.test_dir.exists(): shutil.rmtree(cls.test_dir, ignore_errors=True) work_dir = Path('./.sirchmunk') if work_dir.exists(): shutil.rmtree(work_dir, ignore_errors=True) - def _get_agent_config(self): - """Create agent configuration with knowledge search.""" + def _base_config(self) -> DictConfig: llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key') - llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', + 'https://api.openai.com/v1') llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') - # Read from TEST_* env vars (for test-specific config) - # These can be set from .env file which uses TEST_* prefix embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') - embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') - - config = DictConfig({ + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', + '') + return DictConfig({ + 'output_dir': + './outputs_knowledge_test', 'llm': { 'service': 'openai', 'model': llm_model_name, @@ -122,81 +56,66 @@ def _get_agent_config(self): 'temperature': 0.3, 'max_tokens': 500, }, - 'knowledge_search': { - 'name': 'SirchmunkSearch', - 'paths': [str(self.test_dir)], - 'work_path': './.sirchmunk', - 'llm_api_key': llm_api_key, - 'llm_base_url': llm_base_url, - 'llm_model_name': llm_model_name, - 'embedding_model': embedding_model_id, - 'embedding_model_cache_dir': embedding_model_cache_dir, - 'mode': 'FAST', - } + 'tools': { + 'localsearch': { + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + }, + }, }) - return config - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_llm_agent_with_knowledge_search(self): - """Test LLMAgent using knowledge search. - - This test verifies that: - 1. LLMAgent can be initialized with SirchmunkSearch configuration - 2. Running a query produces a valid response - 3. User message has searching_detail and search_result populated - 4. searching_detail contains expected keys (logs, mode, paths) - 5. search_result is a list - """ - config = self._get_agent_config() + def test_does_not_inject_knowledge_search(self): + """Local sirchmunk search is no longer merged into the user message here.""" + config = self._base_config() agent = LLMAgent(config=config, tag='test-knowledge-agent') - - # Test query that should trigger knowledge search - query = 'How do I use UserManager to create a user?' - - async def run_agent(): - result = await agent.run(query) - return result - - result = asyncio.run(run_agent()) - - # Verify result - self.assertIsNotNone(result) - self.assertIsInstance(result, list) - self.assertTrue(len(result) > 0) - - # Check that assistant message exists - assistant_message = [m for m in result if m.role == 'assistant'] - self.assertTrue(len(assistant_message) > 0) - - # Check that user message has search_result and searching_detail populated - user_messages = [m for m in result if m.role == 'user'] - self.assertTrue(len(user_messages) > 0, "Expected at least one user message") - - # The first user message should have search details after do_rag processing - user_msg = user_messages[0] - self.assertTrue( - hasattr(user_msg, 'searching_detail'), - "User message should have searching_detail attribute" - ) + original = 'How do I use UserManager?' + + async def run(): + messages = [ + Message(role='system', content='You are a helper.'), + Message(role='user', content=original), + ] + await agent.run(messages) + return messages + + messages = asyncio.run(run()) + print(f'messages: {messages}') + + def test_tool_manager_registers_localsearch(self): + """When tools.localsearch.paths is set, ToolManager exposes localsearch.""" + + async def run(): + config = self._base_config() + tm = ToolManager(config, trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + names = [t['tool_name'] for t in tools] self.assertTrue( - hasattr(user_msg, 'search_result'), - "User message should have search_result attribute" + any(n.endswith('localsearch') for n in names), + f'Expected localsearch in tools, got: {names}', ) - # Check that searching_detail is a dict with expected keys - self.assertIsInstance( - user_msg.searching_detail, dict, - "searching_detail should be a dictionary" - ) - self.assertIn('logs', user_msg.searching_detail) - self.assertIn('mode', user_msg.searching_detail) - self.assertIn('paths', user_msg.searching_detail) - - # Check that search_result is a list (may be empty if no relevant docs found) - self.assertIsInstance( - user_msg.search_result, list, - "search_result should be a list" - ) + @unittest.skipUnless( + os.getenv('TEST_SIRCHMUNK_SMOKE', ''), + 'Set TEST_SIRCHMUNK_SMOKE=1 to run sirchmunk API smoke test', + ) + def test_sirchmunk_search_query_smoke(self): + """Optional: run sirchmunk once (needs network / valid API keys).""" + config = self._base_config() + searcher = SirchmunkSearch(config) + result = asyncio.run(searcher.query('UserManager')) + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) if __name__ == '__main__': From b437bdc35039807e1c18f811139d8705fe873344 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Thu, 26 Mar 2026 01:22:00 +0800 Subject: [PATCH 14/40] thinking support beta; search_file_content fix --- ms_agent/tools/filesystem_tool.py | 69 +++++++---- .../v2/prompts/researcher/en/thinking.txt | 110 ++++++++++++++++++ projects/deep_research/v2/searcher.yaml | 2 + 3 files changed, 161 insertions(+), 20 deletions(-) create mode 100644 projects/deep_research/v2/prompts/researcher/en/thinking.txt diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index a107a7c59..0440941f3 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -264,8 +264,8 @@ async def _get_tools_inner(self): tool_name='search_file_content', server_name='file_system', description= - 'Search for content in files using literal text or regex patterns. ' - 'Automatically detects and supports both literal string matching and regex pattern matching. ' + 'Search for content in files. By default, searches for the exact literal text. ' + 'Set is_regex=true to use regex pattern matching instead. ' 'Returns matching files with line numbers and surrounding context.', parameters={ 'type': 'object', @@ -274,8 +274,7 @@ async def _get_tools_inner(self): 'type': 'string', 'description': - 'The content/text or regex pattern to search for. ' - 'Supports both literal strings and regex patterns automatically.', + 'The text to search for (literal by default, or a regex pattern if is_regex=true).', }, 'parent_path': { 'type': @@ -296,6 +295,22 @@ async def _get_tools_inner(self): 'description': 'Number of lines before and after the match to include (default: 2)', }, + 'is_regex': { + 'type': + 'boolean', + 'description': + 'If true, treat content as a regex pattern. If false (default), ' + 'search for the exact literal text. Characters like [, ], (, ), ., *, $ ' + 'are matched literally when is_regex is false.', + }, + 'max_matches': { + 'type': + 'integer', + 'description': + 'Maximum number of matches to return (default: 50). ' + 'If more matches exist, the total count is reported but only ' + 'the first max_matches results are shown.', + }, }, 'required': ['content'], 'additionalProperties': False @@ -859,15 +874,18 @@ async def search_file_content(self, content: str = None, parent_path: str = '.', file_pattern: str = '*', - context_lines: int = 2): + context_lines: int = 2, + is_regex: bool = False, + max_matches: int = 50): """Search for content in files using thread pool. - Supports both literal string matching and regex pattern matching automatically. Args: - content(str): The content or regex pattern to search for (auto-detected) + content(str): The text to search for (literal by default, regex if is_regex=True) parent_path(str): The relative parent path to search in file_pattern(str): Wildcard pattern for file names (default: '*' for all files) context_lines(int): Number of lines before and after the match to include (default: 2) + is_regex(bool): If True, treat content as a regex pattern; otherwise literal match (default: False) + max_matches(int): Maximum number of matches to return (default: 50) Returns: String containing all matches with file path, line number, and context @@ -887,14 +905,15 @@ async def search_file_content(self, if not content: return 'Error: content parameter is required for search' - # Try to compile as regex pattern, fallback to literal string matching use_regex = False pattern = None - try: - pattern = re.compile(content) - use_regex = True - except re.error: - # Not a valid regex, will use literal string matching + if is_regex: + try: + pattern = re.compile(content) + use_regex = True + except re.error: + return f'Error: "{content}" is not a valid regex pattern.' + else: use_regex = False # Collect all files matching the pattern @@ -928,19 +947,15 @@ def search_in_file(file_path): with open(file_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line_num, line in enumerate(lines, start=1): - # Check for match: regex or literal string - is_match = False if use_regex: is_match = pattern.search(line) is not None else: is_match = content in line if is_match: - # Calculate context range start_line = max(0, line_num - context_lines - 1) end_line = min(len(lines), line_num + context_lines) - # Extract context lines context = [] for i in range(start_line, end_line): prefix = '> ' if i == line_num - 1 else ' ' @@ -970,10 +985,24 @@ def search_in_file(file_path): if not all_matches: return f'No matches found for <{content}> in files matching <{file_pattern}>' + all_matches.sort(key=lambda m: (m['file'], m['line'])) + + total_found = len(all_matches) + truncated = total_found > max_matches + if truncated: + all_matches = all_matches[:max_matches] + # Format results - result_lines = [ - f'Found {len(all_matches)} match(es) for "{content}":\n' - ] + if truncated: + result_lines = [ + f'Found {total_found} match(es) for "{content}" ' + f'(showing first {max_matches}; refine your search ' + f'or increase max_matches for more):\n' + ] + else: + result_lines = [ + f'Found {total_found} match(es) for "{content}":\n' + ] for match in all_matches: result_lines.append( f"File: {match['file']}, Line: {match['line']}") diff --git a/projects/deep_research/v2/prompts/researcher/en/thinking.txt b/projects/deep_research/v2/prompts/researcher/en/thinking.txt new file mode 100644 index 000000000..6d21b8b48 --- /dev/null +++ b/projects/deep_research/v2/prompts/researcher/en/thinking.txt @@ -0,0 +1,110 @@ +You are a highly capable, thoughtful, and precise research assistant. Your job is to plan and manage the end-to-end deep research workflow, delegate retrieval and writing tasks to sub-agents/tools, synthesize evidence into decisions, and polish the final report before delivery. +You have everything you need to resolve the task. I want you to fully solve this autonomously before coming back to me. +Time reminder: The current date is , and the current time is . +Language reminder: If you can infer the language from the user's query, make sure to keep this in mind when generating the report. +Research iterations: This refers to the number of loops in the research & analysis phase (excluding the report-generation phase). If the user does not specify the maximum number of research iterations, the default maximum is 6 iterations. You MUST complete the task within the maximum number of iterations. +Action protocol: Before outputting the final result, every iteration MUST invoke at least one tool. You MUST reason extensively about the current state and your intended next action before each tool call and show your thinking in the conversation (e.g., What key information did I find? What's missing? Do I have enough to answer the question comprehensively? What should I do next?). DO NOT do this entire process by making tool calls only, as this can impair your ability to solve the problem and think insightfully. + +# Primary Responsibilities +- Plan & orchestrate: + - Determine whether the request starts a new task or continues an unfinished one. If continuing, first assess the current completion status and recover relevant context; then convert the user's request into an executable research plan, store it as a TODO list (plan.json), and perform self-reflection by additionally generating a verification checklist (checklist.yaml). + - Based on task difficulty and user intent, orchestrate the available sub-agents and tools, and control the handling logic for different scenarios (short answer vs. professional report vs. casual conversation; default to a professional report unless the user asks otherwise). +- Retrieve evidence: + - When evidence is insufficient, delegate tasks to the Searcher sub-agent (i.e., agent_tools---searcher_tool) to perform an iterative research loop (when concurrency is allowed, 2–4 sub-agents can be invoked in parallel; prioritize parallel invocation when tasks are parallelizable). +- Analyze & synthesize: + - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. +- Draft, review, deliver: + - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter's output preserves maximum evidence fidelity. + +# Reference Workflow +The following is a proven workflow that works well for most research tasks. +You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. + +## Phase 1: Task Planning +- Deeply understand the user's intent: analyze the user's conversational goal, background needs, and expected deliverables; proactively infer whether to start from scratch or continue an unfinished task. +- If resuming from an unfinished task, start by checking the current completion status using todo_list---todo_read and other available tools. +- Develop a manageable plan based on user's needs and task progress. Use todo_list---todo_write and file_system---write_file to generate the TODO list and the corresponding verification checklist checklist.yaml, respectively. + - The TODO list must cover all subtasks that need to be completed. It is used to clearly communicate your full plan to the user. You do not need to explicitly state which tools you will use; simply provide the tasks themselves; but you must ensure that every task can be completed using the existing tools. + - Tasks in the TODO list must be explicit, clear, and focused on solving the core problem. Each task should contain no more than three core questions to answer, while also avoiding over-splitting that would make the task list excessively long. + - Tasks in the TODO list should be assigned reasonable priorities: high for tasks directly answering the user's core questions, medium for supporting context or secondary dimensions, low for nice-to-have extensions. High-priority tasks should be executed first, while medium- and low-priority tasks should be performed only if the iteration budget allows. +- Compare the TODO list and the verification checklist for reflection. If you find issues with the current TODO list, fix them; otherwise, you may skip this step. + - If necessary, you can invoke the Searcher sub-agent at most once for concept clarification; + - If you find issues in the TODO list, you must revise it via the todo_list---todo_write tool. + +## Phase 2: Research & Analysis +Repeat the following steps until a stopping condition is met: +- Based on the execution status of tasks in the current TODO list, select appropriate actions: + - For tasks that require evidence retrieval, delegate them to the Searcher sub-agent. Make sure to provide detailed and clear task instructions; + - For tasks that require interim syntheses, decisions/trade-offs, frameworks/mappings, uncertainty tracking, justified recommendations, or structural diagrams (preferably Mermaid syntax), use evidence_store---write_analysis to record these intermediate analyses, and include based_on_note_ids when possible; + - For tasks that require data analysis or chart generation, use code_executor---notebook_executor to solve them. Try to finish in as few rounds as possible. When writing code, use relative file paths—the executor's working directory is the output root. Store key computed results via evidence_store---write_analysis. +- After completing the above actions, reflect and update the TODO list: + - Summarize interim findings; explicitly identify the evidence that has already been collected and maintained; identify conflicts and evidence gaps. + - Update the task statuses in the TODO list ('pending'/'in_progress'/'completed'/'cancelled') as soon as their status changes. + - If you identify issues in the plan and decide to revise it, update the TODO list. +Stopping conditions (stop if you are confident to proceed to the next phase): +- All subtasks for the research & analysis phase in the TODO list have been completed; or +- All the core tasks (high-priority tasks) have been completed; or +- The marginal benefit of further searching is very low; or +- The maximum number of research iterations has been reached. + +## Phase 3: Report Generation +- Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. + - Do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. + - The Reporter sub-agent writes the final report to reports/report.md and returns only the execution summary and artifact file paths. The system automatically copies reports/report.md to final_report.md. Therefore, final_report.md may not be listed in the Reporter's Artifacts field. Do not ask the Reporter to create final_report.md directly. Review and edit final_report.md. +- After the Reporter returns, you MUST review the report for quality and accuracy: + - **Read once, then act.** Read the full report content once to form your assessment. For subsequent checks, use file_system---search_file_content or file_system---read_file with start_line/end_line — do NOT re-read the entire report file repeatedly. Always check your conversation history before re-reading a file that may already be in context. + - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. + - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. + - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. + - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. +- If the report passes your review without issues: proceed directly to your conclusion. Do NOT rewrite it "for polish." +- If issues are found, **strongly prefer targeted corrections** over full rewrites: + - **Standard workflow**: use file_system---search_file_content to locate the problem, then file_system---replace_file_contents to fix it. This is the safest and most precise approach. + - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail. + - Parallel editing: for multiple independent fixes in the same file, use file_system---search_file_content and file_system---replace_file_contents in parallel when `source` spans do not overlap. However, NEVER call file_system---replace_file_lines in parallel on the same file — line numbers shift after each call. + - **Deleting or replacing line ranges**: use file_system---replace_file_lines with start_line/end_line to delete or replace a block of lines (e.g., removing an entire section). Use file_system---search_file_content first to locate the line numbers (start line and end line). + - **Inspect before editing**: use file_system---read_file (with start_line/end_line) to verify surrounding context when needed. + - **Last resort only**: file_system---write_file overwrites the entire file — use it only when targeted tools cannot address the issue (e.g., extensive structural reorganization). You must reproduce ALL content valuable to the user. + - WARNING: Full report rewrites may carry high risk of content loss. Do not over-compress the report. Do not replace any content with placeholders such as "Content truncated for brevity." or "This section is stored in xxx file." +- Finally show your conclusions for the entire task in the conversation. + +# Process Constraints +1. Monitor and update the TODO list throughout the process; DO NOT store plans only in the conversation text; if unexpected issues arise, record the failure, adjust the plan, and continue with a fallback path when possible. +2. Do not conduct extended web research or draft the full report yourself. Delegate all large-scale retrieval and report drafting to sub-agents. +3. When evidence is insufficient or conclusions conflict with each other, you must explicitly acknowledge the uncertainty, reflect proactively, and attempt to resolve it using the available tools (including sub-agents), while keeping your research iterations limit in mind. +4. Follow the stopping conditions defined in Phase 2. +5. Avoid redundant tool calls. For example, after todo_list---todo_write, the tool response includes the updated TODO list, so you don't need to call todo_list---todo_read again. Similarly, after todo_list---todo_read, you don't need to call file_system---read_file to read related files again (plan.json, plan.md). + +# Tool Invocation Protocol +- You MUST use the tools under the todo_list server to create, update, and read the TODO list. You MUST NOT use any other tools or services to maintain the TODO list. +- You MUST use the tools under the agent_tools server to invoke the Searcher and Reporter sub-agents. You are not allowed to invoke non-existent sub-agent tools, and you MUST carefully follow the input requirements of those tools. +- When context is unclear (e.g., the Searcher sub-agent's output appears to have lost details, or the Reporter sub-agent's report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. +- You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). +- For file-level operations, keep using relative paths. + +# Quality Constraints +- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent's research conclusions and stored evidence. +- Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. +- Strictly control scope: if the user asks for X, do not drift to Y. +- Citation integrity in the final report: + - **Fix if broken**: + - Invalid citation forms (e.g., [Note ID]-style placeholders) — replace with proper `[1]`, `[2]`, ... numbered markers. + - Multiple ## References / ## 参考文献 sections (e.g., per-chapter reference lists) are not allowed — this includes any variant headings such as "## 参考文献(合并版)", "## References (Merged)", "## 参考资料", or similar. The report body and individual chapters must NOT contain any reference/bibliography list; remove such sections entirely. Keep only one unified reference section at the very end of the report. Re-number in-text citations if needed. + - **Supplement if missing**: Add numbered citations `[1]`, `[2]`, ... in the body (multiple may appear together like `[1][3]`). The report must end with exactly one `## References` (English) or `## 参考文献` (Chinese) section with consistent numbering. Do not use long-title links (e.g., `[Title](URL)`) in the body text. + - **Preserve by default**: Do not alter correct citations delivered by the Reporter sub-agent. Your edits must not cause citation loss. +- For the final report, you MUST use the language specified by the user; if none is specified, you must keep it consistent with the language the user is using. +- The final report in final_report.md MUST follow the "Default Report Style" section. + +# Default Report Style +- Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. +- Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. +- Prefer a clean heading hierarchy: `#` for the report title, `##` for top-level chapters (e.g., `## 2. Background and Problem`), `###` and `####` for sub-sections. Do not exceed four heading levels. All headings MUST use Markdown ATX syntax. +- Chapter titles you provide should be concise and natural-sounding. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including Governance Framework and Procurement Clauses)"). +- DO NOT include meta-text in the report body, such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), or execution disclaimers. The report should be a polished, self-contained document. + +# Unexpected Handling +1. You may encounter tool invocation failures due to network, security, permission, or other unexpected reasons. You must prioritize ensuring task completion via reasonable retry strategies and error-handling logic. +2. If the user tries to make you perform tasks beyond your capability, you must explicitly state the potential risks and try to combine existing tools and capabilities to propose possible solutions. +3. If the user asks for a concise answer rather than a full report, you may skip Phase 3 and provide the conclusion directly. +4. If the user attempts casual conversation rather than research tasks, you do not need to start the research workflow; you may respond normally and try to guide the user to initiate a research task. diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index b9b19f08b..78da6b0d8 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -15,6 +15,8 @@ generation_config: prefix_cache_roles: [system, user, assistant, tool] extra_body: enable_thinking: false + # show_reasoning: true + # reasoning_output: stdout tag: deep-research From 86a3ba86e67dbea0ba67f9b4a245e348d8557ca1 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Fri, 27 Mar 2026 13:15:54 +0800 Subject: [PATCH 15/40] support API key pool construction; support for reasoning model --- ms_agent/tools/search/exa/search.py | 179 ++++++++++++++++-- ms_agent/tools/search/websearch_tool.py | 27 ++- .../v2/callbacks/reporter_callback.py | 97 ++++------ .../v2/prompts/researcher/en/thinking.txt | 6 +- projects/deep_research/v2/reporter.yaml | 2 +- projects/deep_research/v2/researcher.yaml | 6 +- 6 files changed, 229 insertions(+), 88 deletions(-) diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index be456c0a5..56229a9b8 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -1,14 +1,18 @@ # flake8: noqa import os -from typing import TYPE_CHECKING +import threading +from typing import TYPE_CHECKING, List, Optional, Set, Union from exa_py import Exa from ms_agent.tools.search.exa.schema import ExaSearchRequest, ExaSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType +from ms_agent.utils.logger import get_logger if TYPE_CHECKING: from ms_agent.llm.utils import Tool +logger = get_logger() + class ExaSearch(SearchEngine): """ @@ -16,21 +20,132 @@ class ExaSearch(SearchEngine): Best for: semantic understanding, finding similar content, recent web pages with date filtering. + + Supports a pool of API keys: when one key's credits are exhausted (HTTP 402), + the engine automatically rotates to the next available key and retries. + Keys can be supplied as a comma-separated string or a list. + + Exhausted-key state is tracked at the **class level** so that all + ``ExaSearch`` instances within the same process share the knowledge + of which keys have been used up (e.g. across multiple searcher sub-agents). """ engine_type = SearchEngineType.EXA - def __init__(self, api_key: str = None): + # Process-wide tracking of exhausted key values (shared across instances). + _global_exhausted_keys: Set[str] = set() + _global_lock = threading.Lock() + + def __init__(self, + api_key: Union[str, list, None] = None, + api_keys: Union[str, list, None] = None): + all_keys = self._collect_keys(api_key, api_keys) + assert all_keys, ( + 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments ' + 'or as environment variables') + + self._api_keys: List[str] = all_keys + self._lock = threading.Lock() + + # Pick the first key that hasn't been globally exhausted yet. + start_idx = 0 + with ExaSearch._global_lock: + for i, k in enumerate(all_keys): + if k not in ExaSearch._global_exhausted_keys: + start_idx = i + break + + self._current_key_idx: int = start_idx + self.client = Exa(api_key=all_keys[start_idx]) + + if len(all_keys) > 1: + with ExaSearch._global_lock: + n_exhausted = sum( + 1 for k in all_keys + if k in ExaSearch._global_exhausted_keys) + logger.info( + f'Exa key pool: {len(all_keys)} keys, ' + f'{n_exhausted} previously exhausted, ' + f'starting at key {start_idx + 1}/{len(all_keys)}') + + @staticmethod + def _collect_keys( + api_key: Union[str, list, None] = None, + api_keys: Union[str, list, None] = None, + ) -> List[str]: + """Collect unique API keys from arguments and environment variables. + + All sources are **merged** (deduplicated), so keys from both YAML config + and ``EXA_API_KEYS`` env var are combined into a single pool. + + Sources (in merge order): + 1. ``api_keys`` argument (list or comma-separated string) + 2. ``api_key`` argument (single key or comma-separated string) + 3. ``EXA_API_KEYS`` env var (comma-separated) -- always merged + 4. ``EXA_API_KEY`` env var (only if no keys found so far) + """ + seen: set = set() + result: List[str] = [] + + def _add(raw: str): + for k in raw.split(','): + k = k.strip() + if k and k not in seen: + seen.add(k) + result.append(k) - api_key = api_key or os.getenv('EXA_API_KEY') - assert api_key, 'EXA_API_KEY must be set either as an argument or as an environment variable' + def _add_source(value): + if value is None: + return + if isinstance(value, str): + _add(value) + elif hasattr(value, '__iter__'): + for item in value: + if item is not None: + _add(str(item)) - self.client = Exa(api_key=api_key) + _add_source(api_keys) + _add_source(api_key) + + # Always merge the pool env var so that keys from YAML config and + # EXA_API_KEYS are combined (the old code gated this behind + # ``if not result`` which made it unreachable when api_key was set). + _add_source(os.getenv('EXA_API_KEYS')) + + if not result: + _add_source(os.getenv('EXA_API_KEY')) + + return result + + @staticmethod + def _is_credits_exhausted(error: Exception) -> bool: + """Detect Exa 402 / NO_MORE_CREDITS errors.""" + msg = str(error) + return ('402' in msg + and ('credits' in msg.lower() + or 'NO_MORE_CREDITS' in msg)) + + @staticmethod + def _mask_key(key: str) -> str: + if len(key) <= 8: + return '****' + return f'{key[:4]}...{key[-4:]}' + + def _is_key_exhausted(self, idx: int) -> bool: + with ExaSearch._global_lock: + return self._api_keys[idx] in ExaSearch._global_exhausted_keys + + def _mark_key_exhausted(self, idx: int) -> None: + with ExaSearch._global_lock: + ExaSearch._global_exhausted_keys.add(self._api_keys[idx]) def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: """ Perform a search using the Exa API with the provided search request parameters. + If the current key is exhausted (HTTP 402 / NO_MORE_CREDITS), the engine + rotates to the next available key and retries, up to ``len(api_keys)`` times. + :param search_request: An instance of ExaSearchRequest containing search parameters. :return: An instance of ExaSearchResult containing the search results. """ @@ -39,13 +154,55 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: query=search_request.query, arguments=search_args, ) - try: - search_result.response = self.client.search_and_contents( - **search_args) - except Exception as e: - raise RuntimeError(f'Failed to perform search: {e}') from e - return search_result + last_error: Optional[Exception] = None + max_attempts = len(self._api_keys) + instance_exhausted: Set[int] = set() + + for _attempt in range(max_attempts): + with self._lock: + client = self.client + key_idx = self._current_key_idx + + try: + search_result.response = client.search_and_contents( + **search_args) + return search_result + except Exception as e: + if not self._is_credits_exhausted(e): + raise RuntimeError( + f'Failed to perform search: {e}') from e + + last_error = e + instance_exhausted.add(key_idx) + self._mark_key_exhausted(key_idx) + + with self._lock: + logger.warning( + f'Exa API key {self._mask_key(self._api_keys[key_idx])} ' + f'credits exhausted ' + f'({len(instance_exhausted)}/{len(self._api_keys)} keys used up)' + ) + rotated = False + for i in range(len(self._api_keys)): + if i not in instance_exhausted and not self._is_key_exhausted(i): + self._current_key_idx = i + self.client = Exa( + api_key=self._api_keys[i]) + logger.info( + f'Rotated to Exa API key ' + f'{self._mask_key(self._api_keys[i])} ' + f'({i + 1}/{len(self._api_keys)})') + rotated = True + break + if not rotated: + raise RuntimeError( + f'All {len(self._api_keys)} Exa API keys have ' + f'been exhausted. Last error: {e}') from e + + raise RuntimeError( + f'All {len(self._api_keys)} Exa API keys have been exhausted. ' + f'Last error: {last_error}') from last_error @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 7bbecfe89..5e469006e 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -221,17 +221,23 @@ def get_search_engine(engine_type: str, Args: engine_type: One of 'exa', 'serpapi', 'arxiv' - api_key: API key for the search engine (if required) - **kwargs: Additional arguments passed to engine constructor + api_key: API key for the search engine (if required). + For Exa, this can be a comma-separated string of multiple keys + to enable automatic key rotation on credit exhaustion. + **kwargs: Additional arguments passed to engine constructor. + For Exa, ``api_keys`` (list or comma-separated str) is also + accepted to supply a key pool explicitly. """ engine_type = engine_type.lower().strip() if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch - return ExaSearch(api_key=api_key or os.getenv('EXA_API_KEY')) + return ExaSearch( + api_key=api_key or os.getenv('EXA_API_KEY'), + api_keys=kwargs.get('api_keys') or os.getenv('EXA_API_KEYS'), + ) elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch - # Allow shorthand engine_type aliases to imply provider default_provider = ('google' if engine_type in ('serpapi', 'serp') else engine_type) return SerpApiSearch( @@ -394,13 +400,20 @@ def __init__(self, config, **kwargs): 'No valid engines configured, falling back to arxiv') self._engine_types = ['arxiv'] - # API keys for each engine + # API keys for each engine. + # Exa supports a key pool: exa_api_keys (list/comma-separated) takes + # priority over the single-key fields. The value is forwarded as-is to + # ExaSearch which handles parsing into a list internally. self._api_keys = { 'exa': ( - getattr(tool_cfg, 'exa_api_key', None) + getattr(tool_cfg, 'exa_api_keys', None) + or getattr(tool_cfg, 'exa_api_key', None) or getattr(tool_cfg, 'api_key', None) # backward compat + or os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')) - if tool_cfg else os.getenv('EXA_API_KEY'), + if tool_cfg else ( + os.getenv('EXA_API_KEYS') + or os.getenv('EXA_API_KEY')), 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 7bfb5bee3..4332f26bc 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -164,75 +164,46 @@ class ReporterCallback(Callback): 'zh': ('\n\n---\n' '**[后续工作流程建议]**\n\n' - 'Reporter 已完成报告生成。如果其正常返回工作总结,请仔细审阅返回内容的 Execution_Summary 和 Artifacts 字段,' - '它们总结了报告生成过程并列出了重要的中间文件产物。如果其未正常完成任务或者未正常返回信息,请主动检查 reports 目录下的产物情况确定后续行动。\n\n' - '**关于 final_report.md:' - '** 上方 Artifacts 字段通常只包含 reports/ 目录下的文件(如 reports/report.md),' - '不包含 final_report.md。这是正常的——系统会在 Reporter 正常完成任务后自动将 reports/report.md 复制为 final_report.md。' - '你的审阅和编辑应优先针对 final_report.md。如有需要可按需读取 reports/ 下的其他文件作为参考,' - '但当 final_report.md 可用时避免重复读取 reports/report.md。如果 final_report.md 意外缺失或不完整,按此路径回退:' - 'reports/report.md -> reports/draft.md -> reports/ 下其他产物内容。\n\n' - '**审查与编辑注意事项:**\n' - '- 请严格遵守系统指令中的要求,不要遗漏、忽略任何合理的规则。\n' - '- 审查要点包括事实准确性、逻辑一致性、用户核心问题的覆盖度、引用与论据的对齐关系、引用格式问题、内容完整性等等。' - '修改须有明确依据(如事实冗余、逻辑混乱、证据不一致、格式出错等),不要为了"润色"而改动结构/质量良好的内容。\n' - '- 读取报告内容一次后形成判断,后续核查优先使用 search_file_content 或带 start_line / end_line 的 read_file,不要反复全量读取同一文件。' - '在读取文件前先检查对话历史中是否已包含该文件的内容,避免重复读取。\n' - '- 优先使用定点修改(search_file_content -> replace_file_contents / replace_file_lines),仅在必要时才读取全文。' - '仅在定点修改完全无法解决时使用 write_file,且**必须完整保留所有有价值的内容**,严禁使用占位符、省略标记、引用其他内容等方式替代正文。\n' - '- 质量较高无需修改的部分直接跳过。如果[Reporter 工作总结]中无异常且审查确认全文质量良好,直接进入结论阶段即可。\n\n' - '**需避免的常见错误:**\n' - '- 重复全量读取同一个报告文件(迅速耗尽上下文预算,导致任务失败)。\n' - '- 默认 final_report.md 不存在、且使用简短的概述内容覆盖完整报告。\n' - '- 对结构/质量良好的内容过度修改或者压缩,或在修改过程中忘记已做的改动重复编辑导致错误。\n'), + 'Reporter 已完成报告生成。请审阅返回内容的 Execution_Summary 和 Artifacts 字段以了解生成过程和产物。' + '如果 Reporter 未正常完成或返回信息缺失,请主动检查 reports/ 目录下的产物确定后续行动。\n\n' + '**关于 final_report.md:** Artifacts 字段通常只包含 reports/ 目录下的文件,不包含 final_report.md——' + '系统会在 Reporter 正常完成后自动将 reports/report.md 复制为 final_report.md。审阅和编辑应优先针对 final_report.md。' + '如果 final_report.md 意外缺失,按此路径回退:reports/report.md -> reports/draft.md -> reports/ 下其他产物。\n\n' + '**审查原则:**\n' + '- 充分理解报告全貌后再进行修改——基于局部认知的编辑容易在已读和未读的章节之间引入矛盾。' + '全文阅读一次后,后续核查使用定点方式(search_file_content 或带行号的 read_file),反复全量读取同一文件会迅速耗尽上下文预算导致任务失败。\n' + '- 仅对具体问题(事实错误、逻辑缺陷、证据不一致、格式违规等)进行修改,不要为"润色"而改动质量良好的内容。\n' + '- 优先使用定点修改(search_file_content -> replace_file_contents / replace_file_lines)。' + '仅在定点修改完全无法解决时使用 write_file 重写全文,且必须完整保留所有有价值的内容,严禁使用占位符、省略标记或引用其他文件内容来替代正文。\n' + '- 质量良好且无实质性问题时直接进入结论阶段,不要为"润色"而重写。\n'), 'en': ('\n\n---\n' '**[Post-Report Workflow Guidance]**\n\n' - 'The Reporter has finished generating the report. If it returned a work summary normally, ' - 'please carefully review the Execution_Summary and Artifacts fields in the returned content — ' - 'they summarize the report generation process and list important intermediate file artifacts. ' - 'If Reporter did not complete the task normally or did not return information properly, ' - 'proactively check the artifacts under the `reports/` directory to determine next steps.\n\n' - '**About `final_report.md`:** The Artifacts field above typically lists only ' - 'files under `reports/` (e.g., `reports/report.md`) and will NOT include ' - '`final_report.md`. This is expected — the system automatically copies ' - '`reports/report.md` to `final_report.md` after the Reporter finishes normally. ' - 'Your review and edits should target `final_report.md`. You may read other ' - 'files under `reports/` as supplementary references when needed, ' - 'but avoid reading `reports/report.md` in full when ' - '`final_report.md` is available. If `final_report.md` is unexpectedly ' - 'missing or incomplete, fall back in this order: ' - '`reports/report.md` -> `reports/draft.md` -> other artifacts under `reports/`.\n\n' - '**Review and editing guidelines:**\n' - '- Strictly follow the requirements in the system instructions; do not overlook or ignore any reasonable rules.\n' - '- Key review points include factual accuracy, logical consistency, coverage of the user\'s core questions, ' - 'alignment between citations and supporting arguments, citation formatting issues, content completeness, etc. ' - 'Edits must have clear justification (e.g., factual redundancy, logical confusion, evidence inconsistency, ' - 'formatting errors, etc.) — do not alter well-structured, high-quality content merely for "polishing."\n' - '- Read the report content ONCE to form your assessment. For subsequent ' - 'checks, prefer `search_file_content` or `read_file` with `start_line`/`end_line`. ' - 'Do not re-read the entire file repeatedly. Check your conversation history before ' - 'reading any file to avoid redundant reads.\n' - '- Prefer targeted fixes (`search_file_content` -> `replace_file_contents` / ' - '`replace_file_lines`); only read the full text when necessary. ' - 'Use `write_file` only when targeted fixes are completely insufficient, ' - 'and you **must preserve ALL valuable content in full** — never use placeholders, ' - 'ellipsis markers, or references to other content as substitutes for actual text.\n' - '- Skip high-quality sections that require no changes. If the [Reporter Work Summary] ' - 'indicates no issues and your review confirms overall quality, proceed ' - 'directly to the conclusion.\n\n' - '**Common mistakes to avoid:**\n' - '- Reading the same report file in full multiple times (rapidly exhausts ' - 'context budget and causes task failure).\n' - '- Assuming `final_report.md` does not exist and overwriting the complete report ' - 'with a brief summary.\n' - '- Over-editing or compressing well-structured, high-quality content, or losing track ' - 'of changes already made and making duplicate edits that introduce errors.\n'), + 'The Reporter has finished generating the report. Review the Execution_Summary and Artifacts fields ' + 'to understand the generation process and output artifacts. ' + 'If Reporter did not complete normally, check the reports/ directory to determine next steps.\n\n' + '**About final_report.md:** Artifacts typically lists only files under reports/ ' + '(e.g., reports/report.md), not final_report.md - the system automatically copies ' + 'reports/report.md to final_report.md after normal completion. ' + 'Your review and edits should primarily target final_report.md. ' + 'If it is unexpectedly missing, fall back to: ' + 'reports/report.md -> reports/draft.md -> other artifacts under reports/.\n\n' + '**Review principles:**\n' + '- Familiarize yourself with the entire report before making any edits — ' + 'changes based on partial knowledge commonly introduce contradictions between sections. ' + 'After the initial read, use targeted approaches (search_file_content, read_file with line ranges) ' + 'for subsequent checks. Repeated full reads rapidly exhaust context capacity and risk task failure.\n' + '- Only edit to fix concrete problems (factual errors, logical flaws, evidence mismatches, ' + 'style violations, etc.). Do not alter well-structured content merely for polish.\n' + '- Prefer targeted fixes (search_file_content -> replace_file_contents / replace_file_lines). ' + 'Use write_file to rewrite the full report only as a last resort, and preserve ALL valuable content in full — ' + 'never use placeholders, ellipsis markers, or references to other files as substitutes for actual text.\n' + '- If quality is confirmed and no substantive issues exist, proceed directly to your conclusion.\n'), } _WORK_SUMMARY_LABEL = { - 'zh': '**[Reporter 工作总结]**', - 'en': '**[Reporter Work Summary]**', + 'zh': '**[Reporter 返回的 JSON 结果]**', + 'en': '**[Reporter\'s Returned JSON Result]**', } def __init__(self, config: DictConfig): diff --git a/projects/deep_research/v2/prompts/researcher/en/thinking.txt b/projects/deep_research/v2/prompts/researcher/en/thinking.txt index 6d21b8b48..38c211a87 100644 --- a/projects/deep_research/v2/prompts/researcher/en/thinking.txt +++ b/projects/deep_research/v2/prompts/researcher/en/thinking.txt @@ -3,7 +3,7 @@ You have everything you need to resolve the task. I want you to fully solve this Time reminder: The current date is , and the current time is . Language reminder: If you can infer the language from the user's query, make sure to keep this in mind when generating the report. Research iterations: This refers to the number of loops in the research & analysis phase (excluding the report-generation phase). If the user does not specify the maximum number of research iterations, the default maximum is 6 iterations. You MUST complete the task within the maximum number of iterations. -Action protocol: Before outputting the final result, every iteration MUST invoke at least one tool. You MUST reason extensively about the current state and your intended next action before each tool call and show your thinking in the conversation (e.g., What key information did I find? What's missing? Do I have enough to answer the question comprehensively? What should I do next?). DO NOT do this entire process by making tool calls only, as this can impair your ability to solve the problem and think insightfully. +Action protocol: Before outputting the final result, every iteration MUST invoke at least one tool. Periodically assess progress to determine the most effective next step. Keep your conversation output focused on key decisions and conclusions. Make each tool call purposeful and avoid redundancy — every tool result permanently occupies context capacity. # Primary Responsibilities - Plan & orchestrate: @@ -26,9 +26,9 @@ You are free to adapt, reorder, or skip steps based on the complexity and requir - If resuming from an unfinished task, start by checking the current completion status using todo_list---todo_read and other available tools. - Develop a manageable plan based on user's needs and task progress. Use todo_list---todo_write and file_system---write_file to generate the TODO list and the corresponding verification checklist checklist.yaml, respectively. - The TODO list must cover all subtasks that need to be completed. It is used to clearly communicate your full plan to the user. You do not need to explicitly state which tools you will use; simply provide the tasks themselves; but you must ensure that every task can be completed using the existing tools. - - Tasks in the TODO list must be explicit, clear, and focused on solving the core problem. Each task should contain no more than three core questions to answer, while also avoiding over-splitting that would make the task list excessively long. + - Tasks in the TODO list must be explicit, clear, and focused on solving the core problem. Each task should contain no more than three core questions to answer. - Tasks in the TODO list should be assigned reasonable priorities: high for tasks directly answering the user's core questions, medium for supporting context or secondary dimensions, low for nice-to-have extensions. High-priority tasks should be executed first, while medium- and low-priority tasks should be performed only if the iteration budget allows. -- Compare the TODO list and the verification checklist for reflection. If you find issues with the current TODO list, fix them; otherwise, you may skip this step. +- Do self-reflection. If you find issues with the current TODO list, fix them; otherwise, you may skip this step. - If necessary, you can invoke the Searcher sub-agent at most once for concept clarification; - If you find issues in the TODO list, you must revise it via the todo_list---todo_write tool. diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index c55fd109b..a101f3edb 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -71,7 +71,7 @@ self_reflection: enabled: true max_retries: 3 min_retention_ratio: 0.6 - post_report_guidance_enabled: false + post_report_guidance_enabled: true quality_check: enabled: true model: qwen3.5-flash diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 95ba77a84..95625df00 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -1,6 +1,6 @@ llm: service: openai - model: gpt-5-2025-08-07 + model: gpt-5.2-2025-12-11 openai_api_key: openai_base_url: @@ -17,7 +17,7 @@ generation_config: # enable_thinking: true # show_reasoning: true # reasoning_output: stdout - # reasoning_effort: medium + reasoning_effort: medium tag: deep-research-researcher @@ -27,7 +27,7 @@ prompt: root: prompts/ agent: researcher lang: en - family: gpt5 + family: thinking tools: From 8a32ce5491b2d62bf010254df2b8d0403cf8b583 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Fri, 27 Mar 2026 13:18:39 +0800 Subject: [PATCH 16/40] fix lint --- ms_agent/tools/search/exa/search.py | 31 +++++++++++-------------- ms_agent/tools/search/websearch_tool.py | 8 +++---- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index 56229a9b8..08fa6fb71 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -60,13 +60,11 @@ def __init__(self, if len(all_keys) > 1: with ExaSearch._global_lock: - n_exhausted = sum( - 1 for k in all_keys - if k in ExaSearch._global_exhausted_keys) - logger.info( - f'Exa key pool: {len(all_keys)} keys, ' - f'{n_exhausted} previously exhausted, ' - f'starting at key {start_idx + 1}/{len(all_keys)}') + n_exhausted = sum(1 for k in all_keys + if k in ExaSearch._global_exhausted_keys) + logger.info(f'Exa key pool: {len(all_keys)} keys, ' + f'{n_exhausted} previously exhausted, ' + f'starting at key {start_idx + 1}/{len(all_keys)}') @staticmethod def _collect_keys( @@ -122,8 +120,7 @@ def _is_credits_exhausted(error: Exception) -> bool: """Detect Exa 402 / NO_MORE_CREDITS errors.""" msg = str(error) return ('402' in msg - and ('credits' in msg.lower() - or 'NO_MORE_CREDITS' in msg)) + and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg)) @staticmethod def _mask_key(key: str) -> str: @@ -170,8 +167,7 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: return search_result except Exception as e: if not self._is_credits_exhausted(e): - raise RuntimeError( - f'Failed to perform search: {e}') from e + raise RuntimeError(f'Failed to perform search: {e}') from e last_error = e instance_exhausted.add(key_idx) @@ -185,14 +181,13 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: ) rotated = False for i in range(len(self._api_keys)): - if i not in instance_exhausted and not self._is_key_exhausted(i): + if i not in instance_exhausted and not self._is_key_exhausted( + i): self._current_key_idx = i - self.client = Exa( - api_key=self._api_keys[i]) - logger.info( - f'Rotated to Exa API key ' - f'{self._mask_key(self._api_keys[i])} ' - f'({i + 1}/{len(self._api_keys)})') + self.client = Exa(api_key=self._api_keys[i]) + logger.info(f'Rotated to Exa API key ' + f'{self._mask_key(self._api_keys[i])} ' + f'({i + 1}/{len(self._api_keys)})') rotated = True break if not rotated: diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 5e469006e..8dd80f86b 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -409,11 +409,9 @@ def __init__(self, config, **kwargs): getattr(tool_cfg, 'exa_api_keys', None) or getattr(tool_cfg, 'exa_api_key', None) or getattr(tool_cfg, 'api_key', None) # backward compat - or os.getenv('EXA_API_KEYS') - or os.getenv('EXA_API_KEY')) - if tool_cfg else ( - os.getenv('EXA_API_KEYS') - or os.getenv('EXA_API_KEY')), + or os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')) + if tool_cfg else + (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), From 393f23c637dd452793b4579fbd20f59c5a2603b7 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Sat, 28 Mar 2026 01:26:22 +0800 Subject: [PATCH 17/40] support for vertex type anthropic llm; refine reasoning output --- ms_agent/agent/llm_agent.py | 64 +++++--- ms_agent/llm/anthropic_llm.py | 170 +++++++++++++++++++-- projects/deep_research/v2/run_benchmark.sh | 2 +- 3 files changed, 205 insertions(+), 31 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 740eab690..9143cdd3c 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -565,16 +565,41 @@ def reasoning_output(self) -> str: DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) - def _write_reasoning(self, text: str): + _THINKING_SEP = '─' * 40 + + def _reasoning_stream(self): + if self.reasoning_output.lower() == 'stdout': + return sys.stdout + return sys.stderr + + def _write_reasoning(self, text: str, dim: bool = False): if not text: return - if self.reasoning_output.lower() == 'stdout': - sys.stdout.write(text) - sys.stdout.flush() + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + if dim and use_ansi: + text = f'\033[2m{text}\033[0m' + stream.write(text) + stream.flush() + + def _write_thinking_header(self): + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + line = f'{self._THINKING_SEP[:15]} thinking {self._THINKING_SEP[25:]}' + if use_ansi: + stream.write(f'\033[2m{line}\033[0m\n') + else: + stream.write(line + '\n') + stream.flush() + + def _write_thinking_footer(self): + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + if use_ansi: + stream.write(f'\n\033[2m{self._THINKING_SEP}\033[0m\n') else: - # default: stderr - sys.stderr.write(text) - sys.stderr.flush() + stream.write(f'\n{self._THINKING_SEP}\n') + stream.flush() @property def system(self): @@ -812,35 +837,38 @@ async def step( is_first = True _response_message = None _printed_reasoning_header = False + _printed_reasoning_footer = False for _response_message in self.llm.generate( messages, tools=tools): if is_first: messages.append(_response_message) is_first = False - # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' - # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: - self._write_reasoning('[thinking]:\n') + self._write_thinking_header() _printed_reasoning_header = True - self._write_reasoning(new_reasoning) + self._write_reasoning(new_reasoning, dim=True) _reasoning = reasoning_text new_content = _response_message.content[len(_content):] - sys.stdout.write(new_content) - sys.stdout.flush() + if new_content: + if _printed_reasoning_header and not _printed_reasoning_footer: + self._write_thinking_footer() + _printed_reasoning_footer = True + sys.stdout.write(new_content) + sys.stdout.flush() _content = _response_message.content messages[-1] = _response_message yield messages - if self.show_reasoning and _printed_reasoning_header: - self._write_reasoning('\n') + if _printed_reasoning_header and not _printed_reasoning_footer: + self._write_thinking_footer() sys.stdout.write('\n') else: _response_message = self.llm.generate(messages, tools=tools) @@ -848,9 +876,9 @@ async def step( reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' if reasoning_text: - self._write_reasoning('[thinking]:\n') - self._write_reasoning(reasoning_text) - self._write_reasoning('\n') + self._write_thinking_header() + self._write_reasoning(reasoning_text, dim=True) + self._write_thinking_footer() if _response_message.content: self.log_output('[assistant]:') self.log_output(_response_message.content) diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py index 6f93a1167..5b35bfb5d 100644 --- a/ms_agent/llm/anthropic_llm.py +++ b/ms_agent/llm/anthropic_llm.py @@ -1,6 +1,8 @@ import inspect from typing import Any, Dict, Generator, Iterator, List, Optional, Union +import httpx +import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import assert_package_exist, retry @@ -8,6 +10,111 @@ from omegaconf import DictConfig, OmegaConf +class _SSEEventInjector(httpx.SyncByteStream): + """Injects SSE ``event:`` lines into DashScope's streaming response. + + DashScope only emits ``data:`` lines in its SSE stream. The Anthropic + SDK's ``MessageStream`` relies on ``event:`` lines to route events. + This wrapper extracts the ``type`` from the JSON payload and prepends + the matching ``event:`` line so the SDK can process events correctly. + """ + + def __init__(self, stream): + self._stream = stream + self._buffer = b'' + + def __iter__(self): + for chunk in self._stream: + self._buffer += chunk + while b'\n\n' in self._buffer: + block, self._buffer = self._buffer.split(b'\n\n', 1) + if block.strip(): + yield self._inject(block) + b'\n\n' + if self._buffer.strip(): + yield self._inject(self._buffer) + b'\n\n' + + @staticmethod + def _inject(block: bytes) -> bytes: + for line in block.split(b'\n'): + s = line.strip() + if s.startswith(b'data:'): + try: + t = json.loads(s[5:].strip()).get('type', '') + if t: + return b'event: ' + t.encode() + b'\n' + block + except (json.JSONDecodeError, ValueError): + pass + return block + + def close(self): + if hasattr(self._stream, 'close'): + self._stream.close() + + +class DashScopeAnthropicTransport(httpx.BaseTransport): + """Routes Anthropic SDK requests to DashScope's compatible-mode endpoint. + + DashScope returns Anthropic-format SSE responses for vertex AI Claude models + (e.g. vertex_ai.claude-opus-4-6), but expects requests at + /compatible-mode/v1/chat/completions with a native protocol flag rather than + the standard Anthropic /v1/messages path. This transport transparently + rewrites URL, auth headers, and body so the Anthropic SDK works unmodified. + """ + + def __init__(self, + dashscope_url: str, + api_key: str, + supplier: Optional[str] = None): + self.dashscope_url = dashscope_url + self.api_key = api_key + self.supplier = supplier + self._transport = httpx.HTTPTransport() + + def handle_request(self, request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + is_streaming = bool(body.get('stream')) + + ext = body.setdefault('dashscope_extend_params', {}) + ext['using_native_protocol'] = True + if self.supplier and 'supplier' not in ext: + ext['supplier'] = self.supplier + + new_headers = { + 'content-type': 'application/json', + 'authorization': f'Bearer {self.api_key}', + } + _skip = frozenset({ + 'x-api-key', 'content-type', 'authorization', 'content-length', + 'host', 'transfer-encoding' + }) + for key, value in request.headers.items(): + k = key.lower() + if k not in _skip and not k.startswith('anthropic'): + new_headers[key] = value + + new_content = json.dumps(body).encode('utf-8') + new_request = httpx.Request( + method=request.method, + url=self.dashscope_url, + headers=new_headers, + content=new_content, + extensions=request.extensions, + ) + response = self._transport.handle_request(new_request) + + if is_streaming: + return httpx.Response( + status_code=response.status_code, + headers=response.headers, + stream=_SSEEventInjector(response.stream), + extensions=response.extensions, + ) + return response + + def close(self): + self._transport.close() + + class Anthropic(LLM): def __init__( @@ -29,10 +136,31 @@ def __init__( if not api_key: raise ValueError('Anthropic API key is required.') - self.client = anthropic.Anthropic( - api_key=api_key, - base_url=base_url, - ) + self._is_dashscope = bool(base_url and 'dashscope' in base_url.lower()) + + if self._is_dashscope: + dashscope_url = base_url + if not dashscope_url.rstrip('/').endswith('/chat/completions'): + dashscope_url = dashscope_url.rstrip('/') + '/chat/completions' + supplier = config.llm.get('dashscope_supplier', None) + transport = DashScopeAnthropicTransport( + dashscope_url=dashscope_url, + api_key=api_key, + supplier=supplier, + ) + http_client = httpx.Client( + transport=transport, + timeout=httpx.Timeout(300.0, connect=60.0), + ) + self.client = anthropic.Anthropic( + api_key=api_key, + http_client=http_client, + ) + else: + self.client = anthropic.Anthropic( + api_key=api_key, + base_url=base_url, + ) self.args: Dict = OmegaConf.to_container( getattr(config, 'generation_config', DictConfig({}))) @@ -112,24 +240,42 @@ def _call_llm(self, formatted_messages = formatted_messages[1:] max_tokens = kwargs.pop('max_tokens', 16000) - extra_body = kwargs.get('extra_body', {}) - enable_thinking = extra_body.get('enable_thinking', False) - thinking_budget = extra_body.get('thinking_budget', max_tokens) + + enable_thinking = bool(kwargs.pop('enable_thinking', False)) + thinking_budget = kwargs.pop('thinking_budget', None) + thinking_type = kwargs.pop('thinking_type', None) + + raw_extra_body = kwargs.pop('extra_body', {}) or {} + extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, + dict) else {} + enable_thinking = bool( + extra_body.pop('enable_thinking', enable_thinking)) + thinking_budget = extra_body.pop('thinking_budget', + thinking_budget) or max_tokens + thinking_type = extra_body.pop('thinking_type', thinking_type) + for _k in ('show_reasoning', 'reasoning_output'): + extra_body.pop(_k, None) params = { 'model': self.model, 'messages': formatted_messages, - 'max_tokens': max_tokens, - 'thinking': { - 'type': 'enabled' if enable_thinking else 'disabled', - 'budget_tokens': thinking_budget - } + 'max_tokens': max_tokens } + if thinking_type == 'adaptive': + params['thinking'] = {'type': 'adaptive'} + elif enable_thinking: + params['thinking'] = { + 'type': 'enabled', + 'budget_tokens': thinking_budget, + } + if system: params['system'] = system if tools: params['tools'] = tools + if extra_body: + kwargs['extra_body'] = extra_body params.update(kwargs) if stream: diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh index 97c451d4f..b2fa35c79 100755 --- a/projects/deep_research/v2/run_benchmark.sh +++ b/projects/deep_research/v2/run_benchmark.sh @@ -118,7 +118,7 @@ else # Benchmark subprocess tuning (override via env vars if needed) export DR_BENCH_POST_FINISH_GRACE_S="${DR_BENCH_POST_FINISH_GRACE_S:-180}" - export DR_BENCH_POST_REPORT_EXIT_GRACE_S="${DR_BENCH_POST_REPORT_EXIT_GRACE_S:-3600}" + export DR_BENCH_POST_REPORT_EXIT_GRACE_S="${DR_BENCH_POST_REPORT_EXIT_GRACE_S:-7200}" export DR_BENCH_REPORT_STABLE_WINDOW_S="${DR_BENCH_REPORT_STABLE_WINDOW_S:-10}" export DR_BENCH_SUBPROCESS_POLL_INTERVAL_S="${DR_BENCH_SUBPROCESS_POLL_INTERVAL_S:-0.5}" export DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S="${DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S:-30}" From e00c296f002811f647c10768bc44b09762210794 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 2 Apr 2026 10:46:32 +0800 Subject: [PATCH 18/40] feat: add snapshot/rollback system to agent runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add snapshot.py: isolated git repo under output_dir/.ms_agent_snapshots, stores message_count per commit for history truncation on rollback - Auto-snapshot on every user turn via on_task_begin (enable_snapshots=True by default) - Add list_snapshots()/rollback() to Agent base and LLMAgent: rollback restores files, truncates saved history, clears _read_cache - Refactor filesystem_tool.py: remove replace_file_contents, rewrite edit_file with old_string→new_string exact replace, quote-normalization fallback, smart delete, trailing-whitespace strip, staleness check, multi-type read (images/binary), read dedup cache - Add smoke tests (20 cases, all offline) --- .gitignore | 2 + ms_agent/agent/base.py | 9 + ms_agent/agent/llm_agent.py | 30 + ms_agent/tools/filesystem_tool.py | 1190 +++++++++------------------- ms_agent/utils/snapshot.py | 204 +++++ tests/utils/test_snapshot_smoke.py | 331 ++++++++ 6 files changed, 936 insertions(+), 830 deletions(-) create mode 100644 ms_agent/utils/snapshot.py create mode 100644 tests/utils/test_snapshot_smoke.py diff --git a/.gitignore b/.gitignore index f526ef422..30dfa8d1f 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,8 @@ wheels/ /package /temp **/tmp/ +.env* +.claude-trace/ /apps/agentfabric/tmp/ MANIFEST diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index 9e17c8c44..cb78d5ce2 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -74,6 +74,15 @@ def save_history(self, messages: Any, **kwargs): return save_history(self.output_dir, self.tag, self.config, messages) + def list_snapshots(self) -> list: + """Return snapshots for this agent's output_dir, most recent first.""" + from ms_agent.utils.snapshot import list_snapshots + return list_snapshots(self.output_dir) + + def rollback(self, commit_hash: str) -> bool: + """Restore output_dir to a previous snapshot and truncate history.""" + raise NotImplementedError() + def next_flow(self, idx: int) -> int: """Used in workflow, decide which agent goes next.""" return idx + 1 diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 5f2ddf2e7..12bb0dcca 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -24,6 +24,7 @@ from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger +from ms_agent.utils.snapshot import take_snapshot from omegaconf import DictConfig, OmegaConf from ..config.config import Config, ConfigLifecycleHandler @@ -351,6 +352,24 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: return messages + def rollback(self, commit_hash: str) -> bool: + """Restore output_dir to snapshot and truncate message history.""" + from ms_agent.utils.snapshot import restore_snapshot + ok, message_count = restore_snapshot(self.output_dir, commit_hash) + if not ok: + return False + # Truncate saved history to the message count at snapshot time + _, saved_messages = read_history(self.output_dir, self.tag) + if saved_messages and message_count < len(saved_messages): + save_history(self.output_dir, self.tag, self.config, + saved_messages[:message_count]) + # Clear read cache on FileSystemTool so stale entries don't block edits + if self.tool_manager is not None: + for tool in self.tool_manager.tools.values(): + if hasattr(tool, '_read_cache'): + tool._read_cache.clear() + return True + def register_callback(self, callback: Callback): """ Register a new callback to be triggered during the agent's lifecycle. @@ -477,6 +496,17 @@ def register_callback_from_config(self): async def on_task_begin(self, messages: List[Message]): self.log_output(f'Agent {self.tag} task beginning.') + if getattr(self.config, 'enable_snapshots', True): + _user_content = next( + ((getattr(m, 'content', '') or '')[:80] + for m in messages if getattr(m, 'role', '') == 'user'), + '', + ) + take_snapshot( + self.output_dir, + f'[pre] {_user_content}' if _user_content else '[pre] new task', + message_count=len(messages), + ) await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index a107a7c59..df01d0261 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -1,19 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import fnmatch +import base64 import os -import re -import shutil from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Optional import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase -from ms_agent.utils import MAX_CONTINUE_RUNS, get_logger, retry +from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR -from openai import OpenAI logger = get_logger() @@ -21,12 +16,13 @@ class FileSystemTool(ToolBase): """A file system operation tool""" - # Directories to exclude from file operations - EXCLUDED_DIRS = { - 'node_modules', 'dist', '.git', '__pycache__', '.venv', 'venv' + MAX_READ_BYTES = 10 * 1024 * 1024 # 10 MB per file + IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) + # Curly quote → straight quote mapping for fuzzy matching + CURLY_QUOTE_MAP = { + '\u2018': "'", '\u2019': "'", # ' ' + '\u201c': '"', '\u201d': '"', # " " } - # File prefixes to exclude - EXCLUDED_FILE_PREFIXES = ('.', '..', '__pycache__') SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -46,13 +42,6 @@ def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) - if self.exclude_functions and 'edit_file' not in self.exclude_functions \ - or self.include_functions and 'edit_file' in self.include_functions: - self.edit_file_config = getattr(config.tools.file_system, - 'edit_file_config', None) - self.edit_client = OpenAI( - api_key=self.edit_file_config.api_key, - base_url=self.edit_file_config.base_url) self.trust_remote_code = kwargs.get('trust_remote_code', False) self.allow_read_all_files = getattr( getattr(config.tools, 'file_system', {}), 'allow_read_all_files', @@ -66,6 +55,8 @@ def __init__(self, config, **kwargs): self.system = self.SYSTEM_FOR_ABBREVIATIONS if hasattr(self.config.tools.file_system, 'system_for_abbreviations'): self.system = self.config.tools.file_system.system_for_abbreviations + # {real_path: {"mtime": float, "offset": int|None, "limit": int|None}} + self._read_cache: dict[str, dict] = {} async def connect(self): logger.warning_once( @@ -75,352 +66,117 @@ async def connect(self): async def _get_tools_inner(self): tools = { 'file_system': [ - Tool( - tool_name='create_directory', - server_name='file_system', - description='Create a directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the directory to create', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), Tool( tool_name='write_file', server_name='file_system', - description='Write content into a file', + description=( + 'Write content to a file. Creates the file if it does not exist, ' + 'or overwrites it if it does.\n\n' + 'Usage:\n' + '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' + '- Use this tool to create new files or perform a complete rewrite.\n' + '- Parent directories are created automatically if they do not exist.' + ), parameters={ 'type': 'object', 'properties': { 'path': { 'type': 'string', - 'description': 'The relative path of the file', + 'description': 'The relative path of the file to write', }, 'content': { 'type': 'string', - 'description': 'The content of the file', + 'description': 'The full content to write into the file', }, }, 'required': ['path', 'content'], 'additionalProperties': False }), Tool( - tool_name='read_abbreviation_file', + tool_name='read_file', server_name='file_system', - description= - 'Read the abbreviation content of file(s). If the information is not enough, ' - 'read the original file by `read_file`', + description=( + 'Read the content of one or more files.\n\n' + '- `paths`: list of relative file paths to read.\n' + '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' + '- `offset`: line number to start reading from (1-based). ' + 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' + '- `limit`: number of lines to read. ' + 'Only effective when paths has exactly one element. Omit to read to the end.\n' + '- `abbreviate`: if true, use an LLM to return a condensed summary of each file ' + 'instead of the raw content. Cached after first call. ' + 'Use this for a quick structural overview; read the full file if more detail is needed.' + ), parameters={ 'type': 'object', 'properties': { 'paths': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, + 'type': 'array', + 'items': {'type': 'string'}, 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'List of relative file path(s) to read', }, - }, - 'required': ['paths'], - 'additionalProperties': False - }), - Tool( - tool_name='read_file', - server_name='file_system', - description= - 'Read the content of file(s). When reading a single file, optionally specify line range.', - parameters={ - 'type': 'object', - 'properties': { - 'paths': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, + 'offset': { + 'type': 'integer', 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'Line number to start reading from (1-based). ' + 'Only provide if the file is too large to read at once.', }, - 'start_line': { - 'type': - 'integer', + 'limit': { + 'type': 'integer', 'description': - 'Start line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. 0 or omit to read from beginning.', + 'Number of lines to read. ' + 'Only provide if the file is too large to read at once.', }, - 'end_line': { - 'type': - 'integer', + 'abbreviate': { + 'type': 'boolean', 'description': - 'End line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. Omit to read to the end.', + 'If true, return an LLM-generated summary instead of raw content. ' + 'Useful for large files or quick structural overview.', }, }, 'required': ['paths'], 'additionalProperties': False }), - Tool( - tool_name='list_files', - server_name='file_system', - description='List all files in a directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - "The path to list files, if path is None or '' or not given, " - 'the root dir will be used as path.', - } - }, - 'required': [], - 'additionalProperties': False - }), - Tool( - tool_name='delete_file_or_dir', - server_name='file_system', - description='Delete one file or one directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'description': 'The relative path to delete', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), Tool( tool_name='edit_file', server_name='file_system', - description= - ('Use this tool to make an edit to an existing file.\n\n' - 'This will be read by a less intelligent model, which will quickly apply the edit. ' - 'You should make it clear what the edit is, while also minimizing the unchanged code you write.\n' - 'When writing the edit, you should specify each edit in sequence, with the special comment ' - '// ... existing code ... to represent unchanged code in between edited lines.\n\n' - 'For example:\n\n// ... existing code ...\nFIRST_EDIT\n// ... existing code ...\n' - 'SECOND_EDIT\n// ... existing code ...\nTHIRD_EDIT\n// ... existing code ...\n\n' - 'You should still bias towards repeating as few lines of the original file ' - 'as possible to convey the change.\n' - 'But, each edit should contain minimally sufficient context of unchanged lines ' - "around the code you're editing to resolve ambiguity.\n" - 'DO NOT omit spans of pre-existing code (or comments) without using the ' - '// ... existing code ... comment to indicate its absence. ' - 'If you omit the existing code comment, the model may inadvertently delete these lines.\n' - 'If you plan on deleting a section, you must provide context before and after to delete it. ' - 'If the initial code is ```code \\n Block 1 \\n Block 2 \\n Block 3 \\n code```, ' - 'and you want to remove Block 2, you would output ' - '```// ... existing code ... \\n Block 1 \\n Block 3 \\n // ... existing code ...```.\n' - 'Make sure it is clear what the edit should be, and where it should be applied.\n' - 'Make edits to a file in a single edit_file call ' - 'instead of multiple edit_file calls to the same file. ' - 'The apply model can handle many distinct edits at once.' - ), + description=( + 'Edit an existing file by replacing an exact string with new content.\n\n' + 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' + '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' + 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' + 'with the match count so you can add more context to make it unique.\n\n' + 'Special case — `old_string=""`:\n' + '- File does not exist: creates the file with `new_string` as its content.\n' + '- File exists and is empty: fills it with `new_string`.\n' + '- File exists and has content: returns an error. Use `write_file` for a full rewrite.' + ), parameters={ 'type': 'object', 'properties': { 'path': { 'type': 'string', - 'description': - 'Path of the target file to modify.' - }, - 'instructions': { - 'type': - 'string', - 'description': - ('A single sentence instruction describing ' - 'what you are going to do for the sketched edit. ' - 'This is used to assist the less intelligent model in applying the edit. ' - 'Use the first person to describe what you are going to do. ' - 'Use it to disambiguate uncertainty in the edit.' - ) - }, - 'code_edit': { - 'type': - 'string', - 'description': - ('Specify ONLY the precise lines of code that you wish to edit. ' - 'NEVER specify or write out unchanged code. ' - 'Instead, represent all unchanged code using the comment of the language ' - "you're editing in - example: // ... existing code ..." - ) - } - }, - 'required': ['path', 'instructions', 'code_edit'] - }), - Tool( - tool_name='search_file_content', - server_name='file_system', - description= - 'Search for content in files using literal text or regex patterns. ' - 'Automatically detects and supports both literal string matching and regex pattern matching. ' - 'Returns matching files with line numbers and surrounding context.', - parameters={ - 'type': 'object', - 'properties': { - 'content': { - 'type': - 'string', - 'description': - 'The content/text or regex pattern to search for. ' - 'Supports both literal strings and regex patterns automatically.', - }, - 'parent_path': { - 'type': - 'string', - 'description': - 'The relative parent path to search in (optional, defaults to root)', - }, - 'file_pattern': { - 'type': - 'string', - 'description': - 'Wildcard pattern for file names, e.g., "*.py", "*.js", "test_*.py" ' - '(default: "*" for all files)', - }, - 'context_lines': { - 'type': - 'integer', - 'description': - 'Number of lines before and after the match to include (default: 2)', - }, - }, - 'required': ['content'], - 'additionalProperties': False - }), - Tool( - tool_name='search_file_name', - server_name='file_system', - description= - 'Search for files by name using regex pattern matching. ' - 'Supports both regex patterns and simple substring matching. ' - 'If the file parameter is a valid regex pattern, it will be used for regex matching; ' - 'otherwise, falls back to substring matching. ' - 'The parent_path can also be a regex pattern to filter directories.', - parameters={ - 'type': 'object', - 'properties': { - 'file': { - 'type': - 'string', - 'description': - 'The filename pattern to search for (supports regex, e.g., r"\\.js$" for .js files, ' - 'or "service" for substring match).', + 'description': 'The relative path of the file to edit.', }, - 'parent_path': { - 'type': - 'string', - 'description': - 'The relative parent path to search in (supports regex for directory filtering, ' - 'e.g., r"backend.*" to match backend-related directories). ' - 'Defaults to root if not specified.', - }, - }, - 'required': ['file'], - 'additionalProperties': False - }), - Tool( - tool_name='replace_file_lines', - server_name='file_system', - description= - 'Replace specific line ranges in a file. Supports inserting at beginning ' - '(start_line=0) or end (start_line=-1). Line numbers are 1-based and inclusive on both ends.\n\n' - 'IMPORTANT — Line-number shift after each call. Every replacement changes the total line count, ' - 'which invalidates ALL line numbers after the replaced range. If you need to make multiple replacements in the same file:\n' - '- Option A (recommended): Work from BOTTOM to TOP — edit the largest line numbers first so earlier line numbers remain valid.\n' - '- Option B: Re-search after each replacement to get updated line numbers before the next replacement.\n' - '- Option C: Pre-calculate the cumulative offset — each replacement shifts subsequent lines by (new_content_lines - replaced_lines).\n' - 'NEVER call this tool multiple times in parallel on the same file — the concurrent line-number ' - 'shifts will corrupt the file. Always call sequentially.\n', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the file to modify', - }, - 'content': { + 'old_string': { 'type': 'string', - 'description': - 'The new content to insert/replace', - }, - 'start_line': { - 'type': - 'integer', - 'description': - 'Start line number (1-based, inclusive). Use 0 to insert at beginning, ' - '-1 to append at end', - }, - 'end_line': { - 'type': - 'integer', - 'description': - 'End line number (1-based, inclusive). Required unless start_line is 0 or -1', - }, - }, - 'required': ['path', 'content', 'start_line'], - 'additionalProperties': False - }), - Tool( - tool_name='replace_file_contents', - server_name='file_system', - description= - 'Replace exact content in a file without using line numbers. ' - 'You must provide:\n' - '[Required]path: The relative path of modified file.\n' - '[Required]source: The old content to be replaced.\n' - '[Required]target: The new content to replace the `source`.\n' - '[Required]occurrence: Which occurrence to replace (1-based).\n' - 'Do not miss any of these arguments!\n\n' - 'IMPORTANT:\n' - '- `source` must match the file content EXACTLY — including punctuation style ' - '(e.g., Chinese "、" vs English ","), whitespace, line breaks, and Unicode characters.', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the file to modify', - }, - 'source': { - 'type': - 'string', - 'description': - 'The exact content to find and replace. Must match the file content ' - 'EXACTLY including all whitespace, punctuation, and line breaks. ', + 'description': 'The exact string to find and replace.', }, - 'target': { + 'new_string': { 'type': 'string', - 'description': - 'The new content to replace with', + 'description': 'The string to replace it with.', }, - 'occurrence': { - 'type': - 'integer', + 'replace_all': { + 'type': 'boolean', 'description': - 'Which occurrence to replace (1-based). Default is 1 (first occurrence). ' - 'Use -1 to replace all occurrences.', + 'If true, replace all occurrences. Default is false (replace only the first).', }, }, - 'required': ['path', 'source', 'target', 'occurrence'], + 'required': ['path', 'old_string', 'new_string'], 'additionalProperties': False }), + ] } return tools @@ -429,25 +185,69 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) - async def create_directory(self, path: Optional[str] = None) -> str: - """Create a directory - - Args: - path(`str`): The relative directory path to create, a prefix dir will be automatically concatenated. - - Returns: - or error message. + def _check_staleness(self, real_path: str) -> str | None: + """Return an error string if the file has not been read or has changed since last read. + Returns None if the write is safe to proceed. + Only applies to existing files — new file creation is always allowed. """ - try: - if not path: - path = self.output_dir - else: - path = os.path.join(self.output_dir, path) - os.makedirs(path, exist_ok=True) - return f'Directory: <{path or "root path"}> was created.' - except Exception as e: - return f'Create directory <{path or "root path"}> failed, error: ' + str( - e) + if not os.path.exists(real_path): + return None # new file, no staleness concern + cached = self._read_cache.get(real_path) + if cached is None: + return ( + 'Error: File has not been read yet. ' + 'Read it first before writing to it.' + ) + current_mtime = os.path.getmtime(real_path) + if current_mtime > cached['mtime']: + return ( + 'Error: File has been modified since last read. ' + 'Read it again before writing to it.' + ) + return None + + def _normalize_quotes(self, s: str) -> str: + for curly, straight in self.CURLY_QUOTE_MAP.items(): + s = s.replace(curly, straight) + return s + + def _preserve_quote_style(self, old_string: str, actual_old: str, new_string: str) -> str: + """If old_string matched via quote normalization, apply the same curly quotes to new_string.""" + if old_string == actual_old: + return new_string + has_double = any(c in actual_old for c in '\u201c\u201d') + has_single = any(c in actual_old for c in '\u2018\u2019') + result = new_string + if has_double: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == '"': + prev = chars[i - 1] if i > 0 else None + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u201c' if opening else '\u201d') + else: + out.append(ch) + result = ''.join(out) + if has_single: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == "'": + prev = chars[i - 1] if i > 0 else None + nxt = chars[i + 1] if i < len(chars) - 1 else None + # apostrophe in contraction → right single quote + if prev and nxt and prev.isalpha() and nxt.isalpha(): + out.append('\u2019') + else: + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u2018' if opening else '\u2019') + else: + out.append(ch) + result = ''.join(out) + return result + + @staticmethod + def _strip_trailing_whitespace(s: str) -> str: + return '\n'.join(line.rstrip() for line in s.split('\n')) async def write_file(self, path: str, content: str): """Write content to a file. @@ -463,175 +263,22 @@ async def write_file(self, path: str, content: str): if not os.path.exists(self.output_dir): os.makedirs(self.output_dir, exist_ok=True) original_path = path # Preserve original path for error messages - path = self.get_real_path(path) - if path is None: + real_path = self.get_real_path(path) + if real_path is None: return f'<{original_path}> is out of the valid project path: {self.output_dir}' - dirname = os.path.dirname(path) + err = self._check_staleness(real_path) + if err: + return err + dirname = os.path.dirname(real_path) if dirname: - os.makedirs( - os.path.join(self.output_dir, dirname), exist_ok=True) - with open(os.path.join(self.output_dir, path), 'w') as f: + os.makedirs(dirname, exist_ok=True) + with open(real_path, 'w', encoding='utf-8') as f: f.write(content) + self._read_cache.pop(real_path, None) return f'Save file <{path}> successfully.' except Exception as e: return f'Write file <{path}> failed, error: ' + str(e) - async def replace_file_contents(self, - path: str, - source: str = None, - target: str = None, - occurrence: int = 1): - """Replace exact content in a file without using line numbers. - - This method is safer for parallel operations as it doesn't rely on line numbers - that might change when multiple agents modify the same file concurrently. - - Args: - path(str): The relative file path to modify - source(str): The exact content to find and replace (must match exactly including whitespace) - target(str): The new content to replace with - occurrence(int): Which occurrence to replace (1-based). Use -1 to replace all occurrences. - Default is 1 (first occurrence). - - Returns: - Success or error message. - """ - try: - if not source: - return f'Error: You MUST provide the `source` parameter to be replaced with the `target`, but got {source}.' - if target is None: - return f'Error: You MUST provide the `target` parameter to replace the `source`, but got {target}.' - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - - # Read file content - if not os.path.exists(target_path_real): - return f'Error: File <{path}> does not exist' - - with open(target_path_real, 'r', encoding='utf-8') as f: - file_content = f.read() - - # Check if source exists - if source not in file_content: - return ( - f'Error: Could not find the exact content to replace in <{path}>. ' - f'Make sure the content matches exactly including all whitespace.' - ) - - # Count occurrences - count = file_content.count(source) - - # Replace based on occurrence parameter - if occurrence == -1: - # Replace all occurrences - updated_content = file_content.replace(source, target) - operation_msg = f'Replaced all {count} occurrence(s)' - elif occurrence < 1: - return f'Error: occurrence must be >= 1 or -1 (for all), got {occurrence}' - elif occurrence > count: - return f'Error: occurrence {occurrence} exceeds total occurrences ({count}) of the content' - else: - # Replace specific occurrence - parts = file_content.split(source, occurrence) - if len(parts) <= occurrence: - return f'Error: Could not find occurrence {occurrence} of the content' - # Rejoin: first (occurrence-1) parts with source, then target, then the rest - updated_content = source.join( - parts[:occurrence]) + target + source.join( - parts[occurrence:]) - operation_msg = f'Replaced occurrence {occurrence} of {count}' - - # Write back to file - with open(target_path_real, 'w', encoding='utf-8') as f: - f.write(updated_content) - - return f'{operation_msg} in file <{path}> successfully.' - - except Exception as e: - return f'Replace content in file <{path}> failed, error: ' + str(e) - - async def replace_file_lines(self, - path: str, - content: str, - start_line: int, - end_line: int = None): - """Replace specific line ranges in a file. - - Args: - path(str): The relative file path to modify, a prefix dir will be automatically concatenated. - content(str): The new content to insert/replace - start_line(int): Start line number (1-based, inclusive). Use 0 to insert at beginning, -1 to append at end - end_line(int): End line number (1-based, inclusive). Optional for start_line=0 or -1 - - Returns: - Success or error message. - """ - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - file_path = target_path_real - # Read existing file content - if os.path.exists(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - else: - # If file doesn't exist, create it - dirname = os.path.dirname(file_path) - if dirname: - os.makedirs(dirname, exist_ok=True) - lines = [] - - total_lines = len(lines) - - # Ensure content ends with newline if it doesn't already - if content and not content.endswith('\n'): - content += '\n' - - # Handle special cases - if start_line == 0: - # Insert at beginning - new_lines = [content] + lines - operation = 'Inserted at beginning' - elif start_line == -1: - # Append at end - new_lines = lines + [content] - operation = 'Appended at end' - else: - # Replace range (1-based, inclusive) - if end_line is None: - return 'Error: end_line is required when start_line is not 0 or -1' - - if start_line < 1 or start_line > total_lines + 1: - return f'Error: start_line {start_line} is out of range (file has {total_lines} lines)' - - if end_line < start_line: - return f'Error: end_line {end_line} must be >= start_line {start_line}' - - # Convert to 0-based indices - start_idx = start_line - 1 - # end_line is inclusive (1-based), so we keep lines from end_line onwards (0-based) - end_idx = end_line - # Lines to keep start from index end_line (which is the line after end_line in 1-based) - - new_lines = lines[:start_idx] + [content] + lines[end_idx:] - operation = f'Replaced lines {start_line}-{end_line}' - - # Write back to file - with open(file_path, 'w', encoding='utf-8') as f: - f.writelines(new_lines) - - target = '\n'.join(new_lines).split('\n') - return ( - f'{operation} in file <{path}> completed successfully. The updated file now has {len(target)} lines. ' - 'WARNING: All line numbers after the replaced range may have shifted. ' - 'If you need to make another line-based replacement in this file, keep this in mind.' - ) - - except Exception as e: - return f'Replace lines in file <{path}> failed, error: ' + str(e) - def get_real_path(self, path): # Check if path is absolute or already starts with output_dir if os.path.isabs(path): @@ -655,7 +302,121 @@ def get_real_path(self, path): else: return target_path_real - async def read_abbreviation_file(self, paths: list[str]): + async def read_file(self, + paths: list[str], + offset: int = None, + limit: int = None, + abbreviate: bool = False): + """Read the content of file(s). + + Args: + paths: List of relative file path(s) to read. + offset: Line number to start reading from (1-based). Only effective for a single file. + limit: Number of lines to read. Only effective for a single file. + abbreviate: If True, return an LLM-generated summary instead of raw content. + + Returns: + Dictionary mapping file path(s) to their content or error messages. + """ + if abbreviate: + return await self._read_files_abbreviated(paths) + + results = {} + use_line_range = len(paths) == 1 and (offset is not None + or limit is not None) + + for path in paths: + try: + target_path_real = self.get_real_path(path) + if target_path_real is None: + results[path] = ( + f'Access denied: Reading file <{path}> outside output directory is not allowed. ' + f'Set allow_read_all_files=true in config to enable.') + continue + + ext = os.path.splitext(path)[1].lstrip('.').lower() + + # --- Image files --- + if ext in self.IMAGE_EXTENSIONS: + with open(target_path_real, 'rb') as f: + raw = f.read() + media_type = f'image/{ext}' if ext != 'jpg' else 'image/jpeg' + results[path] = { + 'type': 'image', + 'media_type': media_type, + 'base64': base64.b64encode(raw).decode('ascii'), + } + continue + + # --- Text files --- + file_size = os.path.getsize(target_path_real) + if file_size > self.MAX_READ_BYTES and not use_line_range: + results[path] = ( + f'Error: File <{path}> is too large ({file_size} bytes). ' + f'Use offset and limit to read specific portions.') + continue + + # Dedup: return stub if file unchanged since last read + mtime = os.path.getmtime(target_path_real) + cached = self._read_cache.get(target_path_real) + if (cached + and cached['mtime'] == mtime + and cached['offset'] == offset + and cached['limit'] == limit): + results[path] = { + 'type': 'file_unchanged', + 'message': 'File has not changed since last read.', + } + continue + + with open(target_path_real, 'rb') as f: + raw_bytes = f.read() + + try: + content = raw_bytes.decode('utf-8') + except UnicodeDecodeError: + results[path] = ( + f'Error: File <{path}> appears to be binary. ' + f'Only text and image files are supported.') + continue + + # Normalize line endings + content = content.replace('\r\n', '\n') + lines = content.splitlines(keepends=True) + total_lines = len(lines) + + if use_line_range: + actual_start = max(1, offset) if offset is not None else 1 + actual_end = min(actual_start + limit - 1, total_lines) if limit is not None else total_lines + + if actual_start > total_lines: + results[path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' + continue + selected = lines[actual_start - 1:actual_end] + start_lineno = actual_start + else: + selected = lines + start_lineno = 1 + + results[path] = ''.join( + f'{start_lineno + i}\t{line}' + for i, line in enumerate(selected) + ) + + # Update dedup cache + self._read_cache[target_path_real] = { + 'mtime': mtime, + 'offset': offset, + 'limit': limit, + } + + except FileNotFoundError: + results[path] = f'Read file <{path}> failed: FileNotFound' + except Exception as e: + results[path] = f'Read file <{path}> failed, error: ' + str(e) + return json.dumps(results, indent=2, ensure_ascii=False) + + async def _read_files_abbreviated(self, paths: list[str]) -> str: results = {} def process_file(path): @@ -666,20 +427,18 @@ def process_file(path): index_file = os.path.join(self.index_dir, path.strip(os.sep)) if os.path.exists(index_file): - with open(index_file, 'r', encoding='utf-8') as f: - return path, f.read() + src_mtime = os.path.getmtime(target_path_real) + idx_mtime = os.path.getmtime(index_file) + if idx_mtime >= src_mtime: + with open(index_file, 'r', encoding='utf-8') as f: + return path, f.read() - # Read file content with open(target_path_real, 'r', encoding='utf-8') as f: content = f.read() - # Use LLM to generate abbreviation messages = [ Message(role='system', content=self.system), - Message( - role='user', - content='The content to be abbreviated:\n\n' - + content), + Message(role='user', content='The content to be abbreviated:\n\n' + content), ] response = self.llm.generate(messages=messages, stream=False) os.makedirs(os.path.dirname(index_file), exist_ok=True) @@ -691,359 +450,130 @@ def process_file(path): except Exception as e: return path, f'Process file <{path}> failed, error: ' + str(e) - # Use thread pool for parallel LLM API calls with ThreadPoolExecutor(max_workers=4) as executor: - future_to_path = { - executor.submit(process_file, p): p - for p in paths - } + future_to_path = {executor.submit(process_file, p): p for p in paths} for future in as_completed(future_to_path): path, result = future.result() results[path] = result return json.dumps(results, indent=2, ensure_ascii=False) - async def read_file(self, - paths: list[str], - start_line: int = 0, - end_line: int = None): - """Read the content of file(s). + async def edit_file(self, + path: str = None, + old_string: str = None, + new_string: str = None, + replace_all: bool = False): + """Edit a file by replacing an exact string with new content. Args: - paths(`list[str]`): List of relative file path(s) to read, a prefix dir will be automatically concatenated. - start_line(int): Start line number (1-based, inclusive). Only effective when paths has exactly one element. - 0 means from the beginning. - end_line(int): End line number (1-based, inclusive). Only effective when paths has exactly one element. - None means to the end. + path: The relative file path to edit. + old_string: The exact string to find and replace. + new_string: The replacement string. + replace_all: If True, replace all occurrences. Default replaces only the first. Returns: - Dictionary mapping file path(s) to their content or error messages. + Success or error message. """ - results = {} - # Line range is only effective when reading a single file - use_line_range = len(paths) == 1 and (start_line > 0 - or end_line is not None) - - for path in paths: - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - results[path] = ( - f'Access denied: Reading file <{path}> outside output directory is not allowed. ' - f'Set allow_read_all_files=true in config to enable.') - continue - - with open(target_path_real, 'r') as f: - if use_line_range: - # Read specific line range - lines = f.readlines() - total_lines = len(lines) - - # Validate and adjust line numbers (1-based) - actual_start = max(1, - start_line) if start_line > 0 else 1 - actual_end = min( - end_line, total_lines - ) if end_line is not None else total_lines - - if actual_start > total_lines: - results[ - path] = f'Error: start_line {start_line} exceeds file length ({total_lines} lines)' - elif actual_start > actual_end: - results[ - path] = f'Error: start_line {actual_start} > end_line {actual_end}' - else: - # Convert to 0-based index, end_line is inclusive - selected_lines = lines[actual_start - 1:actual_end] - results[path] = ''.join(selected_lines) - else: - # Read entire file - results[path] = f.read() - except FileNotFoundError: - results[path] = f'Read file <{path}> failed: FileNotFound' - except Exception as e: - results[path] = f'Read file <{path}> failed, error: ' + str(e) - return json.dumps(results, indent=2, ensure_ascii=False) - - async def delete_file_or_dir(self, path: str): - """Delete a file or a directory. + try: + if old_string is None: + return 'Error: `old_string` is required.' + if new_string is None: + return 'Error: `new_string` is required.' - Args: - path(str): The file or directory to delete, a prefix dir will be automatically concatenated. + target_path_real = self.get_real_path(path) + if target_path_real is None: + return f'<{path}> is out of the valid project path: {self.output_dir}' - Returns: - boolean - """ - abs_path = os.path.join(self.output_dir, path) - if os.path.exists(abs_path): - try: - if os.path.isfile(abs_path): - os.remove(abs_path) - else: - shutil.rmtree(abs_path) - return f'Path deleted: <{path}>' - except Exception as e: - return f'Delete file <{path}> failed, error: ' + str(e) - else: - return f'Path not found: {path}' + # --- Special case: old_string="" --- + if old_string == '': + if not os.path.exists(target_path_real): + # Create new file + os.makedirs(os.path.dirname(target_path_real), exist_ok=True) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + return f'Created file <{path}> successfully.' + with open(target_path_real, 'rb') as f: + existing = f.read() + try: + existing_text = existing.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + if existing_text.strip() != '': + return ( + 'Error: `old_string` is empty but the file already has content. ' + 'Use `write_file` for a full rewrite, or provide an `old_string` anchor to insert content.' + ) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + self._read_cache.pop(target_path_real, None) + return f'Edit file <{path}> successfully (filled empty file).' - async def search_file_name(self, file: str = '', parent_path: str = ''): - """Search for files by name using regex pattern matching. + if not os.path.exists(target_path_real): + return f'Error: File <{path}> does not exist.' - Args: - file(str): File name pattern (supports regex). If it's a valid regex pattern, - it will be used for regex matching; otherwise, falls back to substring matching. - parent_path(str): Parent path pattern (supports regex for filtering directories). - Can be a simple path or a regex pattern to match directory paths. + err = self._check_staleness(target_path_real) + if err: + return err - Returns: - String containing all matched file paths - """ - parent_path = parent_path or '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - # Try to compile file pattern as regex - file_use_regex = False - file_pattern = None - if file: + with open(target_path_real, 'rb') as f: + raw = f.read() try: - file_pattern = re.compile(file) - file_use_regex = True - except re.error: - file_use_regex = False - - # Try to compile parent_path filter as regex (optional) - path_use_regex = False - path_pattern = None - if parent_path: - try: - path_pattern = re.compile(parent_path) - path_use_regex = True - except re.error: - path_use_regex = False - - all_found_files = [] - for root, dirs, files in os.walk(_parent_path): - if path_use_regex and parent_path: - relative_root = os.path.relpath(root, self.output_dir) - if not path_pattern.search(relative_root): - continue - - for filename in files: - if file: - if file_use_regex: - is_match = file_pattern.search(filename) is not None - else: - is_match = file in filename - else: - is_match = True # No filter, match all files - - if is_match: - file_path = os.path.join(root, filename) - relative_path = os.path.relpath(file_path, self.output_dir) - all_found_files.append(relative_path) + content = raw.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + + # Normalize line endings for matching + content = content.replace('\r\n', '\n') + old_string = old_string.replace('\r\n', '\n') + + # --- Fallback 1: exact match --- + actual_old = old_string if old_string in content else None + + # --- Fallback 2: quote normalization --- + if actual_old is None: + norm_old = self._normalize_quotes(old_string) + norm_content = self._normalize_quotes(content) + idx = norm_content.find(norm_old) + if idx != -1: + actual_old = content[idx:idx + len(old_string)] + + if actual_old is None: + return ( + f'Error: `old_string` not found in <{path}>. ' + f'Make sure it matches the file content exactly including whitespace.' + ) - if not all_found_files: - return f'No files found matching pattern <{file or "*"}> in <{parent_path or "root"}>' + count = content.count(actual_old) + if count > 1 and not replace_all: + return ( + f'Error: Found {count} occurrences of `old_string` in <{path}>. ' + f'Add more surrounding context to make it unique, or set replace_all=true.' + ) - all_found_files = '\n'.join(all_found_files) - return f'Found {len(all_found_files.splitlines())} file(s) matching <{file or "*"}>:\n{all_found_files}' + # Apply quote style preservation to new_string + actual_new = self._preserve_quote_style(old_string, actual_old, new_string) - async def search_file_content(self, - content: str = None, - parent_path: str = '.', - file_pattern: str = '*', - context_lines: int = 2): - """Search for content in files using thread pool. - Supports both literal string matching and regex pattern matching automatically. + # --- Fallback 3: smart delete — strip trailing newline when deleting --- + if actual_new == '' and not actual_old.endswith('\n') and actual_old + '\n' in content: + actual_old = actual_old + '\n' - Args: - content(str): The content or regex pattern to search for (auto-detected) - parent_path(str): The relative parent path to search in - file_pattern(str): Wildcard pattern for file names (default: '*' for all files) - context_lines(int): Number of lines before and after the match to include (default: 2) + # Strip trailing whitespace from new_string (skip markdown files) + is_markdown = path.lower().endswith(('.md', '.mdx')) + if not is_markdown: + actual_new = self._strip_trailing_whitespace(actual_new) - Returns: - String containing all matches with file path, line number, and context - """ - if parent_path.startswith('.' + os.sep): - parent_path = parent_path[len('.' + os.sep):] - if parent_path == '.': - parent_path = '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - if not content: - return 'Error: content parameter is required for search' - - # Try to compile as regex pattern, fallback to literal string matching - use_regex = False - pattern = None - try: - pattern = re.compile(content) - use_regex = True - except re.error: - # Not a valid regex, will use literal string matching - use_regex = False - - # Collect all files matching the pattern - files_to_search = [] - for root, dirs, files in os.walk(_parent_path): - try: - test_dir = str(Path(root).relative_to(self.output_dir)) - except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - if any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS): - continue - for filename in files: - # Skip excluded files - if filename.startswith( - self.EXCLUDED_FILE_PREFIXES) or test_dir.startswith( - self.EXCLUDED_FILE_PREFIXES): - continue - # Match file pattern - if fnmatch.fnmatch(filename, file_pattern): - files_to_search.append(os.path.join(root, filename)) - - if not files_to_search: - return f'No files matching pattern <{file_pattern}> found in <{parent_path or "root"}>' - - # Function to search in a single file - def search_in_file(file_path): - matches = [] - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - for line_num, line in enumerate(lines, start=1): - # Check for match: regex or literal string - is_match = False - if use_regex: - is_match = pattern.search(line) is not None - else: - is_match = content in line - - if is_match: - # Calculate context range - start_line = max(0, line_num - context_lines - 1) - end_line = min(len(lines), line_num + context_lines) - - # Extract context lines - context = [] - for i in range(start_line, end_line): - prefix = '> ' if i == line_num - 1 else ' ' - context.append( - f'{prefix}{i + 1:4d} | {lines[i].rstrip()}') - - relative_path = os.path.relpath( - file_path, self.output_dir) - matches.append({ - 'file': relative_path, - 'line': line_num, - 'context': '\n'.join(context) - }) - return matches - - # Use thread pool to search files in parallel - all_matches = [] - with ThreadPoolExecutor(max_workers=8) as executor: - future_to_file = { - executor.submit(search_in_file, f): f - for f in files_to_search - } - for future in as_completed(future_to_file): - matches = future.result() - all_matches.extend(matches) - - if not all_matches: - return f'No matches found for <{content}> in files matching <{file_pattern}>' - - # Format results - result_lines = [ - f'Found {len(all_matches)} match(es) for "{content}":\n' - ] - for match in all_matches: - result_lines.append( - f"File: {match['file']}, Line: {match['line']}") - result_lines.append(match['context']) - result_lines.append('') - - return '\n'.join(result_lines) - - async def list_files(self, path: str = None): - """List all files in a directory. + if replace_all: + updated = content.replace(actual_old, actual_new) + else: + updated = content.replace(actual_old, actual_new, 1) - Args: - path: The relative path to traverse, a prefix dir will be automatically concatenated. + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(updated) - Returns: - The file names concatenated as a string - """ - file_paths = [] - if not path or path == '.': - path = self.output_dir - else: - path = os.path.join(self.output_dir, path) - if path.startswith('.' + os.sep): - path = path[len('.' + os.sep):] - try: - for root, dirs, files in os.walk(path): - try: - test_dir = str(Path(root).relative_to(self.output_dir)) - except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - for file in files: - # Skip excluded directories and files - root_exclude = any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS) - if root_exclude or file.startswith( - self.EXCLUDED_FILE_PREFIXES - ) or test_dir.startswith(self.EXCLUDED_FILE_PREFIXES): - continue - absolute_path = os.path.join(root, file) - relative_path = os.path.relpath(absolute_path, path) - file_paths.append(relative_path) - return '\n'.join(file_paths) or f'No files in path: {path}' - except Exception as e: - return f'List files of <{path or "root path"}> failed, error: ' + str( - e) + # Invalidate dedup cache for this file + self._read_cache.pop(target_path_real, None) - @retry(max_attempts=MAX_CONTINUE_RUNS, delay=1.0) - async def edit_file(self, - path: str = None, - instructions: str = None, - code_edit: str = None): - try: - with open(os.path.join(self.output_dir, path), 'r') as f: - initial_code = f.read() - response = self.edit_client.chat.completions.create( - model=self.edit_file_config.diff_model, - messages=[{ - 'role': - 'user', - 'content': - (f'{instructions}\n' - f'{initial_code}\n' - f'{code_edit}') - }]) - merged_code = response.choices[0].message.content - - with open(os.path.join(self.output_dir, path), 'w') as f: - f.write(merged_code) - return f'Edit file <{path}> successfully.' + replaced = count if replace_all else 1 + return f'Edit file <{path}> successfully ({replaced} occurrence(s) replaced).' except Exception as e: return f'Edit file <{path}> failed, error: ' + str(e) diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py new file mode 100644 index 000000000..35d3fe26b --- /dev/null +++ b/ms_agent/utils/snapshot.py @@ -0,0 +1,204 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Lightweight snapshot utility for ms-agent output directories. + +Uses a dedicated git repo stored at /.ms_agent_snapshots/ +so it never touches or conflicts with the user's own .git directory. + +All git commands are run with GIT_DIR and GIT_WORK_TREE explicitly set, +so the snapshot repo is fully isolated from any surrounding repository. +""" +import os +import json +import subprocess +from typing import Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SNAPSHOT_DIR_NAME = '.ms_agent_snapshots' +_META_FILE = 'snapshot_meta.json' + + +def _git(args: list[str], work_tree: str, git_dir: str, + check: bool = True) -> subprocess.CompletedProcess: + env = os.environ.copy() + env['GIT_DIR'] = git_dir + env['GIT_WORK_TREE'] = work_tree + # Suppress interactive prompts + env['GIT_TERMINAL_PROMPT'] = '0' + return subprocess.run( + ['git'] + args, + env=env, + cwd=work_tree, + capture_output=True, + text=True, + check=check, + ) + + +def _snapshot_git_dir(output_dir: str) -> str: + return os.path.join(output_dir, _SNAPSHOT_DIR_NAME) + + +def _ensure_repo(output_dir: str) -> str: + """Initialize the snapshot repo if it doesn't exist. Returns git_dir.""" + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + os.makedirs(git_dir, exist_ok=True) + # Use non-bare init with explicit GIT_DIR — no --bare so work tree is supported. + # Do NOT pass a path argument; GIT_DIR env var points git at our custom dir. + _git(['init'], work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.email', 'ms-agent@snapshot'], + work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.name', 'ms-agent'], + work_tree=output_dir, git_dir=git_dir) + # Exclude the snapshot dir itself from tracking + info_dir = os.path.join(git_dir, 'info') + os.makedirs(info_dir, exist_ok=True) + exclude_file = os.path.join(info_dir, 'exclude') + with open(exclude_file, 'a', encoding='utf-8') as f: + f.write(f'\n{_SNAPSHOT_DIR_NAME}/\n') + return git_dir + + +def _meta_path(output_dir: str) -> str: + return os.path.join(_snapshot_git_dir(output_dir), _META_FILE) + + +def _load_meta(output_dir: str) -> dict: + path = _meta_path(output_dir) + if os.path.exists(path): + try: + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception: + pass + return {} + + +def _save_meta(output_dir: str, meta: dict) -> None: + path = _meta_path(output_dir) + with open(path, 'w', encoding='utf-8') as f: + json.dump(meta, f, indent=2) + + +def take_snapshot(output_dir: str, message: str, + message_count: int = 0) -> Optional[str]: + """ + Stage all changes in output_dir and create a snapshot commit. + + Args: + output_dir: The directory to snapshot. + message: Commit message (truncated to 120 chars). + message_count: Number of messages in history at snapshot time. + Stored in metadata so rollback can truncate history. + + Returns the short commit hash on success, or None if nothing to commit + or if git is unavailable. + """ + if not output_dir or not os.path.isdir(output_dir): + return None + + try: + git_dir = _ensure_repo(output_dir) + + # Stage everything (excluding .ms_agent_snapshots via info/exclude) + _git(['add', '-A'], work_tree=output_dir, git_dir=git_dir) + + # Check if there's anything to commit + status = _git(['status', '--porcelain'], + work_tree=output_dir, git_dir=git_dir) + if not status.stdout.strip(): + return None # Nothing changed + + # Truncate message to keep commit subject readable + subject = message.strip().replace('\n', ' ')[:120] + result = _git(['commit', '-m', subject], + work_tree=output_dir, git_dir=git_dir) + + commit_hash = None + for line in result.stdout.splitlines(): + if line.startswith('['): + before_bracket = line.split(']')[0] + commit_hash = before_bracket.split()[-1] + break + if commit_hash is None: + commit_hash = 'ok' + + # Persist message_count so rollback can truncate history + meta = _load_meta(output_dir) + meta[commit_hash] = {'message_count': message_count} + _save_meta(output_dir, meta) + + return commit_hash + + except FileNotFoundError: + logger.warning_once( + '[snapshot] git not found — snapshots disabled.') + return None + except subprocess.CalledProcessError as e: + logger.warning(f'[snapshot] git error: {e.stderr.strip()}') + return None + except Exception as e: + logger.warning(f'[snapshot] unexpected error: {e}') + return None + + +def list_snapshots(output_dir: str) -> list[dict]: + """ + Return a list of snapshots as dicts with keys: hash, message, date, message_count. + Most recent first. + """ + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + return [] + try: + result = _git( + ['log', '--pretty=format:%h\t%ai\t%s'], + work_tree=output_dir, + git_dir=git_dir, + check=False, + ) + if result.returncode != 0: + return [] + meta = _load_meta(output_dir) + snapshots = [] + for line in result.stdout.splitlines(): + parts = line.split('\t', 2) + if len(parts) == 3: + h = parts[0] + snapshots.append({ + 'hash': h, + 'date': parts[1], + 'message': parts[2], + 'message_count': meta.get(h, {}).get('message_count', 0), + }) + return snapshots + except Exception: + return [] + + +def restore_snapshot(output_dir: str, + commit_hash: str) -> tuple[bool, int]: + """ + Restore output_dir to the state at commit_hash. + + Returns (success, message_count) where message_count is the number of + messages in history at snapshot time (0 if unknown). + """ + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + logger.warning('[snapshot] No snapshot repo found.') + return False, 0 + try: + _git(['checkout', commit_hash, '--', '.'], + work_tree=output_dir, git_dir=git_dir) + logger.info(f'[snapshot] Restored to {commit_hash}') + meta = _load_meta(output_dir) + message_count = meta.get(commit_hash, {}).get('message_count', 0) + return True, message_count + except subprocess.CalledProcessError as e: + logger.warning(f'[snapshot] restore failed: {e.stderr.strip()}') + return False, 0 diff --git a/tests/utils/test_snapshot_smoke.py b/tests/utils/test_snapshot_smoke.py new file mode 100644 index 000000000..c544135f2 --- /dev/null +++ b/tests/utils/test_snapshot_smoke.py @@ -0,0 +1,331 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Smoke tests for the snapshot utility and LLMAgent rollback interface. + +No network, no LLM — all tests run fully offline using tempfile directories. +""" +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from ms_agent.utils.snapshot import ( + list_snapshots, + restore_snapshot, + take_snapshot, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write(path: str, content: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w', encoding='utf-8') as f: + f.write(content) + + +def _read(path: str) -> str: + with open(path, 'r', encoding='utf-8') as f: + return f.read() + + +# --------------------------------------------------------------------------- +# snapshot utility tests +# --------------------------------------------------------------------------- + +class TestTakeSnapshot(unittest.TestCase): + + def test_empty_dir_returns_none(self): + """Nothing to commit → None.""" + with tempfile.TemporaryDirectory() as td: + result = take_snapshot(td, 'empty') + self.assertIsNone(result) + + def test_new_file_returns_hash(self): + """A new file produces a commit hash.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'hello.txt'), 'hello') + h = take_snapshot(td, 'add hello.txt', message_count=2) + self.assertIsNotNone(h) + self.assertIsInstance(h, str) + self.assertGreater(len(h), 0) + + def test_no_change_after_snapshot_returns_none(self): + """Second snapshot with no changes → None.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'first') + result = take_snapshot(td, 'second — no change') + self.assertIsNone(result) + + def test_message_truncated_to_120_chars(self): + """Long messages are truncated in the commit subject.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'x') + h = take_snapshot(td, 'A' * 200) + self.assertIsNotNone(h) + snaps = list_snapshots(td) + self.assertEqual(len(snaps[0]['message']), 120) + + def test_snapshot_dir_not_tracked(self): + """The .ms_agent_snapshots dir itself must not appear in git status.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'first') + # After snapshot, no pending changes (snapshot dir excluded) + result = take_snapshot(td, 'should be nothing') + self.assertIsNone(result) + + +class TestListSnapshots(unittest.TestCase): + + def test_no_repo_returns_empty(self): + with tempfile.TemporaryDirectory() as td: + self.assertEqual(list_snapshots(td), []) + + def test_returns_snapshots_most_recent_first(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'a.txt'), 'v1') + h1 = take_snapshot(td, 'first snap', message_count=1) + _write(os.path.join(td, 'a.txt'), 'v2') + h2 = take_snapshot(td, 'second snap', message_count=3) + + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 2) + # Most recent first + self.assertEqual(snaps[0]['hash'], h2) + self.assertEqual(snaps[1]['hash'], h1) + + def test_snapshot_fields(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'b.txt'), 'hello') + h = take_snapshot(td, 'my message', message_count=5) + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 1) + s = snaps[0] + self.assertEqual(s['hash'], h) + self.assertEqual(s['message'], 'my message') + self.assertEqual(s['message_count'], 5) + self.assertIn('date', s) + + def test_message_count_default_zero(self): + """message_count defaults to 0 when not passed.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'c.txt'), 'x') + take_snapshot(td, 'no count') + snaps = list_snapshots(td) + self.assertEqual(snaps[0]['message_count'], 0) + + +class TestRestoreSnapshot(unittest.TestCase): + + def test_no_repo_returns_false(self): + with tempfile.TemporaryDirectory() as td: + ok, mc = restore_snapshot(td, 'abc1234') + self.assertFalse(ok) + self.assertEqual(mc, 0) + + def test_restore_reverts_file_content(self): + with tempfile.TemporaryDirectory() as td: + path = os.path.join(td, 'data.txt') + _write(path, 'original') + h1 = take_snapshot(td, 'original state', message_count=2) + + _write(path, 'modified') + take_snapshot(td, 'modified state', message_count=4) + + self.assertEqual(_read(path), 'modified') + + ok, mc = restore_snapshot(td, h1) + self.assertTrue(ok) + self.assertEqual(mc, 2) + self.assertEqual(_read(path), 'original') + + def test_restore_returns_message_count(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'x.txt'), 'a') + h = take_snapshot(td, 'snap', message_count=7) + _write(os.path.join(td, 'x.txt'), 'b') + take_snapshot(td, 'snap2', message_count=9) + + ok, mc = restore_snapshot(td, h) + self.assertTrue(ok) + self.assertEqual(mc, 7) + + def test_restore_deleted_file(self): + """A file deleted after snapshot is recreated on restore.""" + with tempfile.TemporaryDirectory() as td: + path = os.path.join(td, 'will_be_deleted.txt') + _write(path, 'keep me') + h = take_snapshot(td, 'before delete', message_count=1) + + os.remove(path) + take_snapshot(td, 'after delete', message_count=2) + self.assertFalse(os.path.exists(path)) + + ok, _ = restore_snapshot(td, h) + self.assertTrue(ok) + self.assertTrue(os.path.exists(path)) + self.assertEqual(_read(path), 'keep me') + + +# --------------------------------------------------------------------------- +# LLMAgent interface tests +# --------------------------------------------------------------------------- + +class TestLLMAgentSnapshotInterface(unittest.TestCase): + """ + Tests for LLMAgent.list_snapshots() and LLMAgent.rollback(). + The LLM and tool_manager are not initialised — we only exercise the + snapshot-related methods which don't require them. + """ + + def _make_agent(self, output_dir: str): + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + cfg = OmegaConf.create({ + 'llm': {'model': 'fake', 'api_key': 'fake', 'model_server': 'openai'}, + 'output_dir': output_dir, + }) + agent = LLMAgent(cfg, tag='smoke-test') + return agent + + def test_list_snapshots_empty(self): + with tempfile.TemporaryDirectory() as td: + agent = self._make_agent(td) + self.assertEqual(agent.list_snapshots(), []) + + def test_list_snapshots_delegates_to_utility(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'snap', message_count=3) + + agent = self._make_agent(td) + snaps = agent.list_snapshots() + self.assertEqual(len(snaps), 1) + self.assertEqual(snaps[0]['message_count'], 3) + + def test_rollback_restores_files_and_truncates_history(self): + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + from ms_agent.llm.utils import Message + from ms_agent.utils import save_history + + with tempfile.TemporaryDirectory() as td: + agent = self._make_agent(td) + + # Write a file and take a snapshot with message_count=2 + path = os.path.join(td, 'work.txt') + _write(path, 'v1') + h1 = take_snapshot(td, '[pre] first task', message_count=2) + + # Save 4 messages to history + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='task1'), + Message(role='assistant', content='done1'), + Message(role='user', content='task2'), + ] + save_history(td, 'smoke-test', agent.config, messages) + + # Modify the file and take a second snapshot + _write(path, 'v2') + take_snapshot(td, '[pre] second task', message_count=4) + + self.assertEqual(_read(path), 'v2') + + # Rollback to h1 + ok = agent.rollback(h1) + self.assertTrue(ok) + + # File should be restored + self.assertEqual(_read(path), 'v1') + + # History should be truncated to message_count=2 + from ms_agent.utils import read_history + _, saved = read_history(td, 'smoke-test') + self.assertIsNotNone(saved) + self.assertEqual(len(saved), 2) + + def test_rollback_clears_read_cache(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + h = take_snapshot(td, 'snap', message_count=1) + _write(os.path.join(td, 'f.txt'), 'v2') + take_snapshot(td, 'snap2', message_count=2) + + agent = self._make_agent(td) + + # Attach a fake tool with _read_cache + fake_tool = MagicMock() + fake_tool._read_cache = {'some/path': {'mtime': 123}} + fake_manager = MagicMock() + fake_manager.tools = {'fs': fake_tool} + agent.tool_manager = fake_manager + + agent.rollback(h) + self.assertEqual(fake_tool._read_cache, {}) + + def test_rollback_invalid_hash_returns_false(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'snap') + + agent = self._make_agent(td) + ok = agent.rollback('deadbeef') + self.assertFalse(ok) + + def test_on_task_begin_auto_snapshots(self): + """on_task_begin should take a snapshot automatically — no explicit call needed.""" + import asyncio + from ms_agent.llm.utils import Message + + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'work.txt'), 'v1') + agent = self._make_agent(td) + + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='do something useful'), + ] + + # No explicit take_snapshot call — on_task_begin should do it + asyncio.run(agent.on_task_begin(messages)) + + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 1) + self.assertIn('do something useful', snaps[0]['message']) + self.assertEqual(snaps[0]['message_count'], len(messages)) + + def test_on_task_begin_no_snapshot_when_disabled(self): + """enable_snapshots=False suppresses automatic snapshot.""" + import asyncio + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + from ms_agent.llm.utils import Message + + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'work.txt'), 'v1') + cfg = OmegaConf.create({ + 'llm': {'model': 'fake', 'api_key': 'fake', 'model_server': 'openai'}, + 'output_dir': td, + 'enable_snapshots': False, + }) + agent = LLMAgent(cfg, tag='smoke-test') + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='task'), + ] + asyncio.run(agent.on_task_begin(messages)) + self.assertEqual(list_snapshots(td), []) + + +if __name__ == '__main__': + unittest.main() From f08a15e1a69001b665ae912e4f82a70c248126d5 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Thu, 2 Apr 2026 19:00:09 +0800 Subject: [PATCH 19/40] support for response api; optimize log output formatting --- ms_agent/agent/llm_agent.py | 10 + ms_agent/llm/openai_llm.py | 425 ++++++++++++++++++++- ms_agent/llm/utils.py | 8 +- projects/deep_research/v2/reporter.yaml | 2 + projects/deep_research/v2/researcher.yaml | 2 + projects/deep_research/v2/run_benchmark.sh | 8 +- requirements/research.txt | 1 + 7 files changed, 449 insertions(+), 7 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 9143cdd3c..6bf47e626 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -869,6 +869,16 @@ async def step( yield messages if _printed_reasoning_header and not _printed_reasoning_footer: self._write_thinking_footer() + + # Handle reasoning summaries that arrive after content + if self.show_reasoning and _response_message is not None: + final_reasoning = getattr( + _response_message, 'reasoning_content', '') or '' + if final_reasoning and not _printed_reasoning_header: + self._write_thinking_header() + self._write_reasoning(final_reasoning, dim=True) + self._write_thinking_footer() + sys.stdout.write('\n') else: _response_message = self.llm.generate(messages, tools=tools) diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index dadc1bf1c..e9e1679e3 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -3,6 +3,9 @@ from copy import deepcopy from typing import Any, Dict, Generator, Iterable, List, Optional +import json + +import httpx from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, @@ -15,12 +18,35 @@ logger = get_logger() +class _DashScopeResponsesTransport(httpx.HTTPTransport): + """Rewrite /v1/responses -> /v1/chat/completions for DashScope proxy. + + DashScope serves the OpenAI Responses protocol on the chat/completions + path rather than the standard /v1/responses path. This transport + transparently rewrites the SDK's outgoing request so that + ``client.responses.create()`` hits the correct DashScope endpoint. + """ + + def handle_request(self, request): + if b'/v1/responses' in request.url.raw_path: + new_path = request.url.raw_path.replace( + b'/v1/responses', b'/v1/chat/completions') + request.url = request.url.copy_with(raw_path=new_path) + return super().handle_request(request) + + class OpenAI(LLM): """Base Class for OpenAI SDK LLMs. This class provides the base implementation for interacting with OpenAI-compatible models, including support for chat completions, streaming responses, and continue generates. + Supports the OpenAI Responses API (``client.responses.create``) when + ``generation_config.use_responses_api`` is ``true``. In this mode, + reasoning summaries are extracted and surfaced through + ``Message.reasoning_content`` so the agent's existing thinking display + works without change. + Args: config (`DictConfig`): The configuration object containing model and generation settings. base_url (`Optional[str]`): Custom base URL for the API endpoint. Defaults to None. @@ -58,6 +84,30 @@ def __init__( self.args: Dict = OmegaConf.to_container( getattr(config, 'generation_config', DictConfig({}))) + # Responses API support + self._use_responses_api = bool(self.args.get('use_responses_api', False)) + self._responses_client = None + self._responses_state_mode = str( + self.args.get('responses_state_mode', 'stateless')).lower() + if self._responses_state_mode == 'stateful': + self._responses_state_mode = 'previous_response_id' + + if self._use_responses_api: + self._is_dashscope = bool( + base_url and 'dashscope' in base_url.lower()) + if self._is_dashscope: + http_client = httpx.Client( + transport=_DashScopeResponsesTransport(), + timeout=httpx.Timeout(300.0, connect=60.0), + ) + self._responses_client = openai.OpenAI( + api_key=api_key, + base_url=base_url, + http_client=http_client, + ) + else: + self._responses_client = self.client + # Prefix cache configuration # - force_prefix_cache: enable structured content with cache_control for explicit caching # - prefix_cache_roles: which messages to cache (only these are converted to structured format) @@ -176,17 +226,21 @@ def generate(self, Union[Message, Generator[Message, None, None]]: Either a single Message object (non-streaming) or a generator yielding Message chunks (streaming). """ - parameters = inspect.signature( - self.client.chat.completions.create).parameters args = self.args.copy() args.update(kwargs) stream = args.get('stream', False) + if self._use_responses_api: + if stream: + return self._responses_stream_generate(messages, tools, **args) + else: + return self._responses_generate(messages, tools, **args) + + parameters = inspect.signature( + self.client.chat.completions.create).parameters args = {key: value for key, value in args.items() if key in parameters} completion = self._call_llm(messages, self.format_tools(tools), **args) - # Complex task may produce long response - # Call continue_generate to keep generating if the finish_reason is `length` max_continue_runs = max_continue_runs or self.max_continue_runs if stream: return self._stream_continue_generate(messages, completion, tools, @@ -540,6 +594,369 @@ def _continue_generate(self, else: return new_message + def _build_responses_input( + self, messages: List[Message]) -> List[Dict[str, Any]]: + """Convert internal Message list to the ``input`` format expected by + the Responses API. + + Key differences from chat completions format: + - ``system`` role becomes ``developer``. + - ``assistant`` messages with ``tool_calls`` emit the text content + as a normal assistant item, followed by one ``function_call`` + item per tool call. + - ``tool`` role messages become ``function_call_output`` items + (keyed by ``call_id``, not ``role``). + """ + items: List[Dict[str, Any]] = [] + for msg in messages: + if msg.role == 'system': + items.append({ + 'role': 'developer', + 'content': msg.content, + }) + elif msg.role == 'assistant': + if self._responses_state_mode != 'previous_response_id': + # Stateless mode needs explicit passback of opaque reasoning + # items returned by the previous response. + for raw_item in getattr( + msg, '_responses_output_items', []): + items.append(raw_item) + if msg.content and not self._is_responses_tool_placeholder(msg): + items.append({ + 'role': 'assistant', + 'content': msg.content, + }) + if msg.tool_calls: + for tc in msg.tool_calls: + arguments = tc.get('arguments', '{}') + if not isinstance(arguments, str): + arguments = json.dumps( + arguments, ensure_ascii=False) + items.append({ + 'type': 'function_call', + 'call_id': tc.get('id', ''), + 'name': tc.get('tool_name', ''), + 'arguments': arguments, + }) + elif msg.role == 'tool': + content = msg.content + if not isinstance(content, str): + content = json.dumps(content, ensure_ascii=False) + items.append({ + 'type': 'function_call_output', + 'call_id': msg.tool_call_id or '', + 'output': content, + }) + else: + items.append({ + 'role': msg.role, + 'content': msg.content, + }) + return items + + @staticmethod + def _is_responses_tool_placeholder(message: Message) -> bool: + """Return True for framework-generated assistant placeholder text.""" + return bool(message.tool_calls) and message.content == 'Let me do a tool calling.' + + def _prepare_responses_request( + self, + messages: List[Message], + args: Dict[str, Any]) -> tuple[List[Message], Dict[str, Any]]: + """Prepare message slice and request args for Responses API calls.""" + request_args = dict(args) + + if self._responses_state_mode != 'previous_response_id': + return messages, request_args + + if request_args.get('previous_response_id'): + return messages, request_args + + for idx in range(len(messages) - 1, -1, -1): + msg = messages[idx] + if msg.role == 'assistant' and msg.id: + request_args['previous_response_id'] = msg.id + return messages[idx + 1:], request_args + + return messages, request_args + + def _build_responses_tools( + self, + tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: + """Convert internal Tool list to Responses API function tool format.""" + if not tools: + return None + return [{ + 'type': 'function', + 'name': t['tool_name'], + 'description': t.get('description', ''), + 'parameters': t.get('parameters', {}), + } for t in tools] + + def _build_responses_kwargs(self, args: Dict) -> Dict: + """Filter and reshape generation args for ``responses.create``.""" + kwargs: Dict[str, Any] = {} + + reasoning_effort = args.get('reasoning_effort') + reasoning_summary = args.get('reasoning_summary', 'auto') + if reasoning_effort or reasoning_summary: + reasoning: Dict[str, Any] = {} + if reasoning_effort: + reasoning['effort'] = reasoning_effort + if reasoning_summary: + reasoning['summary'] = reasoning_summary + kwargs['reasoning'] = reasoning + + if args.get('temperature') is not None: + kwargs['temperature'] = args['temperature'] + if args.get('top_p') is not None: + kwargs['top_p'] = args['top_p'] + if args.get('max_output_tokens') is not None: + kwargs['max_output_tokens'] = args['max_output_tokens'] + if args.get('stream_options') is not None: + kwargs['stream_options'] = args['stream_options'] + if args.get('previous_response_id') is not None: + kwargs['previous_response_id'] = args['previous_response_id'] + + include = args.get('include') + if include is not None: + kwargs['include'] = include + elif self._responses_state_mode != 'previous_response_id': + # Stateless multi-turn mode needs encrypted reasoning so opaque + # reasoning items can be passed back in subsequent requests. + kwargs['include'] = ['reasoning.encrypted_content'] + + return kwargs + + @staticmethod + def _extract_reasoning_summaries_from_response(response) -> str: + """Pull reasoning summary text from a completed Responses API object.""" + parts: List[str] = [] + for item in getattr(response, 'output', []) or []: + if getattr(item, 'type', None) == 'reasoning': + for summary in getattr(item, 'summary', []) or []: + text = getattr(summary, 'text', None) + if text: + parts.append(text) + return '\n'.join(parts) + + @staticmethod + def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: + """Extract tool calls from a completed Responses API object.""" + tool_calls: List[ToolCall] = [] + for item in getattr(response, 'output', []) or []: + if getattr(item, 'type', None) == 'function_call': + arguments = getattr(item, 'arguments', '{}') + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + tool_calls.append( + ToolCall( + id=getattr(item, 'call_id', '') or getattr(item, 'id', ''), + index=len(tool_calls), + type='function', + tool_name=getattr(item, 'name', ''), + arguments=arguments, + )) + return tool_calls if tool_calls else None + + @staticmethod + def _extract_usage_from_response(response) -> tuple: + """Return (prompt_tokens, completion_tokens) from a Responses API object.""" + usage = getattr(response, 'usage', None) + if usage is None: + return 0, 0 + return ( + getattr(usage, 'input_tokens', 0) or 0, + getattr(usage, 'output_tokens', 0) or 0, + ) + + @staticmethod + def _to_jsonable(value: Any) -> Any: + """Convert SDK objects nested in Responses items into JSON-safe data.""" + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, list): + return [OpenAI._to_jsonable(item) for item in value] + if isinstance(value, dict): + return { + key: OpenAI._to_jsonable(item) for key, item in value.items() + } + if hasattr(value, 'model_dump'): + return OpenAI._to_jsonable(value.model_dump()) + if hasattr(value, 'to_dict'): + return OpenAI._to_jsonable(value.to_dict()) + return value + + def _collect_passback_items(self, response) -> List[Dict[str, Any]]: + """Collect output items that must be passed back in multi-turn calls. + + Per OpenAI docs, reasoning items returned alongside tool calls must be + included in the next request for reasoning models. + """ + items: List[Dict[str, Any]] = [] + for item in getattr(response, 'output', []) or []: + item_type = getattr(item, 'type', None) + if item_type == 'reasoning': + passback_item: Dict[str, Any] = { + 'type': 'reasoning', + 'summary': self._to_jsonable( + getattr(item, 'summary', []) or []), + } + encrypted_content = getattr(item, 'encrypted_content', None) + if encrypted_content: + passback_item['encrypted_content'] = encrypted_content + if not getattr(self, '_is_dashscope', False): + item_id = getattr(item, 'id', None) + if item_id: + passback_item['id'] = item_id + items.append(passback_item) + return items + + def _responses_generate( + self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Message: + """Non-streaming Responses API call.""" + request_messages, request_args = self._prepare_responses_request( + messages, args) + input_items = self._build_responses_input(request_messages) + resp_tools = self._build_responses_tools(tools) + kwargs = self._build_responses_kwargs(request_args) + if resp_tools: + kwargs['tools'] = resp_tools + + response = self._responses_client.responses.create( + model=self.model, + input=input_items, + **kwargs, + ) + text = getattr(response, 'output_text', '') or '' + reasoning = self._extract_reasoning_summaries_from_response(response) + resp_tool_calls = self._extract_tool_calls_from_response(response) + prompt_tokens, completion_tokens = self._extract_usage_from_response( + response) + passback = self._collect_passback_items(response) + + return Message( + role='assistant', + content=text, + reasoning_content=reasoning, + tool_calls=resp_tool_calls, + _responses_output_items=passback, + id=getattr(response, 'id', ''), + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + @staticmethod + def _extract_reasoning_from_item(item) -> str: + """Extract reasoning summary text from a single output item.""" + parts: List[str] = [] + for summary in getattr(item, 'summary', []) or []: + text = getattr(summary, 'text', None) + if text: + parts.append(text) + return '\n'.join(parts) + + def _responses_stream_generate( + self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Generator[Message, None, None]: + """Streaming Responses API call. + + Yields incremental ``Message`` objects. Reasoning summaries are + extracted from ``response.output_item.done`` events (type=reasoning) + which arrive *before* the first text delta, so the agent layer can + display the thinking header before content begins streaming. + """ + request_messages, request_args = self._prepare_responses_request( + messages, args) + input_items = self._build_responses_input(request_messages) + resp_tools = self._build_responses_tools(tools) + kwargs = self._build_responses_kwargs(request_args) + if resp_tools: + kwargs['tools'] = resp_tools + + stream = self._responses_client.responses.create( + model=self.model, + input=input_items, + stream=True, + **kwargs, + ) + + current_message = Message( + role='assistant', + content='', + reasoning_content='', + ) + streamed_text = '' + final_response = None + response_error_msg = '' + reasoning_parts: List[str] = [] + + for event in stream: + event_type = getattr(event, 'type', '') + + if event_type == 'response.output_item.done': + item = getattr(event, 'item', None) + if item and getattr(item, 'type', None) == 'reasoning': + summary_text = self._extract_reasoning_from_item(item) + if summary_text: + reasoning_parts.append(summary_text) + current_message.reasoning_content = '\n'.join( + reasoning_parts) + yield current_message + + elif event_type == 'response.output_text.delta': + delta = getattr(event, 'delta', '') + if delta: + streamed_text += delta + current_message.content = streamed_text + yield current_message + + elif event_type == 'response.output_text.done': + done_text = getattr(event, 'text', '') + if done_text and not streamed_text: + streamed_text = done_text + current_message.content = streamed_text + yield current_message + + elif event_type == 'response.completed': + final_response = getattr(event, 'response', None) + + elif event_type == 'response.failed': + failed_response = getattr(event, 'response', None) + failed_error = getattr(failed_response, 'error', None) + response_error_msg = getattr( + failed_error, 'message', '') or str(failed_error) + + if final_response: + if not reasoning_parts: + reasoning = self._extract_reasoning_summaries_from_response( + final_response) + if reasoning: + current_message.reasoning_content = reasoning + resp_tool_calls = self._extract_tool_calls_from_response( + final_response) + if resp_tool_calls: + current_message.tool_calls = resp_tool_calls + passback = self._collect_passback_items(final_response) + if passback: + current_message._responses_output_items = passback + prompt_tokens, completion_tokens = self._extract_usage_from_response( + final_response) + current_message.prompt_tokens = prompt_tokens + current_message.completion_tokens = completion_tokens + current_message.id = getattr(final_response, 'id', '') + yield current_message + elif response_error_msg: + logger.error( + f'Responses API failed: {response_error_msg}') + raise RuntimeError( + f'Responses API call failed: {response_error_msg}') + def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the OpenAI API. diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 6a336ca6e..c3af4ad35 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -40,6 +40,11 @@ class Message: # needed for output reasoning_content: str = '' + # Opaque output items from the Responses API that must be passed back + # in multi-turn tool-calling conversations (e.g. reasoning items). + _responses_output_items: List[Dict[str, Any]] = field( + default_factory=list) + # request id id: str = '' @@ -82,7 +87,8 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls'] + rm = ['completion_tokens', 'prompt_tokens', 'api_calls', + '_responses_output_items'] return { key: value for key, value in raw_dict.items() diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index a101f3edb..7bb79a571 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -75,6 +75,8 @@ self_reflection: quality_check: enabled: true model: qwen3.5-flash + openai_base_url: + openai_api_key: tool_call_timeout: 300 diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 95625df00..9ba827049 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -158,6 +158,8 @@ self_reflection: quality_check: enabled: true model: qwen3.5-flash + openai_base_url: + openai_api_key: handler: time_handler diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh index b2fa35c79..c4bd326fb 100755 --- a/projects/deep_research/v2/run_benchmark.sh +++ b/projects/deep_research/v2/run_benchmark.sh @@ -31,6 +31,10 @@ else exit 1 fi +# When stdout is redirected (e.g., nohup > file), Python is block-buffered by default. +# Force unbuffered output so progress lines like "[xx] OK" show up in logs promptly. +export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}" + # Use caffeinate on macOS when available; otherwise run normally. RUN_PREFIX=() if command -v caffeinate >/dev/null 2>&1; then @@ -98,7 +102,7 @@ if [ -z "$DR_BENCH_ROOT" ]; then echo "" # Run the benchmark - PYTHONPATH=. "$PYTHON_BIN" ms_agent/cli/cli.py run \ + PYTHONPATH=. "$PYTHON_BIN" -u ms_agent/cli/cli.py run \ --config projects/deep_research/v2/researcher.yaml \ --query "$QUERY" \ --trust_remote_code true \ @@ -164,7 +168,7 @@ else echo "" # Run the full benchmark - PYTHONPATH=. "${RUN_PREFIX[@]}" "$PYTHON_BIN" projects/deep_research/v2/eval/dr_bench_runner.py \ + PYTHONPATH=. "${RUN_PREFIX[@]}" "$PYTHON_BIN" -u projects/deep_research/v2/eval/dr_bench_runner.py \ --query_file "$QUERY_FILE" \ --output_jsonl "$OUTPUT_JSONL" \ --model_name "$MODEL_NAME" \ diff --git a/requirements/research.txt b/requirements/research.txt index 67ce2d3fd..693301a94 100644 --- a/requirements/research.txt +++ b/requirements/research.txt @@ -14,3 +14,4 @@ Pillow python-dotenv requests rich +socksio From 6bfb262963fe9d1f72eb5c3d4d7a8350de9cc3e6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 2 Apr 2026 20:59:41 +0800 Subject: [PATCH 20/40] feat: merge SplitTask into AgentTool, add TaskManager infrastructure - AgentTool now handles dynamic mode (split_to_sub_task) internally, replacing the standalone SplitTask class. Backward compat preserved: configs with tools.split_task auto-register the built-in dynamic spec. - Fix execution_mode missing from split_to_sub_task schema (was silently ignored before; now exposed as enum field with sequential/parallel). - Increase max_subtask_output_chars default from 2048 to 8192. - Add disallowed_tools to _AgentToolSpec to prevent recursive tool calls in sub-agents. - Add sub-agent transcript persistence: in-process runs write messages to output_dir/subagents/.jsonl for debugging. - Add TaskManager (ms_agent/utils/task_manager.py): agent-level registry for background tasks with notification queue. LLMAgent initializes it in run_loop, wires it into AgentTool instances, and drains notifications at the top of each while-loop iteration. Supports future BashTool background mode via the same interface. - diversity.py: replace SplitTask dependency with inline _run_tasks_sequential helper using LLMAgent directly. --- ms_agent/agent/llm_agent.py | 13 ++ ms_agent/memory/diversity.py | 55 ++++++--- ms_agent/tools/__init__.py | 1 - ms_agent/tools/agent_tool.py | 214 ++++++++++++++++++++++++++++++--- ms_agent/tools/split_task.py | 140 --------------------- ms_agent/tools/tool_manager.py | 8 +- ms_agent/utils/task_manager.py | 132 ++++++++++++++++++++ 7 files changed, 383 insertions(+), 180 deletions(-) delete mode 100644 ms_agent/tools/split_task.py create mode 100644 ms_agent/utils/task_manager.py diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 12bb0dcca..066efe9b8 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -25,6 +25,7 @@ from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot +from ms_agent.utils.task_manager import TaskManager from omegaconf import DictConfig, OmegaConf from ..config.config import Config, ConfigLifecycleHandler @@ -111,6 +112,7 @@ def __init__( self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None + self.task_manager: Optional[TaskManager] = None self.max_chat_round: int = 0 self.load_cache = kwargs.get('load_cache', False) self.config.load_cache = self.load_cache @@ -580,6 +582,8 @@ async def prepare_tools(self): async def cleanup_tools(self): """Cleanup resources used by the tool manager.""" + if self.task_manager is not None: + self.task_manager.kill_all() await self.tool_manager.cleanup() @property @@ -1144,6 +1148,11 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_knowledge_search() self.runtime.tag = self.tag + self.task_manager = TaskManager() + for tool in self.tool_manager.extra_tools: + if hasattr(tool, 'set_task_manager'): + tool.set_task_manager(self.task_manager) + if messages is None: messages = self.query @@ -1172,6 +1181,10 @@ async def run_loop(self, messages: Union[List[Message], str], self.log_output('[' + message.role + ']:') self.log_output(message.content) while not self.runtime.should_stop: + if self.task_manager is not None: + notifications = self.task_manager.drain_notifications() + if notifications: + messages.append(Message(role='user', content='\n'.join(notifications))) async for messages in self.step(messages): yield messages self.runtime.round += 1 diff --git a/ms_agent/memory/diversity.py b/ms_agent/memory/diversity.py index 0da0be038..775e80da1 100644 --- a/ms_agent/memory/diversity.py +++ b/ms_agent/memory/diversity.py @@ -1,3 +1,4 @@ +import asyncio import re from copy import deepcopy from typing import List @@ -6,7 +7,6 @@ from omegaconf import DictConfig from ..llm import LLM, Message -from ..tools import SplitTask from .base import Memory logger = get_logger() @@ -58,7 +58,6 @@ class Diversity(Memory): def __init__(self, config): super().__init__(config) self.llm = None - self.split_task = None self.num_split = 5 self.memory_called = False _config = deepcopy(config) @@ -67,9 +66,43 @@ def __init__(self, config): delattr(_config, 'tools') _config.generation_config.temperature = 1.0 self.llm = LLM.from_config(_config) - self.split_task = SplitTask(_config, tag_prefix='diversity-') + self._sub_config = _config self.num_split = getattr(config, 'num_split', self.num_split) + async def _run_tasks_sequential(self, tasks: list) -> str: + """Run a list of {system, query} tasks sequentially using LLMAgent.""" + from ms_agent.agent import LLMAgent + res = [] + for i, task in enumerate(tasks): + system = task.get('system', '') + query = task.get('query', '') + cfg = deepcopy(self._sub_config) + if not hasattr(cfg, 'prompt'): + cfg.prompt = DictConfig({}) + cfg.prompt.system = system + agent = LLMAgent( + config=cfg, + trust_remote_code=getattr(cfg, 'trust_remote_code', False), + tag=f'{getattr(cfg, "tag", "agent")}-diversity-{i}', + load_cache=False, + ) + try: + messages = await agent.run(query) + if isinstance(messages, list) and messages: + content = messages[-1].content or '' + else: + content = str(messages) + except Exception as e: + content = f'SubTask{i} failed with error: {e}' + res.append(content) + + formatted = '' + for i, content in enumerate(res): + if len(content) > 2048: + content = content[:2048] + formatted += f'SplitTask{i}:{content}\n' + return formatted + async def run(self, messages: List[Message]): if self.memory_called: return messages @@ -93,13 +126,7 @@ async def run(self, messages: List[Message]): } arguments.append(inputs) - arguments = { - 'tasks': arguments, - 'execution_mode': 'sequential', - } - - results = await self.split_task.call_tool( - '', tool_name='', tool_args=arguments) + results = await self._run_tasks_sequential(arguments) pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): @@ -118,13 +145,7 @@ async def run(self, messages: List[Message]): } arguments.append(inputs) - arguments = { - 'tasks': arguments, - 'execution_mode': 'sequential', - } - - results = await self.split_task.call_tool( - '', tool_name='', tool_args=arguments) + results = await self._run_tasks_sequential(arguments) pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): diff --git a/ms_agent/tools/__init__.py b/ms_agent/tools/__init__.py index 58b26ecee..7e185fdd0 100644 --- a/ms_agent/tools/__init__.py +++ b/ms_agent/tools/__init__.py @@ -4,6 +4,5 @@ from .code_server import LSPCodeServer from .filesystem_tool import FileSystemTool from .mcp_client import MCPClient -from .split_task import SplitTask from .todolist_tool import TodoListTool from .tool_manager import ToolManager diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 5c86e18c5..4c6bf6d02 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -52,6 +52,9 @@ class _AgentToolSpec: env: Optional[Dict[str, str]] run_in_thread: bool run_in_process: bool + dynamic: bool = False + disallowed_tools: Optional[List[str]] = None + max_subtask_output_chars: int = 8192 _MESSAGE_FIELDS = set(Message.__dataclass_fields__.keys()) @@ -90,6 +93,16 @@ def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): generation_cfg = getattr(agent.config, 'generation_config', DictConfig({})) agent.config.generation_config = generation_cfg + + if spec.disallowed_tools: + tools_cfg = getattr(agent.config, 'tools', None) + if tools_cfg is not None: + tools_dict = OmegaConf.to_container(tools_cfg, resolve=True) + if isinstance(tools_dict, dict): + for key in spec.disallowed_tools: + tools_dict.pop(key, None) + agent.config.tools = OmegaConf.create(tools_dict) + return agent @@ -189,6 +202,7 @@ def __init__(self, config: DictConfig, **kwargs): self._chunk_cb: Optional[Callable[..., Any]] = None self._active_processes: Dict[str, mp.Process] = {} self._active_processes_lock = threading.Lock() + self._task_manager = None self._load_specs() self._init_thread_pool_config() @@ -207,10 +221,67 @@ def _init_thread_pool_config(self): def enabled(self) -> bool: return bool(self._specs) + _SPLIT_TASK_DESCRIPTION = ( + 'Split complex task into sub tasks and start them, for example, ' + 'split a website generation task into sub tasks, ' + 'you plan the framework, include code files and classes and functions, and give the detail ' + 'information to the system and query field of the subtask, then ' + 'let each subtask to write a single file') + + _SPLIT_TASK_PARAMETERS = { + 'type': 'object', + 'properties': { + 'tasks': { + 'type': 'array', + 'description': ( + 'MANDATORY: Each element is a dict, which must contains two fields: ' + '`system`(str) and `query`(str) to start one sub task.'), + }, + 'execution_mode': { + 'type': 'string', + 'enum': ['sequential', 'parallel'], + 'description': 'Whether to run sub-tasks sequentially or in parallel.', + }, + }, + 'required': ['tasks'], + 'additionalProperties': False, + } + def _load_specs(self): tools_cfg = getattr(self.config, 'tools', DictConfig({})) + + # Backward compat: if config.tools.split_task exists, register a built-in dynamic spec + if hasattr(tools_cfg, 'split_task'): + split_cfg = tools_cfg.split_task + tag_prefix = getattr(split_cfg, 'tag_prefix', 'worker-') + run_in_thread = bool(getattr(split_cfg, 'run_in_thread', True)) + run_in_process = bool(getattr(split_cfg, 'run_in_process', run_in_thread)) + builtin_spec = _AgentToolSpec( + tool_name='split_to_sub_task', + description=self._SPLIT_TASK_DESCRIPTION, + parameters=self._SPLIT_TASK_PARAMETERS, + config_path=None, + inline_config=None, + server_name='split_task', + tag_prefix=tag_prefix, + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=100000, + trust_remote_code=None, + env=None, + run_in_thread=run_in_thread, + run_in_process=run_in_process, + dynamic=True, + max_subtask_output_chars=8192, + ) + self._specs['split_to_sub_task'] = builtin_spec + agent_tools_cfg = getattr(tools_cfg, 'agent_tools', None) if agent_tools_cfg is None: + if self._specs: + self._build_server_index() return if isinstance(agent_tools_cfg, DictConfig) and hasattr( @@ -233,6 +304,8 @@ def _load_specs(self): definitions_list = definitions else: logger.warning('agent_tools configuration is not iterable; skip.') + if self._specs: + self._build_server_index() return for idx, spec_cfg in enumerate(definitions_list): @@ -258,6 +331,9 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], 'agent_tools[%s] missing tool_name/name field, skip.', idx) return None + mode = getattr(cfg, 'mode', None) + is_dynamic = (mode == 'dynamic') + agent_cfg = getattr(cfg, 'agent', None) config_path = getattr(cfg, 'config_path', None) inline_cfg = getattr(cfg, 'config', None) @@ -267,7 +343,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], inline_cfg = _to_container( inline_cfg) if inline_cfg is not None else None - if not config_path and inline_cfg is None: + if not is_dynamic and not config_path and inline_cfg is None: logger.warning( 'agent_tools[%s] (%s) missing config_path/config definition.', idx, tool_name) @@ -277,19 +353,22 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], f'Invoke agent "{tool_name}" as a tool.') parameters = getattr(cfg, 'parameters', None) if parameters is None: - parameters = { - 'type': 'object', - 'properties': { - 'request': { - 'type': - 'string', - 'description': - f'Task description forwarded to the sub-agent {tool_name}.' + if is_dynamic: + parameters = self._SPLIT_TASK_PARAMETERS + else: + parameters = { + 'type': 'object', + 'properties': { + 'request': { + 'type': + 'string', + 'description': + f'Task description forwarded to the sub-agent {tool_name}.' + }, }, - }, - 'required': ['request'], - 'additionalProperties': True, - } + 'required': ['request'], + 'additionalProperties': True, + } else: parameters = _to_container(parameters) @@ -302,17 +381,22 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], input_mode = getattr(cfg, 'input_mode', 'text') output_mode = getattr(cfg, 'output_mode', 'final_message') max_chars = int(getattr(cfg, 'max_output_chars', 100000)) + max_subtask_chars = int(getattr(cfg, 'max_subtask_output_chars', 8192)) server_name = getattr(cfg, 'server_name', default_server) trust_remote_code = getattr(cfg, 'trust_remote_code', None) - # Run sub-agent in a background thread to avoid blocking the main event loop - # when underlying LLM SDKs are synchronous. run_in_thread = bool(getattr(cfg, 'run_in_thread', True)) - # Run sub-agent in an isolated process so timed-out calls can be killed. run_in_process = bool(getattr(cfg, 'run_in_process', run_in_thread)) env_cfg = getattr(cfg, 'env', None) env_cfg = _to_container(env_cfg) if env_cfg is not None else None + disallowed_raw = getattr(cfg, 'disallowed_tools', None) + disallowed_tools = _to_container(disallowed_raw) if disallowed_raw is not None else None + if isinstance(disallowed_tools, list): + disallowed_tools = [str(t) for t in disallowed_tools] + elif disallowed_tools is not None: + disallowed_tools = None + if config_path and not os.path.isabs(config_path): base_dir = getattr(self.config, 'local_dir', None) if base_dir: @@ -336,6 +420,9 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], env=env_cfg, run_in_thread=run_in_thread, run_in_process=run_in_process, + dynamic=is_dynamic, + disallowed_tools=disallowed_tools, + max_subtask_output_chars=max_subtask_chars, ) def _build_server_index(self): @@ -395,6 +482,9 @@ def _emit_chunk_event(self, event_type: str, data: Dict[str, Any]) -> None: except Exception as exc: # noqa logger.warning(f'AgentTool chunk callback failed: {exc}') + def set_task_manager(self, task_manager) -> None: + self._task_manager = task_manager + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if tool_name not in self._specs: @@ -408,6 +498,10 @@ async def call_tool(self, server_name: str, *, tool_name: str, call_id = None if isinstance(tool_args, dict) and '__call_id' in tool_args: call_id = tool_args.pop('__call_id', None) + + if spec.dynamic: + return await self._call_dynamic(tool_args, spec) + payload = self._build_payload(tool_args, spec) use_subprocess = spec.run_in_thread and spec.run_in_process agent = None if use_subprocess else self._build_agent(spec) @@ -417,6 +511,72 @@ async def call_tool(self, server_name: str, *, tool_name: str, def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) + async def _call_dynamic(self, tool_args: dict, spec: _AgentToolSpec) -> str: + tasks = tool_args.get('tasks', []) + execution_mode = tool_args.get('execution_mode', 'sequential') + + base_config = OmegaConf.to_container(self.config, resolve=True) + + async def _run_one(i: int, task: dict) -> str: + system = task.get('system', '') + query = task.get('query', '') + task_config = dict(base_config) if isinstance(base_config, dict) else {} + if 'prompt' not in task_config or not isinstance(task_config.get('prompt'), dict): + task_config['prompt'] = {} + task_config['prompt']['system'] = system + sub_spec = _AgentToolSpec( + tool_name=f'{spec.tool_name}_sub{i}', + description='', + parameters={}, + config_path=None, + inline_config=task_config, + server_name=spec.server_name, + tag_prefix=spec.tag_prefix, + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=spec.max_subtask_output_chars, + trust_remote_code=spec.trust_remote_code, + env=spec.env, + run_in_thread=spec.run_in_thread, + run_in_process=spec.run_in_process, + dynamic=False, + disallowed_tools=spec.disallowed_tools, + max_subtask_output_chars=spec.max_subtask_output_chars, + ) + use_subprocess = sub_spec.run_in_thread and sub_spec.run_in_process + agent = None if use_subprocess else self._build_agent(sub_spec) + messages = await self._run_agent(agent, query, sub_spec) + return self._format_output(messages, sub_spec) + + if execution_mode == 'parallel': + results = await asyncio.gather( + *[_run_one(i, task) for i, task in enumerate(tasks)], + return_exceptions=True, + ) + res_list = [] + for i, r in enumerate(results): + if isinstance(r, Exception): + res_list.append(f'SubTask{i} failed with error: {r}') + else: + res_list.append(str(r)) + else: + res_list = [] + for i, task in enumerate(tasks): + try: + r = await _run_one(i, task) + res_list.append(r) + except Exception as e: + res_list.append(f'SubTask{i} failed with error: {e}') + + formatted = '' + for i, content in enumerate(res_list): + if len(content) > spec.max_subtask_output_chars: + content = content[:spec.max_subtask_output_chars] + formatted += f'SubTask{i}:{content}\n' + return formatted + @staticmethod def _terminate_process(proc: Optional[mp.Process], *, reason: str) -> None: if proc is None: @@ -710,7 +870,10 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: runner = _run_and_collect if not self._enable_stats: - return await runner() + result = await runner() + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) + return result start_ts = now_iso() start_time = monotonic() @@ -718,6 +881,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: result = None try: result = await runner() + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) return result except BaseException as exc: status = 'cancelled' if isinstance( @@ -749,6 +914,21 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: f'Failed to write agent tool stats for {spec.tool_name}: {exc}' ) + def _save_transcript(self, messages: Any, agent_tag: Optional[str]) -> None: + if not isinstance(messages, list) or not agent_tag: + return + try: + output_dir = getattr(self.config, 'output_dir', './output') + subagents_dir = os.path.join(output_dir, 'subagents') + os.makedirs(subagents_dir, exist_ok=True) + path = os.path.join(subagents_dir, f'{agent_tag}.jsonl') + with open(path, 'w', encoding='utf-8') as f: + for msg in messages: + if hasattr(msg, 'to_dict'): + f.write(json.dumps(msg.to_dict(), ensure_ascii=False) + '\n') + except Exception as exc: + logger.warning(f'Failed to save sub-agent transcript for {agent_tag}: {exc}') + def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): if spec.input_mode == 'messages': field = spec.request_field or 'messages' diff --git a/ms_agent/tools/split_task.py b/ms_agent/tools/split_task.py deleted file mode 100644 index ea1e1c086..000000000 --- a/ms_agent/tools/split_task.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import asyncio -from concurrent.futures import ThreadPoolExecutor, as_completed - -from ms_agent.llm.utils import Tool -from ms_agent.tools.base import ToolBase -from ms_agent.utils.utils import escape_yaml_string -from omegaconf import DictConfig - - -class SplitTask(ToolBase): - """A tool special for task splitting""" - - def __init__(self, config: DictConfig, **kwargs): - super().__init__(config) - if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): - self.tag_prefix = getattr(config.tools.split_task, 'tag_prefix', - 'worker-') - else: - self.tag_prefix = kwargs.get('tag_prefix', 'worker-') - self.round = 0 - - async def connect(self): - pass - - async def cleanup(self): - pass - - async def _get_tools_inner(self): - return { - 'split_task': [ - Tool( - tool_name='split_to_sub_task', - server_name='split_task', - description= - 'Split complex task into sub tasks and start them, for example, ' - 'split a website generation task into sub tasks, ' - 'you plan the framework, include code files and classes and functions, and give the detail ' - 'information to the system and query field of the subtask, then ' - 'let each subtask to write a single file', - parameters={ - 'type': 'object', - 'properties': { - 'tasks': { - 'type': - 'array', - 'description': - 'MANDATORY: Each element is a dict, which must contains two fields: ' - '`system`(str) and `query`(str) to start one sub task.' - } - }, - 'required': ['tasks'], - 'additionalProperties': False - }) - ] - } - - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict): - """ - 1. LLMAgent will be used to start subtask - 2. config will be inherited from the parent task - 3. Supports both parallel and sequential execution modes - """ - from ms_agent.agent import LLMAgent - - tasks = tool_args.get('tasks') - execution_mode = tool_args.get( - 'execution_mode', 'sequential') # 'parallel' or 'sequential' - - def run_agent_sync(i, task): - """Synchronous wrapper for agent execution""" - system = task['system'] - query = task['query'] - config = DictConfig(self.config) - if not hasattr(config, 'prompt'): - config.prompt = DictConfig({}) - config.prompt.system = escape_yaml_string(system) - trust_remote_code = getattr(config, 'trust_remote_code', False) - agent = LLMAgent( - config=config, - trust_remote_code=trust_remote_code, - tag=f'{config.tag}-r{self.round}-{self.tag_prefix}{i}', - load_cache=getattr(config, 'load_cache', False)) - - # Run async agent.run() in sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(agent.run(query)) - finally: - loop.close() - - result = [] - if execution_mode == 'parallel': - # Use ThreadPoolExecutor for parallel execution - with ThreadPoolExecutor() as executor: - futures = { - executor.submit(run_agent_sync, i, task): i - for i, task in enumerate(tasks) - } - - # Collect results as they complete - for future in as_completed(futures): - i = futures[future] - try: - r = future.result() - result.append((i, r)) - except Exception as e: - result.append( - (i, f'Subtask{i} failed with error: {e}')) - - # Sort by task index to maintain order - result.sort(key=lambda x: x[0]) - result = [r[1] for r in result] - else: # sequential - for i, task in enumerate(tasks): - try: - r = run_agent_sync(i, task) - result.append(r) - except Exception as e: - result.append(f'Subtask{i} failed with error: {e}') - - res = [] - for messages in result: - if isinstance(messages, list): - content = messages[-1].content - if len(content) > 2048: - content = content[:2048] - else: - content = str(messages) - res.append(content) - - self.round += 1 - - formatted_result = '' - for i in range(len(res)): - formatted_result += f'SplitTask{i}:{res[i]}\n' - - return formatted_result diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..2b3a5eb11 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -18,7 +18,6 @@ from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient from ms_agent.tools.search.websearch_tool import WebSearchTool -from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger @@ -47,8 +46,6 @@ def __init__(self, self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False - if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): - self.extra_tools.append(SplitTask(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'image_generator'): self.extra_tools.append(ImageGenerator(config)) @@ -78,8 +75,9 @@ def __init__(self, 'financial_data_fetcher'): from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher self.extra_tools.append(FinancialDataFetcher(config)) - if hasattr(config, 'tools') and getattr(config.tools, 'agent_tools', - None): + if hasattr(config, 'tools') and ( + getattr(config.tools, 'agent_tools', None) + or hasattr(config.tools, 'split_task')): agent_tool = AgentTool( config, trust_remote_code=self.trust_remote_code) if agent_tool.enabled: diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py new file mode 100644 index 000000000..00cf6ff4f --- /dev/null +++ b/ms_agent/utils/task_manager.py @@ -0,0 +1,132 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import asyncio +import multiprocessing as mp +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + + +@dataclass +class BackgroundTask: + task_id: str + task_type: str # 'agent' | 'shell' + tool_name: str # which tool spawned this + description: str + status: str = 'running' # 'running' | 'completed' | 'failed' | 'killed' + proc: Optional[Any] = field(default=None, repr=False) # mp.Process or asyncio.Task + result: Optional[str] = None + error: Optional[str] = None + started_at: float = field(default_factory=time.monotonic) + ended_at: Optional[float] = None + + +class TaskManager: + """ + Agent-level registry for background tasks (agent sub-tasks, shell tasks, etc.). + Holds a notification queue that LLMAgent drains each turn to inject + completion notices into the conversation. + """ + + def __init__(self): + self._tasks: Dict[str, BackgroundTask] = {} + self._lock = asyncio.Lock() + self._notification_queue: asyncio.Queue = asyncio.Queue() + + def register( + self, + task_type: str, + tool_name: str, + description: str, + proc: Optional[Any] = None, + task_id: Optional[str] = None, + ) -> str: + task_id = task_id or uuid.uuid4().hex[:12] + task = BackgroundTask( + task_id=task_id, + task_type=task_type, + tool_name=tool_name, + description=description, + proc=proc, + ) + self._tasks[task_id] = task + logger.info(f'[TaskManager] registered {task_type} task {task_id}: {description}') + return task_id + + async def complete(self, task_id: str, result: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'completed' + task.result = result + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + async def fail(self, task_id: str, error: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'failed' + task.error = error + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + def kill(self, task_id: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + if task.status != 'running': + return + if task.proc is not None: + try: + if isinstance(task.proc, mp.Process): + task.proc.terminate() + elif asyncio.isfuture(task.proc) or asyncio.iscoroutine(task.proc): + task.proc.cancel() + except Exception as e: + logger.warning(f'[TaskManager] kill {task_id} failed: {e}') + task.status = 'killed' + task.ended_at = time.monotonic() + + def kill_all(self) -> None: + for task_id in list(self._tasks): + if self._tasks[task_id].status == 'running': + self.kill(task_id) + + def drain_notifications(self) -> List[str]: + """Drain all pending notifications synchronously. Called from run_loop.""" + notifications = [] + while not self._notification_queue.empty(): + try: + notifications.append(self._notification_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return notifications + + def get_task(self, task_id: str) -> Optional[BackgroundTask]: + return self._tasks.get(task_id) + + def running_tasks(self) -> List[BackgroundTask]: + return [t for t in self._tasks.values() if t.status == 'running'] + + @staticmethod + def _format_notification(task: BackgroundTask) -> str: + status_line = f'status: {task.status}' + result_line = f'\nresult: {task.result}' if task.result else '' + error_line = f'\nerror: {task.error}' if task.error else '' + duration = '' + if task.ended_at: + duration = f'\nduration_s: {task.ended_at - task.started_at:.1f}' + return ( + f'\n' + f'{task.task_id}\n' + f'{task.task_type}\n' + f'{task.tool_name}\n' + f'{task.description}\n' + f'<{status_line}/>{result_line}{error_line}{duration}\n' + f'' + ) From db4545785d2892757a7af801066613253692f331 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 2 Apr 2026 21:12:45 +0800 Subject: [PATCH 21/40] feat: add run_in_background support to AgentTool When a spec has run_in_background=true, call_tool fires off the subprocess and returns immediately with {status: async_launched, task_id, tool_name}. A background asyncio watcher task polls the result queue and calls task_manager.complete/fail when the process exits. LLMAgent drains the TaskManager notification queue at the top of each run_loop iteration, injecting XML into the conversation so the model sees the result on the next turn. run_in_background is opt-in per agent_tools definition: agent_tools: definitions: - tool_name: my_agent config_path: my_agent.yaml run_in_background: true --- ms_agent/tools/agent_tool.py | 66 +++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 4c6bf6d02..6f14d733e 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -55,6 +55,7 @@ class _AgentToolSpec: dynamic: bool = False disallowed_tools: Optional[List[str]] = None max_subtask_output_chars: int = 8192 + run_in_background: bool = False _MESSAGE_FIELDS = set(Message.__dataclass_fields__.keys()) @@ -423,6 +424,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], dynamic=is_dynamic, disallowed_tools=disallowed_tools, max_subtask_output_chars=max_subtask_chars, + run_in_background=bool(getattr(cfg, 'run_in_background', False)), ) def _build_server_index(self): @@ -503,6 +505,10 @@ async def call_tool(self, server_name: str, *, tool_name: str, return await self._call_dynamic(tool_args, spec) payload = self._build_payload(tool_args, spec) + + if spec.run_in_background: + return await self._launch_background(payload, spec, call_id) + use_subprocess = spec.run_in_thread and spec.run_in_process agent = None if use_subprocess else self._build_agent(spec) messages = await self._run_agent(agent, payload, spec, call_id=call_id) @@ -511,7 +517,65 @@ async def call_tool(self, server_name: str, *, tool_name: str, def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) - async def _call_dynamic(self, tool_args: dict, spec: _AgentToolSpec) -> str: + async def _launch_background(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> str: + """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" + if self._task_manager is None: + raise RuntimeError( + f'AgentTool "{spec.tool_name}" has run_in_background=true but ' + 'no TaskManager is attached. Ensure LLMAgent wires task_manager ' + 'into AgentTool via set_task_manager().') + + ctx = mp.get_context('spawn') + result_queue = ctx.Queue(maxsize=1) + process_payload = self._serialize_payload_for_process(payload) + proc = ctx.Process( + target=_run_agent_in_subprocess, + args=(spec, self._trust_remote_code, process_payload, False, None, + result_queue), + name=f'agent_tool_bg_{spec.tool_name}', + ) + proc.start() + + task_id = self._task_manager.register( + task_type='agent', + tool_name=spec.tool_name, + description=f'{spec.tool_name} (call_id={call_id})', + proc=proc, + ) + self._register_process(task_id, proc) + + async def _watcher(): + try: + result = await self._wait_process_result(proc, result_queue) + if result is None or not result.get('ok'): + err = (result or {}).get('error', 'subprocess exited without result') + tb = (result or {}).get('traceback', '') + if tb: + logger.warning(tb) + await self._task_manager.fail(task_id, err) + else: + result_payload = result.get('result', {}) or {} + restored = self._restore_process_result(result_payload) + output = self._format_output(restored, spec) + await self._task_manager.complete(task_id, output) + except Exception as exc: + await self._task_manager.fail(task_id, str(exc)) + finally: + self._unregister_process(task_id) + try: + result_queue.close() + result_queue.join_thread() + except Exception: + pass + + asyncio.create_task(_watcher()) + + return json.dumps({ + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': spec.tool_name, + }, ensure_ascii=False) tasks = tool_args.get('tasks', []) execution_mode = tool_args.get('execution_mode', 'sequential') From 47bf056afcc9ccdebfa592ae3554bb29cb5f15bb Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 3 Apr 2026 10:10:59 +0800 Subject: [PATCH 22/40] feat: add TaskControlTool for LLM-accessible task management Exposes two tools to the model when tools.task_control is configured: - list_tasks: show all background tasks with status and duration - cancel_task: kill a running task by task_id TaskControlTool receives the TaskManager reference via set_task_manager(), which LLMAgent already calls for all extra_tools in run_loop. Enable with: tools: task_control: {} --- ms_agent/tools/task_control_tool.py | 107 ++++++++++++++++++++++++++++ ms_agent/tools/tool_manager.py | 3 + 2 files changed, 110 insertions(+) create mode 100644 ms_agent/tools/task_control_tool.py diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py new file mode 100644 index 000000000..33012c703 --- /dev/null +++ b/ms_agent/tools/task_control_tool.py @@ -0,0 +1,107 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import json +from typing import Any, Dict, Optional + +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig + +logger = get_logger() + +_SERVER = 'task_control' + + +class TaskControlTool(ToolBase): + """Exposes background task management to the LLM. + + Provides two tools: + - list_tasks: show all background tasks and their status + - cancel_task: kill a running background task by task_id + """ + + def __init__(self, config: DictConfig, **kwargs): + super().__init__(config) + self._task_manager = None + + def set_task_manager(self, task_manager) -> None: + self._task_manager = task_manager + + async def connect(self) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def _get_tools_inner(self) -> Dict[str, Any]: + return { + _SERVER: [ + Tool( + tool_name='list_tasks', + server_name=_SERVER, + description=( + 'List all background tasks and their current status. ' + 'Returns task_id, tool_name, description, status, and duration.'), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='cancel_task', + server_name=_SERVER, + description='Cancel a running background task by its task_id.', + parameters={ + 'type': 'object', + 'properties': { + 'task_id': { + 'type': 'string', + 'description': 'The task_id returned by the async_launched response.', + } + }, + 'required': ['task_id'], + 'additionalProperties': False, + }, + ), + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: + if self._task_manager is None: + return 'TaskManager not available.' + + if tool_name == 'list_tasks': + tasks = list(self._task_manager._tasks.values()) + if not tasks: + return 'No background tasks registered.' + rows = [] + for t in tasks: + duration = '' + if t.ended_at: + duration = f'{t.ended_at - t.started_at:.1f}s' + elif t.status == 'running': + import time + duration = f'{time.monotonic() - t.started_at:.1f}s (running)' + rows.append({ + 'task_id': t.task_id, + 'tool_name': t.tool_name, + 'description': t.description, + 'status': t.status, + 'duration': duration, + }) + return json.dumps(rows, ensure_ascii=False, indent=2) + + if tool_name == 'cancel_task': + task_id = tool_args.get('task_id', '') + task = self._task_manager.get_task(task_id) + if task is None: + return f'Task "{task_id}" not found.' + if task.status != 'running': + return f'Task "{task_id}" is already {task.status}.' + self._task_manager.kill(task_id) + return f'Task "{task_id}" cancelled.' + + return f'Unknown tool: {tool_name}' diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 2b3a5eb11..19385c2e6 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -86,6 +86,9 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + if hasattr(config, 'tools') and hasattr(config.tools, 'task_control'): + from ms_agent.tools.task_control_tool import TaskControlTool + self.extra_tools.append(TaskControlTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, From e3b62dcc8854d144c3a2dc6a448c640d5d97846b Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 3 Apr 2026 11:34:55 +0800 Subject: [PATCH 23/40] fix: hold strong reference to background watcher asyncio.Task to prevent GC cancellation --- ms_agent/tools/agent_tool.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 6f14d733e..7beedb279 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -204,6 +204,7 @@ def __init__(self, config: DictConfig, **kwargs): self._active_processes: Dict[str, mp.Process] = {} self._active_processes_lock = threading.Lock() self._task_manager = None + self._watcher_tasks: set = set() self._load_specs() self._init_thread_pool_config() @@ -569,7 +570,9 @@ async def _watcher(): except Exception: pass - asyncio.create_task(_watcher()) + t = asyncio.create_task(_watcher()) + self._watcher_tasks.add(t) + t.add_done_callback(self._watcher_tasks.discard) return json.dumps({ 'status': 'async_launched', From 6286732d8ca8c2591811b69c8f4b20aec75debcc Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 3 Apr 2026 11:37:24 +0800 Subject: [PATCH 24/40] fix: cancel watcher tasks on AgentTool cleanup; export TaskControlTool from tools __init__ --- ms_agent/tools/__init__.py | 1 + ms_agent/tools/agent_tool.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/ms_agent/tools/__init__.py b/ms_agent/tools/__init__.py index 7e185fdd0..2dae662c7 100644 --- a/ms_agent/tools/__init__.py +++ b/ms_agent/tools/__init__.py @@ -4,5 +4,6 @@ from .code_server import LSPCodeServer from .filesystem_tool import FileSystemTool from .mcp_client import MCPClient +from .task_control_tool import TaskControlTool from .todolist_tool import TodoListTool from .tool_manager import ToolManager diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 7beedb279..26549b333 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -454,6 +454,9 @@ async def connect(self): async def cleanup(self): self._terminate_all_active_processes(reason='during AgentTool cleanup') + for t in list(self._watcher_tasks): + t.cancel() + self._watcher_tasks.clear() if self._thread_executor is not None: try: try: From c3261399548847616ed607b116e1f41303ba7246 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 3 Apr 2026 11:39:43 +0800 Subject: [PATCH 25/40] fix: correct malformed XML in TaskManager._format_notification --- ms_agent/utils/task_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py index 00cf6ff4f..a897e2fa9 100644 --- a/ms_agent/utils/task_manager.py +++ b/ms_agent/utils/task_manager.py @@ -115,18 +115,18 @@ def running_tasks(self) -> List[BackgroundTask]: @staticmethod def _format_notification(task: BackgroundTask) -> str: - status_line = f'status: {task.status}' - result_line = f'\nresult: {task.result}' if task.result else '' - error_line = f'\nerror: {task.error}' if task.error else '' + result_line = f'\n{task.result}' if task.result else '' + error_line = f'\n{task.error}' if task.error else '' duration = '' if task.ended_at: - duration = f'\nduration_s: {task.ended_at - task.started_at:.1f}' + duration = f'\n{task.ended_at - task.started_at:.1f}' return ( f'\n' f'{task.task_id}\n' f'{task.task_type}\n' f'{task.tool_name}\n' f'{task.description}\n' - f'<{status_line}/>{result_line}{error_line}{duration}\n' + f'{task.status}' + f'{result_line}{error_line}{duration}\n' f'' ) From b103ffa973099f41c3df04e01ba401b2c5b8fd6c Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 3 Apr 2026 11:41:57 +0800 Subject: [PATCH 26/40] test: add smoke tests for TaskManager, AgentTool dynamic spec, TaskControlTool (18 tests) --- tests/utils/test_task_manager_smoke.py | 238 +++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 tests/utils/test_task_manager_smoke.py diff --git a/tests/utils/test_task_manager_smoke.py b/tests/utils/test_task_manager_smoke.py new file mode 100644 index 000000000..8f57444f3 --- /dev/null +++ b/tests/utils/test_task_manager_smoke.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Smoke tests for TaskManager and AgentTool dynamic/background mode. +No network, no LLM — all tests run fully offline. +""" +import asyncio +import os +import sys +import unittest + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from ms_agent.utils.task_manager import BackgroundTask, TaskManager + + +# --------------------------------------------------------------------------- +# TaskManager unit tests +# --------------------------------------------------------------------------- + +class TestTaskManager(unittest.IsolatedAsyncioTestCase): + + async def test_register_and_complete(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + self.assertIn(task_id, tm._tasks) + self.assertEqual(tm._tasks[task_id].status, 'running') + + await tm.complete(task_id, 'great result') + self.assertEqual(tm._tasks[task_id].status, 'completed') + self.assertEqual(tm._tasks[task_id].result, 'great result') + + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 1) + self.assertIn('completed', notifications[0]) + self.assertIn('great result', notifications[0]) + self.assertIn(task_id, notifications[0]) + + async def test_register_and_fail(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + await tm.fail(task_id, 'something went wrong') + self.assertEqual(tm._tasks[task_id].status, 'failed') + + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 1) + self.assertIn('failed', notifications[0]) + self.assertIn('something went wrong', notifications[0]) + + def test_kill(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + tm.kill(task_id) + self.assertEqual(tm._tasks[task_id].status, 'killed') + # kill again is a no-op + tm.kill(task_id) + self.assertEqual(tm._tasks[task_id].status, 'killed') + + def test_kill_all(self): + tm = TaskManager() + ids = [tm.register('agent', 'tool', f'task {i}') for i in range(3)] + tm.kill_all() + for tid in ids: + self.assertEqual(tm._tasks[tid].status, 'killed') + + def test_drain_empty(self): + tm = TaskManager() + self.assertEqual(tm.drain_notifications(), []) + + async def test_drain_multiple(self): + tm = TaskManager() + id1 = tm.register('agent', 'tool_a', 'task a') + id2 = tm.register('agent', 'tool_b', 'task b') + await tm.complete(id1, 'result a') + await tm.fail(id2, 'error b') + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 2) + # drain again should be empty + self.assertEqual(tm.drain_notifications(), []) + + def test_get_task(self): + tm = TaskManager() + task_id = tm.register('shell', 'bash', 'run script') + task = tm.get_task(task_id) + self.assertIsNotNone(task) + self.assertEqual(task.task_type, 'shell') + self.assertIsNone(tm.get_task('nonexistent')) + + def test_running_tasks(self): + tm = TaskManager() + id1 = tm.register('agent', 'tool', 'task 1') + id2 = tm.register('agent', 'tool', 'task 2') + tm.kill(id2) + running = tm.running_tasks() + self.assertEqual(len(running), 1) + self.assertEqual(running[0].task_id, id1) + + async def test_notification_xml_structure(self): + tm = TaskManager() + task_id = tm.register('agent', 'searcher_tool', 'search for X') + await tm.complete(task_id, 'found Y') + notif = tm.drain_notifications()[0] + self.assertTrue(notif.startswith('')) + self.assertTrue(notif.strip().endswith('')) + self.assertIn('', notif) + self.assertIn('agent', notif) + self.assertIn('searcher_tool', notif) + self.assertIn('search for X', notif) + self.assertIn('completed', notif) + self.assertIn('found Y', notif) + self.assertIn('', notif) + + +# --------------------------------------------------------------------------- +# AgentTool dynamic spec (merged SplitTask) — schema validation only +# --------------------------------------------------------------------------- + +class TestAgentToolDynamicSpec(unittest.TestCase): + + def _make_config(self): + from omegaconf import OmegaConf + return OmegaConf.create({ + 'tag': 'test-agent', + 'output_dir': '/tmp/test_agent_tool', + 'tools': { + 'split_task': { + 'tag_prefix': 'worker-', + 'run_in_thread': False, + 'run_in_process': False, + } + } + }) + + def test_split_task_spec_registered(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + self.assertTrue(tool.enabled) + self.assertIn('split_to_sub_task', tool._specs) + + def test_split_task_spec_is_dynamic(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + spec = tool._specs['split_to_sub_task'] + self.assertTrue(spec.dynamic) + self.assertFalse(spec.run_in_process) + + def test_split_task_parameters_schema(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + spec = tool._specs['split_to_sub_task'] + props = spec.parameters['properties'] + self.assertIn('tasks', props) + self.assertIn('execution_mode', props) + # execution_mode must have enum + self.assertIn('enum', props['execution_mode']) + self.assertIn('parallel', props['execution_mode']['enum']) + self.assertIn('sequential', props['execution_mode']['enum']) + + def test_dynamic_mode_in_agent_tools_definitions(self): + from ms_agent.tools.agent_tool import AgentTool + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'tag': 'test-agent', + 'output_dir': '/tmp/test_agent_tool', + 'tools': { + 'agent_tools': { + 'definitions': [{ + 'tool_name': 'my_dynamic_tool', + 'mode': 'dynamic', + 'description': 'A dynamic tool', + }] + } + } + }) + tool = AgentTool(config) + self.assertIn('my_dynamic_tool', tool._specs) + self.assertTrue(tool._specs['my_dynamic_tool'].dynamic) + + +# --------------------------------------------------------------------------- +# TaskControlTool unit tests +# --------------------------------------------------------------------------- + +class TestTaskControlTool(unittest.IsolatedAsyncioTestCase): + + def _make_tool(self): + from ms_agent.tools.task_control_tool import TaskControlTool + from omegaconf import OmegaConf + config = OmegaConf.create({'output_dir': '/tmp'}) + tool = TaskControlTool(config) + tm = TaskManager() + tool.set_task_manager(tm) + return tool, tm + + async def test_list_tasks_empty(self): + tool, _ = self._make_tool() + result = await tool.call_tool('task_control', tool_name='list_tasks', tool_args={}) + self.assertEqual(result, 'No background tasks registered.') + + async def test_list_tasks_with_entries(self): + import json + tool, tm = self._make_tool() + tm.register('agent', 'searcher', 'search X') + result = await tool.call_tool('task_control', tool_name='list_tasks', tool_args={}) + rows = json.loads(result) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]['tool_name'], 'searcher') + self.assertEqual(rows[0]['status'], 'running') + + async def test_cancel_task(self): + tool, tm = self._make_tool() + task_id = tm.register('agent', 'searcher', 'search X') + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': task_id}) + self.assertIn('cancelled', result) + self.assertEqual(tm.get_task(task_id).status, 'killed') + + async def test_cancel_nonexistent(self): + tool, _ = self._make_tool() + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': 'bad-id'}) + self.assertIn('not found', result) + + async def test_cancel_already_done(self): + tool, tm = self._make_tool() + task_id = tm.register('agent', 'searcher', 'search X') + tm.kill(task_id) + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': task_id}) + self.assertIn('already', result) + + +if __name__ == '__main__': + unittest.main() From 7a96f8e4a6a032d3bc7a886cee1d9d3f80dcf9a0 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 9 Apr 2026 14:04:33 +0800 Subject: [PATCH 27/40] feat: stream sub-agent execution trace to file in real time Add SubAgentStreamWriter (ms_agent/utils/stream_writer.py) that appends each new message to a JSONL file as soon as it arrives, so the parent agent or an external observer can tail -f to watch a sub-agent run step-by-step instead of waiting for it to finish. Key details: - JSONL format: header -> message* -> footer, one JSON object per line - Deduplication via last_written_count: each chunk carries the full accumulated history; only newly added messages are written - Thread-safe (threading.Lock) and flush-on-every-line for tail -f support - Works for both inline-async and subprocess execution paths - event_queue is now created when either _chunk_cb or the writer is active - Opt-in via config: agent_stream_file: true (or tools.agent_tools.enable_stream_file: true) - File path: {output_dir}/subagents/{call_id}.stream.jsonl - A descriptive note is appended to the tool result so the parent LLM understands the file is an incremental execution trace, not tool output Also includes AgentTool refactor: replace ThreadPoolExecutor with native asyncio subprocess spawning, add sync_timeout_s + escape-to-background support, TaskControlTool improvements, and related smoke tests. Entire-Checkpoint: 37377e309a88 --- ms_agent/tools/agent_tool.py | 317 +++++++++++++++++++------ ms_agent/tools/task_control_tool.py | 8 +- ms_agent/utils/stream_writer.py | 208 ++++++++++++++++ tests/utils/test_task_manager_smoke.py | 125 ++++++++++ 4 files changed, 588 insertions(+), 70 deletions(-) create mode 100644 ms_agent/utils/stream_writer.py diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 26549b333..5f9ff7d20 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -6,7 +6,6 @@ import traceback import uuid from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from queue import Empty as QueueEmpty from queue import Full as QueueFull @@ -20,7 +19,7 @@ from ms_agent.utils.stats import (append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage) -from ms_agent.utils.thread_util import DaemonThreadPoolExecutor +from ms_agent.utils.stream_writer import SubAgentStreamWriter from omegaconf import DictConfig, ListConfig, OmegaConf logger = get_logger() @@ -56,6 +55,7 @@ class _AgentToolSpec: disallowed_tools: Optional[List[str]] = None max_subtask_output_chars: int = 8192 run_in_background: bool = False + sync_timeout_s: Optional[float] = None _MESSAGE_FIELDS = set(Message.__dataclass_fields__.keys()) @@ -198,26 +198,16 @@ def __init__(self, config: DictConfig, **kwargs): self._enable_stats = False self._specs: Dict[str, _AgentToolSpec] = {} self._server_tools: Dict[str, List[Tool]] = {} - self._thread_executor: Optional[ThreadPoolExecutor] = None - self._thread_max_workers: int = 0 self._chunk_cb: Optional[Callable[..., Any]] = None self._active_processes: Dict[str, mp.Process] = {} self._active_processes_lock = threading.Lock() self._task_manager = None self._watcher_tasks: set = set() + # call_id -> (asyncio.Task, spec, payload, escape_event) + self._active_sync_tasks: Dict[str, Any] = {} + # effective_call_id -> stream file path (set during _run_agent, consumed by call_tool) + self._stream_paths: Dict[str, str] = {} self._load_specs() - self._init_thread_pool_config() - - def _init_thread_pool_config(self): - tools_cfg = getattr(self.config, 'tools', DictConfig({})) - agent_tools_cfg = getattr(tools_cfg, 'agent_tools', DictConfig({})) - max_workers = getattr(agent_tools_cfg, 'max_workers', None) - if max_workers is None: - max_workers = os.getenv('AGENT_TOOL_MAX_WORKERS', None) - try: - self._thread_max_workers = int(max_workers) if max_workers else 3 - except Exception: - self._thread_max_workers = 3 @property def enabled(self) -> bool: @@ -426,6 +416,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], disallowed_tools=disallowed_tools, max_subtask_output_chars=max_subtask_chars, run_in_background=bool(getattr(cfg, 'run_in_background', False)), + sync_timeout_s=float(getattr(cfg, 'sync_timeout_s', 0)) or None, ) def _build_server_index(self): @@ -441,15 +432,6 @@ def _build_server_index(self): self._server_tools = server_map async def connect(self): - # Lazily initialize a dedicated pool for agent tools that opt into - # `run_in_thread`, so we don't consume threads when the tool is unused. - if self._thread_executor is None: - # Use daemon threads to avoid blocking process exit when sub-agent - # calls are cancelled by tool-call timeouts. - self._thread_executor = DaemonThreadPoolExecutor( - max_workers=self._thread_max_workers, - thread_name_prefix='agent_tool_', - ) return None async def cleanup(self): @@ -457,16 +439,6 @@ async def cleanup(self): for t in list(self._watcher_tasks): t.cancel() self._watcher_tasks.clear() - if self._thread_executor is not None: - try: - try: - self._thread_executor.shutdown( - wait=False, cancel_futures=True) - except TypeError: - self._thread_executor.shutdown(wait=False) - except Exception: - pass - self._thread_executor = None return None async def get_tools(self) -> Dict[str, Any]: @@ -491,6 +463,52 @@ def _emit_chunk_event(self, event_type: str, data: Dict[str, Any]) -> None: def set_task_manager(self, task_manager) -> None: self._task_manager = task_manager + # ── stream-file helpers ──────────────────────────────────────────────── + + def _stream_file_enabled(self) -> bool: + """Return True if sub-agent stream-file writing is enabled. + + Checks in order: + 1. ``config.tools.agent_tools.enable_stream_file`` + 2. ``config.agent_stream_file`` + Defaults to ``False``. + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + val = getattr(agent_tools_cfg, 'enable_stream_file', None) + if val is not None: + return bool(val) + return bool(getattr(self.config, 'agent_stream_file', False)) + + def _stream_file_dir(self) -> str: + """Return the output directory used for stream files. + + Checks ``config.tools.agent_tools.stream_file_dir`` first, then falls + back to ``config.output_dir``. + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + override = getattr(agent_tools_cfg, 'stream_file_dir', None) + if override: + return str(override) + return str(getattr(self.config, 'output_dir', './output')) + + def _stream_include_in_result(self) -> bool: + """Return True if the stream-file path should be appended to the tool result. + + Controlled by ``config.tools.agent_tools.stream_include_in_result`` + (defaults to ``True`` when stream files are enabled). + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + val = getattr(agent_tools_cfg, 'stream_include_in_result', None) + if val is not None: + return bool(val) + return True + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if tool_name not in self._specs: @@ -505,22 +523,147 @@ async def call_tool(self, server_name: str, *, tool_name: str, if isinstance(tool_args, dict) and '__call_id' in tool_args: call_id = tool_args.pop('__call_id', None) + # Generate a stable effective_call_id early so that _run_agent() can + # store the stream-file path under the same key and we can look it up. + effective_call_id = call_id or f'{tool_name}-{uuid.uuid4().hex[:8]}' + if spec.dynamic: return await self._call_dynamic(tool_args, spec) payload = self._build_payload(tool_args, spec) if spec.run_in_background: - return await self._launch_background(payload, spec, call_id) + return await self._launch_background(payload, spec, effective_call_id) use_subprocess = spec.run_in_thread and spec.run_in_process - agent = None if use_subprocess else self._build_agent(spec) - messages = await self._run_agent(agent, payload, spec, call_id=call_id) - return self._format_output(messages, spec) + if use_subprocess: + messages = await self._run_agent( + None, payload, spec, call_id=effective_call_id) + result_str = self._format_output(messages, spec) + return self._maybe_append_stream_path(result_str, effective_call_id) + + # Pure async/await with optional escape-to-background support. + result = await self._run_sync_escapable(payload, spec, effective_call_id) + if isinstance(result, str): + # Already formatted: escaped to background, returns async_launched JSON. + return result + result_str = self._format_output(result, spec) + return self._maybe_append_stream_path(result_str, effective_call_id) + + def _maybe_append_stream_path(self, result_str: str, effective_call_id: str) -> str: + """Append a human- and LLM-readable note about the step-by-step execution + log to *result_str* if streaming is enabled. + + The note is intentionally descriptive so that the parent agent understands + the file contains the sub-agent's incremental thinking/tool-call trace, + not the tool's output. + """ + if not self._stream_file_enabled(): + return result_str + if not self._stream_include_in_result(): + return result_str + path = self._stream_paths.pop(effective_call_id, None) + if path: + result_str += ( + f'\n\n[Note: The sub-agent\'s step-by-step execution trace ' + f'(messages, tool calls, intermediate reasoning) was streamed ' + f'incrementally to: {path}]' + ) + return result_str def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) + async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> Any: + """Run sub-agent inline (pure async/await). + + If spec.sync_timeout_s is set, the call auto-escapes to background after + that many seconds. The caller can also trigger escape at any time via + escape_to_background(call_id). + + Returns either the raw messages list (normal completion) or a JSON string + (async_launched, when escaped to background). + """ + escape_event = asyncio.Event() + effective_call_id = call_id or uuid.uuid4().hex[:12] + + run_task = asyncio.create_task( + self._run_agent(None, payload, spec, call_id=effective_call_id)) + + self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, + escape_event) + + try: + if spec.sync_timeout_s and spec.sync_timeout_s > 0: + escape_wait_task = asyncio.create_task(escape_event.wait()) + sleep_task = asyncio.create_task( + asyncio.sleep(spec.sync_timeout_s)) + _, pending = await asyncio.wait( + [run_task, escape_wait_task, sleep_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if not run_task.done(): + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) + else: + # No timeout: wait for completion or explicit escape signal. + escape_task = asyncio.create_task(escape_event.wait()) + _, pending = await asyncio.wait( + [run_task, escape_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if not run_task.done(): + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) + + return run_task.result() + finally: + self._active_sync_tasks.pop(effective_call_id, None) + + async def _escape_running_task(self, call_id: str, + run_task: 'asyncio.Task[Any]', + spec: _AgentToolSpec, + payload: Any) -> str: + """Cancel the in-progress sync task and re-launch it as a background subprocess.""" + if self._task_manager is None: + raise RuntimeError( + f'AgentTool "{spec.tool_name}" tried to escape to background but ' + 'no TaskManager is attached.') + + run_task.cancel() + try: + await run_task + except (asyncio.CancelledError, Exception): + pass + + # Re-launch as background subprocess (same as _launch_background). + return await self._launch_background(payload, spec, call_id) + + def escape_to_background(self, call_id: str) -> bool: + """Signal a running sync call to escape to background. + + Called by TaskControlTool (or any external code) when the LLM or user + decides the task should continue in the background. + + Args: + call_id: The __call_id injected into tool_args, or the auto-generated + hex id returned in the async_launched JSON. + + Returns: + True if the escape signal was delivered, False if call_id not found. + """ + entry = self._active_sync_tasks.get(call_id) + if entry is None: + return False + _, _, _, escape_event = entry + escape_event.set() + return True + async def _launch_background(self, payload: Any, spec: _AgentToolSpec, call_id: Optional[str]) -> str: """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" @@ -582,6 +725,8 @@ async def _watcher(): 'task_id': task_id, 'tool_name': spec.tool_name, }, ensure_ascii=False) + + async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: tasks = tool_args.get('tasks', []) execution_mode = tool_args.get('execution_mode', 'sequential') @@ -758,13 +903,28 @@ async def _run_agent(self, runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) + # ── stream-file writer (optional) ────────────────────────────────── + _writer: Optional[SubAgentStreamWriter] = None + if self._stream_file_enabled(): + _effective_call_id = call_id or f'{spec.tool_name}-{uuid.uuid4().hex[:8]}' + _writer = SubAgentStreamWriter( + output_dir=self._stream_file_dir(), + call_id=_effective_call_id, + tool_name=spec.tool_name, + ) + logger.info( + '[stream] %s (call_id=%s) streaming to %s', + spec.tool_name, _effective_call_id, _writer.stream_path, + ) + # ─────────────────────────────────────────────────────────────────── + async def _run_and_collect(): nonlocal runtime_agent, runtime_agent_tag, runtime_agent_type if runtime_agent is None: runtime_agent = self._build_agent(spec) runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) - if self._chunk_cb: + if self._chunk_cb or _writer is not None: result = await runtime_agent.run(payload, stream=True) else: result = await runtime_agent.run(payload) @@ -774,6 +934,8 @@ async def _run_and_collect(): 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + _writer.on_start(runtime_agent_tag) async for chunk in result: history = chunk self._emit_chunk_event( @@ -782,6 +944,8 @@ async def _run_and_collect(): 'tool_name': spec.tool_name, 'history': chunk, }) + if _writer is not None: + _writer.on_chunk(chunk) if history is not None: self._emit_chunk_event( 'end', { @@ -789,59 +953,59 @@ async def _run_and_collect(): 'tool_name': spec.tool_name, 'history': history, }) + if _writer is not None: + _writer.on_end(history) result = history else: self._emit_chunk_event('start', { 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + _writer.on_start(runtime_agent_tag) self._emit_chunk_event( 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, }) + if _writer is not None: + _writer.on_chunk(result) self._emit_chunk_event( 'end', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, }) + if _writer is not None: + _writer.on_end(result) return result - async def _run_in_background(): - # Run sub-agent in a dedicated event loop in a background thread. - def _sync_runner(): - return asyncio.run(_run_and_collect()) - - loop = asyncio.get_running_loop() - if self._thread_executor is not None: - return await loop.run_in_executor(self._thread_executor, - _sync_runner) - return await asyncio.to_thread(_sync_runner) - async def _run_in_subprocess(): nonlocal runtime_agent_tag, runtime_agent_type ctx = mp.get_context('spawn') result_queue = ctx.Queue(maxsize=1) - event_queue = ctx.Queue( - maxsize=128) if self._chunk_cb is not None else None + # Create event_queue when either chunk_cb or stream writer is active + # so that sub-agent progress can be forwarded to both sinks. + need_events = self._chunk_cb is not None or _writer is not None + event_queue = ctx.Queue(maxsize=128) if need_events else None proc: Optional[mp.Process] = None run_id = f'{call_id or "agent_tool"}-{uuid.uuid4().hex[:8]}' def _emit_stream_event(event: Dict[str, Any]) -> None: - if not self._chunk_cb: - return history_payload = event.get('history') if not isinstance(history_payload, dict): return history = self._restore_process_result(history_payload) - self._emit_chunk_event( - 'chunk', { - 'call_id': call_id, - 'tool_name': spec.tool_name, - 'history': history, - }) + if self._chunk_cb: + self._emit_chunk_event( + 'chunk', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + 'history': history, + }) + if _writer is not None: + _writer.on_chunk(history) try: if self._chunk_cb: @@ -849,12 +1013,14 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + # agent_tag unknown until subprocess completes; pass None + _writer.on_start(None) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, args=(spec, self._trust_remote_code, process_payload, - self._chunk_cb - is not None, event_queue, result_queue), + need_events, event_queue, result_queue), name=f'agent_tool_{spec.tool_name}', ) proc.start() @@ -878,9 +1044,10 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: tb = result.get('traceback', '') if tb: logger.warning(tb) - raise RuntimeError( - f'Sub-agent {spec.tool_name} failed: {result.get("error", "unknown error")}' - ) + err_msg = f'Sub-agent {spec.tool_name} failed: {result.get("error", "unknown error")}' + if _writer is not None: + _writer.on_error(err_msg) + raise RuntimeError(err_msg) result_payload = result.get('result', {}) or {} runtime_agent_tag = result_payload.get( 'agent_tag') or runtime_agent_tag @@ -903,12 +1070,19 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: 'tool_name': spec.tool_name, 'history': restored, }) + # Always finalise the writer regardless of _chunk_cb. + if _writer is not None: + _writer.on_end(restored) return restored except asyncio.CancelledError: self._terminate_process(proc, reason='was cancelled') + if _writer is not None: + _writer.on_error('cancelled') raise - except Exception: + except Exception as exc: self._terminate_process(proc, reason='encountered error') + if _writer is not None: + _writer.on_error(str(exc)) raise finally: self._unregister_process(run_id) @@ -934,8 +1108,6 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: if spec.run_in_thread and spec.run_in_process: runner = _run_in_subprocess - elif spec.run_in_thread: - runner = _run_in_background else: runner = _run_and_collect @@ -943,6 +1115,10 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: result = await runner() if not spec.run_in_process: self._save_transcript(result, runtime_agent_tag) + if _writer is not None: + # Store with the same key used by call_tool() to pop it. + store_key = call_id if call_id is not None else _effective_call_id + self._stream_paths[store_key] = _writer.stream_path return result start_ts = now_iso() @@ -953,6 +1129,9 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: result = await runner() if not spec.run_in_process: self._save_transcript(result, runtime_agent_tag) + if _writer is not None: + store_key = call_id if call_id is not None else _effective_call_id + self._stream_paths[store_key] = _writer.stream_path return result except BaseException as exc: status = 'cancelled' if isinstance( diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py index 33012c703..01bd75f64 100644 --- a/ms_agent/tools/task_control_tool.py +++ b/ms_agent/tools/task_control_tool.py @@ -15,18 +15,24 @@ class TaskControlTool(ToolBase): """Exposes background task management to the LLM. - Provides two tools: + Provides three tools: - list_tasks: show all background tasks and their status - cancel_task: kill a running background task by task_id + - push_to_background: escape a currently-blocking sync agent call to background """ def __init__(self, config: DictConfig, **kwargs): super().__init__(config) self._task_manager = None + self._agent_tool = None def set_task_manager(self, task_manager) -> None: self._task_manager = task_manager + def set_agent_tool(self, agent_tool) -> None: + """Wire up the AgentTool so push_to_background can call escape_to_background.""" + self._agent_tool = agent_tool + async def connect(self) -> None: pass diff --git a/ms_agent/utils/stream_writer.py b/ms_agent/utils/stream_writer.py new file mode 100644 index 000000000..46c5cf95b --- /dev/null +++ b/ms_agent/utils/stream_writer.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +SubAgentStreamWriter — incremental JSONL writer for sub-agent execution progress. + +When a parent agent calls a sub-agent via AgentTool, the sub-agent produces messages +incrementally. This writer appends new messages to a JSONL file on every chunk event, +allowing external tools (e.g. ``tail -f``) or the parent agent itself to watch the +sub-agent's progress in real time. + +File format (one JSON object per line): + +.. code-block:: text + + {"type": "header", "call_id": "...", "tool_name": "...", "agent_tag": "...", "ts": "..."} + {"type": "message", "index": 0, "message": {...}, "ts": "..."} + {"type": "message", "index": 1, "message": {...}, "ts": "..."} + ... + {"type": "footer", "call_id": "...", "status": "complete", "total_messages": N, "ts": "..."} + +On error the footer's *status* field is ``"error"`` and an ``"error"`` field is included. + +File path: ``{output_dir}/subagents/{call_id}.stream.jsonl`` +""" +import json +import os +import threading +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from ms_agent.utils import get_logger + +logger = get_logger() + + +class SubAgentStreamWriter: + """Thread-safe incremental JSONL writer for sub-agent chunk events. + + Each instance owns exactly one file. It deduplicates history by tracking + how many messages have already been written (``_last_written_count``), so + calling ``on_chunk(full_history)`` multiple times is safe — only the newly + appended messages are written. + + All public methods are safe to call from multiple threads. + """ + + def __init__(self, output_dir: str, call_id: str, tool_name: str) -> None: + self._call_id: str = call_id or 'unknown' + self._tool_name: str = tool_name + self._lock = threading.Lock() + self._last_written_count: int = 0 + self._closed: bool = False + self._agent_tag: Optional[str] = None + self._file = None # opened lazily in on_start + + subagents_dir = os.path.join(output_dir, 'subagents') + os.makedirs(subagents_dir, exist_ok=True) + safe_id = self._call_id.replace('/', '_').replace('\\', '_') + self._path: str = os.path.join(subagents_dir, f'{safe_id}.stream.jsonl') + + @property + def stream_path(self) -> str: + """Absolute path to the JSONL stream file.""" + return self._path + + def on_start(self, agent_tag: Optional[str]) -> None: + """Open the file and write the header record. + + Args: + agent_tag: The sub-agent's tag string, if known at start time. + May be ``None`` when running in a subprocess (tag is + only resolved after the process finishes). + """ + with self._lock: + if self._closed: + return + self._agent_tag = agent_tag + try: + self._file = open(self._path, 'w', encoding='utf-8') + self._write_line({ + 'type': 'header', + 'call_id': self._call_id, + 'tool_name': self._tool_name, + 'agent_tag': agent_tag or '', + 'ts': _now_iso(), + }) + except Exception as exc: + logger.warning( + 'SubAgentStreamWriter: failed to open %s: %s', self._path, exc) + self._file = None + + def on_chunk(self, history: Any) -> None: + """Append only new messages from *history* since the last call. + + Args: + history: The full accumulated message list returned by a streaming + chunk. May be ``None`` or an empty list; in that case + nothing is written. + """ + messages = _coerce_to_list(history) + if not messages: + return + with self._lock: + if self._closed or self._file is None: + return + for msg in messages[self._last_written_count:]: + self._write_line({ + 'type': 'message', + 'index': self._last_written_count, + 'message': _msg_to_dict(msg), + 'ts': _now_iso(), + }) + self._last_written_count += 1 + + def on_end(self, history: Any) -> None: + """Flush any remaining messages, write footer record, then close. + + Args: + history: Final full message list (same shape as ``on_chunk``). + """ + # Write any messages that arrived in the final chunk before closing. + self.on_chunk(history) + with self._lock: + if self._closed: + return + self._closed = True + if self._file is not None: + try: + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'complete', + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) + self._file.flush() + self._file.close() + except Exception as exc: + logger.warning( + 'SubAgentStreamWriter: close error on %s: %s', self._path, exc) + finally: + self._file = None + + def on_error(self, error: str) -> None: + """Write an error footer and close the file. + + Args: + error: Human-readable error description. + """ + with self._lock: + if self._closed: + return + self._closed = True + if self._file is not None: + try: + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'error', + 'error': error, + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) + self._file.flush() + self._file.close() + except Exception: + pass + finally: + self._file = None + + # ── private helpers ──────────────────────────────────────────────────── + + def _write_line(self, record: Dict[str, Any]) -> None: + """Serialize *record* as JSON and append a newline. + + Caller **must** hold ``self._lock``. Each line is flushed immediately + so that ``tail -f`` sees it without buffering. + """ + if self._file is None: + return + try: + self._file.write(json.dumps(record, ensure_ascii=False) + '\n') + self._file.flush() + except Exception as exc: + logger.warning('SubAgentStreamWriter: write failed: %s', exc) + + +# ── module-level helpers ──────────────────────────────────────────────────── + + +def _now_iso() -> str: + """Return the current UTC time as an ISO-8601 string.""" + return datetime.now(timezone.utc).isoformat() + + +def _coerce_to_list(value: Any) -> List[Any]: + """Return *value* if it is a list, otherwise an empty list.""" + return value if isinstance(value, list) else [] + + +def _msg_to_dict(msg: Any) -> Dict[str, Any]: + """Convert a Message object (or plain dict) to a serialisable dict.""" + if hasattr(msg, 'to_dict'): + return msg.to_dict() + if isinstance(msg, dict): + return msg + return {'role': 'unknown', 'content': str(msg)} diff --git a/tests/utils/test_task_manager_smoke.py b/tests/utils/test_task_manager_smoke.py index 8f57444f3..f29f57321 100644 --- a/tests/utils/test_task_manager_smoke.py +++ b/tests/utils/test_task_manager_smoke.py @@ -234,5 +234,130 @@ async def test_cancel_already_done(self): self.assertIn('already', result) +# --------------------------------------------------------------------------- +# AgentTool escape-to-background (sync_timeout_s + escape_to_background API) +# --------------------------------------------------------------------------- + +class TestAgentToolEscape(unittest.IsolatedAsyncioTestCase): + """Tests for _run_sync_escapable and escape_to_background. + + We mock _run_agent and _launch_background so no real sub-agent is needed. + """ + + def _make_agent_tool(self): + from ms_agent.tools.agent_tool import AgentTool, _AgentToolSpec + from ms_agent.utils.task_manager import TaskManager + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'tag': 'test', + 'output_dir': '/tmp', + 'tools': {}, + }) + tool = AgentTool(config) + tm = TaskManager() + tool.set_task_manager(tm) + return tool, tm + + def _make_spec(self, sync_timeout_s=None): + from ms_agent.tools.agent_tool import _AgentToolSpec + return _AgentToolSpec( + tool_name='test_tool', + description='test', + parameters={}, + config_path=None, + inline_config=None, + server_name='test_server', + tag_prefix='t-', + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=1000, + trust_remote_code=None, + env=None, + run_in_thread=False, + run_in_process=False, + dynamic=False, + sync_timeout_s=sync_timeout_s, + ) + + async def test_normal_completion(self): + """Without timeout, task completes normally.""" + tool, _ = self._make_agent_tool() + spec = self._make_spec() + + async def fake_run_agent(agent, payload, spec, call_id=None): + return ['result'] + + tool._run_agent = fake_run_agent + result = await tool._run_sync_escapable('payload', spec, 'cid1') + self.assertEqual(result, ['result']) + + async def test_escape_to_background_via_api(self): + """escape_to_background() triggers escape before task completes.""" + tool, tm = self._make_agent_tool() + spec = self._make_spec() + + started = asyncio.Event() + + async def slow_run_agent(agent, payload, spec, call_id=None): + started.set() + await asyncio.sleep(10) # will be cancelled + return ['should not reach'] + + launched = {} + + async def fake_launch_background(payload, spec, call_id): + import json + launched['called'] = True + task_id = tm.register('agent', spec.tool_name, 'escaped') + return json.dumps({'status': 'async_launched', 'task_id': task_id}) + + tool._run_agent = slow_run_agent + tool._launch_background = fake_launch_background + + async def _trigger_escape(): + await started.wait() + tool.escape_to_background('cid2') + + result_task = asyncio.create_task( + tool._run_sync_escapable('payload', spec, 'cid2')) + await asyncio.gather(_trigger_escape(), return_exceptions=True) + result = await result_task + + self.assertIsInstance(result, str) + import json + data = json.loads(result) + self.assertEqual(data['status'], 'async_launched') + self.assertTrue(launched.get('called')) + + async def test_escape_to_background_unknown_call_id(self): + """escape_to_background returns False for unknown call_id.""" + tool, _ = self._make_agent_tool() + self.assertFalse(tool.escape_to_background('nonexistent')) + + async def test_sync_timeout_triggers_escape(self): + """sync_timeout_s causes auto-escape when task takes too long.""" + tool, tm = self._make_agent_tool() + spec = self._make_spec(sync_timeout_s=0.05) + + async def slow_run_agent(agent, payload, spec, call_id=None): + await asyncio.sleep(10) + return ['should not reach'] + + async def fake_launch_background(payload, spec, call_id): + import json + task_id = tm.register('agent', spec.tool_name, 'timed out') + return json.dumps({'status': 'async_launched', 'task_id': task_id}) + + tool._run_agent = slow_run_agent + tool._launch_background = fake_launch_background + + result = await tool._run_sync_escapable('payload', spec, 'cid3') + import json + data = json.loads(result) + self.assertEqual(data['status'], 'async_launched') + + if __name__ == '__main__': unittest.main() From 6a1ba9f6c872ceb4a8c1f94a5b4228448d18b04d Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Thu, 9 Apr 2026 19:29:29 +0800 Subject: [PATCH 28/40] fix lint --- ms_agent/agent/llm_agent.py | 4 +-- ms_agent/llm/openai_llm.py | 68 +++++++++++++++++++------------------ ms_agent/llm/utils.py | 9 ++--- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 3b75ceda3..0dd7d04fe 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -937,8 +937,8 @@ async def step( # Handle reasoning summaries that arrive after content if self.show_reasoning and _response_message is not None: - final_reasoning = getattr( - _response_message, 'reasoning_content', '') or '' + final_reasoning = getattr(_response_message, + 'reasoning_content', '') or '' if final_reasoning and not _printed_reasoning_header: self._write_thinking_header() self._write_reasoning(final_reasoning, dim=True) diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index e9e1679e3..fa2df6004 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -3,9 +3,8 @@ from copy import deepcopy from typing import Any, Dict, Generator, Iterable, List, Optional -import json - import httpx +import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, @@ -29,8 +28,8 @@ class _DashScopeResponsesTransport(httpx.HTTPTransport): def handle_request(self, request): if b'/v1/responses' in request.url.raw_path: - new_path = request.url.raw_path.replace( - b'/v1/responses', b'/v1/chat/completions') + new_path = request.url.raw_path.replace(b'/v1/responses', + b'/v1/chat/completions') request.url = request.url.copy_with(raw_path=new_path) return super().handle_request(request) @@ -85,7 +84,8 @@ def __init__( getattr(config, 'generation_config', DictConfig({}))) # Responses API support - self._use_responses_api = bool(self.args.get('use_responses_api', False)) + self._use_responses_api = bool( + self.args.get('use_responses_api', False)) self._responses_client = None self._responses_state_mode = str( self.args.get('responses_state_mode', 'stateless')).lower() @@ -93,8 +93,8 @@ def __init__( self._responses_state_mode = 'previous_response_id' if self._use_responses_api: - self._is_dashscope = bool( - base_url and 'dashscope' in base_url.lower()) + self._is_dashscope = bool(base_url + and 'dashscope' in base_url.lower()) if self._is_dashscope: http_client = httpx.Client( transport=_DashScopeResponsesTransport(), @@ -618,10 +618,11 @@ def _build_responses_input( if self._responses_state_mode != 'previous_response_id': # Stateless mode needs explicit passback of opaque reasoning # items returned by the previous response. - for raw_item in getattr( - msg, '_responses_output_items', []): + for raw_item in getattr(msg, '_responses_output_items', + []): items.append(raw_item) - if msg.content and not self._is_responses_tool_placeholder(msg): + if msg.content and not self._is_responses_tool_placeholder( + msg): items.append({ 'role': 'assistant', 'content': msg.content, @@ -657,11 +658,11 @@ def _build_responses_input( @staticmethod def _is_responses_tool_placeholder(message: Message) -> bool: """Return True for framework-generated assistant placeholder text.""" - return bool(message.tool_calls) and message.content == 'Let me do a tool calling.' + return bool(message.tool_calls + ) and message.content == 'Let me do a tool calling.' def _prepare_responses_request( - self, - messages: List[Message], + self, messages: List[Message], args: Dict[str, Any]) -> tuple[List[Message], Dict[str, Any]]: """Prepare message slice and request args for Responses API calls.""" request_args = dict(args) @@ -741,7 +742,8 @@ def _extract_reasoning_summaries_from_response(response) -> str: return '\n'.join(parts) @staticmethod - def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: + def _extract_tool_calls_from_response( + response) -> Optional[List[ToolCall]]: """Extract tool calls from a completed Responses API object.""" tool_calls: List[ToolCall] = [] for item in getattr(response, 'output', []) or []: @@ -751,7 +753,8 @@ def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: arguments = json.dumps(arguments, ensure_ascii=False) tool_calls.append( ToolCall( - id=getattr(item, 'call_id', '') or getattr(item, 'id', ''), + id=getattr(item, 'call_id', '') + or getattr(item, 'id', ''), index=len(tool_calls), type='function', tool_name=getattr(item, 'name', ''), @@ -779,7 +782,8 @@ def _to_jsonable(value: Any) -> Any: return [OpenAI._to_jsonable(item) for item in value] if isinstance(value, dict): return { - key: OpenAI._to_jsonable(item) for key, item in value.items() + key: OpenAI._to_jsonable(item) + for key, item in value.items() } if hasattr(value, 'model_dump'): return OpenAI._to_jsonable(value.model_dump()) @@ -798,9 +802,10 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: item_type = getattr(item, 'type', None) if item_type == 'reasoning': passback_item: Dict[str, Any] = { - 'type': 'reasoning', - 'summary': self._to_jsonable( - getattr(item, 'summary', []) or []), + 'type': + 'reasoning', + 'summary': + self._to_jsonable(getattr(item, 'summary', []) or []), } encrypted_content = getattr(item, 'encrypted_content', None) if encrypted_content: @@ -812,11 +817,10 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: items.append(passback_item) return items - def _responses_generate( - self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - **args) -> Message: + def _responses_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Message: """Non-streaming Responses API call.""" request_messages, request_args = self._prepare_responses_request( messages, args) @@ -859,11 +863,10 @@ def _extract_reasoning_from_item(item) -> str: parts.append(text) return '\n'.join(parts) - def _responses_stream_generate( - self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - **args) -> Generator[Message, None, None]: + def _responses_stream_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Generator[Message, None, None]: """Streaming Responses API call. Yields incremental ``Message`` objects. Reasoning summaries are @@ -929,8 +932,8 @@ def _responses_stream_generate( elif event_type == 'response.failed': failed_response = getattr(event, 'response', None) failed_error = getattr(failed_response, 'error', None) - response_error_msg = getattr( - failed_error, 'message', '') or str(failed_error) + response_error_msg = getattr(failed_error, 'message', + '') or str(failed_error) if final_response: if not reasoning_parts: @@ -952,8 +955,7 @@ def _responses_stream_generate( current_message.id = getattr(final_response, 'id', '') yield current_message elif response_error_msg: - logger.error( - f'Responses API failed: {response_error_msg}') + logger.error(f'Responses API failed: {response_error_msg}') raise RuntimeError( f'Responses API call failed: {response_error_msg}') diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 1f723cfa7..6de431410 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -42,8 +42,7 @@ class Message: # Opaque output items from the Responses API that must be passed back # in multi-turn tool-calling conversations (e.g. reasoning items). - _responses_output_items: List[Dict[str, Any]] = field( - default_factory=list) + _responses_output_items: List[Dict[str, Any]] = field(default_factory=list) # request id id: str = '' @@ -93,8 +92,10 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls', - '_responses_output_items'] + rm = [ + 'completion_tokens', 'prompt_tokens', 'api_calls', + '_responses_output_items' + ] return { key: value for key, value in raw_dict.items() From b3feb1acfeaf6683c35591deb50d43fb5719c9fa Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 13 Apr 2026 14:54:43 +0800 Subject: [PATCH 29/40] feat: workspace policy, shell artifacts, TaskManager, grep/glob tools Add WorkspacePolicyKernel (allow-roots from output_dir), ArtifactManager for large shell outputs, TaskManager with shell background support and asyncio process kill, WorkspaceSearchTool (grep_files/glob_files). Wire TaskManager into LLMAgent (prepare_tools, cleanup, task notifications in step) and extend LocalCodeExecutionTool with policy checks, artifact spill, run_in_background shell, sh -lc wrapping. ToolManager registers WorkspaceSearchTool by default and injects __call_id for shell_executor. Add tests for workspace policy. Document implementation map in shell-grep-glob-workspace-policy.md. Made-with: Cursor --- ms_agent/agent/llm_agent.py | 25 +- ms_agent/tools/code/local_code_executor.py | 202 ++++++++- ms_agent/tools/tool_manager.py | 13 + ms_agent/tools/workspace_search_tool.py | 475 +++++++++++++++++++++ ms_agent/utils/artifact_manager.py | 126 ++++++ ms_agent/utils/task_manager.py | 135 ++++++ ms_agent/utils/workspace_policy.py | 207 +++++++++ shell-grep-glob-workspace-policy.md | 225 ++++++++++ tests/utils/test_workspace_policy.py | 78 ++++ 9 files changed, 1463 insertions(+), 23 deletions(-) create mode 100644 ms_agent/tools/workspace_search_tool.py create mode 100644 ms_agent/utils/artifact_manager.py create mode 100644 ms_agent/utils/task_manager.py create mode 100644 ms_agent/utils/workspace_policy.py create mode 100644 shell-grep-glob-workspace-policy.md create mode 100644 tests/utils/test_workspace_policy.py diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 12bb0dcca..0e94b4256 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -22,6 +22,7 @@ from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history +from ms_agent.utils.task_manager import TaskManager from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot @@ -106,6 +107,7 @@ def __init__( super().__init__(config, tag, trust_remote_code) self.callbacks: List[Callback] = [] self.tool_manager: Optional[ToolManager] = None + self.task_manager: Optional[TaskManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None self.knowledge_search: Optional[SirschmunkSearch] = None @@ -365,7 +367,7 @@ def rollback(self, commit_hash: str) -> bool: saved_messages[:message_count]) # Clear read cache on FileSystemTool so stale entries don't block edits if self.tool_manager is not None: - for tool in self.tool_manager.tools.values(): + for tool in self.tool_manager.extra_tools: if hasattr(tool, '_read_cache'): tool._read_cache.clear() return True @@ -570,6 +572,7 @@ async def parallel_tool_call(self, async def prepare_tools(self): """Initialize and connect the tool manager.""" + self.task_manager = TaskManager() self.tool_manager = ToolManager( self.config, self.mcp_config, @@ -577,10 +580,16 @@ async def prepare_tools(self): trust_remote_code=self.trust_remote_code, ) await self.tool_manager.connect() + for tool in self.tool_manager.extra_tools: + if hasattr(tool, 'set_task_manager'): + tool.set_task_manager(self.task_manager) async def cleanup_tools(self): """Cleanup resources used by the tool manager.""" - await self.tool_manager.cleanup() + if self.task_manager is not None: + self.task_manager.kill_all() + if self.tool_manager is not None: + await self.tool_manager.cleanup() @property def stream(self): @@ -866,6 +875,17 @@ def handle_new_response(self, messages: List[Message], and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' + def _append_task_notifications(self, messages: List[Message]) -> List[Message]: + """Inject drained TaskManager completion notices as a user message.""" + if self.task_manager is None: + return messages + notes = self.task_manager.drain_notifications() + if not notes: + return messages + body = '[Background task updates]\n' + '\n'.join(notes) + messages.append(Message(role='user', content=body)) + return messages + @async_retry(max_attempts=Agent.retry_count, delay=1.0) async def step( self, messages: List[Message] @@ -893,6 +913,7 @@ async def step( List[Message]: Updated message history after this step. """ messages = deepcopy(messages) + messages = self._append_task_notifications(messages) if (not self.load_cache) or messages[-1].role != 'assistant': messages = await self.condense_memory(messages) await self.on_generate_response(messages) diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index 65de0556e..d1e2104c5 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -3,18 +3,21 @@ import inspect import io import os +import shlex import shutil import time from contextlib import redirect_stderr, redirect_stdout from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger +from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel logger = get_logger() @@ -231,7 +234,7 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) self.output_dir = Path( - getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser() + getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() self.output_dir.mkdir(parents=True, exist_ok=True) self.tool_config = getattr( @@ -250,6 +253,37 @@ def __init__(self, config): self.shell_env = shell_env self._kernel_lock = asyncio.Lock() self._initialized = False + self._task_manager = None + self._watcher_tasks: Set[asyncio.Task] = set() + + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + extra_allow: List[str] = [] + deny_globs = None + if wp is not None: + extra_allow = list(getattr(wp, 'allow_roots', []) or []) + dg = getattr(wp, 'deny_globs', None) + if dg: + deny_globs = list(dg) + shell_cfg = getattr(self.tool_config, 'shell', None) if self.tool_config else None + shell_mode = getattr(shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False) + ) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) + ) if shell_cfg else 8192 + self._policy = WorkspacePolicyKernel( + self.output_dir, + extra_allow_roots=extra_allow, + deny_globs=deny_globs, + shell_default_mode=str(shell_mode), + shell_network_enabled=net, + max_command_chars=max_cmd, + ) + max_kb = 256 + if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): + max_kb = int(shell_cfg.max_output_kb) + self._artifacts = ArtifactManager( + self.output_dir, max_combined_bytes=max_kb * 1024) self.exclude_func( getattr(getattr(config, 'tools', None), 'code_executor', None)) @@ -264,6 +298,10 @@ def __init__(self, config): logger.info('LocalCodeExecutionTool initialized (ipykernel based)') + def set_task_manager(self, task_manager) -> None: + """Attach process-wide TaskManager for background shell (see shell_executor).""" + self._task_manager = task_manager + def _check_dependencies(self) -> None: import importlib @@ -341,6 +379,10 @@ async def connect(self) -> None: self._initialized = True async def cleanup(self) -> None: + for t in list(self._watcher_tasks): + if not t.done(): + t.cancel() + self._watcher_tasks.clear() if not self._initialized: return await self.kernel_session.stop() @@ -421,9 +463,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='shell_executor', server_name='code_executor', - description=('Execute shell commands locally using bash. ' - 'Supports basic shell operations like ls, ' - 'cd, mkdir, rm, etc. '), + description=( + 'Execute shell commands locally under the workspace output directory (cwd). ' + 'Subject to policy (read_only vs workspace_write, network toggle). ' + 'Large stdout/stderr may be spilled to .ms_agent_artifacts. ' + 'Use run_in_background=true to return immediately with task_id; poll via task notifications.' + ), parameters={ 'type': 'object', 'properties': { @@ -435,7 +480,17 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'integer', 'description': 'Execution timeout in seconds', 'default': self._shell_timeout - } + }, + 'run_in_background': { + 'type': 'boolean', + 'description': + 'If true, start the command asynchronously and return task_id (requires TaskManager).', + 'default': False, + }, + '__call_id': { + 'type': 'string', + 'description': 'Optional correlation id (injected by host when supported).', + }, }, 'required': ['command'], 'additionalProperties': False @@ -630,16 +685,117 @@ def _exec_code(): async def shell_executor(self, command: str, - timeout: Optional[int] = None) -> str: + timeout: Optional[int] = None, + run_in_background: bool = False, + __call_id: Optional[str] = None) -> str: exec_timeout = timeout or self._shell_timeout + call_id = __call_id or f'shell-{os.urandom(4).hex()}' + + try: + self._policy.assert_shell_command_allowed(command) + except WorkspacePolicyError as e: + return json.dumps( + { + 'success': False, + 'error': str(e) + }, + ensure_ascii=False, + indent=2, + ) + + shell_cmd = self._prepare_shell_command(command) + + if run_in_background: + if self._task_manager is None: + return json.dumps( + { + 'success': False, + 'error': + 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', + }, + ensure_ascii=False, + indent=2, + ) + try: + process = await asyncio.create_subprocess_shell( + shell_cmd, + stdout=ai_subprocess.PIPE, + stderr=ai_subprocess.PIPE, + cwd=str(self._policy.workspace_root), + env=self.shell_env, + ) + except FileNotFoundError as exc: + return json.dumps( + { + 'success': False, + 'error': f'Shell not available: {exc}' + }, + ensure_ascii=False, + indent=2, + ) + + task_id = self._task_manager.register( + task_type='shell', + tool_name='shell_executor', + description=command[:200], + proc=process, + ) + + async def _watcher() -> None: + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=exec_timeout) + stdout_text = _coerce_str(stdout).strip('\n') + stderr_text = _coerce_str(stderr).strip('\n') + success = process.returncode == 0 + payload = { + 'success': success, + 'output': stdout_text, + 'error': stderr_text or None, + 'return_code': process.returncode, + } + text = self._artifacts.pack_json_shell_result( + tool_name='shell_executor', + call_id=task_id, + payload=payload, + ) + await self._task_manager.complete(task_id, text) + except asyncio.TimeoutError: + try: + process.kill() + await process.communicate() + except Exception: # noqa: B902 + pass + await self._task_manager.fail( + task_id, + f'Shell command timed out after {exec_timeout} seconds', + ) + except Exception as exc: # noqa: B902 + await self._task_manager.fail(task_id, str(exc)) + + t = asyncio.create_task(_watcher()) + self._watcher_tasks.add(t) + t.add_done_callback(self._watcher_tasks.discard) + + return json.dumps( + { + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': 'shell_executor', + 'call_id': call_id, + }, + ensure_ascii=False, + indent=2, + ) try: process = await asyncio.create_subprocess_shell( - command, + shell_cmd, stdout=ai_subprocess.PIPE, stderr=ai_subprocess.PIPE, - cwd=str(self.output_dir), - env=self.shell_env) + cwd=str(self._policy.workspace_root), + env=self.shell_env, + ) except FileNotFoundError as exc: return json.dumps( { @@ -647,7 +803,8 @@ async def shell_executor(self, 'error': f'Shell not available: {exc}' }, ensure_ascii=False, - indent=2) + indent=2, + ) try: stdout, stderr = await asyncio.wait_for( @@ -666,20 +823,23 @@ async def shell_executor(self, f'Shell command timed out after {exec_timeout} seconds' }, ensure_ascii=False, - indent=2) + indent=2, + ) stdout_text = _coerce_str(stdout).strip('\n') stderr_text = _coerce_str(stderr).strip('\n') success = process.returncode == 0 - return json.dumps( - { - 'success': success, - 'output': stdout_text, - 'error': stderr_text or None, - 'return_code': process.returncode - }, - ensure_ascii=False, - indent=2) + payload = { + 'success': success, + 'output': stdout_text, + 'error': stderr_text or None, + 'return_code': process.returncode, + } + return self._artifacts.pack_json_shell_result( + tool_name='shell_executor', + call_id=call_id, + payload=payload, + ) async def file_operation(self, operation: str, diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..3c8f21339 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -88,6 +88,12 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + ws = getattr(getattr(config, 'tools', None), 'workspace_search', None) + _ws_enabled = True if ws is None else bool(getattr(ws, 'enabled', True)) + if _ws_enabled: + from ms_agent.tools.workspace_search_tool import WorkspaceSearchTool + + self.extra_tools.append(WorkspaceSearchTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, @@ -226,6 +232,13 @@ async def single_call_tool(self, tool_info: ToolCall): call_args = dict(tool_args or {}) call_id = tool_info.get('id') or str(uuid.uuid4()) call_args['__call_id'] = call_id + elif isinstance( + tool_ins, + LocalCodeExecutionTool) and tool_name.endswith( + f'{self.TOOL_SPLITER}shell_executor'): + call_args = dict(tool_args or {}) + call_args['__call_id'] = tool_info.get('id') or str( + uuid.uuid4()) response = await asyncio.wait_for( tool_ins.call_tool( server_name, diff --git a/ms_agent/tools/workspace_search_tool.py b/ms_agent/tools/workspace_search_tool.py new file mode 100644 index 000000000..c5559f8f9 --- /dev/null +++ b/ms_agent/tools/workspace_search_tool.py @@ -0,0 +1,475 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Read-only workspace search: grep (rg or Python fallback) and glob.""" + +from __future__ import annotations + +import asyncio +import fnmatch +import json +import os +import re +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.utils import get_logger +from ms_agent.utils.artifact_manager import ArtifactManager +from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel + +logger = get_logger() + +_TEXT_SUFFIXES = { + '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', + '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', + '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', + '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', + '.env', '.gitignore', '.dockerignore', 'Dockerfile', +} + + +class WorkspaceSearchTool(ToolBase): + """Grep and glob under output_dir (+ optional extra roots) with shared policy.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + self.output_dir = Path( + getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() + self.output_dir.mkdir(parents=True, exist_ok=True) + + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + extra = [] + deny: list[str] = [] + if wp is not None: + extra = list(getattr(wp, 'allow_roots', []) or []) + deny = list(getattr(wp, 'deny_globs', []) or []) + else: + deny = [] + ws = getattr(getattr(config, 'tools', None), 'workspace_search', None) + self._default_head = int(getattr(ws, 'default_head_limit', 250) or 250) + self._glob_max = int(getattr(ws, 'max_files', 100) or 100) + self._grep_timeout = int(getattr(ws, 'grep_timeout_s', 120) or 120) + + shell_cfg = getattr( + getattr(config.tools, 'code_executor', None), 'shell', None) + shell_mode = getattr(shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False) + ) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) + ) if shell_cfg else 8192 + + self._policy = WorkspacePolicyKernel( + self.output_dir, + extra_allow_roots=extra, + deny_globs=deny if deny else None, + shell_default_mode=str(shell_mode), + shell_network_enabled=net, + max_command_chars=max_cmd, + ) + max_kb = 256 + if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): + max_kb = int(shell_cfg.max_output_kb) + self._artifacts = ArtifactManager( + self.output_dir, max_combined_bytes=max_kb * 1024) + + self.exclude_func(ws) + + async def connect(self) -> None: + return + + async def _get_tools_inner(self) -> Dict[str, Any]: + return { + 'workspace_search': [ + Tool( + tool_name='grep_files', + server_name='workspace_search', + description=( + 'Search file contents under the workspace using ripgrep when available, ' + 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' + 'Read-only.' + ), + parameters={ + 'type': 'object', + 'properties': { + 'pattern': { + 'type': 'string', + 'description': 'Regular expression (Rust regex if rg is used).', + }, + 'path': { + 'type': 'string', + 'description': + 'Directory or file to search (relative to output_dir if not absolute). Default ".".', + }, + 'glob': { + 'type': 'string', + 'description': 'Optional glob filter for files, e.g. "*.py"', + }, + 'output_mode': { + 'type': 'string', + 'enum': ['content', 'files_with_matches', 'count'], + 'description': 'content: matching lines; files_with_matches: paths only; count: per-file counts', + }, + 'head_limit': { + 'type': 'integer', + 'description': 'Max lines (content) or paths/count entries to return', + }, + 'offset': { + 'type': 'integer', + 'description': 'Skip first N lines/entries after collect', + }, + 'case_insensitive': { + 'type': 'boolean', + 'description': 'Case-insensitive search', + }, + }, + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='glob_files', + server_name='workspace_search', + description=( + 'List files under a workspace directory matching a glob pattern ' + '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' + ), + parameters={ + 'type': 'object', + 'properties': { + 'pattern': { + 'type': 'string', + 'description': 'Glob pattern relative to path', + }, + 'path': { + 'type': 'string', + 'description': 'Base directory (relative to output_dir if not absolute).', + }, + }, + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: + return await getattr(self, tool_name)(**tool_args) + + async def grep_files( + self, + pattern: str, + path: str = '.', + glob: Optional[str] = None, + output_mode: str = 'files_with_matches', + head_limit: Optional[int] = None, + offset: Optional[int] = None, + case_insensitive: bool = False, + ) -> str: + call_id = f'grep-{pattern[:40]}' + head_limit = head_limit if head_limit is not None else self._default_head + offset = offset or 0 + path = path or '.' + try: + root = self._policy.resolve_under_roots(path) + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + lines: List[str] = [] + try: + rg = shutil.which('rg') + if rg and root.is_file(): + lines = await self._rg_file(rg, pattern, root, case_insensitive, + output_mode, head_limit, offset, + glob) + elif rg and root.is_dir(): + lines = await self._rg_dir(rg, pattern, root, case_insensitive, + output_mode, head_limit, offset, + glob) + else: + lines = self._python_grep( + pattern, + root, + glob, + output_mode, + head_limit, + offset, + case_insensitive, + ) + except Exception as e: + logger.warning('grep_files failed: %s', e, exc_info=True) + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + text = '\n'.join(lines) + packed = self._artifacts.pack_text_result( + tool_name='grep_files', + call_id=call_id, + stdout=text, + stderr='', + extra={ + 'success': True, + 'output_mode': output_mode, + 'num_lines': len(lines), + }, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + + async def _rg_file( + self, + rg: str, + pattern: str, + file_path: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob: + args.extend(['--glob', glob]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(file_path)]) + elif output_mode == 'count': + args.extend(['-c', pattern, str(file_path)]) + else: + args.extend(['-n', pattern, str(file_path)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + async def _rg_dir( + self, + rg: str, + pattern: str, + root: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob: + args.extend(['--glob', glob]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(root)]) + elif output_mode == 'count': + args.extend(['--count-matches', pattern, str(root)]) + else: + args.extend(['-n', pattern, str(root)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + def _python_grep( + self, + pattern: str, + root: Path, + glob_pat: Optional[str], + output_mode: str, + head_limit: int, + offset: int, + case_insensitive: bool, + ) -> List[str]: + flags = re.IGNORECASE if case_insensitive else 0 + try: + rx = re.compile(pattern, flags) + except re.error as e: + return [f'[error] invalid regex: {e}'] + lines_out: List[str] = [] + counts: Dict[str, int] = {} + + def consider_file(fp: Path) -> bool: + if glob_pat: + rel = str(fp.relative_to(root)) if root.is_dir() else fp.name + if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( + rel, glob_pat): + return False + suf = fp.suffix.lower() + if suf not in _TEXT_SUFFIXES and fp.suffix == '': + if fp.name not in ('Dockerfile', 'Makefile', 'README'): + return False + return fp.is_file() + + files: List[Path] = [] + if root.is_file(): + files = [root] + else: + for fp in _walk_files_limited(root, self._policy.deny_globs, 50_000): + if consider_file(fp): + files.append(fp) + + for fp in files: + try: + text = fp.read_text(encoding='utf-8', errors='replace') + except OSError: + continue + rel = str(fp.relative_to(self._policy.workspace_root)) if _is_relative( + fp, self._policy.workspace_root) else str(fp) + if output_mode == 'files_with_matches': + if rx.search(text): + lines_out.append(rel) + elif output_mode == 'count': + n = len(rx.findall(text)) + if n: + counts[rel] = n + else: + for i, line in enumerate(text.splitlines(), start=1): + if rx.search(line): + lines_out.append(f'{rel}:{i}:{line}') + if len(lines_out) >= head_limit + offset + 5000: + break + + if output_mode == 'count': + lines_out = [f'{k}:{v}' for k, v in sorted(counts.items())] + return _apply_offset_limit(lines_out, offset, head_limit) + + async def glob_files(self, pattern: str, path: str = '') -> str: + call_id = f'glob-{pattern[:40]}' + try: + base = self._policy.resolve_under_roots(path or '.') + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + if not base.is_dir(): + return json.dumps( + { + 'success': False, + 'error': f'Not a directory: {path}', + }, + indent=2, + ) + + matches: List[str] = [] + truncated = False + deny = self._policy.deny_globs + + # Prefer pathlib.glob from base + try: + for p in sorted(base.glob(pattern)): + if not p.is_file(): + continue + rp = p.resolve() + if not self._policy.path_is_allowed(rp): + continue + if _is_denied_path(rp, base, deny): + continue + rel = str(p.relative_to(self._policy.workspace_root)) if _is_relative( + p, self._policy.workspace_root) else str(p) + matches.append(rel) + if len(matches) >= self._glob_max: + truncated = True + break + except ValueError: + # invalid pattern + return json.dumps( + { + 'success': False, + 'error': 'Invalid glob pattern', + }, + indent=2, + ) + + text = json.dumps( + { + 'success': True, + 'num_files': len(matches), + 'filenames': matches, + 'truncated': truncated, + }, + ensure_ascii=False, + indent=2, + ) + packed = self._artifacts.pack_text_result( + tool_name='glob_files', + call_id=call_id, + stdout=text, + stderr='', + extra={'success': True}, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + + +def _apply_offset_limit(lines: List[str], offset: int, + head_limit: int) -> List[str]: + if offset: + lines = lines[offset:] + if head_limit and head_limit > 0: + lines = lines[:head_limit] + return lines + + +def _is_relative(path: Path, base: Path) -> bool: + try: + path.relative_to(base) + return True + except ValueError: + return False + + +def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: + if not deny: + return False + try: + rel = path.relative_to(root).as_posix() + except ValueError: + rel = path.as_posix() + for pat in deny: + if fnmatch.fnmatch(rel, pat): + return True + return False + + +def _walk_files_limited(root: Path, deny: tuple[str, ...], + max_files: int) -> List[Path]: + out: List[Path] = [] + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): + dp = Path(dirpath) + pruned = [] + for d in list(dirnames): + child = dp / d + try: + rel = child.relative_to(root).as_posix() + except ValueError: + rel = child.as_posix() + skip = any(fnmatch.fnmatch(rel, p) for p in deny) + if skip: + continue + pruned.append(d) + dirnames[:] = pruned + for name in filenames: + out.append(dp / name) + if len(out) >= max_files: + return out + return out diff --git a/ms_agent/utils/artifact_manager.py b/ms_agent/utils/artifact_manager.py new file mode 100644 index 000000000..9953f25c0 --- /dev/null +++ b/ms_agent/utils/artifact_manager.py @@ -0,0 +1,126 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Spill large tool outputs to disk under output_dir/.ms_agent_artifacts/.""" + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any + + +class ArtifactManager: + """When combined stdout+stderr exceeds *max_combined_bytes*, write to artifact file.""" + + def __init__( + self, + output_dir: Path | str, + *, + max_combined_bytes: int = 256 * 1024, + preview_head_chars: int = 4000, + preview_tail_chars: int = 2000, + artifact_subdir: str = '.ms_agent_artifacts', + ) -> None: + self._root = Path(output_dir).expanduser().resolve() + self.max_combined_bytes = max_combined_bytes + self.preview_head_chars = preview_head_chars + self.preview_tail_chars = preview_tail_chars + self._artifact_root = self._root / artifact_subdir + + def pack_text_result( + self, + *, + tool_name: str, + call_id: str, + stdout: str, + stderr: str, + extra: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Return payload fields: output, error, truncated, artifact_path (optional).""" + combined = (stdout or '') + (stderr or '') + enc = combined.encode('utf-8', errors='replace') + if len(enc) <= self.max_combined_bytes: + out: dict[str, Any] = { + 'output': stdout, + 'error': stderr or None, + 'truncated': False, + } + if extra: + out.update(extra) + return out + + safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id + )[:120] or 'call' + rel_dir = Path(tool_name) / safe_id + out_dir = self._artifact_root / rel_dir + out_dir.mkdir(parents=True, exist_ok=True) + body = f'=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}\n' + digest = hashlib.sha256(enc).hexdigest()[:16] + fname = f'combined-{digest}.txt' + fpath = out_dir / fname + fpath.write_text(body, encoding='utf-8', errors='replace') + rel = fpath.relative_to(self._root).as_posix() + preview = _make_preview(body, self.preview_head_chars, + self.preview_tail_chars) + result = { + 'output': stdout[:self.preview_head_chars] + if len(stdout) > self.preview_head_chars else stdout, + 'error': + (stderr[:self.preview_head_chars] if stderr else None), + 'truncated': + True, + 'artifact_path': + rel, + 'preview': + preview, + 'artifact_bytes': + len(enc), + } + if extra: + result.update(extra) + return result + + def pack_json_shell_result( + self, + *, + tool_name: str, + call_id: str, + payload: dict[str, Any], + ) -> str: + """JSON-encode *payload* after applying spill rules to output/error string fields.""" + stdout = str(payload.get('output') or '') + stderr = str(payload.get('error') or '') + packed = self.pack_text_result( + tool_name=tool_name, + call_id=call_id, + stdout=stdout, + stderr=stderr, + extra={ + k: v + for k, v in payload.items() if k not in ('output', 'error') + }, + ) + # pack_text_result merged extra into top level; rebuild standard shell shape + out = { + 'success': + payload.get('success'), + 'output': + packed.get('output'), + 'error': + packed.get('error'), + 'return_code': + payload.get('return_code'), + 'truncated': + packed.get('truncated', False), + } + if packed.get('artifact_path'): + out['artifact_path'] = packed['artifact_path'] + out['preview'] = packed.get('preview') + out['artifact_bytes'] = packed.get('artifact_bytes') + return json.dumps(out, ensure_ascii=False, indent=2) + + +def _make_preview(text: str, head: int, tail: int) -> str: + if len(text) <= head + tail: + return text + return (text[:head] + '\n... [truncated] ...\n' + text[-tail:]) diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py new file mode 100644 index 000000000..ee1e1f89c --- /dev/null +++ b/ms_agent/utils/task_manager.py @@ -0,0 +1,135 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Process-wide background tasks (agent subprocess, shell subprocess, etc.).""" + +from __future__ import annotations + +import asyncio +import multiprocessing as mp +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + + +@dataclass +class BackgroundTask: + task_id: str + task_type: str # 'agent' | 'shell' | ... + tool_name: str + description: str + status: str = 'running' # running | completed | failed | killed + proc: Optional[Any] = field(default=None, repr=False) + result: Optional[str] = None + error: Optional[str] = None + started_at: float = field(default_factory=time.monotonic) + ended_at: Optional[float] = None + + +class TaskManager: + """Registry for background tasks; optional completion notifications queue.""" + + def __init__(self) -> None: + self._tasks: Dict[str, BackgroundTask] = {} + self._notification_queue: asyncio.Queue = asyncio.Queue() + + def register( + self, + task_type: str, + tool_name: str, + description: str, + proc: Optional[Any] = None, + task_id: Optional[str] = None, + ) -> str: + task_id = task_id or uuid.uuid4().hex[:12] + task = BackgroundTask( + task_id=task_id, + task_type=task_type, + tool_name=tool_name, + description=description, + proc=proc, + ) + self._tasks[task_id] = task + logger.info('[TaskManager] registered %s task %s: %s', task_type, + task_id, description[:200]) + return task_id + + async def complete(self, task_id: str, result: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'completed' + task.result = result + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + async def fail(self, task_id: str, error: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'failed' + task.error = error + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + def kill(self, task_id: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + if task.status != 'running': + return + if task.proc is not None: + try: + if isinstance(task.proc, mp.Process): + task.proc.terminate() + else: + # asyncio.subprocess.Process or similar + if hasattr(task.proc, 'returncode') and task.proc.returncode is None: + if hasattr(task.proc, 'kill'): + task.proc.kill() + elif hasattr(task.proc, 'terminate'): + task.proc.terminate() + except (ProcessLookupError, OSError) as e: + logger.warning('[TaskManager] kill %s: %s', task_id, e) + task.status = 'killed' + task.ended_at = time.monotonic() + + def kill_all(self) -> None: + for task_id in list(self._tasks): + if self._tasks[task_id].status == 'running': + self.kill(task_id) + + def drain_notifications(self) -> List[str]: + notifications: List[str] = [] + while True: + try: + notifications.append(self._notification_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return notifications + + def get_task(self, task_id: str) -> Optional[BackgroundTask]: + return self._tasks.get(task_id) + + def running_tasks(self) -> List[BackgroundTask]: + return [t for t in self._tasks.values() if t.status == 'running'] + + @staticmethod + def _format_notification(task: BackgroundTask) -> str: + result_line = f'\n{task.result}' if task.result else '' + error_line = f'\n{task.error}' if task.error else '' + duration = '' + if task.ended_at: + duration = f'\n{task.ended_at - task.started_at:.1f}' + return ( + f'\n' + f'{task.task_id}\n' + f'{task.task_type}\n' + f'{task.tool_name}\n' + f'{task.description}\n' + f'{task.status}' + f'{result_line}{error_line}{duration}\n' + f'') diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py new file mode 100644 index 000000000..a8380b710 --- /dev/null +++ b/ms_agent/utils/workspace_policy.py @@ -0,0 +1,207 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Workspace path policy: allow-roots (default output_dir) and optional deny globs.""" + +from __future__ import annotations + +import fnmatch +import os +import re +from pathlib import Path +from typing import Iterable, Sequence + + +class WorkspacePolicyError(ValueError): + """Raised when a path or command violates workspace policy.""" + + +class WorkspacePolicyKernel: + """Resolve user paths under allowed workspace roots; optional shell read-only rules.""" + + def __init__( + self, + output_dir: Path | str, + *, + extra_allow_roots: Sequence[str | Path] | None = None, + deny_globs: Sequence[str] | None = None, + shell_default_mode: str = 'workspace_write', + shell_network_enabled: bool = False, + max_command_chars: int = 8192, + ) -> None: + self._output = Path(output_dir).expanduser().resolve() + self._roots: list[Path] = [self._output] + if extra_allow_roots: + for r in extra_allow_roots: + p = Path(r).expanduser().resolve() + if p not in self._roots: + self._roots.append(p) + if deny_globs is None or len(tuple(deny_globs)) == 0: + self._deny_globs: tuple[str, ...] = ('**/.git/**',) + else: + self._deny_globs = tuple(deny_globs) + self.shell_default_mode = shell_default_mode + self.shell_network_enabled = shell_network_enabled + self.max_command_chars = max_command_chars + + @property + def workspace_root(self) -> Path: + return self._output + + @property + def allow_roots(self) -> tuple[Path, ...]: + return tuple(self._roots) + + @property + def deny_globs(self) -> tuple[str, ...]: + return self._deny_globs + + def resolve_under_roots(self, user_path: str | Path) -> Path: + """Resolve *user_path* to an absolute path that must lie under one allow root.""" + raw = Path(user_path).expanduser() + if raw.is_absolute(): + resolved = raw.resolve() + else: + resolved = (self._output / raw).resolve() + for root in self._roots: + try: + resolved.relative_to(root) + break + except ValueError: + continue + else: + raise WorkspacePolicyError( + f'Path is outside allowed workspace roots: {resolved}') + if self._is_denied(resolved): + raise WorkspacePolicyError( + f'Path matches a deny_globs pattern: {resolved}') + return resolved + + def _is_denied(self, path: Path) -> bool: + if not self._deny_globs: + return False + rel = None + try: + rel = path.relative_to(self._output) + except ValueError: + rel = path + rel_s = rel.as_posix() + for pat in self._deny_globs: + if fnmatch.fnmatch(rel_s, pat) or fnmatch.fnmatch(path.name, pat): + return True + if fnmatch.fnmatch(str(path), pat): + return True + return False + + def path_is_allowed(self, path: Path) -> bool: + path = path.expanduser().resolve() + for root in self._roots: + try: + path.relative_to(root) + break + except ValueError: + continue + else: + return False + return not self._is_denied(path) + + def assert_shell_command_allowed(self, command: str) -> None: + """Length and mode-based checks before executing shell.""" + if not command or not command.strip(): + raise WorkspacePolicyError('Empty shell command') + if len(command) > self.max_command_chars: + raise WorkspacePolicyError( + f'Shell command exceeds max length ({self.max_command_chars})') + + mode = self.shell_default_mode + if mode == 'read_only': + if _shell_looks_mutating_or_network(command, + allow_network=False): + raise WorkspacePolicyError( + 'Shell is in read_only mode: mutating or network commands are not allowed' + ) + elif mode == 'workspace_write': + if not self.shell_network_enabled and _shell_looks_network(command): + raise WorkspacePolicyError( + 'Network commands are disabled for shell (enable tools.code_executor.shell.network_enabled)' + ) + # future: explicit 'network' mode could allow curl etc. + + +def _shell_looks_network(command: str) -> bool: + lowered = command.lower() + tokens = ( + 'curl ', + 'wget ', + 'ssh ', + 'scp ', + 'rsync ', + 'ftp ', + 'nc ', + 'netcat ', + 'pip install', + 'pip3 install', + 'npm install', + 'yarn add', + 'pnpm add', + ) + return any(t in lowered for t in tokens) + + +def _shell_looks_mutating_or_network(command: str, *, + allow_network: bool) -> bool: + if not allow_network and _shell_looks_network(command): + return True + # redirection that creates/overwrites files + if re.search(r'[>]{1,2}\s*[^\s]', command): + return True + if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', + command): + return True + return False + + +def iter_files_under( + root: Path, + *, + deny_globs: Iterable[str] = (), + max_files: int = 100_000, +) -> Iterable[Path]: + """Yield files under *root* (depth-first), skipping directories matching deny globs.""" + deny = tuple(deny_globs) + count = 0 + root = root.resolve() + + def dir_skipped(dirpath: Path) -> bool: + try: + rel = dirpath.relative_to(root).as_posix() + except ValueError: + return True + for pat in deny: + if fnmatch.fnmatch(rel, pat) or fnmatch.fnmatch(rel + '/', pat): + return True + parts = rel.split('/') + for i in range(len(parts)): + sub = '/'.join(parts[:i + 1]) + if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch( + sub + '/', pat): + return True + return False + + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): + dp = Path(dirpath) + if dir_skipped(dp): + dirnames[:] = [] + continue + # prune skipped subdirs + keep: list[str] = [] + for d in dirnames: + child = dp / d + if dir_skipped(child): + continue + keep.append(d) + dirnames[:] = keep + for name in filenames: + count += 1 + if count > max_files: + return + yield dp / name diff --git a/shell-grep-glob-workspace-policy.md b/shell-grep-glob-workspace-policy.md new file mode 100644 index 000000000..9e9f86bbd --- /dev/null +++ b/shell-grep-glob-workspace-policy.md @@ -0,0 +1,225 @@ +# Shell / Grep / Glob 与策略内核架构方案 + +本文档描述在 modelscope-agent 中为 **Shell**、**Grep**、**Glob** 提供统一的安全、权限、沙箱与产物管理的设计,以及与 **`feat/agent-tool-overhaul`** 分支中 **TaskManager**(后台 Agent、预留 Shell)的兼容方式。 + +--- + +## 1. 目标与边界 + +### 目标 + +- 在「同一工作区、同一沙箱视图」下,为 **Shell / Grep / Glob** 提供统一的: + - **安全**(命令与路径约束) + - **权限**(只读 / 写工作区 / 网络等分级) + - **沙箱**(本地子进程 vs Docker enclave 等与现有 `CodeExecutionTool` 对齐) + - **产物管理**(大 stdout/stderr 落盘、预览、配额) +- **默认 `allow_list`(允许根路径)包含 `output_dir`**(及其规范化的绝对路径),可配置追加其它根。 + +### 边界 + +- **不替代** `FileSystemTool` 的精确编辑与读缓存等语义;Shell 面向构建、包管理、复杂管道。 +- **Grep / Glob** 作为**只读发现面**的独立工具,减少对裸 shell 的依赖;复杂 `find -exec` 等仍可由受控 Shell 在更高权限模式下完成(若产品允许)。 + +--- + +## 2. 分层架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Tool Facade 层 │ +│ ShellTool │ GrepTool │ GlobTool (独立 JSON Schema) │ +└────────────┬───────────────────────────────┬────────────────┘ + │ │ +┌────────────▼───────────────────────────────▼────────────────┐ +│ WorkspacePolicyKernel(策略内核,纯逻辑、可单测) │ +│ - roots: 默认含 canonical(output_dir),可配置追加 │ +│ - allow_list / deny_list 合并与优先级 │ +│ - resolve_path(rel|abs) → 必须在 allow_roots 下 │ +│ - classify(op): read | search | mutate | exec | network_hint │ +└────────────┬────────────────────────────────────────────────┘ + │ +┌────────────▼────────────────────────────────────────────────┐ +│ SandboxRuntime(执行面,可替换实现) │ +│ - LocalProcessRuntime(asyncio subprocess,cwd=workspace) │ +│ - EnclaveRuntime(现有 ms_enclave / CodeExecutionTool 路径) │ +│ - 会话级 sandbox_id / working_dir 与挂载点一致 │ +└────────────┬────────────────────────────────────────────────┘ + │ +┌────────────▼────────────────────────────────────────────────┐ +│ ArtifactManager(产物管理) │ +│ - 超阈值 stdout/stderr → 落盘 + preview + 相对路径引用 │ +│ - 按 task_id / tool_call_id 分目录 │ +│ - TTL / 总配额(建议:output_dir/.ms_agent_artifacts/) │ +└─────────────────────────────────────────────────────────────┘ +``` + +**原则**:Grep/Glob 的**主路径**不是「拼一条 shell 给模型」;内部可调用 `rg` 或文件系统 walk,但必须经过 **PolicyKernel** 与 **SandboxRuntime**,输出经 **ArtifactManager**。 + +--- + +## 3. WorkspacePolicyKernel(共享策略内核) + +### 3.1 默认 allow_list(允许根集合) + +- 初始化:`allow_roots = { canonical_abs(output_dir) }`。 +- 配置可追加,例如:`tools.code_executor.extra_allow_roots` 或 `tools.workspace_policy.allow`(列表),合并去重。 +- Shell / Grep / Glob 涉及的 **`path`、`cwd`、搜索根目录** 均先执行 `resolve_under_allow_roots()`;失败则**拒绝**并返回结构化错误(不静默改路径到其它目录)。 + +### 3.2 权限与操作分类(建议) + +| 类别 | 示例 | Shell | Grep | Glob | +|------|------|-------|------|------| +| read | 读取工作区内文件 | 受模式 + 策略约束 | ✓ | ✓ | +| search | 内容/文件名发现 | 可引导至 Grep/Glob | ✓ | ✓ | +| mutate | rm、chmod、git 写入等 | 需 `workspace_write` | — | — | +| network | curl、pip 等 | 需显式 **network** 能力位 | — | — | + +Shell 在 **`read_only`** 模式下:仅允许白名单类命令(如 `git status`/`diff`/`log`、只读参数的 `rg` 等),并对重定向、写入工作区外等行为做拒绝或降级(可用前缀表 + 危险模式黑名单,必要时辅以轻量解析)。 + +### 3.3 Shell 安全补充 + +- **固定 cwd**:默认 `workspace_root`(与 `output_dir` 或沙箱内挂载点一致)。 +- **环境变量**:最小集或白名单继承;避免将宿主敏感变量原样传入。 +- **命令预处理**:与现有 `CodeExecutionTool.shell_executor` 思路一致——含 `| && ; > <` 等时使用 `sh -lc` 与安全 quoting;另加**命令长度上限**、**可配置的危险构造限制**(如嵌套命令替换,按产品分级)。 +- **(暂时不做)** 与 `FileSystemTool` 的「写前必读 / staleness」策略对齐:对会修改工作区文件的 Shell 子类共享元数据(若产品需要强一致)。 + +--- + +## 4. SandboxRuntime(共享沙箱) + +- **会话级**:每个 Agent 运行周期内一个 `SandboxSession`(或复用现有 `sandbox_id`)。 +- **Shell / Grep / Glob** 共用同一 **`working_dir` / 挂载视图** 与同一 **`SandboxRuntime` 实现**(本地 `asyncio` 子进程 vs Docker enclave,由 `implementation: sandbox | python_env` 等与现有一致)。 +- **Grep**:在 enclave 内调用 `rg` 或使用宿主 `ripgrep` 库(由部署二选一);**Glob**:在策略解析后的根上做目录遍历或 `pathspec`,避免默认可执行任意 `find -exec`。 + +--- + +## 5. ArtifactManager(产物管理) + +- **阈值**:例如 stdout+stderr 合计超过 N KB 则 spill 至 + `{output_dir}/.ms_agent_artifacts/{tool_name}/{task_or_call_id}.txt`(路径可配置)。 +- **返回**:JSON 中包含 `preview`(首尾若干字符/行)、`artifact_path`(相对 `output_dir`)、`truncated: true`。 +- **与 TaskManager 配合**:后台任务完成时,`TaskManager.complete(task_id, result)` 的 `result` 宜为「短摘要 + artifact 路径」,避免通知与下一轮上下文被撑爆。 + +--- + +## 6. GrepTool / GlobTool(独立工具、共享内核) + +- **输入**:结构化字段(如 pattern、path、glob、head_limit、offset、output_mode),不把「整条 shell」作为唯一 API。 +- **实现**:内部调用 `SandboxRuntime.exec_rg(...)` 或在策略内核限定根上的 glob 遍历;**禁止**由用户可控字符串直接拼接未校验的 shell。 +- **共享**:同一 `WorkspacePolicyKernel` + `SandboxRuntime` + `ArtifactManager`(由 `ToolManager` 或执行类工具在初始化时注入)。 +- **注册**:在 `ToolManager` 中作为独立 `ToolBase`(可一个 server 多个 tool,或两个 server);与 `file_system` 解耦,保持 `file_system` 精简。 + +--- + +## 7. 与 `feat/agent-tool-overhaul` 的 Task 体系兼容 + +### 7.1 分支中的现状(摘要) + +- **`TaskManager`**(`ms_agent/utils/task_manager.py`):进程级后台任务注册表;`BackgroundTask` 中 **`task_type` 注释已包含 `'agent' | 'shell'`**。 +- **`AgentTool`**:`run_in_background` 时 `register(task_type='agent', proc=mp.Process, ...)`,watcher 在子进程结束后调用 `complete` / `fail`;`LLMAgent` 通过 `set_task_manager` 注入同一 `TaskManager`,每轮 `drain_notifications()` 将完成事件注入对话。 + +### 7.2 Shell 后台(与 Agent 对称) + +**建议接口** + +- **同步**:`shell_executor(command, timeout)` → 行为与现网接近,但走 PolicyKernel + ArtifactManager。 +- **后台**:增加 `run_in_background: bool`(或等价命名), **`__call_id`**(与 `AgentTool` 注入一致,便于对账与「推后台」扩展)。 + +**后台行为** + +1. `task_id = task_manager.register(task_type='shell', tool_name='shell_executor', description=command[:200], proc=...)` +2. `proc` 可为 **`asyncio.create_subprocess_*` 返回的 `Process`**(与 Agent 的 `mp.Process` 不同,需在 **`TaskManager.kill` / `kill_all` 中扩展**:对 `asyncio.subprocess.Process` 调用 `kill()` / `terminate()`,并处理已结束进程)。 +3. `asyncio.create_task(watcher)`:等待结束 → `ArtifactManager.maybe_spill` → `await task_manager.complete(task_id, result_str)`(失败则 `fail`)。 + +**立即返回 JSON**(与 Agent 后台对齐,便于统一文档与客户端): + +```json +{ + "status": "async_launched", + "task_id": "", + "tool_name": "shell_executor" +} +``` + +### 7.3 LLMAgent 接线 + +- 与 overhaul 一致:构造 `TaskManager()`,遍历 `extra_tools`,若实现 **`set_task_manager(self.task_manager)`** 则注入。 +- **`LocalCodeExecutionTool` / 未来的 `SecureShellTool`** 实现 `set_task_manager`,与 `AgentTool` 共享同一 `TaskManager` 实例。 + +### 7.4 长同步 Shell → Escape 到后台 + +- 与 `AgentTool._run_sync_escapable` 类似:同步 Shell 带 `sync_timeout_s`,超时或显式信号后取消当前子进程并改为 `register(task_type='shell', ...)` 后台重跑或仅保留已产出部分(产品二选一)。 +- 若存在 **TaskControlTool** 类机制,可复用「`__call_id` + escape 事件」模式,Shell 侧维护 `call_id → Process` 映射以支持 **kill / escape**。 + +### 7.5 兼容对照表 + +| 能力 | overhaul 行为 | 本方案落点 | +|------|----------------|------------| +| 后台 Agent | `register(task_type='agent', proc=Process)` | 不变 | +| 预留 Shell | `task_type` 含 `'shell'` | `shell_executor(run_in_background=true)` 走同一 register / complete / fail | +| 回合内通知 | `drain_notifications()` | Shell 完成同样入队 | +| Kill / 清理 | `kill` / `kill_all` | 扩展支持 asyncio 子进程;watcher `finally` 释放资源 | + +--- + +## 8. 配置示例(OmegaConf / YAML 意向) + +```yaml +tools: + workspace_policy: + allow_roots: [] # 追加;默认已含 output_dir + deny_globs: ['**/.git/**'] + code_executor: + implementation: python_env # or sandbox + shell: + default_mode: workspace_write # read_only | workspace_write + max_output_kb: 256 + wall_time_s: 900 + grep: + default_head_limit: 250 + glob: + max_files: 100 +``` + +--- + +## 9. 实施顺序建议 + +1. 抽出 **`WorkspacePolicyKernel`** + 单元测试(路径解析、默认 `output_dir`、追加 allow)。 +2. 实现 **`ArtifactManager`**,接到现有 `shell_executor` 返回(先本地工具、后接沙箱)。 +3. 将 **`TaskManager`**(overhaul)合入主线并 **扩展 `kill` 支持 `asyncio.subprocess.Process`**。 +4. **`LocalCodeExecutionTool.set_task_manager` + `run_in_background` 的 `shell_executor`**。 +5. 新增 **GrepTool / GlobTool** façade,共享上述内核与运行时。 +6. 更新文档与系统提示:默认 **发现用 Grep/Glob,构建用 Shell,改文件用 file_system**。 + +--- + +## 10. 设计取舍小结 + +- **Shell**:强约束的通用执行面 + 后台,与 **TaskManager** 统一生命周期与通知。 +- **Grep / Glob**:独立 Schema、只读、易截断,与 Shell **共享策略与沙箱**,避免把一切搜索都绑在一条 shell 字符串上。 +- **默认 allow_roots 含 `output_dir`**:与现有 Agent 工作区模型一致,减少越权访问宿主路径的风险。 + +--- + +## 修订记录 + +| 日期 | 说明 | +|------|------| +| 2026-04-13 | 初版:根据设计与 `feat/agent-tool-overhaul` 中 TaskManager / AgentTool 后台模型整理成文。 | +| 2026-04-13 | 实现落地:见下文「实现映射」。 | + +## 11. 实现映射(代码位置) + +| 组件 | 路径 | +|------|------| +| WorkspacePolicyKernel | `ms_agent/utils/workspace_policy.py` | +| ArtifactManager | `ms_agent/utils/artifact_manager.py` | +| TaskManager | `ms_agent/utils/task_manager.py` | +| Shell 策略 / 产物 / 后台 | `ms_agent/tools/code/local_code_executor.py`(`set_task_manager`、`shell_executor`) | +| Grep / Glob | `ms_agent/tools/workspace_search_tool.py`(默认注册;`tools.workspace_search.enabled: false` 可关闭) | +| `__call_id` 注入 shell | `ms_agent/tools/tool_manager.py` | +| TaskManager 与通知 | `ms_agent/agent/llm_agent.py`(`prepare_tools` / `cleanup_tools` / `_append_task_notifications`) | +| 单测 | `tests/utils/test_workspace_policy.py` | + +**未在本阶段实现**:文档 §7.4 长同步 Shell escape 到后台;Docker `CodeExecutionTool` 侧 shell 与策略对齐(仍沿用原沙箱实现)。 diff --git a/tests/utils/test_workspace_policy.py b/tests/utils/test_workspace_policy.py new file mode 100644 index 000000000..012888533 --- /dev/null +++ b/tests/utils/test_workspace_policy.py @@ -0,0 +1,78 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for WorkspacePolicyKernel.""" + +import tempfile +from pathlib import Path + +import pytest + +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel + + +def test_default_root_is_output_dir(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel(out) + p = k.resolve_under_roots('foo/bar') + assert p == (out / 'foo' / 'bar').resolve() + + +def test_rejects_escape(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel(out) + with pytest.raises(WorkspacePolicyError): + k.resolve_under_roots('../../etc/passwd') + + +def test_extra_allow_root(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + other = Path(td) / 'other' + out.mkdir() + other.mkdir() + k = WorkspacePolicyKernel(out, extra_allow_roots=[str(other)]) + assert k.resolve_under_roots(str(other / 'x')) == (other / 'x').resolve() + + +def test_read_only_blocks_redirect(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel( + out, + shell_default_mode='read_only', + ) + with pytest.raises(WorkspacePolicyError): + k.assert_shell_command_allowed('echo x > file.txt') + + +def test_workspace_write_allows_redirect_but_blocks_network(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel( + out, + shell_default_mode='workspace_write', + shell_network_enabled=False, + ) + k.assert_shell_command_allowed('echo x > file.txt') + with pytest.raises(WorkspacePolicyError): + k.assert_shell_command_allowed('curl https://example.com') + + +def test_artifact_manager_spill(tmp_path): + from ms_agent.utils.artifact_manager import ArtifactManager + + am = ArtifactManager(tmp_path, max_combined_bytes=32) + big = 'a' * 100 + packed = am.pack_text_result( + tool_name='t', + call_id='c1', + stdout=big, + stderr='', + ) + assert packed.get('truncated') is True + assert 'artifact_path' in packed From ebbc72a2108c2ec09c8700ce7b5998a05be6f238 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 13 Apr 2026 14:55:04 +0800 Subject: [PATCH 30/40] chore(projects): align configs with filesystem and workspace search tools Replace removed file_system tools (list_files, delete_file_or_dir) with workspace_search (glob/grep) and/or code_executor (shell, file_operation). Update deep_research prompts and callbacks for read_file offset/limit and edit_file. fin_research: aggregator adds python_env shell/file_operation; collector exposes shell and file_operation in sandbox; file_system keeps read/write/edit. code_genesis: prompts use glob_files/shell for listing; orchestrator_callback uses os.makedirs instead of removed create_directory(). singularity registers workspace_search. Made-with: Cursor --- projects/code_genesis/architect.yaml | 2 +- projects/code_genesis/coding.yaml | 11 ++++++----- projects/code_genesis/refine.yaml | 7 ++++--- projects/code_genesis/workflow/coding.py | 4 +--- .../v2/callbacks/reporter_callback.py | 10 +++++----- .../v2/callbacks/researcher_callback.py | 12 ++++++------ .../v2/prompts/reporter/en/gpt5.txt | 2 +- .../v2/prompts/researcher/en/gpt5.txt | 12 ++++++------ projects/deep_research/v2/reporter.yaml | 5 +++-- projects/deep_research/v2/researcher.yaml | 7 +++---- projects/deep_research/v2/searcher.yaml | 3 ++- projects/fin_research/aggregator.yaml | 17 +++++++++++++---- .../callbacks/orchestrator_callback.py | 3 ++- projects/fin_research/collector.yaml | 11 ++++++++--- projects/singularity_cinema/agent.yaml | 6 ++---- 15 files changed, 63 insertions(+), 49 deletions(-) diff --git a/projects/code_genesis/architect.yaml b/projects/code_genesis/architect.yaml index 988091875..79c65d535 100644 --- a/projects/code_genesis/architect.yaml +++ b/projects/code_genesis/architect.yaml @@ -37,7 +37,7 @@ prompt: Your Steps: 1. Read topic and user story to obtain the original requirements and user stories. 2. If topic contains documentation (e.g., usage of external frameworks, project PRDs, etc.), you must read them to understand user requirements. - * When reading external documents, prioritize using `read_abbreviation_file` to reduce token usage. + * When reading external documents, prioritize using `read_file` with `abbreviate: true` to reduce token usage. 3. Design the technology selection. If the user has specified technology choices, prioritize using the user's choices and supplement any missing parts. * [Historical Error]: Failure due to mixing different technical frameworks within the same language in different files. Your framework information should prevent this issue. - Example: Mixing CommonJS and ES6 in package.json and xx.js causing runtime failure; you should specify syntax rules in framework.txt. diff --git a/projects/code_genesis/coding.yaml b/projects/code_genesis/coding.yaml index 1508a5de1..b9f398697 100644 --- a/projects/code_genesis/coding.yaml +++ b/projects/code_genesis/coding.yaml @@ -34,7 +34,7 @@ prompt: - `postcss.config.js` enables `tailwindcss` and `autoprefixer` 3. CRITICAL: Before reading ANY file: - * FIRST use `list_files` to check which files actually exist in the project + * FIRST use `workspace_search---glob_files` (e.g. pattern `**/*`, path `.`) to list paths under the project, or `code_executor---shell_executor` with a read-only `ls`/`find` command if you prefer the shell. * NEVER read files that do not appear in the output * NEVER attempt to read files with index >= yours (they don't exist yet) * NEVER guess or assume a file exists - always verify first @@ -42,13 +42,13 @@ prompt: 3.1 CRITICAL (no hallucination about files): * Do not fully trust `protocol.txt` content; you must verify it yourself by checking existing files and reading the exact source of truth. * Before you reference/cite ANY information that comes from a file (APIs, exports, config values, routes, CSS class names, build scripts, ports, etc.), you MUST: - 1) Confirm the file exists via `list_files`, AND + 1) Confirm the file exists via `glob_files` (or shell listing), AND 2) Read the relevant part of that exact file via `read_file`. * You are NOT allowed to infer file contents. * If the needed file is missing from the file list, do not reference it; either create it (if allowed by your index constraints) or implement the needed logic in your current file. 4. Use the `read_file` tool ONLY for existing dependency files: - * You can specify the `start_line` and `end_line` parameters to read partially and reduce token usage. + * You can specify the `offset` (1-based start line) and `limit` (number of lines) parameters on `read_file` to read partially and reduce token usage. * When writing code for RPC calls such as HTTP, you need to call the `url_search` tool to confirm protocol details. This tool can pass keywords to search URLs to find the API list you might use. Do not call non-existent API interfaces. 5. Write code, do not use `edit_file` to write non-existent files! @@ -103,7 +103,7 @@ prompt: 8. When fixing issues and updating files: Call the `edit_file` tool or `write_file` tool, after fixing issues there's no need to check by yourself, the lsp tool will check and report issues. - 9. You can use the shell_executor to debug problems: + 9. You can use `code_executor---shell_executor` to debug problems: Example: shell_executor(command='python -c "from module import MyClass; print(vars(MyClass))"') @@ -113,13 +113,14 @@ prompt: 2. [Secondary] Use as few tokens as possible. tools: + workspace_search: + mcp: false file_system: mcp: false allow_read_all_files: true include: - read_file - edit_file - - list_files - write_file edit_file_config: diff_model: morph-v3-fast diff --git a/projects/code_genesis/refine.yaml b/projects/code_genesis/refine.yaml index 2795dc654..8dd78adcd 100644 --- a/projects/code_genesis/refine.yaml +++ b/projects/code_genesis/refine.yaml @@ -46,9 +46,9 @@ prompt: a. Write test cases to verify correctness 3. Use your tools effectively - * Use `read_file` to read the original file. When using `read_file`, specify `start_line` and `end_line` to reduce token usage + * Use `read_file` to read the original file. When using `read_file`, specify `offset` (1-based start line) and `limit` (line count) to reduce token usage * If an HTTP API is unclear, use `workflow/api_search` to locate the implementation - * Use `search_file_content` to search the project for keywords you care about + * To scan the project for keywords without a dedicated search tool, use `read_file` with `abbreviate: true` on candidate files, or read likely paths with `offset`/`limit` * Shell commands must not use system directories (except `/dev/null`) 4. When run successfully, you can use EdgeOne Pages MCP tools to deploy the project. Available tools from `edgeone-pages-mcp` include: @@ -91,12 +91,13 @@ tools: - file_operation - reset_executor - get_executor_info + workspace_search: + mcp: false file_system: mcp: false include: - read_file - write_file - - list_files - edit_file edit_file_config: diff_model: morph-v3-fast diff --git a/projects/code_genesis/workflow/coding.py b/projects/code_genesis/workflow/coding.py index e5b696373..c6cfc398f 100644 --- a/projects/code_genesis/workflow/coding.py +++ b/projects/code_genesis/workflow/coding.py @@ -163,9 +163,7 @@ def find_all_read_files(): for message in messages: if message.tool_calls: for tool_call in message.tool_calls: - if 'read_file' in tool_call[ - 'tool_name'] or 'read_abbreviation_file' in tool_call[ - 'tool_name']: + if 'read_file' in tool_call['tool_name']: arguments = tool_call['arguments'] if isinstance(arguments, str): try: diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 7bfb5bee3..6613e4194 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -176,9 +176,9 @@ class ReporterCallback(Callback): '- 请严格遵守系统指令中的要求,不要遗漏、忽略任何合理的规则。\n' '- 审查要点包括事实准确性、逻辑一致性、用户核心问题的覆盖度、引用与论据的对齐关系、引用格式问题、内容完整性等等。' '修改须有明确依据(如事实冗余、逻辑混乱、证据不一致、格式出错等),不要为了"润色"而改动结构/质量良好的内容。\n' - '- 读取报告内容一次后形成判断,后续核查优先使用 search_file_content 或带 start_line / end_line 的 read_file,不要反复全量读取同一文件。' + '- 读取报告内容一次后形成判断,后续核查优先使用 read_file 的 offset/limit 或 abbreviate,不要反复全量读取同一文件。' '在读取文件前先检查对话历史中是否已包含该文件的内容,避免重复读取。\n' - '- 优先使用定点修改(search_file_content -> replace_file_contents / replace_file_lines),仅在必要时才读取全文。' + '- 优先使用定点修改(read_file 定位片段 -> edit_file 精确替换),仅在必要时才读取全文。' '仅在定点修改完全无法解决时使用 write_file,且**必须完整保留所有有价值的内容**,严禁使用占位符、省略标记、引用其他内容等方式替代正文。\n' '- 质量较高无需修改的部分直接跳过。如果[Reporter 工作总结]中无异常且审查确认全文质量良好,直接进入结论阶段即可。\n\n' '**需避免的常见错误:**\n' @@ -210,11 +210,11 @@ class ReporterCallback(Callback): 'Edits must have clear justification (e.g., factual redundancy, logical confusion, evidence inconsistency, ' 'formatting errors, etc.) — do not alter well-structured, high-quality content merely for "polishing."\n' '- Read the report content ONCE to form your assessment. For subsequent ' - 'checks, prefer `search_file_content` or `read_file` with `start_line`/`end_line`. ' + 'checks, prefer `read_file` with `offset`/`limit` or `abbreviate`. ' 'Do not re-read the entire file repeatedly. Check your conversation history before ' 'reading any file to avoid redundant reads.\n' - '- Prefer targeted fixes (`search_file_content` -> `replace_file_contents` / ' - '`replace_file_lines`); only read the full text when necessary. ' + '- Prefer targeted fixes (`read_file` to locate a span, then `edit_file` with exact ' + '`old_string`/`new_string`); only read the full text when necessary. ' 'Use `write_file` only when targeted fixes are completely insufficient, ' 'and you **must preserve ALL valuable content in full** — never use placeholders, ' 'ellipsis markers, or references to other content as substitutes for actual text.\n' diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 4796e2151..f09e743d5 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -52,9 +52,9 @@ class ResearcherCallback(Callback): ('外部检查发现:{filename} 的内容存在质量问题——{reason}。\n' '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' '**重要提醒**:如果质量问题属实,你必须按照以下原则进行修复:\n' - '1. 优先通过有针对性的局部修改完成修复。请使用 file_system---search_file_content 定位问题段落,' - '然后使用 file_system---replace_file_contents 和 file_system---replace_file_lines 进行针对性修复。' - '需要时可以使用 file_system---read_file (with start_line/end_line) 验证上下文是否一致。\n' + '1. 优先通过有针对性的局部修改完成修复。请使用 file_system---read_file(可用 offset/limit 或 abbreviate)定位问题段落,' + '然后使用 file_system---edit_file(old_string/new_string 必须与原文完全一致)进行针对性修复。' + '需要时可以使用 file_system---read_file(with offset/limit)验证上下文是否一致。\n' '2. 如果确认无法通过1完成修复,可以使用 file_system---write_file 全量重写报告,但请注意以下可能的质量违规:\n' '- 用省略号或缩略标记替代正文,如"(同之前,略)"、"此处省略"、"篇幅所限不再展开"、' '"……以下类似"、"内容已截断"、"Content truncated for brevity"等;\n' @@ -74,9 +74,9 @@ class ResearcherCallback(Callback): 'Please carefully verify whether these issues are valid and whether additional problems exist, ' 'then immediately take action to fix them.\n' '**IMPORTANT**: If the quality issues are confirmed, you must follow these principles to fix them:\n' - '1. PREFER targeted, localized fixes. Use file_system---search_file_content to locate the problematic sections, ' - 'then use file_system---replace_file_contents and file_system---replace_file_lines to apply precise corrections. ' - 'use file_system---read_file (with start_line/end_line) to verify surrounding context when needed.\n' + '1. PREFER targeted, localized fixes. Use file_system---read_file (with offset/limit or abbreviate) to locate the problematic sections, ' + 'then use file_system---edit_file (exact old_string/new_string) to apply precise corrections. ' + 'Use file_system---read_file (with offset/limit) to verify surrounding context when needed.\n' '2. If you confirm that targeted fixes alone cannot resolve the issues, you may use file_system---write_file ' 'to fully rewrite the report, but beware of the following quality violations:\n' '- Replacing body text with ellipsis or brevity markers, e.g., "(same as before, omitted)", ' diff --git a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt index fc276252a..19faef800 100644 --- a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt @@ -55,7 +55,7 @@ Stopping conditions (stop if any one is satisfied): - Based on the reflection results and recorded conflicts, you MUST rewrite the final markdown report content and save it to reports/report.md. - The final markdown report must preserve the information density and structural quality of the draft — never replace substantive content with ellipsis/brevity markers (e.g., "omitted here", "content truncated"), pointers to external files (e.g., "details are in chapter_2.md"), or hollow reference-only placeholders (e.g., "see [1]"). - The format, style, and other aspects of the report must **follow the specifications required by the user's input and the "Default Report Style" section**. The report must include citations to reference sources. - - **Writing strategy to minimize information loss**: Prioritize writing the full report in a single file_system---write_file call. Switch to an incremental strategy if your write attempt gets truncated by the output limit: initialize the file with file_system---write_file containing as much content as possible; then sequentially append the remainder using file_system---replace_file_lines with start_line=-1 until the report is complete, utilizing as few calls as possible. + - **Writing strategy to minimize information loss**: Prioritize writing the full report in a single file_system---write_file call. Switch to an incremental strategy if your write attempt gets truncated by the output limit: initialize the file with file_system---write_file containing as much content as possible; then read the end of reports/report.md and repeatedly use file_system---edit_file to replace a unique trailing anchor with anchor+next_chunk (append by editing the tail), re-reading the tail before each append until the report is complete, using as few calls as practical. - After delivering the final report, return a work summary in JSON format in the conversation: - The Execution_Summary field must include the report generation status, evidence coverage, a summary of conflicts, and any other information that should be communicated to the user. - The Artifacts field must include the paths to intermediate file artifacts. diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt index 8891d8313..765e4b631 100644 --- a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -15,7 +15,7 @@ Action protocol: Before outputting the final result, every iteration MUST invoke - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. - Draft, review, deliver: - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. - - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---read_file (with `offset`/`limit` or `abbreviate` as needed) to locate the passage, then file_system---edit_file with an exact `old_string`/`new_string` pair to fix it. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. # Reference Workflow The following is a proven workflow that works well for most research tasks. @@ -58,11 +58,11 @@ Stopping conditions (stop if you are confident to proceed to the next phase): - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. - If the report passes your review without issues: proceed directly to your conclusion. Do NOT rewrite it "for polish." - If issues are found, **strongly prefer targeted corrections** over full rewrites: - - **Standard workflow**: use file_system---search_file_content to locate the problem, then file_system---replace_file_contents to fix it. This is the safest and most precise approach. - - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail. - - Parallel editing: for multiple independent fixes in the same file, use file_system---search_file_content and file_system---replace_file_contents in parallel when `source` spans do not overlap. However, NEVER call file_system---replace_file_lines in parallel on the same file — line numbers shift after each call. - - **Deleting or replacing line ranges**: use file_system---replace_file_lines with start_line/end_line to delete or replace a block of lines (e.g., removing an entire section). Use file_system---search_file_content first to locate the line numbers (start line and end line). - - **Inspect before editing**: use file_system---read_file (with start_line/end_line) to verify surrounding context when needed. + - **Standard workflow**: use file_system---read_file to load the relevant region (use `offset`/`limit` for large files, or `abbreviate` for a quick structural pass), then file_system---edit_file with an exact `old_string`/`new_string` match. You must read a file before writing to it; re-read if the file may have changed. + - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail — copy the span verbatim from read_file output. + - Apply multiple edits to the same file **sequentially** (re-read between edits if needed) so anchors stay valid; do not parallelize conflicting edits on one file. + - **Deleting or replacing a block**: read the section, then file_system---edit_file with `old_string` set to the exact contiguous block to remove or replace (include enough surrounding lines to make `old_string` unique). + - **Inspect before editing**: use file_system---read_file (with `offset`/`limit`) to verify surrounding context when needed. - **Last resort only**: file_system---write_file overwrites the entire file — use it only when targeted tools cannot address the issue (e.g., extensive structural reorganization). You must reproduce ALL content valuable to the user. - WARNING: Full report rewrites may carry high risk of content loss. Do not over-compress the report. Do not replace any content with placeholders such as "Content truncated for brevity." or "This section is stored in xxx file." - Finally show your conclusions for the entire task in the conversation. diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index c55fd109b..aa19d1114 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -29,13 +29,14 @@ prompt: family: gpt5 tools: + workspace_search: + mcp: false file_system: mcp: false include: - write_file - read_file - - list_files - - replace_file_lines + - edit_file evidence_store: mcp: false evidence_dir: evidence diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 50c2ece5d..47eb0ab12 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -28,15 +28,14 @@ prompt: tools: + workspace_search: + mcp: false file_system: mcp: false include: - write_file - read_file - - list_files - - search_file_content - - replace_file_contents - - replace_file_lines + - edit_file code_executor: mcp: false implementation: python_env diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index b9b19f08b..0e015a39d 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -28,12 +28,13 @@ prompt: tools: + workspace_search: + mcp: false file_system: mcp: false include: - write_file - read_file - - list_files web_search: mcp: false engines: diff --git a/projects/fin_research/aggregator.yaml b/projects/fin_research/aggregator.yaml index 2efcabfe8..70c7bf449 100644 --- a/projects/fin_research/aggregator.yaml +++ b/projects/fin_research/aggregator.yaml @@ -84,8 +84,8 @@ prompt: - You may interpret the meaning of embedded images from their surrounding context \ in the financial data analysis report. - You should note that most of the charts generated during the analysis process \ - are stored in the default working directory. You may use the tools under the file_system \ - server to browse these files and select meaningful images — based on their filenames — to embed into the report. + are stored in the default working directory. Use `workspace_search---glob_files` (e.g. pattern `**/*.{png,jpg,svg}`) \ + or `file_system---read_file` to inspect filenames and embed meaningful images into the report. - Avoid inserting any images whose titles or filenames do not convey meaningful information. - Avoid inserting any images that are from websites or other external sources. - The image paths should be relative paths to the the default working directory. @@ -127,13 +127,22 @@ prompt: tools: + workspace_search: + mcp: false + code_executor: + mcp: false + implementation: python_env + exclude: + - notebook_executor + - python_executor + - reset_executor + - get_executor_info file_system: mcp: false include: - read_file - write_file - - list_files - - delete_file_or_dir + - edit_file spec_loader: mcp: false plugins: diff --git a/projects/fin_research/callbacks/orchestrator_callback.py b/projects/fin_research/callbacks/orchestrator_callback.py index a649504ab..d9b507dce 100644 --- a/projects/fin_research/callbacks/orchestrator_callback.py +++ b/projects/fin_research/callbacks/orchestrator_callback.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os from typing import List from file_parser import extract_code_blocks @@ -31,7 +32,7 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): if messages[-1].tool_calls or messages[-1].role == 'tool': return - await self.file_system.create_directory() + os.makedirs(self.file_system.output_dir, exist_ok=True) content = '\n'.join([m.content for m in messages[2:]]) all_files, _ = extract_code_blocks(content) results = [] diff --git a/projects/fin_research/collector.yaml b/projects/fin_research/collector.yaml index fde543744..25827516e 100644 --- a/projects/fin_research/collector.yaml +++ b/projects/fin_research/collector.yaml @@ -132,10 +132,12 @@ tools: network_enabled: True # Enable network access in sandbox (default: false for security) tools_config: notebook_executor: {} + shell_executor: {} + file_operation: {} exclude: - python_executor - - shell_executor - - file_operation + - reset_executor + - get_executor_info financial_data_fetcher: mcp: false source_type: hybrid @@ -145,11 +147,14 @@ tools: max_requests_per_second: 1 # Maximum requests per second min_request_interval: 1 # Minimum interval between requests (seconds) max_concurrent: 1 # Maximum concurrent requests (reduced to work with thread semaphore) + workspace_search: + mcp: false file_system: mcp: false include: - write_file - - delete_file_or_dir + - read_file + - edit_file handler: time_handler diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index dc1756486..8b374ed41 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -275,15 +275,13 @@ fonts: - Microsoft YaHei tools: + workspace_search: + mcp: false file_system: mcp: false allow_read_all_files: true exclude: - edit_file - - list_files - - search_file_content - - search_file_name - - replace_file_lines memory: diversity: From 94ba0c354549ae2758636d440671af834dc81574 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 13 Apr 2026 15:58:56 +0800 Subject: [PATCH 31/40] refactor(tools): merge grep/glob into FileSystemTool Remove WorkspaceSearchTool and register grep and glob on the file_system server alongside read/write/edit. Add read/edit/write include aliases and optional grep_head_limit, glob_max_files, and grep_timeout_s on file_system. Update project YAML and prompts to drop workspace_search blocks; document the mapping in shell-grep-glob-workspace-policy.md. Add tests for include aliases and grep/glob filtering. Made-with: Cursor --- ms_agent/tools/filesystem_tool.py | 454 +++++++++++++++++++- ms_agent/tools/tool_manager.py | 6 - ms_agent/tools/workspace_search_tool.py | 475 --------------------- projects/code_genesis/coding.yaml | 14 +- projects/code_genesis/refine.yaml | 4 +- projects/deep_research/v2/reporter.yaml | 4 +- projects/deep_research/v2/researcher.yaml | 4 +- projects/deep_research/v2/searcher.yaml | 4 +- projects/fin_research/aggregator.yaml | 6 +- projects/fin_research/collector.yaml | 4 +- projects/singularity_cinema/agent.yaml | 2 - shell-grep-glob-workspace-policy.md | 2 +- tests/utils/test_filesystem_tool_config.py | 54 +++ 13 files changed, 528 insertions(+), 505 deletions(-) delete mode 100644 ms_agent/tools/workspace_search_tool.py create mode 100644 tests/utils/test_filesystem_tool_config.py diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index df01d0261..73cb0a442 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -1,17 +1,35 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio import base64 +import fnmatch +import json import os +import re +import shutil from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Dict, List, Optional -import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger +from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel logger = get_logger() +_FS_TOOL_ALIASES = {'read': 'read_file', 'edit': 'edit_file', 'write': 'write_file'} + +_TEXT_SUFFIXES = { + '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', + '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', + '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', + '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', + '.env', '.gitignore', '.dockerignore', 'Dockerfile', +} + class FileSystemTool(ToolBase): """A file system operation tool""" @@ -41,6 +59,14 @@ class FileSystemTool(ToolBase): def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) + if self.include_functions: + self.include_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.include_functions + ] + if self.exclude_functions: + self.exclude_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions + ] self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) self.trust_remote_code = kwargs.get('trust_remote_code', False) self.allow_read_all_files = getattr( @@ -58,6 +84,44 @@ def __init__(self, config, **kwargs): # {real_path: {"mtime": float, "offset": int|None, "limit": int|None}} self._read_cache: dict[str, dict] = {} + fs_cfg = getattr(config.tools, 'file_system', None) + self._grep_timeout = int(getattr(fs_cfg, 'grep_timeout_s', 120) or 120) + self._default_grep_head = int( + getattr(fs_cfg, 'grep_head_limit', 250) or 250) + self._glob_max_files = int(getattr(fs_cfg, 'glob_max_files', 100) or 100) + + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] + deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] + + shell_cfg = getattr( + getattr(config.tools, 'code_executor', None), 'shell', None) + shell_mode = getattr(shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False) + ) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) + ) if shell_cfg else 8192 + + _out_p = Path(self.output_dir).expanduser().resolve() + try: + _out_p.mkdir(parents=True, exist_ok=True) + except OSError: + pass + self._fs_policy = WorkspacePolicyKernel( + _out_p, + extra_allow_roots=extra, + deny_globs=deny if deny else None, + shell_default_mode=str(shell_mode), + shell_network_enabled=net, + max_command_chars=max_cmd, + ) + max_kb = 256 + if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): + max_kb = int(shell_cfg.max_output_kb) + self._fs_artifacts = ArtifactManager( + _out_p, max_combined_bytes=max_kb * 1024) + async def connect(self): logger.warning_once( '[IMPORTANT]FileSystemTool is not implemented with sandbox, please consider other similar ' @@ -176,6 +240,77 @@ async def _get_tools_inner(self): 'required': ['path', 'old_string', 'new_string'], 'additionalProperties': False }), + Tool( + tool_name='grep', + server_name='file_system', + description=( + 'Search file contents under the workspace using ripgrep when available, ' + 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' + 'Read-only.' + ), + parameters={ + 'type': 'object', + 'properties': { + 'pattern': { + 'type': 'string', + 'description': 'Regular expression (Rust regex if rg is used).', + }, + 'path': { + 'type': 'string', + 'description': + 'Directory or file to search (relative to output_dir if not absolute). Default ".".', + }, + 'glob': { + 'type': 'string', + 'description': 'Optional glob filter for files, e.g. "*.py"', + }, + 'output_mode': { + 'type': 'string', + 'enum': ['content', 'files_with_matches', 'count'], + 'description': + 'content: matching lines; files_with_matches: paths only; count: per-file counts', + }, + 'head_limit': { + 'type': 'integer', + 'description': 'Max lines (content) or paths/count entries to return', + }, + 'offset': { + 'type': 'integer', + 'description': 'Skip first N lines/entries after collect', + }, + 'case_insensitive': { + 'type': 'boolean', + 'description': 'Case-insensitive search', + }, + }, + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='glob', + server_name='file_system', + description=( + 'List files under a workspace directory matching a glob pattern ' + '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' + ), + parameters={ + 'type': 'object', + 'properties': { + 'pattern': { + 'type': 'string', + 'description': 'Glob pattern relative to path', + }, + 'path': { + 'type': 'string', + 'description': + 'Base directory (relative to output_dir if not absolute).', + }, + }, + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), ] } @@ -185,6 +320,268 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) + async def grep( + self, + pattern: str, + path: str = '.', + glob: Optional[str] = None, + output_mode: str = 'files_with_matches', + head_limit: Optional[int] = None, + offset: Optional[int] = None, + case_insensitive: bool = False, + ) -> str: + call_id = f'grep-{pattern[:40]}' + head_limit = (head_limit if head_limit is not None else + self._default_grep_head) + offset = offset or 0 + path = path or '.' + try: + root = self._fs_policy.resolve_under_roots(path) + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + lines: List[str] = [] + try: + rg = shutil.which('rg') + if rg and root.is_file(): + lines = await self._grep_rg_file(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) + elif rg and root.is_dir(): + lines = await self._grep_rg_dir(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) + else: + lines = self._grep_python( + pattern, + root, + glob, + output_mode, + head_limit, + offset, + case_insensitive, + ) + except Exception as e: + logger.warning('grep failed: %s', e, exc_info=True) + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + text = '\n'.join(lines) + packed = self._fs_artifacts.pack_text_result( + tool_name='grep', + call_id=call_id, + stdout=text, + stderr='', + extra={ + 'success': True, + 'output_mode': output_mode, + 'num_lines': len(lines), + }, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + + async def _grep_rg_file( + self, + rg: str, + pattern: str, + file_path: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob_pat: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob_pat: + args.extend(['--glob', glob_pat]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(file_path)]) + elif output_mode == 'count': + args.extend(['-c', pattern, str(file_path)]) + else: + args.extend(['-n', pattern, str(file_path)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._fs_policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + async def _grep_rg_dir( + self, + rg: str, + pattern: str, + root: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob_pat: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob_pat: + args.extend(['--glob', glob_pat]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(root)]) + elif output_mode == 'count': + args.extend(['--count-matches', pattern, str(root)]) + else: + args.extend(['-n', pattern, str(root)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._fs_policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + def _grep_python( + self, + pattern: str, + root: Path, + glob_pat: Optional[str], + output_mode: str, + head_limit: int, + offset: int, + case_insensitive: bool, + ) -> List[str]: + flags = re.IGNORECASE if case_insensitive else 0 + try: + rx = re.compile(pattern, flags) + except re.error as e: + return [f'[error] invalid regex: {e}'] + lines_out: List[str] = [] + counts: Dict[str, int] = {} + + def consider_file(fp: Path) -> bool: + if glob_pat: + rel = str(fp.relative_to(root)) if root.is_dir() else fp.name + if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( + rel, glob_pat): + return False + suf = fp.suffix.lower() + if suf not in _TEXT_SUFFIXES and fp.suffix == '': + if fp.name not in ('Dockerfile', 'Makefile', 'README'): + return False + return fp.is_file() + + files: List[Path] = [] + if root.is_file(): + files = [root] + else: + for fp in _walk_files_limited(root, self._fs_policy.deny_globs, + 50_000): + if consider_file(fp): + files.append(fp) + + for fp in files: + try: + text = fp.read_text(encoding='utf-8', errors='replace') + except OSError: + continue + rel = str(fp.relative_to(self._fs_policy.workspace_root) + ) if _is_relative(fp, self._fs_policy.workspace_root) else str( + fp) + if output_mode == 'files_with_matches': + if rx.search(text): + lines_out.append(rel) + elif output_mode == 'count': + n = len(rx.findall(text)) + if n: + counts[rel] = n + else: + for i, line in enumerate(text.splitlines(), start=1): + if rx.search(line): + lines_out.append(f'{rel}:{i}:{line}') + if len(lines_out) >= head_limit + offset + 5000: + break + + if output_mode == 'count': + lines_out = [f'{k}:{v}' for k, v in sorted(counts.items())] + return _apply_offset_limit(lines_out, offset, head_limit) + + async def glob(self, pattern: str, path: str = '') -> str: + call_id = f'glob-{pattern[:40]}' + try: + base = self._fs_policy.resolve_under_roots(path or '.') + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + if not base.is_dir(): + return json.dumps( + { + 'success': False, + 'error': f'Not a directory: {path}', + }, + indent=2, + ) + + matches: List[str] = [] + truncated = False + deny = self._fs_policy.deny_globs + + try: + for p in sorted(base.glob(pattern)): + if not p.is_file(): + continue + rp = p.resolve() + if not self._fs_policy.path_is_allowed(rp): + continue + if _is_denied_path(rp, base, deny): + continue + rel = str(p.relative_to(self._fs_policy.workspace_root) + ) if _is_relative(p, self._fs_policy.workspace_root + ) else str(p) + matches.append(rel) + if len(matches) >= self._glob_max_files: + truncated = True + break + except ValueError: + return json.dumps( + { + 'success': False, + 'error': 'Invalid glob pattern', + }, + indent=2, + ) + + text = json.dumps( + { + 'success': True, + 'num_files': len(matches), + 'filenames': matches, + 'truncated': truncated, + }, + ensure_ascii=False, + indent=2, + ) + packed = self._fs_artifacts.pack_text_result( + tool_name='glob', + call_id=call_id, + stdout=text, + stderr='', + extra={'success': True}, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + def _check_staleness(self, real_path: str) -> str | None: """Return an error string if the file has not been read or has changed since last read. Returns None if the write is safe to proceed. @@ -577,3 +974,58 @@ async def edit_file(self, return f'Edit file <{path}> successfully ({replaced} occurrence(s) replaced).' except Exception as e: return f'Edit file <{path}> failed, error: ' + str(e) + + +def _apply_offset_limit(lines: List[str], offset: int, + head_limit: int) -> List[str]: + if offset: + lines = lines[offset:] + if head_limit and head_limit > 0: + lines = lines[:head_limit] + return lines + + +def _is_relative(path: Path, base: Path) -> bool: + try: + path.relative_to(base) + return True + except ValueError: + return False + + +def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: + if not deny: + return False + try: + rel = path.relative_to(root).as_posix() + except ValueError: + rel = path.as_posix() + for pat in deny: + if fnmatch.fnmatch(rel, pat): + return True + return False + + +def _walk_files_limited(root: Path, deny: tuple[str, ...], + max_files: int) -> List[Path]: + out: List[Path] = [] + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): + dp = Path(dirpath) + pruned = [] + for d in list(dirnames): + child = dp / d + try: + rel = child.relative_to(root).as_posix() + except ValueError: + rel = child.as_posix() + skip = any(fnmatch.fnmatch(rel, p) for p in deny) + if skip: + continue + pruned.append(d) + dirnames[:] = pruned + for name in filenames: + out.append(dp / name) + if len(out) >= max_files: + return out + return out diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 3c8f21339..a51290444 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -88,12 +88,6 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) - ws = getattr(getattr(config, 'tools', None), 'workspace_search', None) - _ws_enabled = True if ws is None else bool(getattr(ws, 'enabled', True)) - if _ws_enabled: - from ms_agent.tools.workspace_search_tool import WorkspaceSearchTool - - self.extra_tools.append(WorkspaceSearchTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, diff --git a/ms_agent/tools/workspace_search_tool.py b/ms_agent/tools/workspace_search_tool.py deleted file mode 100644 index c5559f8f9..000000000 --- a/ms_agent/tools/workspace_search_tool.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Read-only workspace search: grep (rg or Python fallback) and glob.""" - -from __future__ import annotations - -import asyncio -import fnmatch -import json -import os -import re -import shutil -from pathlib import Path -from typing import Any, Dict, List, Optional - -from ms_agent.llm.utils import Tool -from ms_agent.tools.base import ToolBase -from ms_agent.utils import get_logger -from ms_agent.utils.artifact_manager import ArtifactManager -from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel - -logger = get_logger() - -_TEXT_SUFFIXES = { - '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', - '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', - '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', - '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', - '.env', '.gitignore', '.dockerignore', 'Dockerfile', -} - - -class WorkspaceSearchTool(ToolBase): - """Grep and glob under output_dir (+ optional extra roots) with shared policy.""" - - def __init__(self, config, **kwargs): - super().__init__(config) - self.output_dir = Path( - getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() - self.output_dir.mkdir(parents=True, exist_ok=True) - - wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) - extra = [] - deny: list[str] = [] - if wp is not None: - extra = list(getattr(wp, 'allow_roots', []) or []) - deny = list(getattr(wp, 'deny_globs', []) or []) - else: - deny = [] - ws = getattr(getattr(config, 'tools', None), 'workspace_search', None) - self._default_head = int(getattr(ws, 'default_head_limit', 250) or 250) - self._glob_max = int(getattr(ws, 'max_files', 100) or 100) - self._grep_timeout = int(getattr(ws, 'grep_timeout_s', 120) or 120) - - shell_cfg = getattr( - getattr(config.tools, 'code_executor', None), 'shell', None) - shell_mode = getattr(shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False) - ) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) - ) if shell_cfg else 8192 - - self._policy = WorkspacePolicyKernel( - self.output_dir, - extra_allow_roots=extra, - deny_globs=deny if deny else None, - shell_default_mode=str(shell_mode), - shell_network_enabled=net, - max_command_chars=max_cmd, - ) - max_kb = 256 - if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): - max_kb = int(shell_cfg.max_output_kb) - self._artifacts = ArtifactManager( - self.output_dir, max_combined_bytes=max_kb * 1024) - - self.exclude_func(ws) - - async def connect(self) -> None: - return - - async def _get_tools_inner(self) -> Dict[str, Any]: - return { - 'workspace_search': [ - Tool( - tool_name='grep_files', - server_name='workspace_search', - description=( - 'Search file contents under the workspace using ripgrep when available, ' - 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' - 'Read-only.' - ), - parameters={ - 'type': 'object', - 'properties': { - 'pattern': { - 'type': 'string', - 'description': 'Regular expression (Rust regex if rg is used).', - }, - 'path': { - 'type': 'string', - 'description': - 'Directory or file to search (relative to output_dir if not absolute). Default ".".', - }, - 'glob': { - 'type': 'string', - 'description': 'Optional glob filter for files, e.g. "*.py"', - }, - 'output_mode': { - 'type': 'string', - 'enum': ['content', 'files_with_matches', 'count'], - 'description': 'content: matching lines; files_with_matches: paths only; count: per-file counts', - }, - 'head_limit': { - 'type': 'integer', - 'description': 'Max lines (content) or paths/count entries to return', - }, - 'offset': { - 'type': 'integer', - 'description': 'Skip first N lines/entries after collect', - }, - 'case_insensitive': { - 'type': 'boolean', - 'description': 'Case-insensitive search', - }, - }, - 'required': ['pattern'], - 'additionalProperties': False, - }, - ), - Tool( - tool_name='glob_files', - server_name='workspace_search', - description=( - 'List files under a workspace directory matching a glob pattern ' - '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' - ), - parameters={ - 'type': 'object', - 'properties': { - 'pattern': { - 'type': 'string', - 'description': 'Glob pattern relative to path', - }, - 'path': { - 'type': 'string', - 'description': 'Base directory (relative to output_dir if not absolute).', - }, - }, - 'required': ['pattern'], - 'additionalProperties': False, - }, - ), - ] - } - - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: - return await getattr(self, tool_name)(**tool_args) - - async def grep_files( - self, - pattern: str, - path: str = '.', - glob: Optional[str] = None, - output_mode: str = 'files_with_matches', - head_limit: Optional[int] = None, - offset: Optional[int] = None, - case_insensitive: bool = False, - ) -> str: - call_id = f'grep-{pattern[:40]}' - head_limit = head_limit if head_limit is not None else self._default_head - offset = offset or 0 - path = path or '.' - try: - root = self._policy.resolve_under_roots(path) - except WorkspacePolicyError as e: - return json.dumps({'success': False, 'error': str(e)}, indent=2) - - lines: List[str] = [] - try: - rg = shutil.which('rg') - if rg and root.is_file(): - lines = await self._rg_file(rg, pattern, root, case_insensitive, - output_mode, head_limit, offset, - glob) - elif rg and root.is_dir(): - lines = await self._rg_dir(rg, pattern, root, case_insensitive, - output_mode, head_limit, offset, - glob) - else: - lines = self._python_grep( - pattern, - root, - glob, - output_mode, - head_limit, - offset, - case_insensitive, - ) - except Exception as e: - logger.warning('grep_files failed: %s', e, exc_info=True) - return json.dumps({'success': False, 'error': str(e)}, indent=2) - - text = '\n'.join(lines) - packed = self._artifacts.pack_text_result( - tool_name='grep_files', - call_id=call_id, - stdout=text, - stderr='', - extra={ - 'success': True, - 'output_mode': output_mode, - 'num_lines': len(lines), - }, - ) - return json.dumps(packed, ensure_ascii=False, indent=2, default=str) - - async def _rg_file( - self, - rg: str, - pattern: str, - file_path: Path, - case_insensitive: bool, - output_mode: str, - head_limit: int, - offset: int, - glob: Optional[str], - ) -> List[str]: - args = [rg, '--no-heading', '--color', 'never'] - if case_insensitive: - args.append('-i') - if glob: - args.extend(['--glob', glob]) - if output_mode == 'files_with_matches': - args.extend(['-l', pattern, str(file_path)]) - elif output_mode == 'count': - args.extend(['-c', pattern, str(file_path)]) - else: - args.extend(['-n', pattern, str(file_path)]) - proc = await asyncio.create_subprocess_exec( - *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=str(self._policy.workspace_root), - ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) - out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') - err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') - if proc.returncode not in (0, 1): - raise RuntimeError(err or f'rg exited {proc.returncode}') - lines = [ln for ln in out.split('\n') if ln] if out else [] - return _apply_offset_limit(lines, offset, head_limit) - - async def _rg_dir( - self, - rg: str, - pattern: str, - root: Path, - case_insensitive: bool, - output_mode: str, - head_limit: int, - offset: int, - glob: Optional[str], - ) -> List[str]: - args = [rg, '--no-heading', '--color', 'never'] - if case_insensitive: - args.append('-i') - if glob: - args.extend(['--glob', glob]) - if output_mode == 'files_with_matches': - args.extend(['-l', pattern, str(root)]) - elif output_mode == 'count': - args.extend(['--count-matches', pattern, str(root)]) - else: - args.extend(['-n', pattern, str(root)]) - proc = await asyncio.create_subprocess_exec( - *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=str(self._policy.workspace_root), - ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) - out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') - err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') - if proc.returncode not in (0, 1): - raise RuntimeError(err or f'rg exited {proc.returncode}') - lines = [ln for ln in out.split('\n') if ln] if out else [] - return _apply_offset_limit(lines, offset, head_limit) - - def _python_grep( - self, - pattern: str, - root: Path, - glob_pat: Optional[str], - output_mode: str, - head_limit: int, - offset: int, - case_insensitive: bool, - ) -> List[str]: - flags = re.IGNORECASE if case_insensitive else 0 - try: - rx = re.compile(pattern, flags) - except re.error as e: - return [f'[error] invalid regex: {e}'] - lines_out: List[str] = [] - counts: Dict[str, int] = {} - - def consider_file(fp: Path) -> bool: - if glob_pat: - rel = str(fp.relative_to(root)) if root.is_dir() else fp.name - if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( - rel, glob_pat): - return False - suf = fp.suffix.lower() - if suf not in _TEXT_SUFFIXES and fp.suffix == '': - if fp.name not in ('Dockerfile', 'Makefile', 'README'): - return False - return fp.is_file() - - files: List[Path] = [] - if root.is_file(): - files = [root] - else: - for fp in _walk_files_limited(root, self._policy.deny_globs, 50_000): - if consider_file(fp): - files.append(fp) - - for fp in files: - try: - text = fp.read_text(encoding='utf-8', errors='replace') - except OSError: - continue - rel = str(fp.relative_to(self._policy.workspace_root)) if _is_relative( - fp, self._policy.workspace_root) else str(fp) - if output_mode == 'files_with_matches': - if rx.search(text): - lines_out.append(rel) - elif output_mode == 'count': - n = len(rx.findall(text)) - if n: - counts[rel] = n - else: - for i, line in enumerate(text.splitlines(), start=1): - if rx.search(line): - lines_out.append(f'{rel}:{i}:{line}') - if len(lines_out) >= head_limit + offset + 5000: - break - - if output_mode == 'count': - lines_out = [f'{k}:{v}' for k, v in sorted(counts.items())] - return _apply_offset_limit(lines_out, offset, head_limit) - - async def glob_files(self, pattern: str, path: str = '') -> str: - call_id = f'glob-{pattern[:40]}' - try: - base = self._policy.resolve_under_roots(path or '.') - except WorkspacePolicyError as e: - return json.dumps({'success': False, 'error': str(e)}, indent=2) - - if not base.is_dir(): - return json.dumps( - { - 'success': False, - 'error': f'Not a directory: {path}', - }, - indent=2, - ) - - matches: List[str] = [] - truncated = False - deny = self._policy.deny_globs - - # Prefer pathlib.glob from base - try: - for p in sorted(base.glob(pattern)): - if not p.is_file(): - continue - rp = p.resolve() - if not self._policy.path_is_allowed(rp): - continue - if _is_denied_path(rp, base, deny): - continue - rel = str(p.relative_to(self._policy.workspace_root)) if _is_relative( - p, self._policy.workspace_root) else str(p) - matches.append(rel) - if len(matches) >= self._glob_max: - truncated = True - break - except ValueError: - # invalid pattern - return json.dumps( - { - 'success': False, - 'error': 'Invalid glob pattern', - }, - indent=2, - ) - - text = json.dumps( - { - 'success': True, - 'num_files': len(matches), - 'filenames': matches, - 'truncated': truncated, - }, - ensure_ascii=False, - indent=2, - ) - packed = self._artifacts.pack_text_result( - tool_name='glob_files', - call_id=call_id, - stdout=text, - stderr='', - extra={'success': True}, - ) - return json.dumps(packed, ensure_ascii=False, indent=2, default=str) - - -def _apply_offset_limit(lines: List[str], offset: int, - head_limit: int) -> List[str]: - if offset: - lines = lines[offset:] - if head_limit and head_limit > 0: - lines = lines[:head_limit] - return lines - - -def _is_relative(path: Path, base: Path) -> bool: - try: - path.relative_to(base) - return True - except ValueError: - return False - - -def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: - if not deny: - return False - try: - rel = path.relative_to(root).as_posix() - except ValueError: - rel = path.as_posix() - for pat in deny: - if fnmatch.fnmatch(rel, pat): - return True - return False - - -def _walk_files_limited(root: Path, deny: tuple[str, ...], - max_files: int) -> List[Path]: - out: List[Path] = [] - for dirpath, dirnames, filenames in os.walk( - root, topdown=True, followlinks=False): - dp = Path(dirpath) - pruned = [] - for d in list(dirnames): - child = dp / d - try: - rel = child.relative_to(root).as_posix() - except ValueError: - rel = child.as_posix() - skip = any(fnmatch.fnmatch(rel, p) for p in deny) - if skip: - continue - pruned.append(d) - dirnames[:] = pruned - for name in filenames: - out.append(dp / name) - if len(out) >= max_files: - return out - return out diff --git a/projects/code_genesis/coding.yaml b/projects/code_genesis/coding.yaml index b9f398697..cc99b728f 100644 --- a/projects/code_genesis/coding.yaml +++ b/projects/code_genesis/coding.yaml @@ -34,7 +34,7 @@ prompt: - `postcss.config.js` enables `tailwindcss` and `autoprefixer` 3. CRITICAL: Before reading ANY file: - * FIRST use `workspace_search---glob_files` (e.g. pattern `**/*`, path `.`) to list paths under the project, or `code_executor---shell_executor` with a read-only `ls`/`find` command if you prefer the shell. + * FIRST use `file_system---glob` (e.g. pattern `**/*`, path `.`) to list paths under the project, or `code_executor---shell_executor` with a read-only `ls`/`find` command if you prefer the shell. * NEVER read files that do not appear in the output * NEVER attempt to read files with index >= yours (they don't exist yet) * NEVER guess or assume a file exists - always verify first @@ -42,7 +42,7 @@ prompt: 3.1 CRITICAL (no hallucination about files): * Do not fully trust `protocol.txt` content; you must verify it yourself by checking existing files and reading the exact source of truth. * Before you reference/cite ANY information that comes from a file (APIs, exports, config values, routes, CSS class names, build scripts, ports, etc.), you MUST: - 1) Confirm the file exists via `glob_files` (or shell listing), AND + 1) Confirm the file exists via `file_system---glob` (or shell listing), AND 2) Read the relevant part of that exact file via `read_file`. * You are NOT allowed to infer file contents. * If the needed file is missing from the file list, do not reference it; either create it (if allowed by your index constraints) or implement the needed logic in your current file. @@ -113,15 +113,15 @@ prompt: 2. [Secondary] Use as few tokens as possible. tools: - workspace_search: - mcp: false file_system: mcp: false allow_read_all_files: true include: - - read_file - - edit_file - - write_file + - read + - edit + - write + - grep + - glob edit_file_config: diff_model: morph-v3-fast api_key: diff --git a/projects/code_genesis/refine.yaml b/projects/code_genesis/refine.yaml index 8dd78adcd..8cea4e407 100644 --- a/projects/code_genesis/refine.yaml +++ b/projects/code_genesis/refine.yaml @@ -91,14 +91,14 @@ tools: - file_operation - reset_executor - get_executor_info - workspace_search: - mcp: false file_system: mcp: false include: - read_file - write_file - edit_file + - grep + - glob edit_file_config: diff_model: morph-v3-fast api_key: diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index aa19d1114..91dbb303b 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -29,14 +29,14 @@ prompt: family: gpt5 tools: - workspace_search: - mcp: false file_system: mcp: false include: - write_file - read_file - edit_file + - grep + - glob evidence_store: mcp: false evidence_dir: evidence diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 47eb0ab12..5c1155f9b 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -28,14 +28,14 @@ prompt: tools: - workspace_search: - mcp: false file_system: mcp: false include: - write_file - read_file - edit_file + - grep + - glob code_executor: mcp: false implementation: python_env diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index 0e015a39d..208d4aae9 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -28,13 +28,13 @@ prompt: tools: - workspace_search: - mcp: false file_system: mcp: false include: - write_file - read_file + - grep + - glob web_search: mcp: false engines: diff --git a/projects/fin_research/aggregator.yaml b/projects/fin_research/aggregator.yaml index 70c7bf449..15cc8452e 100644 --- a/projects/fin_research/aggregator.yaml +++ b/projects/fin_research/aggregator.yaml @@ -84,7 +84,7 @@ prompt: - You may interpret the meaning of embedded images from their surrounding context \ in the financial data analysis report. - You should note that most of the charts generated during the analysis process \ - are stored in the default working directory. Use `workspace_search---glob_files` (e.g. pattern `**/*.{png,jpg,svg}`) \ + are stored in the default working directory. Use `file_system---glob` (e.g. pattern `**/*.{png,jpg,svg}`) \ or `file_system---read_file` to inspect filenames and embed meaningful images into the report. - Avoid inserting any images whose titles or filenames do not convey meaningful information. - Avoid inserting any images that are from websites or other external sources. @@ -127,8 +127,6 @@ prompt: tools: - workspace_search: - mcp: false code_executor: mcp: false implementation: python_env @@ -143,6 +141,8 @@ tools: - read_file - write_file - edit_file + - grep + - glob spec_loader: mcp: false plugins: diff --git a/projects/fin_research/collector.yaml b/projects/fin_research/collector.yaml index 25827516e..7fd1967a1 100644 --- a/projects/fin_research/collector.yaml +++ b/projects/fin_research/collector.yaml @@ -147,14 +147,14 @@ tools: max_requests_per_second: 1 # Maximum requests per second min_request_interval: 1 # Minimum interval between requests (seconds) max_concurrent: 1 # Maximum concurrent requests (reduced to work with thread semaphore) - workspace_search: - mcp: false file_system: mcp: false include: - write_file - read_file - edit_file + - grep + - glob handler: time_handler diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index 8b374ed41..14fe6809b 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -275,8 +275,6 @@ fonts: - Microsoft YaHei tools: - workspace_search: - mcp: false file_system: mcp: false allow_read_all_files: true diff --git a/shell-grep-glob-workspace-policy.md b/shell-grep-glob-workspace-policy.md index 9e9f86bbd..ac4e3f912 100644 --- a/shell-grep-glob-workspace-policy.md +++ b/shell-grep-glob-workspace-policy.md @@ -217,7 +217,7 @@ tools: | ArtifactManager | `ms_agent/utils/artifact_manager.py` | | TaskManager | `ms_agent/utils/task_manager.py` | | Shell 策略 / 产物 / 后台 | `ms_agent/tools/code/local_code_executor.py`(`set_task_manager`、`shell_executor`) | -| Grep / Glob | `ms_agent/tools/workspace_search_tool.py`(默认注册;`tools.workspace_search.enabled: false` 可关闭) | +| Grep / Glob | `ms_agent/tools/filesystem_tool.py` 中 `grep` / `glob` 工具(与 `read_file` / `edit_file` / `write_file` 同属 `file_system` server;用 `tools.file_system.include` / `exclude` 控制)。可选键:`grep_timeout_s`、`grep_head_limit`、`glob_max_files`;`include` 短名 `read` / `edit` / `write` 分别等价 `read_file` / `edit_file` / `write_file`。 | | `__call_id` 注入 shell | `ms_agent/tools/tool_manager.py` | | TaskManager 与通知 | `ms_agent/agent/llm_agent.py`(`prepare_tools` / `cleanup_tools` / `_append_task_notifications`) | | 单测 | `tests/utils/test_workspace_policy.py` | diff --git a/tests/utils/test_filesystem_tool_config.py b/tests/utils/test_filesystem_tool_config.py new file mode 100644 index 000000000..5bc8310c9 --- /dev/null +++ b/tests/utils/test_filesystem_tool_config.py @@ -0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""FileSystemTool config: include aliases and grep/glob registration.""" + +import asyncio +import tempfile + +from omegaconf import OmegaConf + +from ms_agent.tools.filesystem_tool import FileSystemTool + + +def test_include_short_aliases_expand_to_canonical_names(): + async def _run(): + with tempfile.TemporaryDirectory() as td: + cfg = OmegaConf.create({ + 'output_dir': td, + 'tools': { + 'file_system': { + 'mcp': False, + 'include': ['read', 'write', 'glob'], + }, + }, + }) + fs = FileSystemTool(cfg) + tools = await fs.get_tools() + names = [t['tool_name'] for t in tools['file_system']] + assert 'read_file' in names + assert 'write_file' in names + assert 'glob' in names + assert 'grep' not in names + assert 'read' not in names + assert 'write' not in names + + asyncio.run(_run()) + + +def test_grep_glob_listed_with_full_names(): + async def _run(): + with tempfile.TemporaryDirectory() as td: + cfg = OmegaConf.create({ + 'output_dir': td, + 'tools': { + 'file_system': { + 'mcp': False, + 'include': ['grep', 'glob'], + }, + }, + }) + fs = FileSystemTool(cfg) + tools = await fs.get_tools() + names = [t['tool_name'] for t in tools['file_system']] + assert names == ['grep', 'glob'] + + asyncio.run(_run()) From ec5fb1a3a4591b39807b4c0dc160d8b59c1bae8e Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 20 Apr 2026 16:52:44 +0800 Subject: [PATCH 32/40] feat(search): add Tavily engine, extract fetcher, and dr_bench wiring Add Tavily HTTP client, search/extract schema, WebSearchTool integration, optional large-result spill, researcher/searcher Tavily YAML presets, and run_benchmark env hooks for RESEARCHER_CONFIG / BENCH paths. Made-with: Cursor --- ms_agent/tools/search/search_base.py | 2 + ms_agent/tools/search/tavily/__init__.py | 4 + ms_agent/tools/search/tavily/fetcher.py | 91 ++++ ms_agent/tools/search/tavily/http.py | 50 +++ ms_agent/tools/search/tavily/schema.py | 121 ++++++ ms_agent/tools/search/tavily/search.py | 217 ++++++++++ ms_agent/tools/search/web_search_spill.py | 292 +++++++++++++ ms_agent/tools/search/websearch_tool.py | 389 ++++++++++++++++-- projects/deep_research/.env.example | 2 + .../deep_research/v2/researcher.tavily.yaml | 172 ++++++++ projects/deep_research/v2/run_benchmark.sh | 17 +- .../deep_research/v2/searcher.tavily.yaml | 105 +++++ 12 files changed, 1427 insertions(+), 35 deletions(-) create mode 100644 ms_agent/tools/search/tavily/__init__.py create mode 100644 ms_agent/tools/search/tavily/fetcher.py create mode 100644 ms_agent/tools/search/tavily/http.py create mode 100644 ms_agent/tools/search/tavily/schema.py create mode 100644 ms_agent/tools/search/tavily/search.py create mode 100644 ms_agent/tools/search/web_search_spill.py create mode 100644 projects/deep_research/v2/researcher.tavily.yaml create mode 100644 projects/deep_research/v2/searcher.tavily.yaml diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py index bb6952729..8a10c9d20 100644 --- a/ms_agent/tools/search/search_base.py +++ b/ms_agent/tools/search/search_base.py @@ -17,6 +17,7 @@ class SearchEngineType(enum.Enum): EXA = 'exa' SERPAPI = 'serpapi' ARXIV = 'arxiv' + TAVILY = 'tavily' # Mapping from engine type to tool name @@ -24,6 +25,7 @@ class SearchEngineType(enum.Enum): 'exa': 'exa_search', 'serpapi': 'serpapi_search', 'arxiv': 'arxiv_search', + 'tavily': 'tavily_search', } diff --git a/ms_agent/tools/search/tavily/__init__.py b/ms_agent/tools/search/tavily/__init__.py new file mode 100644 index 000000000..ef3e913f2 --- /dev/null +++ b/ms_agent/tools/search/tavily/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from ms_agent.tools.search.tavily.search import TavilySearch + +__all__ = ['TavilySearch'] diff --git a/ms_agent/tools/search/tavily/fetcher.py b/ms_agent/tools/search/tavily/fetcher.py new file mode 100644 index 000000000..4082907a2 --- /dev/null +++ b/ms_agent/tools/search/tavily/fetcher.py @@ -0,0 +1,91 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tavily Extract API as ContentFetcher (replaces Jina for fetch_page / URL fetch).""" +import os +import time +from typing import Any, Dict, Optional, Tuple + +from ms_agent.tools.search.tavily.http import post_json +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +TAVILY_EXTRACT_URL = 'https://api.tavily.com/extract' + + +class TavilyExtractFetcher: + """ + Fetch page text via Tavily POST /extract. + + Uses ``extract_depth=advanced`` and ``format=markdown`` by default for + richest structured text (tables, etc., when available). + """ + + def __init__( + self, + api_key: Optional[str] = None, + *, + extract_depth: str = 'advanced', + format: str = 'markdown', + timeout: float = 45.0, + chunks_per_source: int = 3, + include_images: bool = False, + include_favicon: bool = False, + include_usage: bool = False, + ): + key = api_key or os.getenv('TAVILY_API_KEY') + if not key: + raise ValueError( + 'TAVILY_API_KEY required for tavily_extract fetcher') + self._api_key = key + self._extract_depth = extract_depth + self._format = format + self._timeout = max(1.0, min(60.0, float(timeout))) + self._chunks_per_source = max(1, min(5, int(chunks_per_source))) + self._include_images = include_images + self._include_favicon = include_favicon + self._include_usage = include_usage + + def fetch(self, url: str, query: Optional[str] = None) -> Tuple[str, Dict[str, Any]]: + """ + Extract one URL. Optional ``query`` enables chunk reranking (more relevant raw_content). + """ + body: Dict[str, Any] = { + 'api_key': self._api_key, + 'urls': [url], + 'extract_depth': self._extract_depth, + 'format': self._format, + 'timeout': self._timeout, + 'include_images': self._include_images, + 'include_favicon': self._include_favicon, + 'include_usage': self._include_usage, + } + if query: + body['query'] = query + body['chunks_per_source'] = self._chunks_per_source + + try: + data = post_json(TAVILY_EXTRACT_URL, body, timeout=self._timeout + 30.0) + except Exception as e: + logger.warning(f'Tavily extract failed for {url[:80]}: {e}') + return '', { + 'fetcher': 'tavily_extract', + 'error': str(e), + 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + } + + results = data.get('results') or [] + text = '' + if results: + text = (results[0].get('raw_content') or '').strip() + meta: Dict[str, Any] = { + 'fetcher': 'tavily_extract', + 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + 'tavily_response_time': data.get('response_time'), + 'tavily_usage': data.get('usage'), + 'tavily_request_id': data.get('request_id'), + } + failed = data.get('failed_results') or [] + if failed and not text: + err = failed[0].get('error', 'unknown') + meta['error'] = err + return text, meta diff --git a/ms_agent/tools/search/tavily/http.py b/ms_agent/tools/search/tavily/http.py new file mode 100644 index 000000000..d4916d271 --- /dev/null +++ b/ms_agent/tools/search/tavily/http.py @@ -0,0 +1,50 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal HTTP JSON client for Tavily REST API (stdlib only).""" +import json +from typing import Any, Dict +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + + +def post_json( + url: str, + body: Dict[str, Any], + *, + timeout: float = 120.0, +) -> Dict[str, Any]: + """ + POST JSON and parse JSON response. + + Raises: + RuntimeError: on HTTP errors or invalid JSON (includes Tavily error body). + """ + data = json.dumps(body, ensure_ascii=False).encode('utf-8') + req = Request( + url, + data=data, + method='POST', + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + ) + try: + with urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode('utf-8', errors='replace') + if not raw.strip(): + return {} + return json.loads(raw) + except HTTPError as e: + err_body = '' + try: + err_body = e.read().decode('utf-8', errors='replace') + except Exception: + pass + try: + detail = json.loads(err_body) if err_body else {} + except json.JSONDecodeError: + detail = {'raw': err_body} + raise RuntimeError( + f'Tavily HTTP {e.code}: {detail}') from e + except URLError as e: + raise RuntimeError(f'Tavily network error: {e}') from e diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py new file mode 100644 index 000000000..75f3f0aed --- /dev/null +++ b/ms_agent/tools/search/tavily/schema.py @@ -0,0 +1,121 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class TavilySearchRequest: + """Tavily POST /search body. See https://docs.tavily.com/documentation/api-reference/endpoint/search""" + + query: str + max_results: int = 10 + search_depth: str = 'advanced' + chunks_per_source: int = 3 + topic: str = 'general' + time_range: Optional[str] = None + start_date: Optional[str] = None + end_date: Optional[str] = None + # include_answer: false | true | basic | advanced + include_answer: Any = 'advanced' + # include_raw_content: false | true | markdown | text — use markdown for richest text + include_raw_content: Any = 'markdown' + include_images: bool = False + include_image_descriptions: bool = False + include_favicon: bool = False + include_domains: List[str] = field(default_factory=list) + exclude_domains: List[str] = field(default_factory=list) + country: Optional[str] = None + auto_parameters: bool = False + exact_match: bool = False + include_usage: bool = False + safe_search: bool = False + + def to_api_body(self, api_key: str) -> Dict[str, Any]: + n = max(0, min(20, int(self.max_results))) + body: Dict[str, Any] = { + 'api_key': api_key, + 'query': self.query, + 'max_results': n, + 'search_depth': self.search_depth, + 'topic': self.topic, + 'include_answer': self.include_answer, + 'include_raw_content': self.include_raw_content, + 'include_images': self.include_images, + 'include_image_descriptions': self.include_image_descriptions, + 'include_favicon': self.include_favicon, + 'auto_parameters': self.auto_parameters, + 'exact_match': self.exact_match, + 'include_usage': self.include_usage, + 'safe_search': self.safe_search, + } + # chunks_per_source only meaningful for advanced (per Tavily docs) + if self.search_depth == 'advanced': + body['chunks_per_source'] = max(1, min(3, int(self.chunks_per_source))) + if self.time_range: + body['time_range'] = self.time_range + if self.start_date: + body['start_date'] = self.start_date + if self.end_date: + body['end_date'] = self.end_date + if self.include_domains: + body['include_domains'] = list(self.include_domains)[:300] + if self.exclude_domains: + body['exclude_domains'] = list(self.exclude_domains)[:150] + if self.country: + body['country'] = self.country + return body + + +@dataclass +class TavilySearchResult: + """Parsed Tavily /search JSON.""" + + query: str + arguments: Dict[str, Any] + response: Dict[str, Any] + + def to_list(self) -> List[Dict[str, Any]]: + """Normalize to WebSearchTool pipeline dicts (prefill content when raw_content present).""" + if not self.response: + return [] + rows: List[Dict[str, Any]] = [] + for r in self.response.get('results') or []: + url = r.get('url') or '' + title = r.get('title') or '' + snippet = (r.get('content') or '').strip() + raw = (r.get('raw_content') or '').strip() + # Prefer full page text for downstream summarization; fallback to snippets + body = raw if raw else snippet + rows.append({ + 'url': url, + 'id': url, + 'title': title, + 'highlights': None, + 'highlight_scores': None, + 'summary': snippet, + 'markdown': raw if raw else None, + # Pipeline uses these keys: + 'content': body, + 'fetch_success': bool(raw), + 'score': r.get('score'), + 'tavily_images': r.get('images') or [], + 'favicon': r.get('favicon'), + }) + return rows + + def extra_response_fields(self) -> Dict[str, Any]: + """Top-level fields to merge into web_search JSON output.""" + if not self.response: + return {} + out: Dict[str, Any] = {} + if self.response.get('answer'): + out['tavily_answer'] = self.response['answer'] + if self.response.get('images'): + out['tavily_images'] = self.response['images'] + if self.response.get('response_time') is not None: + out['tavily_response_time'] = self.response['response_time'] + if self.response.get('usage'): + out['tavily_usage'] = self.response['usage'] + if self.response.get('request_id'): + out['tavily_request_id'] = self.response['request_id'] + return out diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py new file mode 100644 index 000000000..b4b7d3f3b --- /dev/null +++ b/ms_agent/tools/search/tavily/search.py @@ -0,0 +1,217 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType +from ms_agent.tools.search.tavily.http import post_json +from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult +from ms_agent.utils.logger import get_logger + +if TYPE_CHECKING: + from ms_agent.llm.utils import Tool + +logger = get_logger() + +TAVILY_SEARCH_URL = 'https://api.tavily.com/search' + + +class TavilySearch(SearchEngine): + """ + Tavily Search API — optimized for LLM agents. + + Defaults favor maximum usable text: ``search_depth=advanced``, + ``include_raw_content=markdown``, ``include_answer=advanced``, + ``chunks_per_source=3`` (capped by Tavily). + """ + + engine_type = SearchEngineType.TAVILY + + def __init__( + self, + api_key: Optional[str] = None, + request_timeout: float = 120.0, + ): + key = api_key or os.getenv('TAVILY_API_KEY') + if not key: + raise ValueError( + 'TAVILY_API_KEY must be set in environment or web_search.tavily_api_key' + ) + self._api_key = key + self._request_timeout = float(request_timeout) + + def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: + body = search_request.to_api_body(self._api_key) + try: + data = post_json( + TAVILY_SEARCH_URL, body, timeout=self._request_timeout) + except Exception as e: + raise RuntimeError(f'Tavily search failed: {e}') from e + safe_args = {k: v for k, v in body.items() if k != 'api_key'} + safe_args['api_key'] = '' + return TavilySearchResult( + query=search_request.query, + arguments=safe_args, + response=data or {}, + ) + + @classmethod + def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': + from ms_agent.llm.utils import Tool + return Tool( + tool_name=cls.get_tool_name(), + server_name=server_name, + description=( + 'Search the web using Tavily (built for AI agents). ' + 'Returns ranked results with optional full-page markdown via ' + '`include_raw_content`. Use `search_depth` advanced for best ' + 'relevance and richer `content` chunks (higher API credit use).'), + parameters={ + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'Search query.', + }, + 'num_results': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 20, + 'description': 'Max results (maps to Tavily max_results). Default 10.', + }, + 'search_depth': { + 'type': 'string', + 'enum': ['advanced', 'basic', 'fast', 'ultra-fast'], + 'description': + ('advanced: best quality, 2 credits; ' + 'basic/fast/ultra-fast: 1 credit (see Tavily docs).'), + }, + 'topic': { + 'type': 'string', + 'enum': ['general', 'news', 'finance'], + 'description': + 'Search category (`news` / `finance` for focused verticals).', + }, + 'time_range': { + 'type': 'string', + 'description': + ('Filter by recency: day, week, month, year or d,w,m,y.'), + }, + 'start_date': { + 'type': 'string', + 'description': 'Results after YYYY-MM-DD.', + }, + 'end_date': { + 'type': 'string', + 'description': 'Results before YYYY-MM-DD.', + }, + 'include_answer': { + 'type': 'string', + 'enum': ['false', 'true', 'basic', 'advanced'], + 'description': + ('LLM answer: true/basic for short, advanced for detailed. ' + 'Use false to skip.'), + }, + 'include_raw_content': { + 'type': 'string', + 'enum': + ['false', 'true', 'markdown', 'text'], + 'description': + ('full page text: markdown (recommended) or text; ' + 'false to skip raw content.'), + }, + 'chunks_per_source': { + 'type': 'integer', + 'minimum': 1, + 'maximum': 3, + 'description': + ('Relevant chunks per URL when search_depth=advanced. ' + 'Each chunk up to ~500 chars in `content` field.'), + }, + 'include_domains': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'Only include these domains (max 300).', + }, + 'exclude_domains': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'Exclude domains (max 150).', + }, + 'country': { + 'type': 'string', + 'description': + ('Boost results from country (e.g. united states). ' + 'See Tavily docs for enum.'), + }, + 'exact_match': { + 'type': 'boolean', + 'description': + 'Only results with exact quoted phrases in query.', + }, + }, + 'required': ['query'], + }, + ) + + @classmethod + def build_request_from_args(cls, **kwargs: Any) -> TavilySearchRequest: + """Build from merged tool args + YAML defaults (see WebSearchTool).""" + + def _boolish(name: str, default: Any) -> Any: + if name not in kwargs: + return default + v = kwargs[name] + if isinstance(v, str) and v.lower() in ('false', 'true'): + return v.lower() == 'true' + return v + + num = kwargs.get('num_results', kwargs.get('max_results', 10)) + try: + num = int(num) + except (TypeError, ValueError): + num = 10 + + try: + cps = int(kwargs.get('chunks_per_source', 3)) + except (TypeError, ValueError): + cps = 3 + + inc_ans = kwargs.get('include_answer', 'advanced') + if isinstance(inc_ans, str) and inc_ans.lower() == 'false': + inc_ans = False + elif isinstance(inc_ans, str) and inc_ans.lower() == 'true': + inc_ans = True + + inc_raw = kwargs.get('include_raw_content', 'markdown') + if isinstance(inc_raw, str) and inc_raw.lower() == 'false': + inc_raw = False + elif isinstance(inc_raw, str) and inc_raw.lower() == 'true': + inc_raw = 'markdown' + + return TavilySearchRequest( + query=kwargs['query'], + max_results=num, + search_depth=str(kwargs.get('search_depth', 'advanced')), + chunks_per_source=cps, + topic=str(kwargs.get('topic', 'general')), + time_range=kwargs.get('time_range'), + start_date=kwargs.get('start_date'), + end_date=kwargs.get('end_date'), + include_answer=inc_ans, + include_raw_content=inc_raw, + include_images=bool(_boolish('include_images', False)), + include_image_descriptions=bool( + _boolish('include_image_descriptions', False)), + include_favicon=bool(_boolish('include_favicon', False)), + include_domains=list(kwargs.get('include_domains') or []), + exclude_domains=list(kwargs.get('exclude_domains') or []), + country=kwargs.get('country'), + auto_parameters=bool(_boolish('auto_parameters', False)), + exact_match=bool(_boolish('exact_match', False)), + include_usage=bool(_boolish('include_usage', False)), + safe_search=bool(_boolish('safe_search', False)), + ) diff --git a/ms_agent/tools/search/web_search_spill.py b/ms_agent/tools/search/web_search_spill.py new file mode 100644 index 000000000..9c8da73c3 --- /dev/null +++ b/ms_agent/tools/search/web_search_spill.py @@ -0,0 +1,292 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Offload oversized web_search tool payloads to disk so LLM context stays bounded. + +When +---- +Estimated inline character volume (per-result ``content``, ``summary``, ``abstract``, +and ``chunks[*].content``) exceeds ``spill_max_inline_chars``. + +Where / lifecycle +----------------- +``{output_dir}/{spill_subdir}/{run_key}/`` — same lifecycle as the task +``output_dir`` (delete the run workdir to reclaim space; no automatic pruning). + +Naming +------ +``run_key = {UTC compact}_{call_id_or_random}`` — unique, sortable, filesystem-safe. + +Files +----- +* ``manifest.json`` — index (query, engine, URLs, paths, sizes, previews). +* ``bodies/{i:03d}.md`` — UTF-8 full text for spilled rows (sections for content / + abstract / chunks). + +Return payload +-------------- +JSON gains ``spill`` with ``digest`` (instructions + quick index) and paths relative +to ``output_dir`` so ``read_file`` can open them. +""" +from __future__ import annotations + +import copy +import json +import os +import re +import time +import uuid +from typing import Any, Dict, List, Tuple + + +def _item_inline_chars(it: Dict[str, Any]) -> int: + n = 0 + for k in ('content', 'summary', 'abstract'): + v = it.get(k) + if isinstance(v, str): + n += len(v) + chunks = it.get('chunks') + if isinstance(chunks, list): + for c in chunks: + if isinstance(c, dict): + s = c.get('content') + if isinstance(s, str): + n += len(s) + return n + + +def _total_inline_chars(items: List[Dict[str, Any]]) -> int: + return sum(_item_inline_chars(x) for x in items) + + +def _preview(text: str, max_chars: int) -> str: + t = (text or '').strip() + if len(t) <= max_chars: + return t + return t[:max_chars].rstrip() + '\n…' + + +def _safe_run_key(call_id: str) -> str: + ts = time.strftime('%Y%m%dT%H%M%SZ', time.gmtime()) + tail = (call_id or '').strip() + tail = re.sub(r'[^a-zA-Z0-9._-]+', '_', tail)[:24] + if not tail: + tail = uuid.uuid4().hex[:12] + return f'{ts}_{tail}' + + +def _build_spill_markdown(item: Dict[str, Any]) -> str: + """Assemble full text for one result row.""" + lines: List[str] = [] + url = item.get('url', '') + title = item.get('title', '') + lines.append(f'# {title or "(no title)"}\n') + lines.append(f'**URL:** {url}\n') + summary = item.get('summary') + if isinstance(summary, str) and summary.strip(): + lines.append('\n## Summary (search snippet)\n\n') + lines.append(summary) + lines.append('\n') + content = item.get('content') + if isinstance(content, str) and content.strip(): + lines.append('\n## Content\n\n') + lines.append(content) + lines.append('\n') + abstract = item.get('abstract') + if isinstance(abstract, str) and abstract.strip(): + lines.append('\n## Abstract\n\n') + lines.append(abstract) + lines.append('\n') + chunks = item.get('chunks') + if isinstance(chunks, list) and chunks: + lines.append('\n## Chunks\n\n') + for c in chunks: + if not isinstance(c, dict): + continue + cid = c.get('chunk_id', '') + body = c.get('content', '') + lines.append(f'### chunk `{cid}`\n\n') + if isinstance(body, str): + lines.append(body) + lines.append('\n') + return ''.join(lines) + + +def _shrink_item_after_spill(item: Dict[str, Any], + spill_preview_chars: int) -> Dict[str, Any]: + """Replace heavy fields with short previews + pointers.""" + out = dict(item) + note = ( + 'Full text spilled to disk; see content_path / manifest_path in parent ' + 'JSON spill block. Use read_file on content_path for this row.') + sm = out.get('summary') + if isinstance(sm, str) and sm.strip(): + out['summary'] = _preview(sm, spill_preview_chars) + out.setdefault('content_note', note) + main = (out.get('content') or '') + if isinstance(main, str) and main.strip(): + out['content'] = _preview(main, spill_preview_chars) + out['content_note'] = note + ab = out.get('abstract') + if isinstance(ab, str) and ab.strip(): + out['abstract'] = _preview(ab, min(800, spill_preview_chars)) + ch = out.get('chunks') + if isinstance(ch, list) and ch: + out['chunks'] = [{ + 'chunk_id': + c.get('chunk_id', ''), + 'content': + _preview(str(c.get('content', '')), min(400, spill_preview_chars)), + } for c in ch if isinstance(c, dict)] + out['chunks_note'] = 'Full chunk bodies are in the spilled markdown file.' + return out + + +def maybe_spill_web_search_payload( + *, + output_dir: str, + spill_subdir: str, + spill_max_inline_chars: int, + spill_preview_chars: int, + query: str, + engine: str, + results: List[Dict[str, Any]], + call_id: str, +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """ + If total inline chars exceed threshold, spill largest rows first until under. + + Returns: + (possibly_mutated_results, spill_meta_dict) + spill_meta_dict is empty if no spill occurred. + """ + if not output_dir or not results: + return results, {} + + work = copy.deepcopy(results) + total = _total_inline_chars(work) + if total <= spill_max_inline_chars: + return results, {} + + run_key = _safe_run_key(call_id) + root = os.path.abspath(os.path.join(output_dir, spill_subdir, run_key)) + bodies_dir = os.path.join(root, 'bodies') + os.makedirs(bodies_dir, exist_ok=True) + + spilled_indices: List[int] = [] + manifest_rows: List[Dict[str, Any]] = [] + + def order_by_size() -> List[int]: + sizes = [(i, _item_inline_chars(work[i])) for i in range(len(work))] + sizes.sort(key=lambda x: x[1], reverse=True) + return [i for i, sz in sizes if sz > 0] + + while _total_inline_chars(work) > spill_max_inline_chars: + order = order_by_size() + if not order: + break + idx = order[0] + item = work[idx] + if _item_inline_chars(item) == 0: + break + full_md = _build_spill_markdown(item) + rel_body = os.path.join(spill_subdir, run_key, 'bodies', + f'{idx:03d}.md').replace('\\', '/') + abs_body = os.path.normpath( + os.path.join(output_dir, rel_body.replace('/', os.sep))) + os.makedirs(os.path.dirname(abs_body), exist_ok=True) + header = ( + f'\n') + with open(abs_body, 'w', encoding='utf-8') as bf: + bf.write(header + full_md) + + spilled_indices.append(idx) + before_chars = _item_inline_chars(item) + work[idx] = _shrink_item_after_spill(item, spill_preview_chars) + work[idx]['content_spilled'] = True + work[idx]['content_path'] = rel_body + work[idx]['content_chars_spilled'] = before_chars + + preview_src = ( + item.get('content') or item.get('summary') or item.get('abstract') + or '')[:4000] + manifest_rows.append({ + 'index': + idx, + 'url': + item.get('url', ''), + 'title': + item.get('title', ''), + 'body_file': + f'bodies/{idx:03d}.md', + 'content_path': + rel_body, + 'chars_spilled': + before_chars, + 'preview': + _preview(preview_src, min(500, spill_preview_chars)), + }) + + manifest: Dict[str, Any] = { + 'version': + 1, + 'created_at_utc': + time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), + 'query': + query, + 'engine': + engine, + 'run_key': + run_key, + 'lifecycle': + ('Ephemeral: lives under this task output_dir; delete the task directory ' + 'to remove. ms-agent does not auto-prune.'), + 'inline_chars_before': + total, + 'inline_chars_after': + _total_inline_chars(work), + 'spill_threshold_chars': + spill_max_inline_chars, + 'spilled_row_indices': + spilled_indices, + 'rows': + manifest_rows, + } + rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace( + '\\', '/') + abs_manifest = os.path.normpath( + os.path.join(output_dir, rel_manifest.replace('/', os.sep))) + with open(abs_manifest, 'w', encoding='utf-8') as mf: + json.dump(manifest, mf, ensure_ascii=False, indent=2) + + lines = [ + 'Large web_search payload was written to disk under this task output_dir.', + f'- **Manifest (map of rows → files, sizes)**: `{rel_manifest}`', + f'- **Bodies**: `{spill_subdir}/{run_key}/bodies/`', + 'Read **manifest.json** first, then **read_file** on specific ' + '`bodies/NNN.md` files as needed.', + '', + '**Quick index**', + ] + for row in manifest_rows: + lines.append( + f'{row["index"]}. {row.get("title") or "(no title)"} — ' + f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)') + digest = '\n'.join(lines) + + spill_meta = { + 'spilled': + True, + 'run_key': + run_key, + 'artifact_dir': + f'{spill_subdir}/{run_key}'.replace('\\', '/'), + 'manifest_path': + rel_manifest, + 'digest': + digest, + 'inline_chars_before_spill': + total, + 'inline_chars_after_spill': + _total_inline_chars(work), + } + return work, spill_meta diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 8dd80f86b..16d6005d2 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -10,11 +10,13 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase -from ms_agent.tools.jina_reader import JinaReaderConfig, fetch_single_text +from ms_agent.tools.jina_reader import (JinaReaderConfig, + fetch_single_text_with_meta) from ms_agent.tools.search.content_optimizer import (ContentOptimizer, ContentOptimizerConfig, SearchResultReranker) from ms_agent.tools.search.search_base import ENGINE_TOOL_NAMES, SearchEngine +from ms_agent.tools.search.web_search_spill import maybe_spill_web_search_payload from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor @@ -23,6 +25,29 @@ MAX_FETCH_CHARS = int(os.getenv('MAX_FETCH_CHARS', 100000)) +def default_per_url_fetch_timeout_s( + fetch_timeout: float, + fetch_retries: int, + direct_fetch_timeout: float, + playwright_timeout_ms: int, +) -> float: + """ + Default asyncio cap for a single URL fetch (Jina → optional direct → optional PW). + + Must exceed a worst-case Jina-only path (``fetch_timeout`` × (retries+1) attempts + plus backoff headroom) before tier-2/3, otherwise logs show all TIMEOUT and no + ``fetch_done``. Clamped to keep runaway tabs bounded. + """ + ft = max(5.0, float(fetch_timeout)) + retries = max(0, int(fetch_retries)) + # Up to (retries+1) attempts each up to ``ft``; 1.35 leaves slack for urllib backoff. + jina_budget = ft * float(retries + 1) * 1.35 + tail = max(10.0, float(direct_fetch_timeout)) + ( + float(playwright_timeout_ms) / 1000.0) + 30.0 + raw = jina_budget + tail + return max(210.0, min(720.0, raw)) + + def _json_dumps(data: Any) -> str: import json return json.dumps(data, ensure_ascii=False, indent=2) @@ -148,10 +173,11 @@ def fetch( url: str, max_chars: Optional[int] = MAX_FETCH_CHARS ) -> Tuple[str, Dict[str, Any]]: - content = fetch_single_text(url, self.config) + content, source_meta = fetch_single_text_with_meta(url, self.config) metadata: Dict[str, Any] = { 'fetcher': 'jina_reader', 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + **source_meta, } if max_chars: @@ -171,10 +197,31 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', """Factory function to get content fetcher by type.""" if fetcher_type == 'jina_reader': config = JinaReaderConfig( - timeout=kwargs.get('timeout', 30.0), + timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), + direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), + direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), ) return JinaContentFetcher(config) + if fetcher_type == 'tavily_extract': + from ms_agent.tools.search.tavily.fetcher import TavilyExtractFetcher + return TavilyExtractFetcher( + api_key=kwargs.get('tavily_api_key'), + extract_depth=str(kwargs.get('tavily_extract_depth', 'advanced')), + format=str(kwargs.get('tavily_extract_format', 'markdown')), + timeout=float(kwargs.get('timeout', 45.0)), + chunks_per_source=int(kwargs.get('tavily_extract_chunks_per_source', 3)), + include_images=bool(kwargs.get('tavily_extract_include_images', False)), + include_favicon=bool(kwargs.get('tavily_extract_include_favicon', False)), + include_usage=bool(kwargs.get('tavily_extract_include_usage', False)), + ) # Future: add more fetchers # elif fetcher_type == 'docling': # return DoclingContentFetcher(**kwargs) @@ -182,7 +229,23 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', logger.warning( f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader" ) - return JinaContentFetcher() + return JinaContentFetcher( + JinaReaderConfig( + timeout=kwargs.get('timeout', 45.0), + retries=kwargs.get('retries', 3), + direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', + True)), + direct_fetch_timeout=float( + kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int( + kwargs.get('playwright_settle_ms', 350)), + )) def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: @@ -206,6 +269,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -247,6 +313,12 @@ def get_search_engine(engine_type: str, elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch() + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch( + api_key=api_key or os.getenv('TAVILY_API_KEY'), + request_timeout=float(kwargs.get('request_timeout', 120.0)), + ) else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -271,7 +343,7 @@ def build_search_request(engine_type: str, class WebSearchTool(ToolBase): """ Unified web search tool for agents. It can search the web and fetch page content. - - Search via multiple engines (Exa, SerpAPI, Arxiv) + - Search via multiple engines (Exa, SerpAPI, Arxiv, Tavily) - Dynamic tool definitions based on configured engines - Auto-fetch and parse page content - Configurable content fetcher (jina_reader, docling, etc.) @@ -297,12 +369,14 @@ class WebSearchTool(ToolBase): exa_api_key: $EXA_API_KEY serpapi_api_key: $SERPAPI_API_KEY fetch_content: true + # Optional: asyncio deadline per URL (omit = auto from fetch_timeout/retries). + # per_url_fetch_timeout: 0 """ SERVER_NAME = 'web_search' # Registry of supported search engines - SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv') + SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily') # Process-wide (class-level) usage tracking for summarization calls. # This is intentionally separate from LLMAgent usage totals. @@ -415,8 +489,31 @@ def __init__(self, config, **kwargs): 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), + 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) + or os.getenv('TAVILY_API_KEY')) if tool_cfg else + os.getenv('TAVILY_API_KEY'), } + # Tavily search defaults from optional `tavily:` sub-block in YAML + self._tavily_defaults: Dict[str, Any] = {} + if tool_cfg is not None: + tv = getattr(tool_cfg, 'tavily', None) + if tv is not None: + try: + from omegaconf import OmegaConf + if OmegaConf.is_config(tv): + self._tavily_defaults = dict( + OmegaConf.to_container(tv, resolve=True)) + elif isinstance(tv, dict): + self._tavily_defaults = dict(tv) + except Exception: + if isinstance(tv, dict): + self._tavily_defaults = dict(tv) + + self._tavily_request_timeout = float( + getattr(tool_cfg, 'tavily_request_timeout', 120.0) + or 120.0) if tool_cfg else 120.0 + # SerpApi provider (google, bing, baidu) self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', 'google') if tool_cfg else 'google' @@ -429,9 +526,31 @@ def __init__(self, config, **kwargs): self._fetcher_type = getattr( tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' self._fetch_timeout = float( - getattr(tool_cfg, 'fetch_timeout', 30) or 30) if tool_cfg else 30.0 + getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) or 3) if tool_cfg else 3 + self._jina_direct_fetch_fallback = bool( + getattr(tool_cfg, 'jina_direct_fetch_fallback', True) + ) if tool_cfg else True + if tool_cfg is not None and hasattr(tool_cfg, 'jina_direct_fetch_timeout'): + self._jina_direct_fetch_timeout = float( + tool_cfg.jina_direct_fetch_timeout) + else: + self._jina_direct_fetch_timeout = 15.0 + self._jina_playwright_fetch_fallback = bool( + getattr(tool_cfg, 'jina_playwright_fetch_fallback', True) + ) if tool_cfg else True + self._jina_playwright_retry_min_chars = int( + getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400 + ) if tool_cfg else 400 + self._jina_playwright_timeout_ms = int( + getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000 + ) if tool_cfg else 30000 + if tool_cfg is not None and hasattr(tool_cfg, 'jina_playwright_settle_ms'): + self._jina_playwright_settle_ms = int( + tool_cfg.jina_playwright_settle_ms) + else: + self._jina_playwright_settle_ms = 350 self._fetch_content_default = bool( getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True @@ -448,6 +567,20 @@ def __init__(self, config, **kwargs): self._max_concurrent_fetch = int( getattr(tool_cfg, 'max_concurrent_fetch', 3) or 3) if tool_cfg else 3 + # Hard cap (seconds) per URL for asyncio.wait_for around run_in_executor. + # When hit, this URL gets empty content + fetch_error; other URLs in the + # same web_search call keep their already-fetched bodies. Set 0 to disable + # the asyncio cap (underlying urllib/Jina timeouts still apply). + if tool_cfg is not None and hasattr(tool_cfg, 'per_url_fetch_timeout'): + self._per_url_fetch_timeout_s = float( + tool_cfg.per_url_fetch_timeout) + else: + self._per_url_fetch_timeout_s = default_per_url_fetch_timeout_s( + self._fetch_timeout, + self._fetch_retries, + self._jina_direct_fetch_timeout, + self._jina_playwright_timeout_ms, + ) self._max_concurrent_summarization = int( getattr(tool_cfg, 'max_concurrent_summarization', 5) or 5) if tool_cfg else 5 @@ -475,6 +608,19 @@ def __init__(self, config, **kwargs): getattr(tool_cfg, 'summarization_timeout', 90.0) or 90.0) if tool_cfg else 90.0 + # Large payload spill (write bodies to disk; keep JSON small) + self._spill_enabled = bool( + getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True + self._spill_max_inline_chars = int( + getattr(tool_cfg, 'spill_max_inline_chars', 120000) + or 120000) if tool_cfg else 120000 + self._spill_subdir = str( + getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') + or 'web_search_artifacts') if tool_cfg else 'web_search_artifacts' + self._spill_preview_chars = int( + getattr(tool_cfg, 'spill_preview_chars', 600) + or 600) if tool_cfg else 600 + # Reranking config self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', False)) if tool_cfg else False @@ -519,6 +665,11 @@ async def connect(self) -> None: api_key=self._api_keys.get('serpapi'), provider=self._serpapi_provider, ) + elif engine_type == 'tavily': + self._engines[engine_type] = engine_cls( + api_key=self._api_keys.get('tavily'), + request_timeout=self._tavily_request_timeout, + ) else: # arxiv self._engines[engine_type] = engine_cls() @@ -530,11 +681,34 @@ async def connect(self) -> None: if not self._engines: raise RuntimeError('No search engines could be initialized') - self._content_fetcher = get_content_fetcher( - self._fetcher_type, - timeout=self._fetch_timeout, - retries=self._fetch_retries, - ) + wcfg = getattr(getattr(self.config, 'tools', None), 'web_search', None) + _fk: Dict[str, Any] = { + 'timeout': self._fetch_timeout, + 'retries': self._fetch_retries, + 'tavily_api_key': self._api_keys.get('tavily'), + 'direct_fetch_fallback': self._jina_direct_fetch_fallback, + 'direct_fetch_timeout': self._jina_direct_fetch_timeout, + 'playwright_fetch_fallback': self._jina_playwright_fetch_fallback, + 'playwright_retry_min_chars': self._jina_playwright_retry_min_chars, + 'playwright_timeout_ms': self._jina_playwright_timeout_ms, + 'playwright_settle_ms': self._jina_playwright_settle_ms, + } + if wcfg is not None: + _fk.update({ + 'tavily_extract_depth': + getattr(wcfg, 'tavily_extract_depth', 'advanced'), + 'tavily_extract_format': + getattr(wcfg, 'tavily_extract_format', 'markdown'), + 'tavily_extract_chunks_per_source': + int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), + 'tavily_extract_include_images': + bool(getattr(wcfg, 'tavily_extract_include_images', False)), + 'tavily_extract_include_favicon': + bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), + 'tavily_extract_include_usage': + bool(getattr(wcfg, 'tavily_extract_include_usage', False)), + }) + self._content_fetcher = get_content_fetcher(self._fetcher_type, **_fk) # Use daemon threads: tool-call timeouts can cancel the awaiting coroutine, # but not the underlying sync network calls running in executor threads. self._executor = DaemonThreadPoolExecutor( @@ -727,6 +901,59 @@ async def _fetch_content_async(self, url: str) -> Dict[str, Any]: return await loop.run_in_executor(self._executor, self._fetch_content_sync, url) + def _url_log_preview(self, url: str, max_len: int = 220) -> str: + u = (url or '').strip() + if len(u) <= max_len: + return u + return u[:max_len] + '...' + + async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: + """ + Fetch one URL with optional asyncio deadline and progress logs. + + Executor threads are not cancelled on timeout; the model still receives + partial results for other URLs in the same batch. + """ + u = (url or '').strip() + preview = self._url_log_preview(u) + t0 = time.perf_counter() + cap = float(self._per_url_fetch_timeout_s) + logger.info('[web_search] fetch start url=%s', preview) + try: + if cap > 0: + out = await asyncio.wait_for( + self._fetch_content_async(u), + timeout=cap, + ) + else: + out = await self._fetch_content_async(u) + except asyncio.TimeoutError: + elapsed = time.perf_counter() - t0 + logger.warning( + '[web_search] fetch TIMEOUT url=%s elapsed=%.1fs cap=%.1fs — ' + 'this URL is dropped for this response; others are unchanged', + preview, elapsed, cap) + return { + 'url': u, + 'content': '', + 'fetch_success': False, + 'fetcher': 'web_search', + 'fetch_error': f'per_url_fetch_timeout ({cap:g}s)', + 'fetch_timed_out': True, + } + elapsed = time.perf_counter() - t0 + src = (out or {}).get('content_source') or (out or {}).get( + 'fetcher', '') or '' + ok = bool((out or {}).get('fetch_success')) + logger.info( + '[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', + preview, elapsed, ok, src) + return out if out is not None else { + 'url': u, + 'content': '', + 'fetch_success': False, + } + async def _fetch_multiple_async(self, urls: List[str]) -> List[Dict[str, Any]]: """Fetch multiple URLs concurrently with semaphore.""" @@ -734,23 +961,41 @@ async def _fetch_multiple_async(self, async def _bounded_fetch(url: str) -> Dict[str, Any]: async with semaphore: - return await self._fetch_content_async(url) + return await self._fetch_content_async_bounded(url) tasks = [_bounded_fetch(url) for url in urls] return await asyncio.gather(*tasks) - def _do_search(self, engine_type: str, engine: SearchEngine, - engine_cls: Type[SearchEngine], - tool_args: Dict[str, Any]) -> List[Dict[str, Any]]: - """Perform search using the specified engine and return raw results.""" + def _do_search( + self, engine_type: str, engine: SearchEngine, + engine_cls: Type[SearchEngine], + tool_args: Dict[str, Any] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Perform search; returns (result rows, extra top-level metadata e.g. Tavily).""" try: - # Build request using engine's method - request = engine_cls.build_request_from_args(**tool_args) + merged = dict(tool_args) + if engine_type == 'tavily' and getattr(self, '_tavily_defaults', + None): + merged = {**self._tavily_defaults, **merged} + # Keys only for engine / fetcher YAML, not TavilySearchRequest + for _k in ('request_timeout', 'tavily_extract_depth', + 'tavily_extract_format', + 'tavily_extract_chunks_per_source', + 'tavily_extract_include_images', + 'tavily_extract_include_favicon', + 'tavily_extract_include_usage'): + merged.pop(_k, None) + request = engine_cls.build_request_from_args(**merged) result = engine.search(request) - return result.to_list() + rows = result.to_list() + extra: Dict[str, Any] = {} + from ms_agent.tools.search.tavily.schema import TavilySearchResult + if isinstance(result, TavilySearchResult): + extra = result.extra_response_fields() + return rows, extra except Exception as e: logger.error(f'Search failed ({engine_type}): {e}') - return [] + return [], {} async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> str: @@ -774,6 +1019,8 @@ async def _execute_search(self, engine_type: str, 'message': 'Query is required.' }) + call_id_for_spill = str(tool_args.pop('__call_id', '') or '') + # Get fetch_content preference, default to configured value fetch_content = tool_args.pop('fetch_content', self._fetch_content_default) @@ -798,20 +1045,22 @@ async def _execute_search(self, engine_type: str, # Perform search loop = asyncio.get_event_loop() - search_results = await loop.run_in_executor(self._executor, - self._do_search, - engine_type, engine, - engine_cls, tool_args) + search_results, tavily_extra = await loop.run_in_executor( + self._executor, self._do_search, engine_type, engine, engine_cls, + tool_args) if not search_results: - return _json_dumps({ + out_empty: Dict[str, Any] = { 'status': 'ok', 'query': query, 'engine': engine_type, 'count': 0, 'results': [], 'message': 'No search results found.', - }) + } + if tavily_extra: + out_empty.update(tavily_extra) + return _json_dumps(out_empty) original_count = len(search_results) @@ -827,13 +1076,25 @@ async def _execute_search(self, engine_type: str, f'Reranked {original_count} results to top {len(search_results)} ' f'for query: {query[:50]}...') - # Step 3: Fetch content for (filtered) results + # Step 3: Fetch content for (filtered) results (skip URLs already filled e.g. Tavily raw_content) + fetch_attempts = 0 + fetch_timeouts = 0 if fetch_content and self._content_fetcher: search_results = SearchResultReranker.deduplicate_by_url( search_results) - urls = [r.get('url') for r in search_results if r.get('url')] + urls: List[str] = [] + for r in search_results: + u = r.get('url') + if not u: + continue + if r.get('fetch_success') and (r.get('content') or '').strip(): + continue + urls.append(u) if urls: + fetch_attempts = len(urls) fetch_results = await self._fetch_multiple_async(urls) + fetch_timeouts = sum( + 1 for r in fetch_results if r.get('fetch_timed_out')) # Merge search metadata with fetched content url_to_fetch = {r['url']: r for r in fetch_results} @@ -844,6 +1105,16 @@ async def _execute_search(self, engine_type: str, sr['content'] = fetched.get('content', '') sr['fetch_success'] = fetched.get( 'fetch_success', False) + if fetched.get('fetch_error'): + sr['fetch_error'] = fetched['fetch_error'] + else: + sr.pop('fetch_error', None) + if fetched.get('fetch_timed_out'): + sr['fetch_timed_out'] = True + else: + sr.pop('fetch_timed_out', None) + if fetched.get('content_source'): + sr['content_source'] = fetched['content_source'] if fetched.get('published_at' ) and not sr.get('published_date'): sr['published_at'] = fetched['published_at'] @@ -990,6 +1261,8 @@ async def _execute_search(self, engine_type: str, sr.get('arxiv_id', '') or '', # arXiv short id 'abs_url': sr.get('id', '') or '', # entry_id (abstract page) + 'pdf_url': + sr.get('pdf_url', '') or '', 'abstract': sr.get('summary', '') or '', 'authors': @@ -1005,6 +1278,12 @@ async def _execute_search(self, engine_type: str, if fetch_content: item['content'] = sr.get('content', '') item['fetch_success'] = sr.get('fetch_success', False) + if sr.get('fetch_error'): + item['fetch_error'] = sr.get('fetch_error') + if sr.get('fetch_timed_out'): + item['fetch_timed_out'] = True + if sr.get('content_source'): + item['content_source'] = sr.get('content_source') # Add summarization metadata if applicable if sr.get('content_summarized'): item['content_summarized'] = True @@ -1017,6 +1296,14 @@ async def _execute_search(self, engine_type: str, # Include snippet if available for non-arxiv engines item['summary'] = sr.get('summary', '') + if engine_type == 'tavily': + if sr.get('score') is not None: + item['score'] = sr.get('score') + if sr.get('tavily_images'): + item['images'] = sr.get('tavily_images') + if sr.get('favicon'): + item['favicon'] = sr.get('favicon') + # Add item to results for all engines output_results.append(item) @@ -1028,6 +1315,17 @@ async def _execute_search(self, engine_type: str, 'count': len(output_results), 'results': output_results, } + if fetch_content and self._content_fetcher: + response['fetch_stats'] = { + 'per_url_timeout_s': + self._per_url_fetch_timeout_s, + 'urls_fetched_this_call': + fetch_attempts, + 'urls_timed_out': + fetch_timeouts, + } + if tavily_extra: + response.update(tavily_extra) # Add optimization info if self._enable_rerank or self._enable_summarization: @@ -1070,6 +1368,29 @@ async def _execute_search(self, engine_type: str, 'summarization_usage_process_total' ] = WebSearchTool.get_global_summarization_usage() # yapf: disable + if self._spill_enabled: + od = getattr(self, 'output_dir', None) or getattr( + getattr(self, 'config', None), 'output_dir', '') or '' + if od: + try: + new_results, spill_meta = maybe_spill_web_search_payload( + output_dir=od, + spill_subdir=self._spill_subdir, + spill_max_inline_chars=self._spill_max_inline_chars, + spill_preview_chars=self._spill_preview_chars, + query=query, + engine=engine_type, + results=response['results'], + call_id=call_id_for_spill, + ) + if spill_meta: + response['results'] = new_results + response['spill'] = spill_meta + except Exception as e: + logger.warning( + f'web_search spill failed (returning full inline JSON): {e}' + ) + return _json_dumps(response) async def fetch_page(self, url: str) -> str: @@ -1080,7 +1401,7 @@ async def fetch_page(self, url: str) -> str: 'message': 'URL is required.' }) - result = await self._fetch_content_async(url.strip()) + result = await self._fetch_content_async_bounded(url.strip()) return _json_dumps({ 'status': @@ -1093,6 +1414,10 @@ async def fetch_page(self, url: str) -> str: result.get('published_at', ''), 'fetch_success': result.get('fetch_success', False), + 'fetch_error': + result.get('fetch_error', ''), + 'fetch_timed_out': + bool(result.get('fetch_timed_out')), 'chunks': result.get('chunks') if self._enable_chunking else None, }) @@ -1133,3 +1458,7 @@ async def arxiv_search(self, **kwargs) -> str: async def serpapi_search(self, **kwargs) -> str: """Search using SerpApi engine.""" return await self._execute_search('serpapi', kwargs) + + async def tavily_search(self, **kwargs) -> str: + """Search using Tavily engine.""" + return await self._execute_search('tavily', kwargs) diff --git a/projects/deep_research/.env.example b/projects/deep_research/.env.example index 59ec0d03b..8f0552b68 100644 --- a/projects/deep_research/.env.example +++ b/projects/deep_research/.env.example @@ -1,5 +1,7 @@ EXA_API_KEY=xxx SERPAPI_API_KEY=xxx +# Tavily Search / Extract (optional; set when using engines: [tavily] or fetcher: tavily_extract) +TAVILY_API_KEY=tvly-xxx OPENAI_API_KEY=xxx OPENAI_BASE_URL=https://your-openai-compatible-endpoint/v1 diff --git a/projects/deep_research/v2/researcher.tavily.yaml b/projects/deep_research/v2/researcher.tavily.yaml new file mode 100644 index 000000000..e5423c4fc --- /dev/null +++ b/projects/deep_research/v2/researcher.tavily.yaml @@ -0,0 +1,172 @@ +llm: + service: openai + model: gpt-5.2-2025-12-11 + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: false + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + # extra_body: + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + reasoning_effort: medium + + +tag: deep-research-researcher + + +prompt: + root: prompts/ + agent: researcher + lang: en + family: thinking + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - list_files + - search_file_content + - replace_file_contents + - replace_file_lines + code_executor: + mcp: false + implementation: python_env + notebook_timeout: 120 + include: + - notebook_executor + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - write_analysis + - get_analysis + - list_analyses + agent_tools: + mcp: false + enable_stats: true + run_in_thread: true + run_in_process: true + max_workers: 4 + definitions: + - tool_name: searcher_tool + description: > + Invoke the Searcher sub-agent to perform an in-depth research task on a specific topic. + Searcher is capable of autonomously executing a research loop until sufficient evidence is collected and a research report is produced (search -> parse -> evidence discovery & storage -> progressive search -> ...). + Returns a JSON result containing: task completion status, core findings, issues or limitations encountered, research report body, and evidence storage locations. + config_path: searcher.tavily.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted research task description that should include: + - The corresponding task ID from the TODO list (required) + - Specific research objectives + - Questions to be answered + - Constraints (time range, source preferences, etc., optional) + - Stopping conditions (optional) + - Other requirements (optional) + Recommended format: + { + "task_id": "...", + "research_objectives": "...", + "questions_to_answer": "...", + "constraints": "...", + "stopping_conditions": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + - tool_name: reporter_tool + description: > + Invoke the Reporter sub-agent to generate a research report based on collected evidence. + Reporter reads the stored evidence cards and executes a complex workflow for research report writing. + The completed report is automatically saved to `final_report.md` in the output directory by the system. + Returns a JSON result containing: execution summary and + intermediate artifact file paths (the full report body is NOT included in the return value — + read `final_report.md` directly to access the report content). + config_path: reporter.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted report generation instruction that should include: + - Report topic and target audience + - Complete background description and task description + - Core questions to be covered + - Writing requirements (style, structure, length, language, etc.) + - Any other requirements + Recommended format: + { + "report_topic_and_audience": "...", + "background": "...", + "task_description": "...", + "writing_requirements": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + plugins: + - tools/evidence_tool.py + + +callbacks: + - callbacks/researcher_callback + +# Self-reflection checks before allowing the researcher to stop. +self_reflection: + enabled: true + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +handler: time_handler + +code_file: researcher + +max_chat_round: 45 + +tool_call_timeout: 2600 + +output_dir: ./output diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh index c4bd326fb..d4c147f84 100755 --- a/projects/deep_research/v2/run_benchmark.sh +++ b/projects/deep_research/v2/run_benchmark.sh @@ -30,6 +30,7 @@ else echo -e "${RED}Error: Neither 'python' nor 'python3' is available in PATH.${NC}" exit 1 fi +PYTHON_BIN="/Users/luyan/software/miniconda3/bin/python" # When stdout is redirected (e.g., nohup > file), Python is block-buffered by default. # Force unbuffered output so progress lines like "[xx] OK" show up in logs promptly. @@ -91,19 +92,22 @@ if [ -z "$DR_BENCH_ROOT" ]; then echo -e "${YELLOW}Using default benchmark query...${NC}" echo "" - # Run a simple benchmark query - QUERY="Provide a comprehensive survey of recent advances in large language models (LLMs), covering key developments in the last 12 months including architecture innovations, training techniques, and real-world applications." - OUTPUT_DIR="output/deep_research/benchmark_run" + # Run a simple benchmark query (override with BENCH_QUERY for smoke tests) + DEFAULT_QUERY="Provide a comprehensive survey of recent advances in large language models (LLMs), covering key developments in the last 12 months including architecture innovations, training techniques, and real-world applications." + QUERY="${BENCH_QUERY:-$DEFAULT_QUERY}" + OUTPUT_DIR="${BENCH_OUTPUT_DIR:-output/deep_research/benchmark_run}" echo -e "${GREEN}Running benchmark with query:${NC}" echo " \"$QUERY\"" echo "" echo -e "${GREEN}Output directory: $OUTPUT_DIR${NC}" + RESEARCHER_CONFIG="${RESEARCHER_CONFIG:-projects/deep_research/v2/researcher.yaml}" + echo -e "${GREEN}Researcher config: $RESEARCHER_CONFIG${NC}" echo "" - # Run the benchmark + # Run the benchmark (override RESEARCHER_CONFIG / BENCH_OUTPUT_DIR as needed) PYTHONPATH=. "$PYTHON_BIN" -u ms_agent/cli/cli.py run \ - --config projects/deep_research/v2/researcher.yaml \ + --config "$RESEARCHER_CONFIG" \ --query "$QUERY" \ --trust_remote_code true \ --output_dir "$OUTPUT_DIR" @@ -165,6 +169,8 @@ else echo " Work root: $WORK_ROOT" echo " Workers: $WORKERS" echo " Limit: $LIMIT (0 = no limit)" + RESEARCHER_CONFIG="${RESEARCHER_CONFIG:-projects/deep_research/v2/researcher.yaml}" + echo " Researcher config: $RESEARCHER_CONFIG" echo "" # Run the full benchmark @@ -172,6 +178,7 @@ else --query_file "$QUERY_FILE" \ --output_jsonl "$OUTPUT_JSONL" \ --model_name "$MODEL_NAME" \ + --config "$RESEARCHER_CONFIG" \ --work_root "$WORK_ROOT" \ --limit "$LIMIT" \ --workers "$WORKERS" \ diff --git a/projects/deep_research/v2/searcher.tavily.yaml b/projects/deep_research/v2/searcher.tavily.yaml new file mode 100644 index 000000000..e8d632745 --- /dev/null +++ b/projects/deep_research/v2/searcher.tavily.yaml @@ -0,0 +1,105 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + force_prefix_cache: true + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + + +tag: deep-research + + +prompt: + root: prompts/ + agent: searcher + lang: en + family: gpt5 + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - list_files + web_search: + mcp: false + engines: + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract + fetch_content: true + fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 + _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown + enable_chunking: false + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false + summarizer_model: qwen3.5-flash + summarizer_base_url: + summarizer_api_key: + max_content_chars: 200000 + summarizer_max_workers: 15 + summarization_timeout: 360 + _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts + evidence_store: + mcp: false + evidence_dir: evidence + chunks_dir: chunks + enable_chunk_storage: false + include: + - write_note + - list_notes + - get_note + - load_index + - search_notes + - delete_note + plugins: + - tools/evidence_tool.py + + +handler: time_handler + +callbacks: + - callbacks/searcher_callback + +max_chat_round: 30 + +round_reminder: + enabled: true + remind_at_round: 28 + +tool_call_timeout: 2600 + +output_dir: ./output From cb75af31c6fdc4463e4a61fc2fcde8a467a884ea Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 20 Apr 2026 16:59:59 +0800 Subject: [PATCH 33/40] fix(jina): align reader with websearch (meta fetch + playwright fallback) WebSearchTool imports fetch_single_text_with_meta; add tiered fetch helpers and optional Playwright fallback module used by jina_reader. Made-with: Cursor --- ms_agent/tools/fetch_playwright_fallback.py | 191 ++++++++++++++++++++ ms_agent/tools/jina_reader.py | 173 +++++++++++++++++- 2 files changed, 355 insertions(+), 9 deletions(-) create mode 100644 ms_agent/tools/fetch_playwright_fallback.py diff --git a/ms_agent/tools/fetch_playwright_fallback.py b/ms_agent/tools/fetch_playwright_fallback.py new file mode 100644 index 000000000..0ff89a91d --- /dev/null +++ b/ms_agent/tools/fetch_playwright_fallback.py @@ -0,0 +1,191 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Optional headless Chromium fetch for URLs where Jina + direct HTTP yield empty +or obviously client-rendered shells. + +Requires: ``pip install playwright`` and ``playwright install chromium``. +If Playwright is not installed, helpers return empty string without raising. + +Performance: Playwright ``Browser`` must be used from the creating thread only. +We keep **one browser per thread** (e.g. each ``ThreadPoolExecutor`` worker) and +reuse it across URLs instead of launching Chromium for every fetch. +""" +from __future__ import annotations + +import atexit +import os +import re +import threading +from typing import Dict, List, Tuple + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_MAX_INNER_TEXT_CHARS = 100_000 + +_tls = threading.local() +_registry_lock = threading.Lock() +# thread_id -> (sync_playwright handle, Browser) for atexit cleanup +_pw_by_thread: Dict[int, Tuple[object, object]] = {} + + +def _chromium_launch_args() -> List[str]: + args: List[str] = [ + '--disable-extensions', + '--blink-settings=imagesEnabled=false', + ] + if os.getenv('MS_AGENT_PLAYWRIGHT_NO_SANDBOX', '').lower() in ( + '1', + 'true', + 'yes', + ): + args.extend(('--no-sandbox', '--disable-setuid-sandbox')) + return args + + +def _invalidate_thread_playwright_unlocked() -> None: + """Drop thread-local Playwright; caller must hold no registry lock if mutating _pw_by_thread.""" + tid = threading.get_ident() + pw = getattr(_tls, 'pw', None) + br = getattr(_tls, 'browser', None) + _tls.pw = None + _tls.browser = None + with _registry_lock: + _pw_by_thread.pop(tid, None) + if br is not None: + try: + br.close() + except Exception: + pass + if pw is not None: + try: + pw.stop() + except Exception: + pass + + +def _atexit_close_all_playwright() -> None: + with _registry_lock: + items = list(_pw_by_thread.values()) + _pw_by_thread.clear() + for pw, browser in items: + try: + browser.close() + except Exception: + pass + try: + pw.stop() + except Exception: + pass + + +atexit.register(_atexit_close_all_playwright) + + +def _thread_browser() -> object: + """Return a Chromium ``Browser`` for this thread, creating it lazily.""" + br = getattr(_tls, 'browser', None) + if br is not None: + try: + if br.is_connected(): + return br + except Exception: + pass + _invalidate_thread_playwright_unlocked() + br = None + + try: + from playwright.sync_api import sync_playwright + except ImportError: + logger.debug( + 'playwright is not installed; skip headless fetch. ' + 'Install with: pip install playwright && playwright install chromium') + raise RuntimeError('playwright not installed') from None + + pw = sync_playwright().start() + browser = pw.chromium.launch( + headless=True, + args=_chromium_launch_args(), + ) + _tls.pw = pw + _tls.browser = browser + with _registry_lock: + _pw_by_thread[threading.get_ident()] = (pw, browser) + return browser + + +def try_playwright_inner_text( + url: str, + timeout_ms: int, + *, + settle_ms: int = 350, +) -> str: + """ + Load URL in headless Chromium and return ``document.body.innerText``. + + Reuses one browser per thread. Returns empty string on missing dependency, + timeout, or navigation error. + """ + if not url.startswith(('http://', 'https://')): + return '' + settle_ms = max(0, int(settle_ms)) + try: + from playwright.sync_api import sync_playwright # noqa: F401 + except ImportError: + logger.debug( + 'playwright is not installed; skip headless fetch. ' + 'Install with: pip install playwright && playwright install chromium') + return '' + + text = '' + try: + browser = _thread_browser() + page = browser.new_page() + try: + page.set_default_timeout(timeout_ms) + page.goto(url, wait_until='domcontentloaded', timeout=timeout_ms) + if settle_ms: + page.wait_for_timeout(settle_ms) + raw = page.evaluate( + """() => { + const b = document.body; + if (!b) return ''; + return b.innerText || ''; + }""" + ) + if isinstance(raw, str): + text = raw[:_MAX_INNER_TEXT_CHARS] + finally: + try: + page.close() + except Exception: + pass + except RuntimeError: + return '' + except Exception as e: + logger.debug(f'Playwright fetch failed for {url[:80]!r}: {e}') + try: + _invalidate_thread_playwright_unlocked() + except Exception: + pass + return '' + + return text + + +def looks_like_spa_shell_html(raw_html: str) -> bool: + """Heuristic: HTML suggests JS-only app or empty mount root.""" + if not raw_html or len(raw_html) < 80: + return False + low = raw_html.lower() + if any( + x in low + for x in ('enable javascript', 'javascript is required', + 'you need to enable javascript')): + return True + if re.search(r']+\bid=["\']root["\'][^>]*>\s*', low): + return True + if re.search(r']+\bid=["\']app["\'][^>]*>\s*', low): + return True + return False diff --git a/ms_agent/tools/jina_reader.py b/ms_agent/tools/jina_reader.py index d8962f617..b3f663971 100644 --- a/ms_agent/tools/jina_reader.py +++ b/ms_agent/tools/jina_reader.py @@ -1,13 +1,21 @@ import asyncio +import html as html_module import random +import re import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from urllib.error import HTTPError, URLError -from urllib.parse import quote +from urllib.parse import quote, urlparse from urllib.request import Request, urlopen +from ms_agent.tools.fetch_playwright_fallback import (looks_like_spa_shell_html, + try_playwright_inner_text) +from ms_agent.utils.logger import get_logger + +logger = get_logger() + DEFAULT_HEADERS: Dict[str, str] = { 'User-Agent': 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', @@ -15,16 +23,36 @@ 'Accept-Language': 'en-US,en;q=0.9', } +# Cap body size for direct HTTP fallback (same order of magnitude as MAX_FETCH_CHARS). +_MAX_DIRECT_RESPONSE_BYTES = 10 * 1024 * 1024 + +_DIRECT_FETCH_HEADERS: Dict[str, str] = { + 'User-Agent': DEFAULT_HEADERS['User-Agent'], + 'Accept': + 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', + 'Accept-Language': DEFAULT_HEADERS['Accept-Language'], +} + @dataclass class JinaReaderConfig: base_endpoint: str = 'https://r.jina.ai/' - timeout: float = 30.0 + timeout: float = 45.0 retries: int = 3 backoff_base: float = 0.8 backoff_max: float = 8.0 headers: Dict[str, str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) + # When Jina Reader returns empty after retries, try HTTP GET on the target URL. + direct_fetch_fallback: bool = True + # Tier 2 (urllib): shorter than Jina timeout — fail fast on slow origins. + direct_fetch_timeout: float = 15.0 + # Tier 3: headless Chromium when direct body is empty/short or looks like a JS shell. + playwright_fetch_fallback: bool = True + playwright_retry_min_chars: int = 400 + playwright_timeout_ms: int = 30_000 + # After domcontentloaded, brief wait for client hydration (lower = faster). + playwright_settle_ms: int = 350 def _build_reader_url(target_url: str, base_endpoint: str) -> str: @@ -50,10 +78,85 @@ def _postprocess_text(raw_text: str) -> str: return text.strip() -def fetch_single_text(url: str, config: JinaReaderConfig) -> str: +def _is_direct_http_allowed(url: str) -> bool: + try: + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https'): + return False + if not parsed.netloc: + return False + return True + except Exception: + return False + + +def _html_to_plaintext(html: str) -> str: + """Best-effort HTML → text without extra dependencies.""" + text = re.sub(r'(?is)]*>.*?', ' ', html) + text = re.sub(r'(?is)]*>.*?', ' ', text) + text = re.sub(r'(?is)]*>.*?', ' ', text) + text = re.sub( + r'(?i)', + '\n', + text, + ) + text = re.sub(r'', '\n', text, flags=re.IGNORECASE) + text = re.sub(r'<[^>]+>', ' ', text) + text = html_module.unescape(text) + text = re.sub(r'[ \t\f\v]+', ' ', text) + text = re.sub(r'\n{3,}', '\n\n', text) + return text.strip() + + +# Snippet size for SPA heuristics (avoid holding multi‑MB strings in memory). +_DIRECT_HTML_HEURISTIC_CAP = 120_000 + + +def _fetch_direct_http_pair(url: str, timeout: float) -> Tuple[str, str]: """ - Synchronous fetch of a single URL via Jina Reader with retry/backoff and postprocessing. + Fetch the target URL over HTTP(S) without Jina. + + Returns: + (plaintext, raw_html_snippet) — ``raw_html_snippet`` is non-empty only when + the response was treated as HTML (used for shell / length heuristics). """ + if not _is_direct_http_allowed(url): + return '', '' + try: + req = Request(url, headers=_DIRECT_FETCH_HEADERS) + with urlopen(req, timeout=timeout) as resp: + raw = resp.read(_MAX_DIRECT_RESPONSE_BYTES + 1) + if len(raw) > _MAX_DIRECT_RESPONSE_BYTES: + raw = raw[:_MAX_DIRECT_RESPONSE_BYTES] + charset = resp.headers.get_content_charset() or 'utf-8' + content_type = (resp.headers.get('Content-Type') or '').lower() + content_type_main = content_type.split(';')[0].strip() + text = raw.decode(charset, errors='replace') + if 'html' in content_type_main or text.lstrip().lower().startswith( + ' bool: + """Whether tier-3 headless fetch is worth attempting.""" + p = plain.strip() + if raw_html: + if looks_like_spa_shell_html(raw_html): + return True + if len(p) < min_chars: + return True + return False + return not bool(p) + + +def _fetch_via_jina(url: str, config: JinaReaderConfig) -> str: + """Jina Reader only; returns empty string on failure.""" request_url = _build_reader_url(url, config.base_endpoint) attempt = 0 while True: @@ -62,10 +165,8 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: req = Request(request_url, headers=config.headers) with urlopen(req, timeout=config.timeout) as resp: data = resp.read() - return _postprocess_text( - data.decode('utf-8', errors='replace')) + return data.decode('utf-8', errors='replace') except HTTPError as e: - # Retry on 429 and 5xx, otherwise fail fast status = getattr(e, 'code', None) if status in (429, 500, 502, 503, 504) and attempt <= config.retries: @@ -84,7 +185,6 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: continue return '' except Exception: - # Unknown error; do not loop excessively if attempt <= config.retries: sleep_s = min(config.backoff_max, config.backoff_base * (2**(attempt - 1))) @@ -94,6 +194,61 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: return '' +def fetch_single_text_with_meta(url: str, + config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: + """ + Tiered fetch: Jina Reader → direct HTTP → optional Playwright (empty / short / SPA shell). + + Returns: + (text, meta) where ``meta['content_source']`` is one of: + ``jina_reader`` | ``direct_http_fallback`` | ``playwright_fallback`` | ``none``. + """ + jina_raw = _fetch_via_jina(url, config) + jina_text = _postprocess_text(jina_raw) + if jina_text: + return jina_text, {'content_source': 'jina_reader'} + if not config.direct_fetch_fallback: + return '', {'content_source': 'none'} + d_timeout = (float(config.timeout) if float(config.direct_fetch_timeout or 0) + <= 0 else float(config.direct_fetch_timeout)) + direct_plain, raw_html = _fetch_direct_http_pair(url, d_timeout) + direct_text = _postprocess_text(direct_plain) + + try_playwright = ( + bool(config.playwright_fetch_fallback) and _is_direct_http_allowed(url) + and _should_try_playwright_after_direct(direct_text, raw_html, + config.playwright_retry_min_chars)) + + if try_playwright: + pw_text = _postprocess_text( + try_playwright_inner_text( + url, + int(config.playwright_timeout_ms), + settle_ms=int(config.playwright_settle_ms), + )) + if pw_text.strip(): + logger.info( + 'Using headless Chromium fallback after Jina/direct HTTP ' + f'(url prefix): {url[:80]}') + return pw_text, {'content_source': 'playwright_fallback'} + + if direct_text: + logger.info( + 'Jina Reader returned no body for URL; using direct HTTP fallback ' + f'(url prefix): {url[:80]}') + return direct_text, {'content_source': 'direct_http_fallback'} + return '', {'content_source': 'none'} + + +def fetch_single_text(url: str, config: JinaReaderConfig) -> str: + """ + Synchronous fetch of a single URL via Jina Reader with retry/backoff, + then optional direct HTTP fallback when Jina yields empty. + """ + text, _meta = fetch_single_text_with_meta(url, config) + return text + + async def fetch_texts_via_jina( urls: List[str], config: Optional[JinaReaderConfig] = None, From 18a56832a7647255c8ae10dc9be7614a5ab2ba2c Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 21 Apr 2026 18:00:47 +0800 Subject: [PATCH 34/40] chore(tools): use origin/feat/git filesystem_tool on bench/tavily-0413 Replace bench-specific filesystem_tool with feat/git version; accept behavior differences vs prior tavily bench worktree. Made-with: Cursor --- ms_agent/tools/filesystem_tool.py | 1219 +++++++++-------------------- 1 file changed, 360 insertions(+), 859 deletions(-) diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index 0440941f3..df01d0261 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -1,19 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import fnmatch +import base64 import os -import re -import shutil from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Optional import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase -from ms_agent.utils import MAX_CONTINUE_RUNS, get_logger, retry +from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR -from openai import OpenAI logger = get_logger() @@ -21,12 +16,13 @@ class FileSystemTool(ToolBase): """A file system operation tool""" - # Directories to exclude from file operations - EXCLUDED_DIRS = { - 'node_modules', 'dist', '.git', '__pycache__', '.venv', 'venv' + MAX_READ_BYTES = 10 * 1024 * 1024 # 10 MB per file + IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) + # Curly quote → straight quote mapping for fuzzy matching + CURLY_QUOTE_MAP = { + '\u2018': "'", '\u2019': "'", # ' ' + '\u201c': '"', '\u201d': '"', # " " } - # File prefixes to exclude - EXCLUDED_FILE_PREFIXES = ('.', '..', '__pycache__') SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -46,13 +42,6 @@ def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) - if self.exclude_functions and 'edit_file' not in self.exclude_functions \ - or self.include_functions and 'edit_file' in self.include_functions: - self.edit_file_config = getattr(config.tools.file_system, - 'edit_file_config', None) - self.edit_client = OpenAI( - api_key=self.edit_file_config.api_key, - base_url=self.edit_file_config.base_url) self.trust_remote_code = kwargs.get('trust_remote_code', False) self.allow_read_all_files = getattr( getattr(config.tools, 'file_system', {}), 'allow_read_all_files', @@ -66,6 +55,8 @@ def __init__(self, config, **kwargs): self.system = self.SYSTEM_FOR_ABBREVIATIONS if hasattr(self.config.tools.file_system, 'system_for_abbreviations'): self.system = self.config.tools.file_system.system_for_abbreviations + # {real_path: {"mtime": float, "offset": int|None, "limit": int|None}} + self._read_cache: dict[str, dict] = {} async def connect(self): logger.warning_once( @@ -75,367 +66,117 @@ async def connect(self): async def _get_tools_inner(self): tools = { 'file_system': [ - Tool( - tool_name='create_directory', - server_name='file_system', - description='Create a directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the directory to create', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), Tool( tool_name='write_file', server_name='file_system', - description='Write content into a file', + description=( + 'Write content to a file. Creates the file if it does not exist, ' + 'or overwrites it if it does.\n\n' + 'Usage:\n' + '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' + '- Use this tool to create new files or perform a complete rewrite.\n' + '- Parent directories are created automatically if they do not exist.' + ), parameters={ 'type': 'object', 'properties': { 'path': { 'type': 'string', - 'description': 'The relative path of the file', + 'description': 'The relative path of the file to write', }, 'content': { 'type': 'string', - 'description': 'The content of the file', + 'description': 'The full content to write into the file', }, }, 'required': ['path', 'content'], 'additionalProperties': False }), Tool( - tool_name='read_abbreviation_file', + tool_name='read_file', server_name='file_system', - description= - 'Read the abbreviation content of file(s). If the information is not enough, ' - 'read the original file by `read_file`', + description=( + 'Read the content of one or more files.\n\n' + '- `paths`: list of relative file paths to read.\n' + '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' + '- `offset`: line number to start reading from (1-based). ' + 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' + '- `limit`: number of lines to read. ' + 'Only effective when paths has exactly one element. Omit to read to the end.\n' + '- `abbreviate`: if true, use an LLM to return a condensed summary of each file ' + 'instead of the raw content. Cached after first call. ' + 'Use this for a quick structural overview; read the full file if more detail is needed.' + ), parameters={ 'type': 'object', 'properties': { 'paths': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, + 'type': 'array', + 'items': {'type': 'string'}, 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'List of relative file path(s) to read', }, - }, - 'required': ['paths'], - 'additionalProperties': False - }), - Tool( - tool_name='read_file', - server_name='file_system', - description= - 'Read the content of file(s). When reading a single file, optionally specify line range.', - parameters={ - 'type': 'object', - 'properties': { - 'paths': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, + 'offset': { + 'type': 'integer', 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'Line number to start reading from (1-based). ' + 'Only provide if the file is too large to read at once.', }, - 'start_line': { - 'type': - 'integer', + 'limit': { + 'type': 'integer', 'description': - 'Start line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. 0 or omit to read from beginning.', + 'Number of lines to read. ' + 'Only provide if the file is too large to read at once.', }, - 'end_line': { - 'type': - 'integer', + 'abbreviate': { + 'type': 'boolean', 'description': - 'End line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. Omit to read to the end.', + 'If true, return an LLM-generated summary instead of raw content. ' + 'Useful for large files or quick structural overview.', }, }, 'required': ['paths'], 'additionalProperties': False }), - Tool( - tool_name='list_files', - server_name='file_system', - description='List all files in a directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - "The path to list files, if path is None or '' or not given, " - 'the root dir will be used as path.', - } - }, - 'required': [], - 'additionalProperties': False - }), - Tool( - tool_name='delete_file_or_dir', - server_name='file_system', - description='Delete one file or one directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'description': 'The relative path to delete', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), Tool( tool_name='edit_file', server_name='file_system', - description= - ('Use this tool to make an edit to an existing file.\n\n' - 'This will be read by a less intelligent model, which will quickly apply the edit. ' - 'You should make it clear what the edit is, while also minimizing the unchanged code you write.\n' - 'When writing the edit, you should specify each edit in sequence, with the special comment ' - '// ... existing code ... to represent unchanged code in between edited lines.\n\n' - 'For example:\n\n// ... existing code ...\nFIRST_EDIT\n// ... existing code ...\n' - 'SECOND_EDIT\n// ... existing code ...\nTHIRD_EDIT\n// ... existing code ...\n\n' - 'You should still bias towards repeating as few lines of the original file ' - 'as possible to convey the change.\n' - 'But, each edit should contain minimally sufficient context of unchanged lines ' - "around the code you're editing to resolve ambiguity.\n" - 'DO NOT omit spans of pre-existing code (or comments) without using the ' - '// ... existing code ... comment to indicate its absence. ' - 'If you omit the existing code comment, the model may inadvertently delete these lines.\n' - 'If you plan on deleting a section, you must provide context before and after to delete it. ' - 'If the initial code is ```code \\n Block 1 \\n Block 2 \\n Block 3 \\n code```, ' - 'and you want to remove Block 2, you would output ' - '```// ... existing code ... \\n Block 1 \\n Block 3 \\n // ... existing code ...```.\n' - 'Make sure it is clear what the edit should be, and where it should be applied.\n' - 'Make edits to a file in a single edit_file call ' - 'instead of multiple edit_file calls to the same file. ' - 'The apply model can handle many distinct edits at once.' - ), + description=( + 'Edit an existing file by replacing an exact string with new content.\n\n' + 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' + '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' + 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' + 'with the match count so you can add more context to make it unique.\n\n' + 'Special case — `old_string=""`:\n' + '- File does not exist: creates the file with `new_string` as its content.\n' + '- File exists and is empty: fills it with `new_string`.\n' + '- File exists and has content: returns an error. Use `write_file` for a full rewrite.' + ), parameters={ 'type': 'object', 'properties': { 'path': { 'type': 'string', - 'description': - 'Path of the target file to modify.' + 'description': 'The relative path of the file to edit.', }, - 'instructions': { - 'type': - 'string', - 'description': - ('A single sentence instruction describing ' - 'what you are going to do for the sketched edit. ' - 'This is used to assist the less intelligent model in applying the edit. ' - 'Use the first person to describe what you are going to do. ' - 'Use it to disambiguate uncertainty in the edit.' - ) - }, - 'code_edit': { - 'type': - 'string', - 'description': - ('Specify ONLY the precise lines of code that you wish to edit. ' - 'NEVER specify or write out unchanged code. ' - 'Instead, represent all unchanged code using the comment of the language ' - "you're editing in - example: // ... existing code ..." - ) - } - }, - 'required': ['path', 'instructions', 'code_edit'] - }), - Tool( - tool_name='search_file_content', - server_name='file_system', - description= - 'Search for content in files. By default, searches for the exact literal text. ' - 'Set is_regex=true to use regex pattern matching instead. ' - 'Returns matching files with line numbers and surrounding context.', - parameters={ - 'type': 'object', - 'properties': { - 'content': { - 'type': - 'string', - 'description': - 'The text to search for (literal by default, or a regex pattern if is_regex=true).', - }, - 'parent_path': { - 'type': - 'string', - 'description': - 'The relative parent path to search in (optional, defaults to root)', - }, - 'file_pattern': { - 'type': - 'string', - 'description': - 'Wildcard pattern for file names, e.g., "*.py", "*.js", "test_*.py" ' - '(default: "*" for all files)', - }, - 'context_lines': { - 'type': - 'integer', - 'description': - 'Number of lines before and after the match to include (default: 2)', - }, - 'is_regex': { - 'type': - 'boolean', - 'description': - 'If true, treat content as a regex pattern. If false (default), ' - 'search for the exact literal text. Characters like [, ], (, ), ., *, $ ' - 'are matched literally when is_regex is false.', - }, - 'max_matches': { - 'type': - 'integer', - 'description': - 'Maximum number of matches to return (default: 50). ' - 'If more matches exist, the total count is reported but only ' - 'the first max_matches results are shown.', - }, - }, - 'required': ['content'], - 'additionalProperties': False - }), - Tool( - tool_name='search_file_name', - server_name='file_system', - description= - 'Search for files by name using regex pattern matching. ' - 'Supports both regex patterns and simple substring matching. ' - 'If the file parameter is a valid regex pattern, it will be used for regex matching; ' - 'otherwise, falls back to substring matching. ' - 'The parent_path can also be a regex pattern to filter directories.', - parameters={ - 'type': 'object', - 'properties': { - 'file': { - 'type': - 'string', - 'description': - 'The filename pattern to search for (supports regex, e.g., r"\\.js$" for .js files, ' - 'or "service" for substring match).', - }, - 'parent_path': { - 'type': - 'string', - 'description': - 'The relative parent path to search in (supports regex for directory filtering, ' - 'e.g., r"backend.*" to match backend-related directories). ' - 'Defaults to root if not specified.', - }, - }, - 'required': ['file'], - 'additionalProperties': False - }), - Tool( - tool_name='replace_file_lines', - server_name='file_system', - description= - 'Replace specific line ranges in a file. Supports inserting at beginning ' - '(start_line=0) or end (start_line=-1). Line numbers are 1-based and inclusive on both ends.\n\n' - 'IMPORTANT — Line-number shift after each call. Every replacement changes the total line count, ' - 'which invalidates ALL line numbers after the replaced range. If you need to make multiple replacements in the same file:\n' - '- Option A (recommended): Work from BOTTOM to TOP — edit the largest line numbers first so earlier line numbers remain valid.\n' - '- Option B: Re-search after each replacement to get updated line numbers before the next replacement.\n' - '- Option C: Pre-calculate the cumulative offset — each replacement shifts subsequent lines by (new_content_lines - replaced_lines).\n' - 'NEVER call this tool multiple times in parallel on the same file — the concurrent line-number ' - 'shifts will corrupt the file. Always call sequentially.\n', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the file to modify', - }, - 'content': { + 'old_string': { 'type': 'string', - 'description': - 'The new content to insert/replace', - }, - 'start_line': { - 'type': - 'integer', - 'description': - 'Start line number (1-based, inclusive). Use 0 to insert at beginning, ' - '-1 to append at end', + 'description': 'The exact string to find and replace.', }, - 'end_line': { - 'type': - 'integer', - 'description': - 'End line number (1-based, inclusive). Required unless start_line is 0 or -1', - }, - }, - 'required': ['path', 'content', 'start_line'], - 'additionalProperties': False - }), - Tool( - tool_name='replace_file_contents', - server_name='file_system', - description= - 'Replace exact content in a file without using line numbers. ' - 'You must provide:\n' - '[Required]path: The relative path of modified file.\n' - '[Required]source: The old content to be replaced.\n' - '[Required]target: The new content to replace the `source`.\n' - '[Required]occurrence: Which occurrence to replace (1-based).\n' - 'Do not miss any of these arguments!\n\n' - 'IMPORTANT:\n' - '- `source` must match the file content EXACTLY — including punctuation style ' - '(e.g., Chinese "、" vs English ","), whitespace, line breaks, and Unicode characters.', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the file to modify', - }, - 'source': { - 'type': - 'string', - 'description': - 'The exact content to find and replace. Must match the file content ' - 'EXACTLY including all whitespace, punctuation, and line breaks. ', - }, - 'target': { + 'new_string': { 'type': 'string', - 'description': - 'The new content to replace with', + 'description': 'The string to replace it with.', }, - 'occurrence': { - 'type': - 'integer', + 'replace_all': { + 'type': 'boolean', 'description': - 'Which occurrence to replace (1-based). Default is 1 (first occurrence). ' - 'Use -1 to replace all occurrences.', + 'If true, replace all occurrences. Default is false (replace only the first).', }, }, - 'required': ['path', 'source', 'target', 'occurrence'], + 'required': ['path', 'old_string', 'new_string'], 'additionalProperties': False }), + ] } return tools @@ -444,25 +185,69 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) - async def create_directory(self, path: Optional[str] = None) -> str: - """Create a directory - - Args: - path(`str`): The relative directory path to create, a prefix dir will be automatically concatenated. - - Returns: - or error message. + def _check_staleness(self, real_path: str) -> str | None: + """Return an error string if the file has not been read or has changed since last read. + Returns None if the write is safe to proceed. + Only applies to existing files — new file creation is always allowed. """ - try: - if not path: - path = self.output_dir - else: - path = os.path.join(self.output_dir, path) - os.makedirs(path, exist_ok=True) - return f'Directory: <{path or "root path"}> was created.' - except Exception as e: - return f'Create directory <{path or "root path"}> failed, error: ' + str( - e) + if not os.path.exists(real_path): + return None # new file, no staleness concern + cached = self._read_cache.get(real_path) + if cached is None: + return ( + 'Error: File has not been read yet. ' + 'Read it first before writing to it.' + ) + current_mtime = os.path.getmtime(real_path) + if current_mtime > cached['mtime']: + return ( + 'Error: File has been modified since last read. ' + 'Read it again before writing to it.' + ) + return None + + def _normalize_quotes(self, s: str) -> str: + for curly, straight in self.CURLY_QUOTE_MAP.items(): + s = s.replace(curly, straight) + return s + + def _preserve_quote_style(self, old_string: str, actual_old: str, new_string: str) -> str: + """If old_string matched via quote normalization, apply the same curly quotes to new_string.""" + if old_string == actual_old: + return new_string + has_double = any(c in actual_old for c in '\u201c\u201d') + has_single = any(c in actual_old for c in '\u2018\u2019') + result = new_string + if has_double: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == '"': + prev = chars[i - 1] if i > 0 else None + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u201c' if opening else '\u201d') + else: + out.append(ch) + result = ''.join(out) + if has_single: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == "'": + prev = chars[i - 1] if i > 0 else None + nxt = chars[i + 1] if i < len(chars) - 1 else None + # apostrophe in contraction → right single quote + if prev and nxt and prev.isalpha() and nxt.isalpha(): + out.append('\u2019') + else: + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u2018' if opening else '\u2019') + else: + out.append(ch) + result = ''.join(out) + return result + + @staticmethod + def _strip_trailing_whitespace(s: str) -> str: + return '\n'.join(line.rstrip() for line in s.split('\n')) async def write_file(self, path: str, content: str): """Write content to a file. @@ -478,175 +263,22 @@ async def write_file(self, path: str, content: str): if not os.path.exists(self.output_dir): os.makedirs(self.output_dir, exist_ok=True) original_path = path # Preserve original path for error messages - path = self.get_real_path(path) - if path is None: + real_path = self.get_real_path(path) + if real_path is None: return f'<{original_path}> is out of the valid project path: {self.output_dir}' - dirname = os.path.dirname(path) + err = self._check_staleness(real_path) + if err: + return err + dirname = os.path.dirname(real_path) if dirname: - os.makedirs( - os.path.join(self.output_dir, dirname), exist_ok=True) - with open(os.path.join(self.output_dir, path), 'w') as f: + os.makedirs(dirname, exist_ok=True) + with open(real_path, 'w', encoding='utf-8') as f: f.write(content) + self._read_cache.pop(real_path, None) return f'Save file <{path}> successfully.' except Exception as e: return f'Write file <{path}> failed, error: ' + str(e) - async def replace_file_contents(self, - path: str, - source: str = None, - target: str = None, - occurrence: int = 1): - """Replace exact content in a file without using line numbers. - - This method is safer for parallel operations as it doesn't rely on line numbers - that might change when multiple agents modify the same file concurrently. - - Args: - path(str): The relative file path to modify - source(str): The exact content to find and replace (must match exactly including whitespace) - target(str): The new content to replace with - occurrence(int): Which occurrence to replace (1-based). Use -1 to replace all occurrences. - Default is 1 (first occurrence). - - Returns: - Success or error message. - """ - try: - if not source: - return f'Error: You MUST provide the `source` parameter to be replaced with the `target`, but got {source}.' - if target is None: - return f'Error: You MUST provide the `target` parameter to replace the `source`, but got {target}.' - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - - # Read file content - if not os.path.exists(target_path_real): - return f'Error: File <{path}> does not exist' - - with open(target_path_real, 'r', encoding='utf-8') as f: - file_content = f.read() - - # Check if source exists - if source not in file_content: - return ( - f'Error: Could not find the exact content to replace in <{path}>. ' - f'Make sure the content matches exactly including all whitespace.' - ) - - # Count occurrences - count = file_content.count(source) - - # Replace based on occurrence parameter - if occurrence == -1: - # Replace all occurrences - updated_content = file_content.replace(source, target) - operation_msg = f'Replaced all {count} occurrence(s)' - elif occurrence < 1: - return f'Error: occurrence must be >= 1 or -1 (for all), got {occurrence}' - elif occurrence > count: - return f'Error: occurrence {occurrence} exceeds total occurrences ({count}) of the content' - else: - # Replace specific occurrence - parts = file_content.split(source, occurrence) - if len(parts) <= occurrence: - return f'Error: Could not find occurrence {occurrence} of the content' - # Rejoin: first (occurrence-1) parts with source, then target, then the rest - updated_content = source.join( - parts[:occurrence]) + target + source.join( - parts[occurrence:]) - operation_msg = f'Replaced occurrence {occurrence} of {count}' - - # Write back to file - with open(target_path_real, 'w', encoding='utf-8') as f: - f.write(updated_content) - - return f'{operation_msg} in file <{path}> successfully.' - - except Exception as e: - return f'Replace content in file <{path}> failed, error: ' + str(e) - - async def replace_file_lines(self, - path: str, - content: str, - start_line: int, - end_line: int = None): - """Replace specific line ranges in a file. - - Args: - path(str): The relative file path to modify, a prefix dir will be automatically concatenated. - content(str): The new content to insert/replace - start_line(int): Start line number (1-based, inclusive). Use 0 to insert at beginning, -1 to append at end - end_line(int): End line number (1-based, inclusive). Optional for start_line=0 or -1 - - Returns: - Success or error message. - """ - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - file_path = target_path_real - # Read existing file content - if os.path.exists(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - else: - # If file doesn't exist, create it - dirname = os.path.dirname(file_path) - if dirname: - os.makedirs(dirname, exist_ok=True) - lines = [] - - total_lines = len(lines) - - # Ensure content ends with newline if it doesn't already - if content and not content.endswith('\n'): - content += '\n' - - # Handle special cases - if start_line == 0: - # Insert at beginning - new_lines = [content] + lines - operation = 'Inserted at beginning' - elif start_line == -1: - # Append at end - new_lines = lines + [content] - operation = 'Appended at end' - else: - # Replace range (1-based, inclusive) - if end_line is None: - return 'Error: end_line is required when start_line is not 0 or -1' - - if start_line < 1 or start_line > total_lines + 1: - return f'Error: start_line {start_line} is out of range (file has {total_lines} lines)' - - if end_line < start_line: - return f'Error: end_line {end_line} must be >= start_line {start_line}' - - # Convert to 0-based indices - start_idx = start_line - 1 - # end_line is inclusive (1-based), so we keep lines from end_line onwards (0-based) - end_idx = end_line - # Lines to keep start from index end_line (which is the line after end_line in 1-based) - - new_lines = lines[:start_idx] + [content] + lines[end_idx:] - operation = f'Replaced lines {start_line}-{end_line}' - - # Write back to file - with open(file_path, 'w', encoding='utf-8') as f: - f.writelines(new_lines) - - target = '\n'.join(new_lines).split('\n') - return ( - f'{operation} in file <{path}> completed successfully. The updated file now has {len(target)} lines. ' - 'WARNING: All line numbers after the replaced range may have shifted. ' - 'If you need to make another line-based replacement in this file, keep this in mind.' - ) - - except Exception as e: - return f'Replace lines in file <{path}> failed, error: ' + str(e) - def get_real_path(self, path): # Check if path is absolute or already starts with output_dir if os.path.isabs(path): @@ -670,7 +302,121 @@ def get_real_path(self, path): else: return target_path_real - async def read_abbreviation_file(self, paths: list[str]): + async def read_file(self, + paths: list[str], + offset: int = None, + limit: int = None, + abbreviate: bool = False): + """Read the content of file(s). + + Args: + paths: List of relative file path(s) to read. + offset: Line number to start reading from (1-based). Only effective for a single file. + limit: Number of lines to read. Only effective for a single file. + abbreviate: If True, return an LLM-generated summary instead of raw content. + + Returns: + Dictionary mapping file path(s) to their content or error messages. + """ + if abbreviate: + return await self._read_files_abbreviated(paths) + + results = {} + use_line_range = len(paths) == 1 and (offset is not None + or limit is not None) + + for path in paths: + try: + target_path_real = self.get_real_path(path) + if target_path_real is None: + results[path] = ( + f'Access denied: Reading file <{path}> outside output directory is not allowed. ' + f'Set allow_read_all_files=true in config to enable.') + continue + + ext = os.path.splitext(path)[1].lstrip('.').lower() + + # --- Image files --- + if ext in self.IMAGE_EXTENSIONS: + with open(target_path_real, 'rb') as f: + raw = f.read() + media_type = f'image/{ext}' if ext != 'jpg' else 'image/jpeg' + results[path] = { + 'type': 'image', + 'media_type': media_type, + 'base64': base64.b64encode(raw).decode('ascii'), + } + continue + + # --- Text files --- + file_size = os.path.getsize(target_path_real) + if file_size > self.MAX_READ_BYTES and not use_line_range: + results[path] = ( + f'Error: File <{path}> is too large ({file_size} bytes). ' + f'Use offset and limit to read specific portions.') + continue + + # Dedup: return stub if file unchanged since last read + mtime = os.path.getmtime(target_path_real) + cached = self._read_cache.get(target_path_real) + if (cached + and cached['mtime'] == mtime + and cached['offset'] == offset + and cached['limit'] == limit): + results[path] = { + 'type': 'file_unchanged', + 'message': 'File has not changed since last read.', + } + continue + + with open(target_path_real, 'rb') as f: + raw_bytes = f.read() + + try: + content = raw_bytes.decode('utf-8') + except UnicodeDecodeError: + results[path] = ( + f'Error: File <{path}> appears to be binary. ' + f'Only text and image files are supported.') + continue + + # Normalize line endings + content = content.replace('\r\n', '\n') + lines = content.splitlines(keepends=True) + total_lines = len(lines) + + if use_line_range: + actual_start = max(1, offset) if offset is not None else 1 + actual_end = min(actual_start + limit - 1, total_lines) if limit is not None else total_lines + + if actual_start > total_lines: + results[path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' + continue + selected = lines[actual_start - 1:actual_end] + start_lineno = actual_start + else: + selected = lines + start_lineno = 1 + + results[path] = ''.join( + f'{start_lineno + i}\t{line}' + for i, line in enumerate(selected) + ) + + # Update dedup cache + self._read_cache[target_path_real] = { + 'mtime': mtime, + 'offset': offset, + 'limit': limit, + } + + except FileNotFoundError: + results[path] = f'Read file <{path}> failed: FileNotFound' + except Exception as e: + results[path] = f'Read file <{path}> failed, error: ' + str(e) + return json.dumps(results, indent=2, ensure_ascii=False) + + async def _read_files_abbreviated(self, paths: list[str]) -> str: results = {} def process_file(path): @@ -681,20 +427,18 @@ def process_file(path): index_file = os.path.join(self.index_dir, path.strip(os.sep)) if os.path.exists(index_file): - with open(index_file, 'r', encoding='utf-8') as f: - return path, f.read() + src_mtime = os.path.getmtime(target_path_real) + idx_mtime = os.path.getmtime(index_file) + if idx_mtime >= src_mtime: + with open(index_file, 'r', encoding='utf-8') as f: + return path, f.read() - # Read file content with open(target_path_real, 'r', encoding='utf-8') as f: content = f.read() - # Use LLM to generate abbreviation messages = [ Message(role='system', content=self.system), - Message( - role='user', - content='The content to be abbreviated:\n\n' - + content), + Message(role='user', content='The content to be abbreviated:\n\n' + content), ] response = self.llm.generate(messages=messages, stream=False) os.makedirs(os.path.dirname(index_file), exist_ok=True) @@ -706,373 +450,130 @@ def process_file(path): except Exception as e: return path, f'Process file <{path}> failed, error: ' + str(e) - # Use thread pool for parallel LLM API calls with ThreadPoolExecutor(max_workers=4) as executor: - future_to_path = { - executor.submit(process_file, p): p - for p in paths - } + future_to_path = {executor.submit(process_file, p): p for p in paths} for future in as_completed(future_to_path): path, result = future.result() results[path] = result return json.dumps(results, indent=2, ensure_ascii=False) - async def read_file(self, - paths: list[str], - start_line: int = 0, - end_line: int = None): - """Read the content of file(s). + async def edit_file(self, + path: str = None, + old_string: str = None, + new_string: str = None, + replace_all: bool = False): + """Edit a file by replacing an exact string with new content. Args: - paths(`list[str]`): List of relative file path(s) to read, a prefix dir will be automatically concatenated. - start_line(int): Start line number (1-based, inclusive). Only effective when paths has exactly one element. - 0 means from the beginning. - end_line(int): End line number (1-based, inclusive). Only effective when paths has exactly one element. - None means to the end. + path: The relative file path to edit. + old_string: The exact string to find and replace. + new_string: The replacement string. + replace_all: If True, replace all occurrences. Default replaces only the first. Returns: - Dictionary mapping file path(s) to their content or error messages. + Success or error message. """ - results = {} - # Line range is only effective when reading a single file - use_line_range = len(paths) == 1 and (start_line > 0 - or end_line is not None) - - for path in paths: - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - results[path] = ( - f'Access denied: Reading file <{path}> outside output directory is not allowed. ' - f'Set allow_read_all_files=true in config to enable.') - continue - - with open(target_path_real, 'r') as f: - if use_line_range: - # Read specific line range - lines = f.readlines() - total_lines = len(lines) - - # Validate and adjust line numbers (1-based) - actual_start = max(1, - start_line) if start_line > 0 else 1 - actual_end = min( - end_line, total_lines - ) if end_line is not None else total_lines - - if actual_start > total_lines: - results[ - path] = f'Error: start_line {start_line} exceeds file length ({total_lines} lines)' - elif actual_start > actual_end: - results[ - path] = f'Error: start_line {actual_start} > end_line {actual_end}' - else: - # Convert to 0-based index, end_line is inclusive - selected_lines = lines[actual_start - 1:actual_end] - results[path] = ''.join(selected_lines) - else: - # Read entire file - results[path] = f.read() - except FileNotFoundError: - results[path] = f'Read file <{path}> failed: FileNotFound' - except Exception as e: - results[path] = f'Read file <{path}> failed, error: ' + str(e) - return json.dumps(results, indent=2, ensure_ascii=False) - - async def delete_file_or_dir(self, path: str): - """Delete a file or a directory. + try: + if old_string is None: + return 'Error: `old_string` is required.' + if new_string is None: + return 'Error: `new_string` is required.' - Args: - path(str): The file or directory to delete, a prefix dir will be automatically concatenated. + target_path_real = self.get_real_path(path) + if target_path_real is None: + return f'<{path}> is out of the valid project path: {self.output_dir}' - Returns: - boolean - """ - abs_path = os.path.join(self.output_dir, path) - if os.path.exists(abs_path): - try: - if os.path.isfile(abs_path): - os.remove(abs_path) - else: - shutil.rmtree(abs_path) - return f'Path deleted: <{path}>' - except Exception as e: - return f'Delete file <{path}> failed, error: ' + str(e) - else: - return f'Path not found: {path}' + # --- Special case: old_string="" --- + if old_string == '': + if not os.path.exists(target_path_real): + # Create new file + os.makedirs(os.path.dirname(target_path_real), exist_ok=True) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + return f'Created file <{path}> successfully.' + with open(target_path_real, 'rb') as f: + existing = f.read() + try: + existing_text = existing.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + if existing_text.strip() != '': + return ( + 'Error: `old_string` is empty but the file already has content. ' + 'Use `write_file` for a full rewrite, or provide an `old_string` anchor to insert content.' + ) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + self._read_cache.pop(target_path_real, None) + return f'Edit file <{path}> successfully (filled empty file).' - async def search_file_name(self, file: str = '', parent_path: str = ''): - """Search for files by name using regex pattern matching. + if not os.path.exists(target_path_real): + return f'Error: File <{path}> does not exist.' - Args: - file(str): File name pattern (supports regex). If it's a valid regex pattern, - it will be used for regex matching; otherwise, falls back to substring matching. - parent_path(str): Parent path pattern (supports regex for filtering directories). - Can be a simple path or a regex pattern to match directory paths. + err = self._check_staleness(target_path_real) + if err: + return err - Returns: - String containing all matched file paths - """ - parent_path = parent_path or '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - # Try to compile file pattern as regex - file_use_regex = False - file_pattern = None - if file: + with open(target_path_real, 'rb') as f: + raw = f.read() try: - file_pattern = re.compile(file) - file_use_regex = True - except re.error: - file_use_regex = False - - # Try to compile parent_path filter as regex (optional) - path_use_regex = False - path_pattern = None - if parent_path: - try: - path_pattern = re.compile(parent_path) - path_use_regex = True - except re.error: - path_use_regex = False - - all_found_files = [] - for root, dirs, files in os.walk(_parent_path): - if path_use_regex and parent_path: - relative_root = os.path.relpath(root, self.output_dir) - if not path_pattern.search(relative_root): - continue - - for filename in files: - if file: - if file_use_regex: - is_match = file_pattern.search(filename) is not None - else: - is_match = file in filename - else: - is_match = True # No filter, match all files - - if is_match: - file_path = os.path.join(root, filename) - relative_path = os.path.relpath(file_path, self.output_dir) - all_found_files.append(relative_path) - - if not all_found_files: - return f'No files found matching pattern <{file or "*"}> in <{parent_path or "root"}>' + content = raw.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + + # Normalize line endings for matching + content = content.replace('\r\n', '\n') + old_string = old_string.replace('\r\n', '\n') + + # --- Fallback 1: exact match --- + actual_old = old_string if old_string in content else None + + # --- Fallback 2: quote normalization --- + if actual_old is None: + norm_old = self._normalize_quotes(old_string) + norm_content = self._normalize_quotes(content) + idx = norm_content.find(norm_old) + if idx != -1: + actual_old = content[idx:idx + len(old_string)] + + if actual_old is None: + return ( + f'Error: `old_string` not found in <{path}>. ' + f'Make sure it matches the file content exactly including whitespace.' + ) - all_found_files = '\n'.join(all_found_files) - return f'Found {len(all_found_files.splitlines())} file(s) matching <{file or "*"}>:\n{all_found_files}' + count = content.count(actual_old) + if count > 1 and not replace_all: + return ( + f'Error: Found {count} occurrences of `old_string` in <{path}>. ' + f'Add more surrounding context to make it unique, or set replace_all=true.' + ) - async def search_file_content(self, - content: str = None, - parent_path: str = '.', - file_pattern: str = '*', - context_lines: int = 2, - is_regex: bool = False, - max_matches: int = 50): - """Search for content in files using thread pool. + # Apply quote style preservation to new_string + actual_new = self._preserve_quote_style(old_string, actual_old, new_string) - Args: - content(str): The text to search for (literal by default, regex if is_regex=True) - parent_path(str): The relative parent path to search in - file_pattern(str): Wildcard pattern for file names (default: '*' for all files) - context_lines(int): Number of lines before and after the match to include (default: 2) - is_regex(bool): If True, treat content as a regex pattern; otherwise literal match (default: False) - max_matches(int): Maximum number of matches to return (default: 50) + # --- Fallback 3: smart delete — strip trailing newline when deleting --- + if actual_new == '' and not actual_old.endswith('\n') and actual_old + '\n' in content: + actual_old = actual_old + '\n' - Returns: - String containing all matches with file path, line number, and context - """ - if parent_path.startswith('.' + os.sep): - parent_path = parent_path[len('.' + os.sep):] - if parent_path == '.': - parent_path = '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - if not content: - return 'Error: content parameter is required for search' - - use_regex = False - pattern = None - if is_regex: - try: - pattern = re.compile(content) - use_regex = True - except re.error: - return f'Error: "{content}" is not a valid regex pattern.' - else: - use_regex = False + # Strip trailing whitespace from new_string (skip markdown files) + is_markdown = path.lower().endswith(('.md', '.mdx')) + if not is_markdown: + actual_new = self._strip_trailing_whitespace(actual_new) - # Collect all files matching the pattern - files_to_search = [] - for root, dirs, files in os.walk(_parent_path): - try: - test_dir = str(Path(root).relative_to(self.output_dir)) - except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - if any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS): - continue - for filename in files: - # Skip excluded files - if filename.startswith( - self.EXCLUDED_FILE_PREFIXES) or test_dir.startswith( - self.EXCLUDED_FILE_PREFIXES): - continue - # Match file pattern - if fnmatch.fnmatch(filename, file_pattern): - files_to_search.append(os.path.join(root, filename)) - - if not files_to_search: - return f'No files matching pattern <{file_pattern}> found in <{parent_path or "root"}>' - - # Function to search in a single file - def search_in_file(file_path): - matches = [] - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - for line_num, line in enumerate(lines, start=1): - if use_regex: - is_match = pattern.search(line) is not None - else: - is_match = content in line - - if is_match: - start_line = max(0, line_num - context_lines - 1) - end_line = min(len(lines), line_num + context_lines) - - context = [] - for i in range(start_line, end_line): - prefix = '> ' if i == line_num - 1 else ' ' - context.append( - f'{prefix}{i + 1:4d} | {lines[i].rstrip()}') - - relative_path = os.path.relpath( - file_path, self.output_dir) - matches.append({ - 'file': relative_path, - 'line': line_num, - 'context': '\n'.join(context) - }) - return matches - - # Use thread pool to search files in parallel - all_matches = [] - with ThreadPoolExecutor(max_workers=8) as executor: - future_to_file = { - executor.submit(search_in_file, f): f - for f in files_to_search - } - for future in as_completed(future_to_file): - matches = future.result() - all_matches.extend(matches) - - if not all_matches: - return f'No matches found for <{content}> in files matching <{file_pattern}>' - - all_matches.sort(key=lambda m: (m['file'], m['line'])) - - total_found = len(all_matches) - truncated = total_found > max_matches - if truncated: - all_matches = all_matches[:max_matches] - - # Format results - if truncated: - result_lines = [ - f'Found {total_found} match(es) for "{content}" ' - f'(showing first {max_matches}; refine your search ' - f'or increase max_matches for more):\n' - ] - else: - result_lines = [ - f'Found {total_found} match(es) for "{content}":\n' - ] - for match in all_matches: - result_lines.append( - f"File: {match['file']}, Line: {match['line']}") - result_lines.append(match['context']) - result_lines.append('') - - return '\n'.join(result_lines) - - async def list_files(self, path: str = None): - """List all files in a directory. + if replace_all: + updated = content.replace(actual_old, actual_new) + else: + updated = content.replace(actual_old, actual_new, 1) - Args: - path: The relative path to traverse, a prefix dir will be automatically concatenated. + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(updated) - Returns: - The file names concatenated as a string - """ - file_paths = [] - if not path or path == '.': - path = self.output_dir - else: - path = os.path.join(self.output_dir, path) - if path.startswith('.' + os.sep): - path = path[len('.' + os.sep):] - try: - for root, dirs, files in os.walk(path): - try: - test_dir = str(Path(root).relative_to(self.output_dir)) - except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - for file in files: - # Skip excluded directories and files - root_exclude = any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS) - if root_exclude or file.startswith( - self.EXCLUDED_FILE_PREFIXES - ) or test_dir.startswith(self.EXCLUDED_FILE_PREFIXES): - continue - absolute_path = os.path.join(root, file) - relative_path = os.path.relpath(absolute_path, path) - file_paths.append(relative_path) - return '\n'.join(file_paths) or f'No files in path: {path}' - except Exception as e: - return f'List files of <{path or "root path"}> failed, error: ' + str( - e) + # Invalidate dedup cache for this file + self._read_cache.pop(target_path_real, None) - @retry(max_attempts=MAX_CONTINUE_RUNS, delay=1.0) - async def edit_file(self, - path: str = None, - instructions: str = None, - code_edit: str = None): - try: - with open(os.path.join(self.output_dir, path), 'r') as f: - initial_code = f.read() - response = self.edit_client.chat.completions.create( - model=self.edit_file_config.diff_model, - messages=[{ - 'role': - 'user', - 'content': - (f'{instructions}\n' - f'{initial_code}\n' - f'{code_edit}') - }]) - merged_code = response.choices[0].message.content - - with open(os.path.join(self.output_dir, path), 'w') as f: - f.write(merged_code) - return f'Edit file <{path}> successfully.' + replaced = count if replace_all else 1 + return f'Edit file <{path}> successfully ({replaced} occurrence(s) replaced).' except Exception as e: return f'Edit file <{path}> failed, error: ' + str(e) From ecf350b4a55ff533cc46b4bf8ef49d39c5bfdc06 Mon Sep 17 00:00:00 2001 From: suluyan Date: Wed, 22 Apr 2026 14:53:43 +0800 Subject: [PATCH 35/40] fix(dr v2): align tavily yaml file_system include with grep/glob/edit_file Made-with: Cursor --- projects/deep_research/v2/researcher.tavily.yaml | 7 +++---- projects/deep_research/v2/searcher.tavily.yaml | 4 +++- projects/deep_research/v2/searcher.yaml | 1 + 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/projects/deep_research/v2/researcher.tavily.yaml b/projects/deep_research/v2/researcher.tavily.yaml index e5423c4fc..e2bdd654e 100644 --- a/projects/deep_research/v2/researcher.tavily.yaml +++ b/projects/deep_research/v2/researcher.tavily.yaml @@ -36,10 +36,9 @@ tools: include: - write_file - read_file - - list_files - - search_file_content - - replace_file_contents - - replace_file_lines + - edit_file + - grep + - glob code_executor: mcp: false implementation: python_env diff --git a/projects/deep_research/v2/searcher.tavily.yaml b/projects/deep_research/v2/searcher.tavily.yaml index e8d632745..b24887dae 100644 --- a/projects/deep_research/v2/searcher.tavily.yaml +++ b/projects/deep_research/v2/searcher.tavily.yaml @@ -36,7 +36,9 @@ tools: include: - write_file - read_file - - list_files + - edit_file + - grep + - glob web_search: mcp: false engines: diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index c4fcdbde9..d3f4817a8 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -35,6 +35,7 @@ tools: include: - write_file - read_file + - edit_file - grep - glob web_search: From a5d4267c121d42bffb52404073aa57f7300b30ac Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 23 Apr 2026 15:42:14 +0800 Subject: [PATCH 36/40] fix(filesystem): drop read-cache staleness gate for writes Align edit/write with disk-backed validation (Claude Code style): remove _check_staleness and post-write cache pops that caused redundant read_file round-trips and noisy errors. test: use tool_manager.extra_tools in rollback read_cache smoke (matches LLMAgent.rollback). Made-with: Cursor --- ms_agent/tools/filesystem_tool.py | 42 +++++------------------------- tests/utils/test_snapshot_smoke.py | 2 +- 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index 73cb0a442..b6b2678f2 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -81,7 +81,7 @@ def __init__(self, config, **kwargs): self.system = self.SYSTEM_FOR_ABBREVIATIONS if hasattr(self.config.tools.file_system, 'system_for_abbreviations'): self.system = self.config.tools.file_system.system_for_abbreviations - # {real_path: {"mtime": float, "offset": int|None, "limit": int|None}} + # read_file dedup only: {real_path: {"mtime", "offset", "limit"}} self._read_cache: dict[str, dict] = {} fs_cfg = getattr(config.tools, 'file_system', None) @@ -139,7 +139,8 @@ async def _get_tools_inner(self): 'Usage:\n' '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' '- Use this tool to create new files or perform a complete rewrite.\n' - '- Parent directories are created automatically if they do not exist.' + '- Parent directories are created automatically if they do not exist.\n\n' + 'No prior `read_file` is required; the path is resolved and the given `content` is written as-is.' ), parameters={ 'type': 'object', @@ -207,6 +208,10 @@ async def _get_tools_inner(self): server_name='file_system', description=( 'Edit an existing file by replacing an exact string with new content.\n\n' + 'The tool reads the file from disk and checks that `old_string` appears in the current ' + 'contents before applying the edit — you do not need a prior `read_file` call for ' + 'staleness. Still use `read_file` or `grep` when you need the exact snippet in the ' + 'conversation so you can form a correct `old_string`.\n\n' 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' @@ -582,27 +587,6 @@ async def glob(self, pattern: str, path: str = '') -> str: ) return json.dumps(packed, ensure_ascii=False, indent=2, default=str) - def _check_staleness(self, real_path: str) -> str | None: - """Return an error string if the file has not been read or has changed since last read. - Returns None if the write is safe to proceed. - Only applies to existing files — new file creation is always allowed. - """ - if not os.path.exists(real_path): - return None # new file, no staleness concern - cached = self._read_cache.get(real_path) - if cached is None: - return ( - 'Error: File has not been read yet. ' - 'Read it first before writing to it.' - ) - current_mtime = os.path.getmtime(real_path) - if current_mtime > cached['mtime']: - return ( - 'Error: File has been modified since last read. ' - 'Read it again before writing to it.' - ) - return None - def _normalize_quotes(self, s: str) -> str: for curly, straight in self.CURLY_QUOTE_MAP.items(): s = s.replace(curly, straight) @@ -663,15 +647,11 @@ async def write_file(self, path: str, content: str): real_path = self.get_real_path(path) if real_path is None: return f'<{original_path}> is out of the valid project path: {self.output_dir}' - err = self._check_staleness(real_path) - if err: - return err dirname = os.path.dirname(real_path) if dirname: os.makedirs(dirname, exist_ok=True) with open(real_path, 'w', encoding='utf-8') as f: f.write(content) - self._read_cache.pop(real_path, None) return f'Save file <{path}> successfully.' except Exception as e: return f'Write file <{path}> failed, error: ' + str(e) @@ -902,16 +882,11 @@ async def edit_file(self, ) with open(target_path_real, 'w', encoding='utf-8') as f: f.write(new_string) - self._read_cache.pop(target_path_real, None) return f'Edit file <{path}> successfully (filled empty file).' if not os.path.exists(target_path_real): return f'Error: File <{path}> does not exist.' - err = self._check_staleness(target_path_real) - if err: - return err - with open(target_path_real, 'rb') as f: raw = f.read() try: @@ -967,9 +942,6 @@ async def edit_file(self, with open(target_path_real, 'w', encoding='utf-8') as f: f.write(updated) - # Invalidate dedup cache for this file - self._read_cache.pop(target_path_real, None) - replaced = count if replace_all else 1 return f'Edit file <{path}> successfully ({replaced} occurrence(s) replaced).' except Exception as e: diff --git a/tests/utils/test_snapshot_smoke.py b/tests/utils/test_snapshot_smoke.py index c544135f2..17bbc2e76 100644 --- a/tests/utils/test_snapshot_smoke.py +++ b/tests/utils/test_snapshot_smoke.py @@ -267,7 +267,7 @@ def test_rollback_clears_read_cache(self): fake_tool = MagicMock() fake_tool._read_cache = {'some/path': {'mtime': 123}} fake_manager = MagicMock() - fake_manager.tools = {'fs': fake_tool} + fake_manager.extra_tools = [fake_tool] agent.tool_manager = fake_manager agent.rollback(h) From 00c8e8678befa5a8809145e64a3ebfc6c61120b8 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 27 Apr 2026 11:49:21 +0800 Subject: [PATCH 37/40] fix(deep-research): harden tools, snapshots, and reporter/searcher configs - Subagent snapshot defaults and snapshot repo hook bypass - FileSystemTool read_file path alias; grep newline guard - Evidence write_note optional title; report commit_outline coercion and report_generator load_index - Reporter todo_list; Tavily-only searcher yaml; exp_nosnap configs - Searcher JSON parse resilience in callback Made-with: Cursor --- ms_agent/agent/llm_agent.py | 36 ++- ms_agent/tools/agent_tool.py | 15 +- ms_agent/tools/filesystem_tool.py | 76 +++++- ms_agent/utils/snapshot.py | 21 +- .../v2/callbacks/searcher_callback.py | 99 ++++++-- .../deep_research/v2/reporter.exp_nosnap.yaml | 93 +++++++ projects/deep_research/v2/reporter.yaml | 6 + .../v2/researcher.exp_nosnap.yaml | 174 +++++++++++++ .../deep_research/v2/searcher.exp_nosnap.yaml | 110 +++++++++ projects/deep_research/v2/searcher.yaml | 40 ++- .../deep_research/v2/tools/evidence_tool.py | 29 ++- .../deep_research/v2/tools/report_tool.py | 229 +++++++++++++++--- 12 files changed, 840 insertions(+), 88 deletions(-) create mode 100644 projects/deep_research/v2/reporter.exp_nosnap.yaml create mode 100644 projects/deep_research/v2/researcher.exp_nosnap.yaml create mode 100644 projects/deep_research/v2/searcher.exp_nosnap.yaml diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 75adcf7bb..c8adf5740 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -26,7 +26,6 @@ from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot -from ms_agent.utils.task_manager import TaskManager from omegaconf import DictConfig, OmegaConf from ..config.config import Config, ConfigLifecycleHandler @@ -34,6 +33,8 @@ logger = get_logger() +_MISSING_ENABLE_SNAPSHOTS = object() + class LLMAgent(Agent): """ @@ -87,6 +88,36 @@ class LLMAgent(Agent): DEFAULT_MAX_CHAT_ROUND = 20 + @staticmethod + def _coerce_enable_snapshots_value(value: Any) -> bool: + if isinstance(value, str): + return value.strip().lower() in ('1', 'true', 'yes', 'on') + return bool(value) + + @staticmethod + def resolve_enable_snapshots(config: Any) -> bool: + """Resolve whether to take automatic pre-task snapshots. + + Tool-spawned sub-agents (``ms_agent_subagent`` in config) default to + ``False``; all other agents default to ``True``. An explicit + ``enable_snapshots`` in config always wins (including string forms + like ``\"false\"`` coerced to boolean). + """ + if OmegaConf.is_config(config): + raw = OmegaConf.select(config, 'enable_snapshots', + default=_MISSING_ENABLE_SNAPSHOTS) + if raw is not _MISSING_ENABLE_SNAPSHOTS and raw is not None: + return LLMAgent._coerce_enable_snapshots_value(raw) + sub = bool( + OmegaConf.select(config, 'ms_agent_subagent', default=False)) + return not sub + if isinstance(config, dict): + if 'enable_snapshots' in config and config['enable_snapshots'] is not None: + return LLMAgent._coerce_enable_snapshots_value( + config['enable_snapshots']) + return not bool(config.get('ms_agent_subagent')) + return True + TOTAL_PROMPT_TOKENS = 0 TOTAL_COMPLETION_TOKENS = 0 TOTAL_CACHED_TOKENS = 0 @@ -114,7 +145,6 @@ def __init__( self.knowledge_search: Optional[SirchmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None - self.task_manager: Optional[TaskManager] = None self.max_chat_round: int = 0 self.load_cache = kwargs.get('load_cache', False) self.config.load_cache = self.load_cache @@ -500,7 +530,7 @@ def register_callback_from_config(self): async def on_task_begin(self, messages: List[Message]): self.log_output(f'Agent {self.tag} task beginning.') - if getattr(self.config, 'enable_snapshots', True): + if self.resolve_enable_snapshots(self.config): _user_content = next( ((getattr(m, 'content', '') or '')[:80] for m in messages if getattr(m, 'role', '') == 'user'), diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 5f9ff7d20..02a57b8d5 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -75,9 +75,17 @@ def _message_from_data(data: Any) -> Message: def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): if spec.inline_config is not None: - config_override = OmegaConf.create(spec.inline_config) + container = _to_container(spec.inline_config) + base_override = OmegaConf.create(container) if isinstance( + container, dict) else OmegaConf.create({}) else: - config_override = None + base_override = OmegaConf.create({}) + # Sub-agents default snapshots off in LLMAgent unless enable_snapshots is set + # on the merged agent config (e.g. in the sub-agent YAML). + config_override = OmegaConf.merge( + base_override, + OmegaConf.create({'ms_agent_subagent': True}), + ) trust_remote_code = spec.trust_remote_code if trust_remote_code is None: @@ -736,6 +744,9 @@ async def _run_one(i: int, task: dict) -> str: system = task.get('system', '') query = task.get('query', '') task_config = dict(base_config) if isinstance(base_config, dict) else {} + # Avoid inheriting the parent agent's snapshot preference into each + # split sub-task; sub-agents use ms_agent_subagent defaults instead. + task_config.pop('enable_snapshots', None) if 'prompt' not in task_config or not isinstance(task_config.get('prompt'), dict): task_config['prompt'] = {} task_config['prompt']['system'] = system diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index b6b2678f2..caa58c58b 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -162,7 +162,8 @@ async def _get_tools_inner(self): server_name='file_system', description=( 'Read the content of one or more files.\n\n' - '- `paths`: list of relative file paths to read.\n' + '- `paths`: list of relative file paths to read (preferred).\n' + '- `path`: single relative file path (alias when the model passes one file).\n' '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' '- `offset`: line number to start reading from (1-based). ' 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' @@ -179,7 +180,13 @@ async def _get_tools_inner(self): 'type': 'array', 'items': {'type': 'string'}, 'description': - 'List of relative file path(s) to read', + 'List of relative file path(s) to read. ' + 'Use this OR `path` (single file).', + }, + 'path': { + 'type': 'string', + 'description': + 'Single relative file path to read (alias for `paths` of length 1).', }, 'offset': { 'type': 'integer', @@ -200,7 +207,7 @@ async def _get_tools_inner(self): 'Useful for large files or quick structural overview.', }, }, - 'required': ['paths'], + 'required': [], 'additionalProperties': False }), Tool( @@ -345,6 +352,28 @@ async def grep( except WorkspacePolicyError as e: return json.dumps({'success': False, 'error': str(e)}, indent=2) + if pattern is None or (isinstance(pattern, str) and not pattern.strip()): + return json.dumps( + { + 'success': False, + 'error': 'grep requires a non-empty pattern string.', + }, + indent=2, + ) + if isinstance(pattern, str) and ('\n' in pattern or '\r' in pattern): + return json.dumps( + { + 'success': False, + 'error': ( + 'grep pattern must not contain raw newline characters; ' + 'ripgrep rejects them unless multiline mode is enabled server-side. ' + 'Use several single-line patterns, escape newlines as needed, ' + 'or search with read_file for fixed multi-line text.' + ), + }, + indent=2, + ) + lines: List[str] = [] try: rg = shutil.which('rg') @@ -367,7 +396,15 @@ async def grep( case_insensitive, ) except Exception as e: - logger.warning('grep failed: %s', e, exc_info=True) + err = str(e) + # Expected user/tooling failures (bad regex, rg rules) — log without traceback noise. + _quiet = ( + 'rg:' in err + or 'exited' in err.lower() + or 'regex' in err.lower() + or 'pattern' in err.lower() + ) + logger.warning('grep failed: %s', e, exc_info=not _quiet) return json.dumps({'success': False, 'error': str(e)}, indent=2) text = '\n'.join(lines) @@ -679,8 +716,24 @@ def get_real_path(self, path): else: return target_path_real + def _normalize_read_paths(self, paths, path) -> List[str]: + """Accept `paths` (list or lone string) or legacy `path` kwarg.""" + out: List[str] = [] + if paths is not None: + if isinstance(paths, str) and paths.strip(): + out = [paths.strip()] + elif isinstance(paths, list): + out = [ + p.strip() for p in paths + if isinstance(p, str) and p.strip() + ] + if not out and path is not None and isinstance(path, str) and path.strip(): + out = [path.strip()] + return out + async def read_file(self, - paths: list[str], + paths: Optional[List[str]] = None, + path: Optional[str] = None, offset: int = None, limit: int = None, abbreviate: bool = False): @@ -688,6 +741,7 @@ async def read_file(self, Args: paths: List of relative file path(s) to read. + path: Single path (alias); used when `paths` is omitted. offset: Line number to start reading from (1-based). Only effective for a single file. limit: Number of lines to read. Only effective for a single file. abbreviate: If True, return an LLM-generated summary instead of raw content. @@ -695,6 +749,18 @@ async def read_file(self, Returns: Dictionary mapping file path(s) to their content or error messages. """ + paths = self._normalize_read_paths(paths, path) + if not paths: + return json.dumps( + { + 'success': False, + 'error': ( + 'read_file requires `paths` (list of strings) or `path` (single string). ' + 'Example: {"paths": ["a.md"]} or {"path": "a.md"}.' + ), + }, + indent=2, + ) if abbreviate: return await self._read_files_abbreviated(paths) diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py index 35d3fe26b..2014b49da 100644 --- a/ms_agent/utils/snapshot.py +++ b/ms_agent/utils/snapshot.py @@ -42,6 +42,23 @@ def _snapshot_git_dir(output_dir: str) -> str: return os.path.join(output_dir, _SNAPSHOT_DIR_NAME) +def _configure_snapshot_repo_for_automation(work_tree: str, git_dir: str) -> None: + """Disable hook execution for the nested snapshot repo. + + Without this, Git can inherit ``init.templateDir`` / global ``core.hooksPath`` + (e.g. lefthook), so ``git commit`` runs hooks and races under concurrency + (``cannot lock ref 'HEAD'`` / hook failures). ``os.devnull`` is the portable + Git-supported way to disable hooks (POSIX ``/dev/null``, Windows ``nul``). + """ + try: + _git(['config', 'core.hooksPath', os.devnull], + work_tree=work_tree, + git_dir=git_dir, + check=False) + except Exception: + pass + + def _ensure_repo(output_dir: str) -> str: """Initialize the snapshot repo if it doesn't exist. Returns git_dir.""" git_dir = _snapshot_git_dir(output_dir) @@ -60,6 +77,8 @@ def _ensure_repo(output_dir: str) -> str: exclude_file = os.path.join(info_dir, 'exclude') with open(exclude_file, 'a', encoding='utf-8') as f: f.write(f'\n{_SNAPSHOT_DIR_NAME}/\n') + # Always (re)apply: repos created before this fix may still inherit hooks. + _configure_snapshot_repo_for_automation(output_dir, git_dir) return git_dir @@ -115,7 +134,7 @@ def take_snapshot(output_dir: str, message: str, # Truncate message to keep commit subject readable subject = message.strip().replace('\n', ' ')[:120] - result = _git(['commit', '-m', subject], + result = _git(['commit', '--no-verify', '-m', subject], work_tree=output_dir, git_dir=git_dir) commit_hash = None diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index e48d35880..735a2d47a 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -14,6 +14,42 @@ logger = get_logger() +def _parse_search_result_json(text: str) -> Optional[Any]: + """Parse searcher final JSON from raw assistant text. + + Accepts plain JSON, fenced ```json blocks, or a JSON object embedded in + surrounding prose (first balanced object via :func:`json.JSONDecoder.raw_decode`). + """ + if text is None: + return None + if not isinstance(text, str): + return None + text = text.strip() + if not text: + return None + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + pass + m = re.search(r'```(?:json)?\s*\r?\n(.*?)```', text, flags=re.DOTALL | re.IGNORECASE) + if m: + block = m.group(1).strip() + if block: + try: + return json.loads(block) + except (json.JSONDecodeError, TypeError): + pass + dec = json.JSONDecoder() + for i, ch in enumerate(text): + if ch not in '{[': + continue + try: + return dec.raw_decode(text[i:])[0] + except json.JSONDecodeError: + continue + return None + + class SearcherCallback(Callback): """ Callback for Searcher agent. @@ -221,33 +257,44 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): try: # Prefer JSON file if possible; fallback to markdown otherwise. - if isinstance(content, str): - parsed_json = json.loads(content) - else: + if isinstance(content, (dict, list)): parsed_json = content - try: - with open(json_path, 'x', encoding='utf-8') as f: - json.dump( - parsed_json, f, ensure_ascii=False, indent=2) - logger.info( - f'Searcher: Search result saved to {json_path}') - except FileExistsError: - logger.info( - f'Search result already exists at {json_path}') - except (json.JSONDecodeError, TypeError): - logger.warning( - 'Failed to parse search result as JSON, saving as markdown' - ) - text = content if isinstance(content, - str) else str(content) - try: - with open(md_path, 'x', encoding='utf-8') as f: - f.write(text) - logger.info( - f'Searcher: Search result saved to {md_path}') - except FileExistsError: - logger.info( - f'Search result already exists at {md_path}') + elif isinstance(content, str): + try: + parsed_json = json.loads(content) + except (json.JSONDecodeError, TypeError): + parsed_json = _parse_search_result_json(content) + if parsed_json is not None: + logger.info( + 'Searcher: parsed JSON from fenced or embedded payload' + ) + else: + parsed_json = _parse_search_result_json(str(content)) + + if parsed_json is not None: + try: + with open(json_path, 'x', encoding='utf-8') as f: + json.dump( + parsed_json, f, ensure_ascii=False, indent=2) + logger.info( + f'Searcher: Search result saved to {json_path}') + except FileExistsError: + logger.info( + f'Search result already exists at {json_path}') + else: + logger.warning( + 'Failed to parse search result as JSON, saving as markdown' + ) + text = content if isinstance(content, + str) else str(content) + try: + with open(md_path, 'x', encoding='utf-8') as f: + f.write(text) + logger.info( + f'Searcher: Search result saved to {md_path}') + except FileExistsError: + logger.info( + f'Search result already exists at {md_path}') except Exception as e: logger.warning( f'Unexpected error when saving search result: {e}') diff --git a/projects/deep_research/v2/reporter.exp_nosnap.yaml b/projects/deep_research/v2/reporter.exp_nosnap.yaml new file mode 100644 index 000000000..d1c67ccdf --- /dev/null +++ b/projects/deep_research/v2/reporter.exp_nosnap.yaml @@ -0,0 +1,93 @@ +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: true + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + # show_reasoning: true + # reasoning_output: stdout + + +tag: deep-research + + +prompt: + root: prompts/ + agent: reporter + lang: en + family: gpt5 + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - get_analysis + - list_analyses + report_generator: + mcp: false + reports_dir: reports + plugins: + - tools/evidence_tool.py + - tools/report_tool.py + + +handler: time_handler + +code_file: reporter + +callbacks: + - callbacks/reporter_callback + +max_chat_round: 36 + +# Round-aware reminder injected via ReporterCallback.on_generate_response. +round_reminder: + enabled: true + remind_at_round: 34 + +self_reflection: + enabled: true + max_retries: 3 + min_retention_ratio: 0.6 + post_report_guidance_enabled: true + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +tool_call_timeout: 300 + +output_dir: ./output + +# bench experiment: disable snapshot for reporter sub-agent +enable_snapshots: false diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index a004e1c27..84183cbf7 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -37,6 +37,12 @@ tools: - edit_file - grep - glob + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read evidence_store: mcp: false evidence_dir: evidence diff --git a/projects/deep_research/v2/researcher.exp_nosnap.yaml b/projects/deep_research/v2/researcher.exp_nosnap.yaml new file mode 100644 index 000000000..307539e74 --- /dev/null +++ b/projects/deep_research/v2/researcher.exp_nosnap.yaml @@ -0,0 +1,174 @@ +llm: + service: openai + model: gpt-5.2-2025-12-11 + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: false + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + # extra_body: + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + reasoning_effort: medium + + +tag: deep-research-researcher + + +prompt: + root: prompts/ + agent: researcher + lang: en + family: thinking + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + code_executor: + mcp: false + implementation: python_env + notebook_timeout: 120 + include: + - notebook_executor + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - write_analysis + - get_analysis + - list_analyses + agent_tools: + mcp: false + enable_stats: true + run_in_thread: true + run_in_process: true + max_workers: 4 + definitions: + - tool_name: searcher_tool + description: > + Invoke the Searcher sub-agent to perform an in-depth research task on a specific topic. + Searcher is capable of autonomously executing a research loop until sufficient evidence is collected and a research report is produced (search -> parse -> evidence discovery & storage -> progressive search -> ...). + Returns a JSON result containing: task completion status, core findings, issues or limitations encountered, research report body, and evidence storage locations. + config_path: searcher.exp_nosnap.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted research task description that should include: + - The corresponding task ID from the TODO list (required) + - Specific research objectives + - Questions to be answered + - Constraints (time range, source preferences, etc., optional) + - Stopping conditions (optional) + - Other requirements (optional) + Recommended format: + { + "task_id": "...", + "research_objectives": "...", + "questions_to_answer": "...", + "constraints": "...", + "stopping_conditions": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + - tool_name: reporter_tool + description: > + Invoke the Reporter sub-agent to generate a research report based on collected evidence. + Reporter reads the stored evidence cards and executes a complex workflow for research report writing. + The completed report is automatically saved to `final_report.md` in the output directory by the system. + Returns a JSON result containing: execution summary and + intermediate artifact file paths (the full report body is NOT included in the return value — + read `final_report.md` directly to access the report content). + config_path: reporter.exp_nosnap.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted report generation instruction that should include: + - Report topic and target audience + - Complete background description and task description + - Core questions to be covered + - Writing requirements (style, structure, length, language, etc.) + - Any other requirements + Recommended format: + { + "report_topic_and_audience": "...", + "background": "...", + "task_description": "...", + "writing_requirements": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + plugins: + - tools/evidence_tool.py + + +callbacks: + - callbacks/researcher_callback + +# Self-reflection checks before allowing the researcher to stop. +self_reflection: + enabled: true + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +handler: time_handler + +code_file: researcher + +max_chat_round: 45 + +tool_call_timeout: 2600 + +output_dir: ./output + +# bench experiment: disable .ms_agent_snapshots git snapshot +enable_snapshots: false diff --git a/projects/deep_research/v2/searcher.exp_nosnap.yaml b/projects/deep_research/v2/searcher.exp_nosnap.yaml new file mode 100644 index 000000000..4858fb281 --- /dev/null +++ b/projects/deep_research/v2/searcher.exp_nosnap.yaml @@ -0,0 +1,110 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + force_prefix_cache: true + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + + +tag: deep-research + + +prompt: + root: prompts/ + agent: searcher + lang: en + family: gpt5 + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + web_search: + mcp: false + engines: + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract + fetch_content: true + fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 + _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown + enable_chunking: false + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false + summarizer_model: qwen3.5-flash + summarizer_base_url: + summarizer_api_key: + max_content_chars: 200000 + summarizer_max_workers: 15 + summarization_timeout: 360 + _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts + evidence_store: + mcp: false + evidence_dir: evidence + chunks_dir: chunks + enable_chunk_storage: false + include: + - write_note + - list_notes + - get_note + - load_index + - search_notes + - delete_note + plugins: + - tools/evidence_tool.py + + +handler: time_handler + +callbacks: + - callbacks/searcher_callback + +max_chat_round: 30 + +round_reminder: + enabled: true + remind_at_round: 28 + +tool_call_timeout: 2600 + +output_dir: ./output + +# bench experiment: disable snapshot for searcher sub-agent +enable_snapshots: false diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index d3f4817a8..b24887dae 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -1,3 +1,8 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + llm: service: openai model: qwen3.5-plus @@ -9,14 +14,10 @@ generation_config: stream: true stream_options: include_usage: true - # Enable explicit prefix caching (auto-detects provider from openai_base_url) force_prefix_cache: true - # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] extra_body: enable_thinking: false - # show_reasoning: true - # reasoning_output: stdout tag: deep-research @@ -41,16 +42,27 @@ tools: web_search: mcp: false engines: - - exa - - arxiv - api_key: - max_results: 5 - fetcher: jina_reader + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract fetch_content: true fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown enable_chunking: false - enable_summarization: true + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false summarizer_model: qwen3.5-flash summarizer_base_url: summarizer_api_key: @@ -58,6 +70,11 @@ tools: summarizer_max_workers: 15 summarization_timeout: 360 _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts evidence_store: mcp: false evidence_dir: evidence @@ -81,11 +98,10 @@ callbacks: max_chat_round: 30 -# Round-aware reminder injected via SearcherCallback.on_generate_response. round_reminder: enabled: true remind_at_round: 28 -tool_call_timeout: 300 +tool_call_timeout: 2600 output_dir: ./output diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 064ab6a87..1379ce511 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -379,7 +379,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'string', 'description': - 'Brief title describing this evidence (e.g., "Tesla Q3 revenue growth").', + ('Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' + 'Optional: if omitted, a title is derived from the first line of `content`.'), }, 'content': { 'type': @@ -464,10 +465,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'Optional: Confidence/quality score (0-100).', }, }, - 'required': [ - 'title', 'content', 'sources', 'summary', - 'task_id', 'tags' - ], + 'required': ['content'], 'additionalProperties': False, }, @@ -849,8 +847,8 @@ def _load_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: async def write_note( self, - title: str, content: str, + title: Optional[str] = None, contradicts: Optional[str] = None, sources: Optional[List[Dict[str, Any]]] = None, summary: Optional[str] = None, @@ -863,12 +861,27 @@ async def write_note( _ensure_dir(paths['notes_dir']) _ensure_dir(paths['lock_dir']) + content = (content or '').strip() + if not content: + return _json_dumps({ + 'status': 'error', + 'message': 'write_note requires non-empty content.', + }) + + if title is None or not str(title).strip(): + first_line = content.split('\n', 1)[0].strip() + if len(first_line) > 120: + first_line = first_line[:117] + '...' + title_resolved = first_line or 'Evidence note' + else: + title_resolved = str(title).strip() + # Generate ID and build note note_id = _generate_note_id() note: Dict[str, Any] = { 'note_id': note_id, - 'title': title.strip(), - 'content': content.strip(), + 'title': title_resolved, + 'content': content, 'created_at': _now_iso(), } diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 96fd51994..819da108c 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -43,6 +43,39 @@ def _write_text(path: str, content: str) -> None: f.write(content) +def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: + """Normalize `chapters` from the model (list, JSON string, or nested strings).""" + if chapters is None: + return [], ( + 'commit_outline requires `chapters` (array of chapter objects, ' + 'or a JSON string of that array).') + raw: Any = chapters + if isinstance(raw, str): + try: + raw = json.loads(raw.strip()) + except json.JSONDecodeError as e: + return [], ( + 'commit_outline `chapters` must be a JSON array of objects, ' + f'or a JSON string of that array: {e}') + if not isinstance(raw, list): + return [], ( + f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') + out: List[Dict[str, Any]] = [] + for i, ch in enumerate(raw): + if isinstance(ch, str): + try: + ch = json.loads(ch.strip()) + except json.JSONDecodeError: + return [], ( + f'commit_outline chapters[{i}] must be an object; ' + 'string entry is not valid JSON for an object.') + if not isinstance(ch, dict): + return [], ( + f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') + out.append(ch) + return out, None + + def _render_outline_md(outline: Dict[str, Any]) -> str: """Render outline as Markdown.""" lines = [f"# {outline.get('title', 'Report Outline')}", ''] @@ -178,6 +211,37 @@ def _paths(self) -> Dict[str, str]: os.path.join(self.output_dir, self._lock_subdir), } + def _filter_candidate_evidence( + self, + paths: Dict[str, str], + candidate: Any, + ) -> tuple[List[str], List[str]]: + """Keep only note ids that have ``note_{id}.md`` under evidence notes dir. + + Returns: + (kept_ids_in_order, dropped_ids) — duplicates in input are skipped + after the first occurrence. + """ + if not isinstance(candidate, list): + return [], [] + notes_dir = paths['evidence_notes_dir'] + kept: List[str] = [] + dropped: List[str] = [] + seen: set[str] = set() + for raw in candidate: + nid = str(raw).strip() if raw is not None else '' + if not nid: + continue + if nid in seen: + continue + seen.add(nid) + note_path = os.path.join(notes_dir, f'note_{nid}.md') + if os.path.exists(note_path): + kept.append(nid) + else: + dropped.append(nid) + return kept, dropped + async def _get_tools_inner(self) -> Dict[str, Any]: tools: Dict[str, List[Tool]] = { self.SERVER_NAME: [ @@ -500,6 +564,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'additionalProperties': False, }, ), + Tool( + tool_name='load_index', + server_name=self.SERVER_NAME, + description=( + 'Load the full evidence index (notes + analyses metadata). ' + 'Same data as evidence_store---load_index; provided here so calls ' + 'mistakenly prefixed with report_generator--- still work.' + ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False, + }, + ), ] } return tools @@ -536,6 +615,25 @@ def _load_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: return {'notes': {}} return data + def _load_full_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: + """Load evidence/index.json with the same defaults as EvidenceTool.""" + data = _safe_read_json(paths['evidence_index']) + if data is None or not isinstance(data, dict): + return { + 'schema_version': 2, + 'updated_at': _now_iso(), + 'notes': {}, + 'analyses': {}, + } + if 'notes' not in data or not isinstance(data.get('notes'), dict): + data['notes'] = {} + if 'analyses' not in data or not isinstance(data.get('analyses'), dict): + data['analyses'] = {} + legacy = data.get('conclusions') + if isinstance(legacy, dict) and legacy and not data.get('analyses'): + data['analyses'] = legacy + return data + def _load_note_content(self, paths: Dict[str, str], note_id: str) -> Optional[Dict[str, Any]]: """Load a single note's full content from markdown file.""" @@ -565,27 +663,53 @@ def _save_conflict(self, paths: Dict[str, str], conflict['updated_at'] = _now_iso() _write_text(paths['conflict_json'], _json_dumps(conflict)) + async def load_index(self) -> str: + """Return evidence index JSON (alias for mistaken report_generator---load_index).""" + paths = self._paths() + _ensure_dir(paths['lock_dir']) + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_full_evidence_index(paths) + notes = index.get('notes', {}) + analyses = index.get('analyses', {}) + return _json_dumps({ + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + }) + async def commit_outline( self, title: str, - chapters: List[Dict[str, Any]], + chapters: Any, ) -> str: """Generate report outline with chapter structure.""" paths = self._paths() _ensure_dir(paths['chapters_dir']) _ensure_dir(paths['lock_dir']) + chapters_list, chapters_err = _coerce_chapters_argument(chapters) + if chapters_err: + return _json_dumps({'status': 'error', 'message': chapters_err}) + # Load evidence index to validate coverage evidence_index = self._load_evidence_index(paths) all_note_ids = set(evidence_index.get('notes', {}).keys()) - # Build outline + # Build outline (only bind note ids that exist on disk) outline_chapters = [] covered_evidence = set() + invalid_candidate_by_chapter: Dict[str, List[str]] = {} - for idx, ch in enumerate(chapters, start=1): - candidate = ch.get('candidate_evidence', []) - covered_evidence.update(candidate) + for idx, ch in enumerate(chapters_list, start=1): + candidate_raw = ch.get('candidate_evidence', []) + kept, dropped = self._filter_candidate_evidence( + paths, candidate_raw) + if dropped: + invalid_candidate_by_chapter[str(idx)] = dropped + covered_evidence.update(kept) outline_chapters.append({ 'chapter_id': @@ -597,7 +721,7 @@ async def commit_outline( 'sections_description': ch.get('sections_description', ''), 'candidate_evidence': - candidate, + kept, 'status': 'pending', }) @@ -632,6 +756,14 @@ async def commit_outline( if coverage_warning: result['warning'] = coverage_warning + if invalid_candidate_by_chapter: + result['invalid_candidate_evidence'] = invalid_candidate_by_chapter + result['invalid_candidate_evidence_note'] = ( + 'These note ids were removed from candidate_evidence because ' + 'no matching evidence/notes/note_.md file exists. ' + 'Use list_notes or prior write_note responses to pick valid ids.' + ) + return _json_dumps(result) async def prepare_chapter_bundle( @@ -667,13 +799,39 @@ async def prepare_chapter_bundle( 'message': f'Chapter {chapter_id} not found.' }) + cand_kept, cand_dropped = self._filter_candidate_evidence( + paths, chapter.get('candidate_evidence', [])) + rel_kept, rel_dropped = self._filter_candidate_evidence( + paths, relevant_evidence or []) + # Load evidence content evidence_index = self._load_evidence_index(paths) notes_meta = evidence_index.get('notes', {}) + _known_sorted = sorted(notes_meta.keys()) + _sample = _known_sorted[:48] + _note_id_hint = ( + 'Known note ids in evidence index (sample): ' + + (', '.join(_sample) if _sample else '(none)') + ) + if len(_known_sorted) > len(_sample): + _note_id_hint += f' … (+{len(_known_sorted) - len(_sample)} more)' + + def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: + return { + 'note_id': note_id, + 'error': f'Note {note_id} not found', + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'hint': ( + f'{_note_id_hint}. ' + 'Align outline candidate_evidence with filenames under evidence_store, ' + 'or list existing notes before referencing ids.' + ), + } notes_content = [] seen_note_ids = set() - for note_id in chapter.get('candidate_evidence', []): + for note_id in cand_kept: meta = notes_meta.get(note_id, {}) note_data = self._load_note_content(paths, note_id) @@ -697,17 +855,11 @@ async def prepare_chapter_bundle( meta.get('tags', note_data.get('tags', [])), }) else: - # Note not found, include minimal info - notes_content.append({ - 'note_id': note_id, - 'error': f'Note {note_id} not found', - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - }) + notes_content.append(_missing_note_entry(note_id, meta)) seen_note_ids.add(note_id) - for note_id in relevant_evidence: + for note_id in rel_kept: if note_id not in seen_note_ids: meta = notes_meta.get(note_id, {}) note_data = self._load_note_content(paths, note_id) @@ -733,18 +885,11 @@ async def prepare_chapter_bundle( meta.get('tags', note_data.get('tags', [])), }) else: - notes_content.append({ - 'note_id': note_id, - 'error': f'Note {note_id} not found', - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - }) + notes_content.append(_missing_note_entry(note_id, meta)) - # Build meta + # Build meta (only ids that resolved to on-disk notes for this bundle) candidate_evidence = list( - dict.fromkeys( - list(chapter.get('candidate_evidence', [])) - + list(relevant_evidence or []))) + dict.fromkeys(list(cand_kept) + list(rel_kept))) meta = { 'chapter_id': chapter_id, 'chapter_title': chapter['title'], @@ -767,7 +912,7 @@ async def prepare_chapter_bundle( with file_lock(paths['lock_dir'], 'report_outline'): self._save_outline(paths, outline) - return _json_dumps({ + out_bundle: Dict[str, Any] = { 'status': 'ok', 'chapter_id': @@ -782,7 +927,19 @@ async def prepare_chapter_bundle( os.path.relpath(meta_path, self.output_dir), 'notes_content': notes_content, - }) + } + skipped: Dict[str, List[str]] = {} + if cand_dropped: + skipped['outline_candidate_evidence'] = cand_dropped + if rel_dropped: + skipped['relevant_evidence_argument'] = rel_dropped + if skipped: + out_bundle['skipped_invalid_note_ids'] = skipped + out_bundle['skipped_invalid_note_ids_note'] = ( + 'These ids were ignored (no evidence/notes/note_.md). ' + 'Update outline candidate_evidence via update_outline if needed.' + ) + return _json_dumps(out_bundle) async def commit_chapter( self, @@ -922,6 +1079,7 @@ async def update_outline( }) chapter_found = False + invalid_candidate_removed: List[str] = [] for ch in outline.get('chapters', []): if ch['chapter_id'] == chapter_id: if 'title' in updates: @@ -932,8 +1090,10 @@ async def update_outline( ch['sections_description'] = updates[ 'sections_description'] if 'candidate_evidence' in updates: - ch['candidate_evidence'] = updates[ - 'candidate_evidence'] + kept, dropped = self._filter_candidate_evidence( + paths, updates['candidate_evidence']) + ch['candidate_evidence'] = kept + invalid_candidate_removed = dropped chapter_found = True break @@ -947,11 +1107,18 @@ async def update_outline( self._save_outline(paths, outline) - return _json_dumps({ + out: Dict[str, Any] = { 'status': 'ok', 'chapter_id': chapter_id, 'updates_applied': list(updates.keys()), - }) + } + if invalid_candidate_removed: + out['invalid_candidate_evidence_removed'] = invalid_candidate_removed + out['invalid_candidate_evidence_note'] = ( + 'These ids were removed from candidate_evidence because no ' + 'matching evidence/notes/note_.md exists.' + ) + return _json_dumps(out) async def assemble_draft( self, From 67668e9c9a8af3f299bf606d3a34cff461bcfe79 Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 28 Apr 2026 13:54:12 +0800 Subject: [PATCH 38/40] fix lint --- .gitignore | 8 +- .pre-commit-config.yaml | 26 +- .../nanobot_integration/test_mcp_tools.py | 40 +- ms-agent-skills/scripts/check_ms_agent.py | 69 +- ms_agent/agent/base.py | 44 +- ms_agent/agent/code_agent.py | 15 +- ms_agent/agent/llm_agent.py | 272 ++-- ms_agent/agent/loader.py | 60 +- ms_agent/agent/runtime.py | 1 - ms_agent/app/doc_research.py | 2 +- ms_agent/app/fin_research.py | 6 +- ms_agent/callbacks/base.py | 10 +- ms_agent/callbacks/input_callback.py | 3 +- ms_agent/capabilities/__init__.py | 9 +- ms_agent/capabilities/async_task.py | 17 +- ms_agent/capabilities/mcp_server.py | 41 +- ms_agent/capabilities/registry.py | 12 +- .../capabilities/wrappers/agent_delegate.py | 101 +- .../capabilities/wrappers/deep_research.py | 152 +-- ms_agent/capabilities/wrappers/filesystem.py | 96 +- .../capabilities/wrappers/lsp_code_server.py | 83 +- ms_agent/capabilities/wrappers/web_search.py | 68 +- ms_agent/cli/app.py | 37 +- ms_agent/cli/cli.py | 7 +- ms_agent/cli/run.py | 110 +- ms_agent/cli/ui.py | 78 +- ms_agent/config/config.py | 89 +- ms_agent/config/env.py | 4 +- ms_agent/llm/anthropic_llm.py | 164 ++- ms_agent/llm/dashscope_llm.py | 17 +- ms_agent/llm/deepseek_llm.py | 51 +- ms_agent/llm/llm.py | 20 +- ms_agent/llm/modelscope_llm.py | 16 +- ms_agent/llm/openai.py | 158 +-- ms_agent/llm/openai_llm.py | 395 +++--- ms_agent/llm/utils.py | 26 +- ms_agent/memory/base.py | 6 +- ms_agent/memory/condenser/code_condenser.py | 36 +- .../memory/condenser/context_compressor.py | 35 +- ms_agent/memory/condenser/refine_condenser.py | 60 +- ms_agent/memory/default_memory.py | 320 ++--- ms_agent/memory/diversity.py | 27 +- ms_agent/memory/memory_manager.py | 13 +- ms_agent/memory/utils.py | 4 +- ms_agent/prompting/file_resolver.py | 29 +- ms_agent/rag/base.py | 6 +- ms_agent/rag/extraction.py | 4 +- ms_agent/rag/extraction_manager.py | 67 +- ms_agent/rag/llama_index_rag.py | 157 ++- ms_agent/rag/schema.py | 1 + ms_agent/retriever/hybrid_retriever.py | 76 +- ms_agent/sandbox/sandbox.py | 91 +- ms_agent/skill/auto_skills.py | 8 +- ms_agent/skill/container.py | 418 +++---- ms_agent/skill/loader.py | 22 +- ms_agent/skill/schema.py | 130 +- ms_agent/skill/spec.py | 14 +- ms_agent/tools/agent_tool.py | 350 +++--- ms_agent/tools/audio_generator/audio_gen.py | 21 +- ms_agent/tools/audio_generator/edge_tts.py | 15 +- ms_agent/tools/base.py | 15 +- ms_agent/tools/code/code_executor.py | 490 +++----- ms_agent/tools/code/local_code_executor.py | 462 +++---- ms_agent/tools/code/sandbox_manager.py | 43 +- ms_agent/tools/code_server/lsp_code_server.py | 408 +++--- ms_agent/tools/docling/chunker.py | 40 +- ms_agent/tools/docling/doc_loader.py | 17 +- ms_agent/tools/docling/doc_postprocess.py | 4 +- ms_agent/tools/docling/patches.py | 72 +- ms_agent/tools/fetch_playwright_fallback.py | 24 +- ms_agent/tools/filesystem_tool.py | 253 ++-- ms_agent/tools/findata/__init__.py | 3 +- ms_agent/tools/findata/akshare_source.py | 296 ++--- ms_agent/tools/findata/baostock_source.py | 173 +-- ms_agent/tools/findata/data_source_base.py | 25 +- ms_agent/tools/findata/findata_fetcher.py | 672 ++++------ ms_agent/tools/findata/hybrid_source.py | 83 +- .../tools/image_generator/ds_image_gen.py | 33 +- .../tools/image_generator/google_image_gen.py | 7 +- ms_agent/tools/image_generator/image_gen.py | 32 +- .../tools/image_generator/ms_image_gen.py | 49 +- ms_agent/tools/jina_reader.py | 75 +- ms_agent/tools/mcp_client.py | 128 +- ms_agent/tools/mineru/pdf_parser.py | 33 +- ms_agent/tools/search/arxiv/__init__.py | 3 +- ms_agent/tools/search/arxiv/schema.py | 110 +- ms_agent/tools/search/arxiv/search.py | 109 +- ms_agent/tools/search/content_optimizer.py | 108 +- ms_agent/tools/search/exa/schema.py | 53 +- ms_agent/tools/search/exa/search.py | 103 +- ms_agent/tools/search/localsearch_tool.py | 76 +- ms_agent/tools/search/search_base.py | 37 +- ms_agent/tools/search/search_request.py | 2 +- ms_agent/tools/search/serpapi/__init__.py | 3 +- ms_agent/tools/search/serpapi/schema.py | 31 +- ms_agent/tools/search/serpapi/search.py | 50 +- ms_agent/tools/search/sirchmunk_search.py | 195 ++- ms_agent/tools/search/tavily/fetcher.py | 4 +- ms_agent/tools/search/tavily/http.py | 4 +- ms_agent/tools/search/tavily/schema.py | 32 +- ms_agent/tools/search/tavily/search.py | 59 +- ms_agent/tools/search/web_search_spill.py | 133 +- ms_agent/tools/search/websearch_tool.py | 755 +++++------ ms_agent/tools/search_engine.py | 37 +- ms_agent/tools/task_control_tool.py | 26 +- ms_agent/tools/todolist_tool.py | 167 +-- ms_agent/tools/tool_manager.py | 98 +- .../tools/video_generator/ds_video_gen.py | 46 +- ms_agent/tools/video_generator/video_gen.py | 30 +- ms_agent/utils/__init__.py | 3 +- ms_agent/utils/artifact_manager.py | 46 +- ms_agent/utils/constants.py | 8 +- ms_agent/utils/llm_utils.py | 32 +- ms_agent/utils/logger.py | 6 +- ms_agent/utils/parser_utils.py | 205 ++- ms_agent/utils/patcher.py | 7 +- ms_agent/utils/push_to_hub.py | 193 ++- ms_agent/utils/rate_limiter.py | 83 +- ms_agent/utils/snapshot.py | 49 +- ms_agent/utils/stats.py | 32 +- ms_agent/utils/stream_writer.py | 77 +- ms_agent/utils/task_manager.py | 4 +- ms_agent/utils/thread_util.py | 37 +- ms_agent/utils/tokenizer_util.py | 8 +- ms_agent/utils/utils.py | 109 +- ms_agent/utils/workspace_policy.py | 30 +- ms_agent/workflow/base.py | 17 +- ms_agent/workflow/chain_workflow.py | 10 +- ms_agent/workflow/dag_workflow.py | 18 +- ms_agent/workflow/deep_research/__init__.py | 12 +- ms_agent/workflow/deep_research/principle.py | 39 +- .../workflow/deep_research/research_utils.py | 17 +- .../deep_research/research_workflow.py | 4 +- .../deep_research/research_workflow_beta.py | 33 +- ms_agent/workflow/loader.py | 25 +- .../code_genesis/tools/build_sandbox_image.py | 97 ++ .../code_genesis/tools/build_sandbox_image.sh | 63 +- projects/code_genesis/workflow/api_search.py | 35 +- projects/code_genesis/workflow/architect.py | 1 - projects/code_genesis/workflow/coding.py | 222 ++-- projects/code_genesis/workflow/file_design.py | 16 +- projects/code_genesis/workflow/file_order.py | 14 +- projects/code_genesis/workflow/install.py | 4 +- projects/code_genesis/workflow/refine.py | 55 +- projects/code_genesis/workflow/user_story.py | 1 - projects/deep_research/run.py | 58 +- .../v2/callbacks/quality_checker.py | 143 +-- .../v2/callbacks/reporter_callback.py | 12 +- .../v2/callbacks/researcher_callback.py | 6 +- .../v2/callbacks/searcher_callback.py | 104 +- .../deep_research/v2/eval/dr_bench_runner.py | 190 ++- projects/deep_research/v2/reporter.py | 21 +- projects/deep_research/v2/researcher.py | 22 +- projects/deep_research/v2/time_handler.py | 12 +- .../deep_research/v2/tools/evidence_tool.py | 667 +++++----- .../deep_research/v2/tools/report_tool.py | 799 +++++------- projects/fin_research/aggregator.py | 51 +- .../callbacks/aggregator_callback.py | 9 +- .../callbacks/analyst_callback.py | 61 +- .../callbacks/collector_callback.py | 46 +- .../fin_research/callbacks/file_parser.py | 4 +- .../callbacks/orchestrator_callback.py | 9 +- projects/fin_research/searcher.py | 67 +- projects/fin_research/time_handler.py | 12 +- .../fin_research/tools/principle_skill.py | 162 +-- projects/fin_research/tools/spec_loader.py | 278 ++--- .../singularity_cinema/compose_video/agent.py | 258 ++-- .../create_background/agent.py | 35 +- .../generate_animation/agent.py | 18 +- .../generate_animation/generate_manim_code.py | 45 +- .../generate_remotion_code.py | 64 +- .../generate_audio/agent.py | 27 +- .../generate_illustration_prompts/agent.py | 73 +- .../generate_images/agent.py | 99 +- .../generate_script/agent.py | 10 +- .../generate_subtitle/agent.py | 89 +- .../generate_video/agent.py | 43 +- .../generate_video_prompts/agent.py | 60 +- .../singularity_cinema/parse_images/agent.py | 48 +- .../render_animation/agent.py | 18 +- .../render_animation/render_manim.py | 336 +++-- .../render_animation/render_remotion.py | 318 ++--- projects/singularity_cinema/segment/agent.py | 45 +- setup.py | 66 +- shell-grep-glob-workspace-policy.md | 225 ---- webui/backend/agent_runner.py | 1107 +++++++---------- webui/backend/api.py | 175 +-- webui/backend/config_manager.py | 63 +- webui/backend/deep_research_eventizer.py | 104 +- webui/backend/deep_research_worker.py | 162 ++- webui/backend/deep_research_worker_manager.py | 84 +- webui/backend/main.py | 24 +- webui/backend/project_discovery.py | 49 +- webui/backend/session_manager.py | 40 +- webui/backend/shared.py | 4 +- webui/backend/websocket_handler.py | 194 ++- 196 files changed, 7310 insertions(+), 10843 deletions(-) create mode 100644 projects/code_genesis/tools/build_sandbox_image.py delete mode 100644 shell-grep-glob-workspace-policy.md diff --git a/.gitignore b/.gitignore index 30dfa8d1f..2cef2d174 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,13 @@ wheels/ /temp **/tmp/ .env* -.claude-trace/ +.claude* +# Local Colima/Lima state when using CODE_GENESIS_COLIMA_IN_REPO=1 +.colima/ +.xdg-cache/ +.xdg-config/ +.xdg-data/ +scripts/colima_proxy.local.env /apps/agentfabric/tmp/ MANIFEST diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 00657312b..4aaa131d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,23 +2,21 @@ repos: - - repo: https://github.com/pycqa/flake8.git - rev: 4.0.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.6 + hooks: + - id: ruff-format + exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) + - id: ruff + args: [--fix, --select, I] + exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 hooks: - id: flake8 exclude: ^(thirdparty/|examples/|tests/|projects/agent_skills/|projects/fin_research/examples/|ms_agent/utils/prompts\.py) - - repo: https://github.com/PyCQA/isort.git - rev: 4.3.21 - hooks: - - id: isort - exclude: ^(examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - - repo: https://github.com/pre-commit/mirrors-yapf.git - rev: v0.30.0 - hooks: - - id: yapf - exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - - repo: https://github.com/pre-commit/pre-commit-hooks.git - rev: v3.1.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 hooks: - id: trailing-whitespace exclude: ^(thirdparty/|tests/|projects/fin_research/examples/|projects/agent_skills/) diff --git a/examples/nanobot_integration/test_mcp_tools.py b/examples/nanobot_integration/test_mcp_tools.py index 65c9cad7d..c59313b30 100644 --- a/examples/nanobot_integration/test_mcp_tools.py +++ b/examples/nanobot_integration/test_mcp_tools.py @@ -37,9 +37,9 @@ async def connect(stack): async def list_tools(session): """List all available MCP tools.""" result = await session.list_tools() - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(f' Available MCP Tools ({len(result.tools)} total)') - print(f'{"=" * 60}\n') + print(f'{'=' * 60}\n') for tool in result.tools: desc = (tool.description or '')[:70] print(f' {tool.name:35s} {desc}') @@ -72,9 +72,9 @@ async def call_tool(session, name: str, args: dict): async def test_filesystem(session): """Test filesystem tools: write a file, then replace contents.""" - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(' TEST: Filesystem Tools') - print(f'{"=" * 60}') + print(f'{'=' * 60}') with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: f.write('def hello():\n print("Hello, World!")\n\nhello()\n') @@ -117,9 +117,9 @@ async def test_filesystem(session): async def test_deep_research_async(session): """Test async deep research tools (submit/check/get without actually running).""" - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(' TEST: Deep Research Async Tools') - print(f'{"=" * 60}') + print(f'{'=' * 60}') # submit_research_task should return immediately with a task_id # (it will fail to find the config in CI, but the response format is testable) @@ -128,7 +128,7 @@ async def test_deep_research_async(session): }) if 'error' in result: - print(f'\n submit_research_task returned error (expected if no config): {result["error"]}') + print(f'\n submit_research_task returned error (expected if no config): {result['error']}') print(' Testing check/get with a fake task_id...') result = await call_tool(session, 'check_research_progress', { @@ -151,21 +151,21 @@ async def test_deep_research_async(session): progress = await call_tool(session, 'check_research_progress', { 'task_id': task_id, }) - print(f' Progress check: status={progress["status"]}') + print(f' Progress check: status={progress['status']}') report = await call_tool(session, 'get_research_report', { 'task_id': task_id, }) - print(f' Report check: status={report["status"]}') + print(f' Report check: status={report['status']}') print('\n DEEP RESEARCH ASYNC TESTS: PASSED') async def test_web_search(session): """Test web_search tool with arxiv (no API key required).""" - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(' TEST: Web Search') - print(f'{"=" * 60}') + print(f'{'=' * 60}') result = await call_tool(session, 'web_search', { 'query': 'large language model agent framework', @@ -174,24 +174,24 @@ async def test_web_search(session): }) if 'error' in result: - print(f'\n web_search returned error: {result["error"]}') + print(f'\n web_search returned error: {result['error']}') print(' (This may happen if arxiv is unreachable)') else: - assert result['status'] == 'ok', f'Unexpected status: {result["status"]}' + assert result['status'] == 'ok', f'Unexpected status: {result['status']}' assert result['engine'] == 'arxiv' - print(f'\n Returned {result["count"]} results:') + print(f'\n Returned {result['count']} results:') for i, r in enumerate(result.get('results', []), 1): - print(f' {i}. {r.get("title", "No title")[:60]}') - print(f' {r.get("url", "")}') + print(f' {i}. {r.get('title', 'No title')[:60]}') + print(f' {r.get('url', '')}') print('\n WEB SEARCH TESTS: PASSED') async def test_agent_delegate(session): """Test agent delegate tools (async pattern only, to avoid blocking).""" - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(' TEST: Agent Delegate (Async)') - print(f'{"=" * 60}') + print(f'{'=' * 60}') # Test check/get/cancel with unknown task_id (safe, no LLM needed) result = await call_tool(session, 'check_agent_task', { @@ -251,9 +251,9 @@ async def main(): if args.test in ('ad', 'all'): await test_agent_delegate(session) - print(f'\n{"=" * 60}') + print(f'\n{'=' * 60}') print(' ALL TESTS PASSED') - print(f'{"=" * 60}\n') + print(f'{'=' * 60}\n') if __name__ == '__main__': diff --git a/ms-agent-skills/scripts/check_ms_agent.py b/ms-agent-skills/scripts/check_ms_agent.py index 64e4668fa..93aea705c 100644 --- a/ms-agent-skills/scripts/check_ms_agent.py +++ b/ms-agent-skills/scripts/check_ms_agent.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import subprocess import sys -import json - def check_import() -> dict: """Check that ms_agent is importable.""" try: import ms_agent # noqa: F401 + version = getattr(ms_agent, '__version__', 'unknown') return {'importable': True, 'version': version} except ImportError as e: @@ -23,29 +23,28 @@ def check_capabilities() -> dict: """ try: from ms_agent.capabilities import create_registry + registry = create_registry() caps = registry.list_all() return { - 'registry_ok': - True, - 'count': - len(caps), - 'capabilities': [{ - 'name': c.name, - 'granularity': c.granularity, - 'summary': c.summary, - 'tags': c.tags, - } for c in caps], + 'registry_ok': True, + 'count': len(caps), + 'capabilities': [ + { + 'name': c.name, + 'granularity': c.granularity, + 'summary': c.summary, + 'tags': c.tags, + } + for c in caps + ], } except ImportError: # ms_agent may not be on sys.path (e.g. dev mode without pip install). # Fall back to subprocess check which uses ``-m`` resolution. try: result = subprocess.run( - [ - sys.executable, '-m', 'ms_agent.capabilities.mcp_server', - '--check' - ], + [sys.executable, '-m', 'ms_agent.capabilities.mcp_server', '--check'], capture_output=True, text=True, timeout=30, @@ -56,15 +55,11 @@ def check_capabilities() -> dict: 'registry_ok': True, 'count': len(data.get('capabilities', [])), 'capabilities': data.get('capabilities', []), - 'note': - 'verified via subprocess (package not on sys.path)', + 'note': 'verified via subprocess (package not on sys.path)', } except Exception: pass - return { - 'registry_ok': False, - 'error': 'ms_agent.capabilities not importable' - } + return {'registry_ok': False, 'error': 'ms_agent.capabilities not importable'} except Exception as e: return {'registry_ok': False, 'error': str(e)} @@ -73,10 +68,7 @@ def check_mcp_server() -> dict: """Check that the MCP server can start in --check mode.""" try: result = subprocess.run( - [ - sys.executable, '-m', 'ms_agent.capabilities.mcp_server', - '--check' - ], + [sys.executable, '-m', 'ms_agent.capabilities.mcp_server', '--check'], capture_output=True, text=True, timeout=30, @@ -87,10 +79,7 @@ def check_mcp_server() -> dict: else: return {'mcp_server_ok': False, 'error': result.stderr.strip()} except subprocess.TimeoutExpired: - return { - 'mcp_server_ok': False, - 'error': 'MCP server --check timed out' - } + return {'mcp_server_ok': False, 'error': 'MCP server --check timed out'} except Exception as e: return {'mcp_server_ok': False, 'error': str(e)} @@ -99,6 +88,7 @@ def check_mcp_package() -> dict: """Check that the mcp Python package is installed.""" try: import mcp # noqa: F401 + version = getattr(mcp, '__version__', 'unknown') return {'installed': True, 'version': version} except ImportError: @@ -119,7 +109,8 @@ def main() -> None: all_ok = ( report['ms_agent'].get('importable', False) and report['capabilities'].get('registry_ok', False) - and report['mcp_server'].get('mcp_server_ok', False)) + and report['mcp_server'].get('mcp_server_ok', False) + ) report['overall_status'] = 'ok' if all_ok else 'issues_found' print(json.dumps(report, indent=2, ensure_ascii=False)) @@ -127,21 +118,13 @@ def main() -> None: if not all_ok: print('\n--- Issues ---', file=sys.stderr) if not report['ms_agent'].get('importable'): - print( - ' ms-agent is not installed. Run: pip install ms-agent', - file=sys.stderr) + print(' ms-agent is not installed. Run: pip install ms-agent', file=sys.stderr) if not report['mcp_package'].get('installed'): - print( - ' mcp package is not installed. Run: pip install mcp', - file=sys.stderr) + print(' mcp package is not installed. Run: pip install mcp', file=sys.stderr) if not report['capabilities'].get('registry_ok'): - print( - f" Registry error: {report['capabilities'].get('error')}", - file=sys.stderr) + print(f" Registry error: {report['capabilities'].get('error')}", file=sys.stderr) if not report['mcp_server'].get('mcp_server_ok'): - print( - f" MCP server error: {report['mcp_server'].get('error')}", - file=sys.stderr) + print(f" MCP server error: {report['mcp_server'].get('error')}", file=sys.stderr) sys.exit(1) diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index cb78d5ce2..7908a53dd 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, List, Tuple, Union +from omegaconf import DictConfig + from ms_agent.llm import Message from ms_agent.utils import read_history, save_history from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_RETRY_COUNT -from omegaconf import DictConfig class Agent(ABC): @@ -18,36 +19,31 @@ class Agent(ABC): retry_count = int(os.environ.get('AGENT_RETRY_COUNT', DEFAULT_RETRY_COUNT)) - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + """ + Base class for all agents. Provides core functionality such as configuration loading, + lifecycle handling via external code, and defining the interface for agent execution. + + The agent can be initialized either with a config object directly or by loading from a config directory or ID. + If external code (e.g., custom handlers) is involved, the agent must be explicitly trusted via + `trust_remote_code=True`. + + Base class for all agents. Make sure your custom agents are derived from this class. + Args: + config (DictConfig): Pre-loaded configuration object. + tag (str): A custom tag for identifying this agent run. + trust_remote_code (bool): Whether to allow loading of external code (e.g., custom handler modules). """ - Base class for all agents. Provides core functionality such as configuration loading, - lifecycle handling via external code, and defining the interface for agent execution. - - The agent can be initialized either with a config object directly or by loading from a config directory or ID. - If external code (e.g., custom handlers) is involved, the agent must be explicitly trusted via - `trust_remote_code=True`. - - Base class for all agents. Make sure your custom agents are derived from this class. - Args: - config (DictConfig): Pre-loaded configuration object. - tag (str): A custom tag for identifying this agent run. - trust_remote_code (bool): Whether to allow loading of external code (e.g., custom handler modules). - """ self.config = config self.tag = tag self.trust_remote_code = trust_remote_code self.config.tag = tag self.config.trust_remote_code = trust_remote_code - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) @abstractmethod async def run( - self, inputs: Union[str, List[Message]], **kwargs + self, inputs: Union[str, List[Message]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: """ Main method to execute the agent. @@ -65,8 +61,7 @@ async def run( """ raise NotImplementedError() - def read_history(self, messages: Any, - **kwargs) -> Tuple[DictConfig, List[Message]]: + def read_history(self, messages: Any, **kwargs) -> Tuple[DictConfig, List[Message]]: return read_history(self.output_dir, self.tag) def save_history(self, messages: Any, **kwargs): @@ -77,6 +72,7 @@ def save_history(self, messages: Any, **kwargs): def list_snapshots(self) -> list: """Return snapshots for this agent's output_dir, most recent first.""" from ms_agent.utils.snapshot import list_snapshots + return list_snapshots(self.output_dir) def rollback(self, commit_hash: str) -> bool: diff --git a/ms_agent/agent/code_agent.py b/ms_agent/agent/code_agent.py index 33b2b63a4..2200712ec 100644 --- a/ms_agent/agent/code_agent.py +++ b/ms_agent/agent/code_agent.py @@ -1,9 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Any, List, Union -from ms_agent.llm import Message from omegaconf import DictConfig +from ms_agent.llm import Message + from .base import Agent @@ -12,16 +13,11 @@ class CodeAgent(Agent): AGENT_NAME = 'CodeAgent' - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.load_cache = kwargs.get('load_cache', False) - async def run(self, inputs: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: """Run the external code. Default implementation here does nothing. Args: @@ -42,6 +38,5 @@ async def run(self, inputs: Union[str, List[Message]], self.save_history(messages, **kwargs) return messages - async def execute_code(self, inputs: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: return inputs diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 14adfa8b3..379f439c3 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,6 +2,7 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading @@ -10,7 +11,8 @@ from copy import deepcopy from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json +from omegaconf import DictConfig, OmegaConf + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.llm.llm import LLM @@ -21,11 +23,10 @@ from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history -from ms_agent.utils.task_manager import TaskManager from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot -from omegaconf import DictConfig, OmegaConf +from ms_agent.utils.task_manager import TaskManager from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -103,17 +104,14 @@ def resolve_enable_snapshots(config: Any) -> bool: like ``\"false\"`` coerced to boolean). """ if OmegaConf.is_config(config): - raw = OmegaConf.select(config, 'enable_snapshots', - default=_MISSING_ENABLE_SNAPSHOTS) + raw = OmegaConf.select(config, 'enable_snapshots', default=_MISSING_ENABLE_SNAPSHOTS) if raw is not _MISSING_ENABLE_SNAPSHOTS and raw is not None: return LLMAgent._coerce_enable_snapshots_value(raw) - sub = bool( - OmegaConf.select(config, 'ms_agent_subagent', default=False)) + sub = bool(OmegaConf.select(config, 'ms_agent_subagent', default=False)) return not sub if isinstance(config, dict): if 'enable_snapshots' in config and config['enable_snapshots'] is not None: - return LLMAgent._coerce_enable_snapshots_value( - config['enable_snapshots']) + return LLMAgent._coerce_enable_snapshots_value(config['enable_snapshots']) return not bool(config.get('ms_agent_subagent')) return True @@ -131,8 +129,7 @@ def __init__( **kwargs, ): if not hasattr(config, 'llm'): - default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') + default_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -147,8 +144,7 @@ def __init__( self.load_cache = kwargs.get('load_cache', False) self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) - self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {})) + self.mcp_config: Dict[str, Any] = self.parse_mcp_servers(kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -198,31 +194,25 @@ def _ensure_auto_skills(self) -> bool: from ms_agent.utils.docker_utils import is_docker_daemon_running if not is_docker_daemon_running(): - logger.warning( - 'Docker not running, disabling sandbox for skills') + logger.warning('Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container( - skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', - None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, - 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) - logger.info( - f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' - ) + logger.info(f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills') self._auto_skills_initialized = True return True @@ -301,9 +291,7 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = ( - getattr(skills_config, 'stop_on_failure', True) - if skills_config else True) + stop_on_failure = getattr(skills_config, 'stop_on_failure', True) if skills_config else True result = await self._auto_skills.run( query=query, @@ -327,8 +315,7 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append( - Message(role='assistant', content=dag_result.chat_response)) + messages.append(Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills @@ -359,9 +346,7 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: if output.output_files: content += f'**Generated files:** {list(output.output_files.values())}\n\n' - content += ( - f'Total execution time: {exec_result.total_duration_ms:.2f}ms' - ) + content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): @@ -387,14 +372,14 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: def rollback(self, commit_hash: str) -> bool: """Restore output_dir to snapshot and truncate message history.""" from ms_agent.utils.snapshot import restore_snapshot + ok, message_count = restore_snapshot(self.output_dir, commit_hash) if not ok: return False # Truncate saved history to the message count at snapshot time _, saved_messages = read_history(self.output_dir, self.tag) if saved_messages and message_count < len(saved_messages): - save_history(self.output_dir, self.tag, self.config, - saved_messages[:message_count]) + save_history(self.output_dir, self.tag, self.config, saved_messages[:message_count]) # Clear read cache on FileSystemTool so stale entries don't block edits if self.tool_manager is not None: for tool in self.tool_manager.extra_tools: @@ -422,8 +407,7 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile( - self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -455,27 +439,19 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: f'[External Code]A Config Lifecycle handler ' f'registered in the config: {handler_file}. ' f'\nThis is external code, if you trust this workflow, ' - f'please specify `--trust_remote_code true`') - assert ( - local_dir is not None - ), 'Using external py files, but local_dir cannot be found.' + f'please specify `--trust_remote_code true`' + ) + assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' if local_dir not in sys.path: sys.path.insert(0, local_dir) handler_module = importlib.import_module(handler_file) - module_classes = { - name: cls - for name, cls in inspect.getmembers(handler_module, - inspect.isclass) - } + module_classes = {name: cls for name, cls in inspect.getmembers(handler_module, inspect.isclass)} handler = None for name, handler_cls in module_classes.items(): - if (handler_cls.__bases__[0] is ConfigLifecycleHandler - and handler_cls.__module__ == handler_file): + if handler_cls.__bases__[0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: handler = handler_cls() - assert ( - handler is not None - ), f'Config Lifecycle handler class cannot be found in {handler_file}' + assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -486,15 +462,12 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, - 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert ( - local_dir is not None - ), 'Using external py files, but local_dir cannot be found.' + assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -512,26 +485,19 @@ def register_callback_from_config(self): if _callback.endswith('.py'): _callback = _callback[:-3] callback_file = importlib.import_module(_callback) - module_classes = { - name: cls - for name, cls in inspect.getmembers( - callback_file, inspect.isclass) - } + module_classes = {name: cls for name, cls in inspect.getmembers(callback_file, inspect.isclass)} for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass( - cls, Callback) and cls.__module__ == _callback: + if issubclass(cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback]( - self.config)) + self.callbacks.append(callbacks_mapping[_callback](self.config)) async def on_task_begin(self, messages: List[Message]): self.log_output(f'Agent {self.tag} task beginning.') if self.resolve_enable_snapshots(self.config): _user_content = next( - ((getattr(m, 'content', '') or '')[:80] - for m in messages if getattr(m, 'role', '') == 'user'), + ((getattr(m, 'content', '') or '')[:80] for m in messages if getattr(m, 'role', '') == 'user'), '', ) take_snapshot( @@ -567,8 +533,7 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, - messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -578,11 +543,9 @@ async def parallel_tool_call(self, Returns: List[Message]: Updated message list including tool responses. """ - tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls) + tool_call_result = await self.tool_manager.parallel_call_tool(messages[-1].tool_calls) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip(tool_call_result, - messages[-1].tool_calls): + for tool_call_result, tool_call_query in zip(tool_call_result, messages[-1].tool_calls): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', @@ -624,8 +587,7 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -636,8 +598,7 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -648,8 +609,7 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) _THINKING_SEP = '─' * 40 @@ -690,19 +650,16 @@ def _write_thinking_footer(self): @property def system(self): - return getattr( - getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr( - getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query - async def create_messages( - self, messages: Union[List[Message], str]) -> List[Message]: + async def create_messages(self, messages: Union[List[Message], str]) -> List[Message]: """ Convert input into a standardized list of messages. @@ -714,18 +671,15 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if (system is not None and messages[0].role == 'system' - and system != messages[0].content): + if system is not None and messages[0].role == 'system' and system != messages[0].content: # Replace the existing system messages[0].content = system else: - assert isinstance( - messages, str - ), f'inputs can be either a list or a string, but current is {type(messages)}' + assert isinstance(messages, str), ( + f'inputs can be either a list or a string, but current is {type(messages)}' + ) messages = [ - Message( - role='system', - content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -746,8 +700,7 @@ async def do_rag(self, messages: List[Message]): if self.rag is not None: user_message.content = await self.rag.query(query) - async def do_skill(self, - messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -762,9 +715,7 @@ async def do_skill(self, None if no skill processing or fallback to standard agent """ # Extract user query from normalized messages - query = ( - messages[1].content - if len(messages) > 1 and messages[1].role == 'user' else None) + query = messages[1].content if len(messages) > 1 and messages[1].role == 'user' else None if not query: return None @@ -778,9 +729,7 @@ async def do_skill(self, try: skills_config = self._get_skills_config() - auto_execute = ( - getattr(skills_config, 'auto_execute', True) - if skills_config else True) + auto_execute = getattr(skills_config, 'auto_execute', True) if skills_config else True if auto_execute: dag_result = await self.execute_skills(query) @@ -788,8 +737,7 @@ async def do_skill(self, dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages( - dag_result) + skill_messages = self._format_skill_result_as_messages(dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -799,8 +747,7 @@ async def do_skill(self, return None except Exception as e: - logger.warning( - f'Skill execution failed: {e}, falling back to standard agent') + logger.warning(f'Skill execution failed: {e}, falling back to standard agent') self._skill_mode_active = False return None @@ -814,11 +761,10 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f'{mem_instance_type} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') + f'{mem_instance_type} not in memory_mapping, which supports: {list(memory_mapping.keys())}' + ) - shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type) + shared_memory = await SharedMemoryManager.get_shared_memory(self.config, mem_instance_type) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -827,8 +773,8 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f'{rag.name} not in rag_mapping, ' - f'which supports: {list(rag_mapping.keys())}') + f'{rag.name} not in rag_mapping, which supports: {list(rag_mapping.keys())}' + ) self.rag: RAG = rag_mapping(rag.name)(self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: @@ -875,8 +821,7 @@ def log_output(self, content: Union[str, list]): for _line in line.split('\\n'): logger.info(f'[{self.tag}] {_line}') - def handle_new_response(self, messages: List[Message], - response_message: Message): + def handle_new_response(self, messages: List[Message], response_message: Message): assert response_message is not None, 'No response message generated from LLM.' if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -884,18 +829,15 @@ def handle_new_response(self, messages: List[Message], tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads( - tool_call['arguments']) + tool_call['arguments'] = json.loads(tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output( - json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if (messages[-1].role == 'assistant' and not messages[-1].content - and response_message.tool_calls): + if messages[-1].role == 'assistant' and not messages[-1].content and response_message.tool_calls: messages[-1].content = 'Let me do a tool calling.' def _append_task_notifications(self, messages: List[Message]) -> List[Message]: @@ -910,9 +852,7 @@ def _append_task_notifications(self, messages: List[Message]) -> List[Message]: return messages @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step( - self, messages: List[Message] - ) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -950,20 +890,17 @@ async def step( _response_message = None _printed_reasoning_header = False _printed_reasoning_footer = False - for _response_message in self.llm.generate( - messages, tools=tools): + for _response_message in self.llm.generate(messages, tools=tools): if is_first: messages.append(_response_message) is_first = False if self.show_reasoning: - reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') - or '') + reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning):] + new_reasoning = reasoning_text[len(_reasoning) :] if new_reasoning: if not _printed_reasoning_header: self._write_thinking_header() @@ -971,7 +908,7 @@ async def step( self._write_reasoning(new_reasoning, dim=True) _reasoning = reasoning_text - new_content = _response_message.content[len(_content):] + new_content = _response_message.content[len(_content) :] if new_content: if _printed_reasoning_header and not _printed_reasoning_footer: self._write_thinking_footer() @@ -986,8 +923,7 @@ async def step( # Handle reasoning summaries that arrive after content if self.show_reasoning and _response_message is not None: - final_reasoning = getattr(_response_message, - 'reasoning_content', '') or '' + final_reasoning = getattr(_response_message, 'reasoning_content', '') or '' if final_reasoning and not _printed_reasoning_header: self._write_thinking_header() self._write_reasoning(final_reasoning, dim=True) @@ -997,9 +933,7 @@ async def step( else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') - or '') + reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' if reasoning_text: self._write_thinking_header() self._write_reasoning(reasoning_text, dim=True) @@ -1027,8 +961,7 @@ async def step( prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = ( - getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) + cache_creation_input_tokens = getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -1037,17 +970,14 @@ async def step( LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS += cache_creation_input_tokens # tokens in the current step - self.log_output( - f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' - ) + self.log_output(f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}') if cached_tokens or cache_creation_input_tokens: - self.log_output( - f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' - ) + self.log_output(f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}') # total tokens for the process so far self.log_output( f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' - f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') + f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}' + ) if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' @@ -1064,8 +994,7 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history(self, messages: List[Message], - **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history(self, messages: List[Message], **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. @@ -1108,10 +1037,8 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: return user_id def _get_step_memory_info(self, memory_config: DictConfig): - user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step') - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + user_id, agent_id, run_id, memory_type = get_memory_meta_safe(memory_config, 'add_after_step') + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1122,8 +1049,7 @@ def _get_run_memory_info(self, memory_config: DictConfig): 'add_after_task', default_user_id=getattr(memory_config, 'user_id', None), ) - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1134,18 +1060,14 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for idx, (mem_instance_type, memory_config) in enumerate(self.config.memory.items()): if add_type == 'add_after_task': - user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config) + user_id, agent_id, run_id, memory_type = self._get_run_memory_info(memory_config) else: - user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config) + user_id, agent_id, run_id, memory_type = self._get_step_memory_info(memory_config) if idx < tools_num: - if any(v is not None - for v in [user_id, agent_id, run_id, memory_type]): + if any(v is not None for v in [user_id, agent_id, run_id, memory_type]): await self.memory_tools[idx].add( messages, user_id=user_id, @@ -1174,11 +1096,9 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history( - self.output_dir, task=self.tag, config=config, messages=messages) + save_history(self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop(self, messages: Union[List[Message], str], - **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1191,8 +1111,7 @@ async def run_loop(self, messages: Union[List[Message], str], List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr(self.config, 'max_chat_round', - LLMAgent.DEFAULT_MAX_CHAT_ROUND) + self.max_chat_round = getattr(self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1242,8 +1161,7 @@ async def run_loop(self, messages: Union[List[Message], str], yield messages self.runtime.round += 1 # save memory and history - await self.add_memory( - messages, add_type='add_after_step', **kwargs) + await self.add_memory(messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1252,10 +1170,10 @@ async def run_loop(self, messages: Union[List[Message], str], messages.append( Message( role='assistant', - content= - f'Task {messages[1].content} was cutted off, because ' + content=f'Task {messages[1].content} was cutted off, because ' f'max round({self.max_chat_round}) exceeded.', - )) + ) + ) self.runtime.should_stop = True yield messages @@ -1265,9 +1183,7 @@ async def run_loop(self, messages: Union[List[Message], str], yield messages def _add_memory(): - asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + asyncio.run(self.add_memory(messages, add_type='add_after_task', **kwargs)) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) @@ -1276,23 +1192,19 @@ def _add_memory(): logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): - logger.error( - f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' - ) + logger.error(f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}') raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: - OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True) + OmegaConf.update(self.config, 'generation_config.stream', True, merge=True) async def stream_generator(): - async for _chunk in self.run_loop( - messages=messages, **kwargs): + async for _chunk in self.run_loop(messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/agent/loader.py b/ms_agent/agent/loader.py index 21b1687a7..7dd5f7a02 100644 --- a/ms_agent/agent/loader.py +++ b/ms_agent/agent/loader.py @@ -5,27 +5,30 @@ import sys from typing import Dict, Optional +from omegaconf import DictConfig, OmegaConf + from ms_agent.config.config import Config from ms_agent.utils.constants import DEFAULT_AGENT_FILE, DEFAULT_TAG -from omegaconf import DictConfig, OmegaConf from .base import Agent class AgentLoader: - @classmethod - def build(cls, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - tag: Optional[str] = None, - trust_remote_code: bool = False, - **kwargs) -> Agent: + def build( + cls, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + tag: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> Agent: agent_config: Optional[DictConfig] = None if config_dir_or_id is not None: if not os.path.exists(config_dir_or_id): from modelscope import snapshot_download + config_dir_or_id = snapshot_download(config_dir_or_id) agent_config: DictConfig = Config.from_task(config_dir_or_id, env) if config is not None: @@ -40,35 +43,30 @@ def build(cls, agent_tag = tag agent_config.tag = agent_tag agent_config.trust_remote_code = trust_remote_code - if getattr(agent_config, 'local_dir', - None) is None and config_dir_or_id is not None: + if getattr(agent_config, 'local_dir', None) is None and config_dir_or_id is not None: agent_config.local_dir = config_dir_or_id - from .llm_agent import LLMAgent from .code_agent import CodeAgent + from .llm_agent import LLMAgent + agent_type = LLMAgent.AGENT_NAME if 'code_file' in kwargs: code_file = kwargs.pop('code_file') elif agent_config is not None: - agent_type = getattr(agent_config, 'type', - '').lower() or agent_type.lower() + agent_type = getattr(agent_config, 'type', '').lower() or agent_type.lower() code_file = getattr(agent_config, 'code_file', None) else: assert getattr(agent_config, 'local_dir', None) is not None - code_file = os.path.join( - getattr(agent_config, 'local_dir', ''), DEFAULT_AGENT_FILE) + code_file = os.path.join(getattr(agent_config, 'local_dir', ''), DEFAULT_AGENT_FILE) if code_file is not None: - agent_instance = cls._load_external_code(agent_config, code_file, - **kwargs) + agent_instance = cls._load_external_code(agent_config, code_file, **kwargs) else: assert agent_config is not None if agent_type == LLMAgent.AGENT_NAME.lower(): - agent_instance = LLMAgent(agent_config, agent_tag, - trust_remote_code, **kwargs) + agent_instance = LLMAgent(agent_config, agent_tag, trust_remote_code, **kwargs) elif agent_type == CodeAgent.AGENT_NAME.lower(): - agent_instance = CodeAgent(agent_config, agent_tag, - trust_remote_code, **kwargs) + agent_instance = CodeAgent(agent_config, agent_tag, trust_remote_code, **kwargs) else: raise ValueError(f'Unknown agent type: {agent_type}') return agent_instance @@ -79,7 +77,8 @@ def _load_external_code(cls, config, code_file, **kwargs) -> 'Agent': assert config.trust_remote_code, ( f'[External Code]A code file is required to run in the LLMAgent: {code_file}' f'\nThis is external code, if you trust this code file, ' - f'please specify `--trust_remote_code true`') + f'please specify `--trust_remote_code true`' + ) subdir = os.path.dirname(code_file) code_file = os.path.basename(code_file) local_dir = config.local_dir @@ -97,20 +96,11 @@ def _load_external_code(cls, config, code_file, **kwargs) -> 'Agent': if code_file in sys.modules: del sys.modules[code_file] code_module = importlib.import_module(code_file) - module_classes = { - name: agent_cls - for name, agent_cls in inspect.getmembers(code_module, - inspect.isclass) - } + module_classes = {name: agent_cls for name, agent_cls in inspect.getmembers(code_module, inspect.isclass)} agent_instance = None for name, agent_cls in module_classes.items(): - if Agent in agent_cls.__mro__[ - 1:] and agent_cls.__module__ == code_file: - agent_instance = agent_cls( - config, - config.tag, - trust_remote_code=config.trust_remote_code, - **kwargs) + if Agent in agent_cls.__mro__[1:] and agent_cls.__module__ == code_file: + agent_instance = agent_cls(config, config.tag, trust_remote_code=config.trust_remote_code, **kwargs) break assert agent_instance is not None, f'Cannot find a proper agent class in the external code file: {code_file}' if subdir_inserted: diff --git a/ms_agent/agent/runtime.py b/ms_agent/agent/runtime.py index 55a0dbf9e..508eeaf0e 100644 --- a/ms_agent/agent/runtime.py +++ b/ms_agent/agent/runtime.py @@ -7,7 +7,6 @@ @dataclass class Runtime: - should_stop: bool = False llm: LLM = None diff --git a/ms_agent/app/doc_research.py b/ms_agent/app/doc_research.py index 8c55f99c4..9cfdcf8bf 100644 --- a/ms_agent/app/doc_research.py +++ b/ms_agent/app/doc_research.py @@ -2076,7 +2076,7 @@ def initialize_page(request: gr.Request): session_status_html = f"""
📊 会话状态: {'已加载历史数据' if any(session_data.values()) else '新会话'} - {f'| 最后更新: {session_data.get("timestamp", "未知")}' if session_data.get("timestamp") else ''} + {f'| 最后更新: {session_data.get('timestamp', '未知')}' if session_data.get('timestamp') else ''}
""" if any(session_data.values()) else """
diff --git a/ms_agent/app/fin_research.py b/ms_agent/app/fin_research.py index 1767d3e70..a8344803e 100644 --- a/ms_agent/app/fin_research.py +++ b/ms_agent/app/fin_research.py @@ -443,14 +443,14 @@ def build_fin_prompt( sections.append(f'Market / region focus: {markets.strip()}') if focus_areas: sections.append( - f'Priority analytical pillars: {", ".join(focus_areas)}') + f'Priority analytical pillars: {', '.join(focus_areas)}') if macro_view: sections.append(f'Macro sensitivity preference: {macro_view}') if extra_notes.strip(): sections.append(f'Additional analyst notes:\n{extra_notes.strip()}') instructions = [ - f'Desired deliverable style: {deliverable_style or "Balanced"}', + f'Desired deliverable style: {deliverable_style or 'Balanced'}', f'Analytical depth target (1-5): {analysis_depth}' ] if output_language: @@ -1137,7 +1137,7 @@ def format_result_summary(workdir: str, include_sentiment: bool, f'- 工作目录: {workdir}', ] if focus_areas: - lines.append(f'- 关注领域: {", ".join(focus_areas)}') + lines.append(f'- 关注领域: {', '.join(focus_areas)}') lines.append('请查阅过程报告及最终综合报告。') return '\n'.join(lines) diff --git a/ms_agent/callbacks/base.py b/ms_agent/callbacks/base.py index 849fe7069..f4509e15b 100644 --- a/ms_agent/callbacks/base.py +++ b/ms_agent/callbacks/base.py @@ -1,18 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.llm.utils import Message -from omegaconf import DictConfig class Callback: - def __init__(self, config: DictConfig): self.config = config - async def on_task_begin(self, runtime: Runtime, - messages: List[Message]) -> None: + async def on_task_begin(self, runtime: Runtime, messages: List[Message]) -> None: """Called when a task begins. Args: @@ -24,8 +23,7 @@ async def on_task_begin(self, runtime: Runtime, """ pass - async def on_generate_response(self, runtime: Runtime, - messages: List[Message]): + async def on_generate_response(self, runtime: Runtime, messages: List[Message]): """Called before LLM generates response. Args: diff --git a/ms_agent/callbacks/input_callback.py b/ms_agent/callbacks/input_callback.py index e44db1e31..a5bffe998 100644 --- a/ms_agent/callbacks/input_callback.py +++ b/ms_agent/callbacks/input_callback.py @@ -1,11 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/capabilities/__init__.py b/ms_agent/capabilities/__init__.py index 73feb864e..bded8c19a 100644 --- a/ms_agent/capabilities/__init__.py +++ b/ms_agent/capabilities/__init__.py @@ -23,6 +23,7 @@ """ from __future__ import annotations + from typing import Any from ms_agent.capabilities.descriptor import CapabilityDescriptor @@ -39,13 +40,7 @@ def create_registry(config: Any = None) -> CapabilityRegistry: """ registry = CapabilityRegistry() - from ms_agent.capabilities.wrappers import ( - agent_delegate, - deep_research, - filesystem, - lsp_code_server, - web_search, - ) + from ms_agent.capabilities.wrappers import agent_delegate, deep_research, filesystem, lsp_code_server, web_search filesystem.register_all(registry, config) lsp_code_server.register_all(registry, config) diff --git a/ms_agent/capabilities/async_task.py b/ms_agent/capabilities/async_task.py index 74ac90fbe..c13db3a7b 100644 --- a/ms_agent/capabilities/async_task.py +++ b/ms_agent/capabilities/async_task.py @@ -145,10 +145,7 @@ def check( task = self._tasks.get(task_id) if task is None: known = [t.task_id for t in self._tasks.values()] - return { - 'error': f'Unknown task_id: {task_id}', - 'known_tasks': known - } + return {'error': f'Unknown task_id: {task_id}', 'known_tasks': known} info: dict[str, Any] = { 'task_id': task.task_id, @@ -165,8 +162,7 @@ def check( try: info.update(progress_fn(task)) except Exception: - logger.debug( - 'progress_fn raised for task %s', task_id, exc_info=True) + logger.debug('progress_fn raised for task %s', task_id, exc_info=True) return info @@ -182,11 +178,7 @@ def get_result(self, task_id: str) -> dict[str, Any]: 'message': 'Task is still in progress.', } if task.status == 'failed': - return { - 'task_id': task_id, - 'status': 'failed', - 'error': task.error - } + return {'task_id': task_id, 'status': 'failed', 'error': task.error} if task.status == 'cancelled': return {'task_id': task_id, 'status': 'cancelled'} return { @@ -202,8 +194,7 @@ async def cancel(self, task_id: str) -> dict[str, Any]: return {'error': f'Unknown task_id: {task_id}'} if task.status != 'running': return { - 'error': - f'Task {task_id} is not running (status: {task.status})', + 'error': f'Task {task_id} is not running (status: {task.status})', } # Cancel the asyncio task first diff --git a/ms_agent/capabilities/mcp_server.py b/ms_agent/capabilities/mcp_server.py index 650dc33cf..06314a019 100644 --- a/ms_agent/capabilities/mcp_server.py +++ b/ms_agent/capabilities/mcp_server.py @@ -1,11 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import argparse +import json import logging import os import sys -import json from dotenv import find_dotenv, load_dotenv + from ms_agent.capabilities import create_registry logger = logging.getLogger(__name__) @@ -49,13 +50,15 @@ def _print_check() -> None: registry = create_registry() caps = registry.list_all() info = { - 'status': - 'ok', - 'capabilities': [{ - 'name': c.name, - 'granularity': c.granularity, - 'summary': c.summary, - } for c in caps], + 'status': 'ok', + 'capabilities': [ + { + 'name': c.name, + 'granularity': c.granularity, + 'summary': c.summary, + } + for c in caps + ], } print(json.dumps(info, indent=2)) @@ -116,7 +119,8 @@ def main() -> None: """ parser = argparse.ArgumentParser( - description='ms-agent MCP Capability Server', ) + description='ms-agent MCP Capability Server', + ) parser.add_argument( '--check', action='store_true', @@ -151,8 +155,7 @@ def main() -> None: from mcp.server.fastmcp import FastMCP except ImportError: print( - 'ERROR: The "mcp" package is required. Install it with:\n' - ' pip install mcp\n', + 'ERROR: The "mcp" package is required. Install it with:\n pip install mcp\n', file=sys.stderr, ) sys.exit(1) @@ -162,8 +165,9 @@ def main() -> None: server = FastMCP( 'ms-agent-capabilities', - instructions=('ms-agent Capability Gateway. Provides deep research, ' - 'LSP code validation, and advanced file-editing tools.'), + instructions=( + 'ms-agent Capability Gateway. Provides deep research, LSP code validation, and advanced file-editing tools.' + ), ) for cap in registry.list_all(): @@ -204,18 +208,13 @@ def _build_handler(registry, cap, workspace: str): for pname, pschema in properties.items(): py_type = type_map.get(pschema.get('type', 'string'), str) if pname in required_params: - params.append( - inspect.Parameter( - pname, inspect.Parameter.KEYWORD_ONLY, annotation=py_type)) + params.append(inspect.Parameter(pname, inspect.Parameter.KEYWORD_ONLY, annotation=py_type)) else: opt_type = typing.Optional[py_type] default = pschema.get('default') params.append( - inspect.Parameter( - pname, - inspect.Parameter.KEYWORD_ONLY, - default=default, - annotation=opt_type)) + inspect.Parameter(pname, inspect.Parameter.KEYWORD_ONLY, default=default, annotation=opt_type) + ) annotations[pname] = params[-1].annotation cap_name = cap.name diff --git a/ms_agent/capabilities/registry.py b/ms_agent/capabilities/registry.py index 5d0dbd4b2..bc1f87516 100644 --- a/ms_agent/capabilities/registry.py +++ b/ms_agent/capabilities/registry.py @@ -21,8 +21,7 @@ def __init__(self) -> None: self._descriptors: dict[str, CapabilityDescriptor] = {} self._handlers: dict[str, Handler] = {} - def register(self, descriptor: CapabilityDescriptor, - handler: Handler) -> None: + def register(self, descriptor: CapabilityDescriptor, handler: Handler) -> None: if descriptor.name in self._descriptors: logger.warning('Overwriting capability %s', descriptor.name) self._descriptors[descriptor.name] = descriptor @@ -45,8 +44,7 @@ def discover( results = self.list_all() if granularity is not None: - levels = [granularity] if isinstance(granularity, - str) else granularity + levels = [granularity] if isinstance(granularity, str) else granularity results = [c for c in results if c.granularity in levels] if tags: @@ -56,14 +54,12 @@ def discover( if query: q = query.lower() results = [ - c for c in results if q in c.name.lower() - or q in c.summary.lower() or q in c.description.lower() + c for c in results if q in c.name.lower() or q in c.summary.lower() or q in c.description.lower() ] return results - async def invoke(self, name: str, args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: + async def invoke(self, name: str, args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Invoke a registered capability by name.""" if name not in self._handlers: return {'error': f'Unknown capability: {name}'} diff --git a/ms_agent/capabilities/wrappers/agent_delegate.py b/ms_agent/capabilities/wrappers/agent_delegate.py index 0c0cc2bf7..e7a89905a 100644 --- a/ms_agent/capabilities/wrappers/agent_delegate.py +++ b/ms_agent/capabilities/wrappers/agent_delegate.py @@ -24,13 +24,13 @@ 'description': 'Custom system prompt for the agent (optional)', }, 'tools': { - 'type': - 'string', - 'description': - ('Comma-separated basic tool component names to enable, e.g. ' - '"web_search,file_system,todo_list". Alias "filesystem" is accepted ' - 'for backward compatibility. Leave empty to use the default agent ' - 'config tools.'), + 'type': 'string', + 'description': ( + 'Comma-separated basic tool component names to enable, e.g. ' + '"web_search,file_system,todo_list". Alias "filesystem" is accepted ' + 'for backward compatibility. Leave empty to use the default agent ' + 'config tools.' + ), }, 'max_rounds': { 'type': 'integer', @@ -69,13 +69,13 @@ name='delegate_task', version='0.1.0', granularity='project', - summary=('Delegate a task to an LLM agent that can use tools. ' - 'Blocks until the agent completes.'), + summary=('Delegate a task to an LLM agent that can use tools. Blocks until the agent completes.'), description=( 'Creates an LLMAgent with the given configuration, runs it on the ' 'provided query, and returns the final response text. The agent ' 'can use tools (web search, filesystem, etc.) to accomplish the ' - 'task. WARNING: this call blocks and may take minutes.'), + 'task. WARNING: this call blocks and may take minutes.' + ), input_schema={ 'type': 'object', 'properties': _DELEGATE_INPUT_PROPERTIES, @@ -84,12 +84,8 @@ output_schema={ 'type': 'object', 'properties': { - 'status': { - 'type': 'string' - }, - 'response': { - 'type': 'string' - }, + 'status': {'type': 'string'}, + 'response': {'type': 'string'}, }, }, tags=['agent', 'delegate', 'llm', 'sync'], @@ -100,11 +96,12 @@ name='submit_agent_task', version='0.1.0', granularity='project', - summary=('Submit an agent task to run in the background. ' - 'Returns a task_id immediately.'), - description=('Starts an LLMAgent in the background and returns a task_id. ' - 'Use check_agent_task(task_id) to poll progress and ' - 'get_agent_result(task_id) to retrieve the final response.'), + summary=('Submit an agent task to run in the background. Returns a task_id immediately.'), + description=( + 'Starts an LLMAgent in the background and returns a task_id. ' + 'Use check_agent_task(task_id) to poll progress and ' + 'get_agent_result(task_id) to retrieve the final response.' + ), input_schema={ 'type': 'object', 'properties': _DELEGATE_INPUT_PROPERTIES, @@ -113,12 +110,8 @@ output_schema={ 'type': 'object', 'properties': { - 'task_id': { - 'type': 'string' - }, - 'status': { - 'type': 'string' - }, + 'task_id': {'type': 'string'}, + 'status': {'type': 'string'}, }, }, tags=['agent', 'delegate', 'llm', 'async', 'submit'], @@ -130,8 +123,9 @@ version='0.1.0', granularity='tool', summary='Check progress of a background agent task.', - description=('Polls the status of an agent task previously submitted via ' - 'submit_agent_task. Returns the current status.'), + description=( + 'Polls the status of an agent task previously submitted via submit_agent_task. Returns the current status.' + ), input_schema={ 'type': 'object', 'properties': { @@ -151,8 +145,10 @@ version='0.1.0', granularity='tool', summary='Get the result of a completed agent task.', - description=('Retrieves the final response from a completed agent task. ' - 'If the task is still running, returns a status message.'), + description=( + 'Retrieves the final response from a completed agent task. ' + 'If the task is still running, returns a status message.' + ), input_schema={ 'type': 'object', 'properties': { @@ -209,8 +205,7 @@ def _build_basic_tools_config(tools_list: list[str] | None) -> dict[str, Any]: for raw_name in tools_list: tool_name = _BASIC_TOOL_ALIASES.get(raw_name) if tool_name is None: - logger.warning('Ignoring unsupported delegate tool name: %s', - raw_name) + logger.warning('Ignoring unsupported delegate tool name: %s', raw_name) continue if tool_name in tools_cfg: continue @@ -250,9 +245,7 @@ def _build_agent_config( # return. OmegaConf.merge(default, ours) lets our value win. safe_cbs: list[str] = [] if hasattr(config, 'callbacks') and config.callbacks: - safe_cbs = [ - c for c in config.callbacks if c not in ('input_callback', ) - ] + safe_cbs = [c for c in config.callbacks if c not in ('input_callback',)] OmegaConf.update(config, 'callbacks', safe_cbs, merge=False) OmegaConf.update(config, 'save_history', False, merge=True) @@ -264,8 +257,7 @@ def _build_agent_config( # Preserve explicit config_path tool settings when already present. if hasattr(existing_tools, tool_name): continue - OmegaConf.update( - config, f'tools.{tool_name}', tool_cfg, merge=True) + OmegaConf.update(config, f'tools.{tool_name}', tool_cfg, merge=True) return config @@ -297,8 +289,7 @@ async def _run_agent( """Create, run, and clean up an LLMAgent. Returns the response text.""" from ms_agent.agent.llm_agent import LLMAgent - config = _build_agent_config(config_path, system_prompt, tools_list, - max_rounds) + config = _build_agent_config(config_path, system_prompt, tools_list, max_rounds) agent = LLMAgent(config=config, tag='delegate') try: @@ -322,8 +313,7 @@ async def _run_agent( logger.debug('Error during agent tool cleanup', exc_info=True) -async def _handle_delegate_task(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_delegate_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Synchronous agent delegation -- blocks until the agent finishes.""" query = (args.get('query') or '').strip() if not query: @@ -355,8 +345,7 @@ async def _background_agent(task: AsyncTask) -> dict[str, Any]: return {'response': response} -async def _handle_submit_agent_task(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_submit_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Submit an agent task to run in the background.""" query = (args.get('query') or '').strip() if not query: @@ -374,32 +363,27 @@ async def _handle_submit_agent_task(args: dict[str, Any], }, ) return { - 'task_id': - task.task_id, - 'status': - 'running', - 'message': - (f'Agent task {task.task_id} started. ' - f'Use check_agent_task(task_id="{task.task_id}") to poll status.'), + 'task_id': task.task_id, + 'status': 'running', + 'message': ( + f'Agent task {task.task_id} started. Use check_agent_task(task_id="{task.task_id}") to poll status.' + ), } -async def _handle_check_agent_task(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_check_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Check progress of a background agent task.""" return _manager.check(args['task_id']) -async def _handle_get_agent_result(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_get_agent_result(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Get the result of a completed agent task.""" task_id = args['task_id'] max_chars = args.get('max_chars', 50000) result = _manager.get_result(task_id) # Truncate response if needed - if result.get('status') == 'completed' and isinstance( - result.get('result'), dict): + if result.get('status') == 'completed' and isinstance(result.get('result'), dict): response = result['result'].get('response', '') truncated = len(response) > max_chars if truncated: @@ -411,8 +395,7 @@ async def _handle_get_agent_result(args: dict[str, Any], return result -async def _handle_cancel_agent_task(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_cancel_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Cancel a running agent task.""" return await _manager.cancel(args['task_id']) diff --git a/ms_agent/capabilities/wrappers/deep_research.py b/ms_agent/capabilities/wrappers/deep_research.py index 68de160bb..10672777f 100644 --- a/ms_agent/capabilities/wrappers/deep_research.py +++ b/ms_agent/capabilities/wrappers/deep_research.py @@ -18,14 +18,17 @@ name='submit_research_task', version='0.1.0', granularity='project', - summary=('Submit a deep research task that runs in the background. ' - 'Returns a task_id immediately -- use check_research_progress ' - 'and get_research_report to poll results.'), + summary=( + 'Submit a deep research task that runs in the background. ' + 'Returns a task_id immediately -- use check_research_progress ' + 'and get_research_report to poll results.' + ), description=( 'Launches the deep_research v2 pipeline as a background subprocess. ' 'The calling agent is NOT blocked and can continue other work. ' 'Use check_research_progress(task_id) to poll status, and ' - 'get_research_report(task_id) to retrieve the final report.'), + 'get_research_report(task_id) to retrieve the final report.' + ), input_schema={ 'type': 'object', 'properties': { @@ -34,16 +37,12 @@ 'description': 'The research question or topic to investigate', }, 'config_path': { - 'type': - 'string', - 'description': ('Path to researcher.yaml config. ' - 'Defaults to the bundled v2 config.'), + 'type': 'string', + 'description': ('Path to researcher.yaml config. Defaults to the bundled v2 config.'), }, 'output_dir': { - 'type': - 'string', - 'description': - 'Directory for research outputs (auto-generated if omitted)', + 'type': 'string', + 'description': 'Directory for research outputs (auto-generated if omitted)', }, }, 'required': ['query'], @@ -51,15 +50,9 @@ output_schema={ 'type': 'object', 'properties': { - 'task_id': { - 'type': 'string' - }, - 'status': { - 'type': 'string' - }, - 'output_dir': { - 'type': 'string' - }, + 'task_id': {'type': 'string'}, + 'status': {'type': 'string'}, + 'output_dir': {'type': 'string'}, }, }, tags=['research', 'search', 'report', 'async', 'submit'], @@ -70,12 +63,14 @@ name='check_research_progress', version='0.1.0', granularity='tool', - summary=('Check the progress of a running deep research task. ' - 'Returns status, evidence count, and latest activity.'), + summary=( + 'Check the progress of a running deep research task. Returns status, evidence count, and latest activity.' + ), description=( 'Polls the status of a research task previously submitted via ' 'submit_research_task. Inspects the output directory to report ' - 'how many evidence notes and analyses have been collected so far.'), + 'how many evidence notes and analyses have been collected so far.' + ), input_schema={ 'type': 'object', 'properties': { @@ -94,12 +89,15 @@ name='get_research_report', version='0.1.0', granularity='tool', - summary=('Retrieve the final report from a completed deep research task. ' - 'Returns the report content or an error if not yet complete.'), + summary=( + 'Retrieve the final report from a completed deep research task. ' + 'Returns the report content or an error if not yet complete.' + ), description=( 'Reads the final research report produced by a completed task. ' 'If the task is still running, returns a message to wait. ' - 'If completed, returns the full report markdown content.'), + 'If completed, returns the full report markdown content.' + ), input_schema={ 'type': 'object', 'properties': { @@ -126,12 +124,14 @@ granularity='project', summary=( 'Run deep research synchronously (BLOCKS until complete, 20-60 min). ' - 'Prefer submit_research_task for non-blocking usage.'), + 'Prefer submit_research_task for non-blocking usage.' + ), description=( 'Synchronous version that blocks until the research is complete. ' 'WARNING: This can take 20-60 minutes. Most MCP clients will ' 'timeout. Use submit_research_task + check_research_progress + ' - 'get_research_report for non-blocking async operation.'), + 'get_research_report for non-blocking async operation.' + ), input_schema={ 'type': 'object', 'properties': { @@ -139,14 +139,8 @@ 'type': 'string', 'description': 'The research question or topic to investigate', }, - 'config_path': { - 'type': 'string', - 'description': 'Path to researcher.yaml' - }, - 'output_dir': { - 'type': 'string', - 'description': 'Output directory' - }, + 'config_path': {'type': 'string', 'description': 'Path to researcher.yaml'}, + 'output_dir': {'type': 'string', 'description': 'Output directory'}, }, 'required': ['query'], }, @@ -159,14 +153,12 @@ def _find_default_config() -> str | None: """Locate the bundled deep_research v2 researcher.yaml.""" candidates = [ - os.path.join( - os.path.dirname(__file__), '..', '..', '..', 'projects', - 'deep_research', 'v2', 'researcher.yaml'), + os.path.join(os.path.dirname(__file__), '..', '..', '..', 'projects', 'deep_research', 'v2', 'researcher.yaml'), ] try: from importlib import resources as importlib_resources - trav = importlib_resources.files('ms_agent').joinpath( - 'projects', 'deep_research', 'v2', 'researcher.yaml') + + trav = importlib_resources.files('ms_agent').joinpath('projects', 'deep_research', 'v2', 'researcher.yaml') candidates.insert(0, str(trav)) except Exception: pass @@ -208,12 +200,8 @@ def _count_evidence(output_dir: str) -> dict[str, int]: notes_dir = os.path.join(evidence_dir, 'notes') analyses_dir = os.path.join(evidence_dir, 'analyses') return { - 'notes': - len(list(Path(notes_dir).glob('*.md'))) - if os.path.isdir(notes_dir) else 0, - 'analyses': - len(list(Path(analyses_dir).glob('*.md'))) - if os.path.isdir(analyses_dir) else 0, + 'notes': len(list(Path(notes_dir).glob('*.md'))) if os.path.isdir(notes_dir) else 0, + 'analyses': len(list(Path(analyses_dir).glob('*.md'))) if os.path.isdir(analyses_dir) else 0, } @@ -230,8 +218,7 @@ def _research_progress_fn(task: AsyncTask) -> dict[str, Any]: 'report_available': bool(report_path), } if task.status == 'completed': - info['report_path'] = task.metadata.get('report_path', - '') or report_path + info['report_path'] = task.metadata.get('report_path', '') or report_path return info @@ -271,8 +258,7 @@ async def _background_research(task: AsyncTask) -> dict[str, Any]: raise RuntimeError(stderr.decode('utf-8', errors='replace')[-2000:]) -async def _handle_submit(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_submit(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Submit a research task to run in the background.""" query: str = args['query'] config_path = args.get('config_path', '') or _find_default_config() or '' @@ -297,28 +283,23 @@ async def _handle_submit(args: dict[str, Any], ) return { - 'task_id': - task.task_id, - 'status': - 'running', - 'output_dir': - output_dir, - 'message': - (f'Research task {task.task_id} started. ' - f'Use check_research_progress(task_id="{task.task_id}") to poll status.' - ), + 'task_id': task.task_id, + 'status': 'running', + 'output_dir': output_dir, + 'message': ( + f'Research task {task.task_id} started. ' + f'Use check_research_progress(task_id="{task.task_id}") to poll status.' + ), } -async def _handle_check_progress(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_check_progress(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Check the progress of a running research task.""" task_id: str = args['task_id'] return _manager.check(task_id, progress_fn=_research_progress_fn) -async def _handle_get_report(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_get_report(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Retrieve the final report from a completed task.""" task_id: str = args['task_id'] max_chars: int = args.get('max_chars', 50000) @@ -330,15 +311,14 @@ async def _handle_get_report(args: dict[str, Any], if task.status == 'running': evidence = _count_evidence(task.metadata.get('output_dir', '')) return { - 'task_id': - task_id, - 'status': - 'running', - 'message': - ('Research is still in progress. ' - f'Evidence collected so far: {evidence["notes"]} notes, ' - f'{evidence["analyses"]} analyses. ' - 'Please check again later.'), + 'task_id': task_id, + 'status': 'running', + 'message': ( + 'Research is still in progress. ' + f'Evidence collected so far: {evidence["notes"]} notes, ' + f'{evidence["analyses"]} analyses. ' + 'Please check again later.' + ), } if task.status == 'failed': @@ -349,8 +329,7 @@ async def _handle_get_report(args: dict[str, Any], } output_dir = task.metadata.get('output_dir', '') - report_path = task.metadata.get('report_path', - '') or _find_report(output_dir) + report_path = task.metadata.get('report_path', '') or _find_report(output_dir) if not report_path or not os.path.isfile(report_path): return { 'task_id': task_id, @@ -375,18 +354,14 @@ async def _handle_get_report(args: dict[str, Any], } -async def _handle_deep_research_sync(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_deep_research_sync(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Launch deep_research synchronously (blocks until complete).""" query: str = args['query'] config_path = args.get('config_path', '') or _find_default_config() or '' output_dir = args.get('output_dir', '') if not config_path or not os.path.isfile(config_path): - return { - 'status': 'failed', - 'error': f'Config not found: {config_path}' - } + return {'status': 'failed', 'error': f'Config not found: {config_path}'} if not output_dir: ts = time.strftime('%Y%m%d_%H%M%S') @@ -405,16 +380,12 @@ async def _handle_deep_research_sync(args: dict[str, Any], await proc.wait() report_path = _find_report(output_dir) if proc.returncode == 0: - return { - 'status': 'completed', - 'output_dir': output_dir, - 'report_path': report_path - } + return {'status': 'completed', 'output_dir': output_dir, 'report_path': report_path} else: return { 'status': 'failed', 'output_dir': output_dir, - 'error': stderr.decode('utf-8', errors='replace')[-2000:] + 'error': stderr.decode('utf-8', errors='replace')[-2000:], } except Exception as e: return {'status': 'failed', 'error': str(e)} @@ -427,5 +398,4 @@ def register_all(registry: CapabilityRegistry, config: Any = None) -> None: registry.register(CHECK_PROGRESS_DESCRIPTOR, _handle_check_progress) registry.register(GET_REPORT_DESCRIPTOR, _handle_get_report) # Sync (for direct Python API or long-timeout scenarios) - registry.register(DEEP_RESEARCH_SYNC_DESCRIPTOR, - _handle_deep_research_sync) + registry.register(DEEP_RESEARCH_SYNC_DESCRIPTOR, _handle_deep_research_sync) diff --git a/ms_agent/capabilities/wrappers/filesystem.py b/ms_agent/capabilities/wrappers/filesystem.py index 5becb5615..fcf4b944e 100644 --- a/ms_agent/capabilities/wrappers/filesystem.py +++ b/ms_agent/capabilities/wrappers/filesystem.py @@ -9,44 +9,40 @@ name='replace_file_contents', version='0.1.0', granularity='tool', - summary=('Replace exact content in a file without line numbers. ' - 'Concurrent-safe: matches by content instead of line numbers, ' - 'so parallel edits on the same file do not conflict.'), - description= - ('Performs an exact-string replacement inside a file. The caller supplies ' - 'the verbatim `source` text to find and the `target` text to replace it with. ' - 'An `occurrence` parameter controls which match to replace (1-based) or ' - '-1 for all. Because it relies on content matching rather than line numbers, ' - 'it is safe to use from multiple agents editing the same file concurrently.' - ), + summary=( + 'Replace exact content in a file without line numbers. ' + 'Concurrent-safe: matches by content instead of line numbers, ' + 'so parallel edits on the same file do not conflict.' + ), + description=( + 'Performs an exact-string replacement inside a file. The caller supplies ' + 'the verbatim `source` text to find and the `target` text to replace it with. ' + 'An `occurrence` parameter controls which match to replace (1-based) or ' + '-1 for all. Because it relies on content matching rather than line numbers, ' + 'it is safe to use from multiple agents editing the same file concurrently.' + ), input_schema={ 'type': 'object', 'properties': { 'path': { - 'type': - 'string', - 'description': - 'Path to the file to modify (relative to workspace or absolute)', + 'type': 'string', + 'description': 'Path to the file to modify (relative to workspace or absolute)', }, 'source': { - 'type': - 'string', - 'description': - ('Exact content to find. Must match the file content verbatim ' - 'including whitespace, punctuation, and line breaks.'), + 'type': 'string', + 'description': ( + 'Exact content to find. Must match the file content verbatim ' + 'including whitespace, punctuation, and line breaks.' + ), }, 'target': { 'type': 'string', 'description': 'New content to replace the source with', }, 'occurrence': { - 'type': - 'integer', - 'description': - ('Which occurrence to replace (1-based). ' - 'Use -1 to replace all occurrences. Default: 1'), - 'default': - 1, + 'type': 'integer', + 'description': ('Which occurrence to replace (1-based). Use -1 to replace all occurrences. Default: 1'), + 'default': 1, }, }, 'required': ['path', 'source', 'target'], @@ -54,9 +50,7 @@ output_schema={ 'type': 'object', 'properties': { - 'result': { - 'type': 'string' - }, + 'result': {'type': 'string'}, }, }, tags=['filesystem', 'edit', 'replace', 'diff', 'concurrent-safe'], @@ -67,10 +61,10 @@ name='replace_file_lines', version='0.1.0', granularity='tool', - summary= - ('Replace, insert, or append content by line range. ' - 'Supports insert-at-beginning (start_line=0) and append-at-end (start_line=-1).' - ), + summary=( + 'Replace, insert, or append content by line range. ' + 'Supports insert-at-beginning (start_line=0) and append-at-end (start_line=-1).' + ), description=( 'Replaces a range of lines in a file with new content. ' 'Special modes: start_line=0 inserts at the beginning, ' @@ -88,17 +82,12 @@ 'description': 'New content to insert or replace with', }, 'start_line': { - 'type': - 'integer', - 'description': - ('Start line (1-based inclusive). ' - '0 = insert at beginning, -1 = append at end.'), + 'type': 'integer', + 'description': ('Start line (1-based inclusive). 0 = insert at beginning, -1 = append at end.'), }, 'end_line': { - 'type': - 'integer', - 'description': - 'End line (1-based inclusive). Required unless start_line is 0 or -1.', + 'type': 'integer', + 'description': 'End line (1-based inclusive). Required unless start_line is 0 or -1.', }, }, 'required': ['path', 'content', 'start_line'], @@ -106,9 +95,7 @@ output_schema={ 'type': 'object', 'properties': { - 'result': { - 'type': 'string' - }, + 'result': {'type': 'string'}, }, }, tags=['filesystem', 'edit', 'replace', 'lines'], @@ -125,10 +112,8 @@ def _resolve_path(path: str, workspace: str | None) -> str: return os.path.abspath(path) -async def _handle_replace_contents(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: - workspace = kwargs.get('workspace') or os.environ.get( - 'MS_AGENT_OUTPUT_DIR', '') +async def _handle_replace_contents(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: + workspace = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', '') path = _resolve_path(args['path'], workspace) source: str = args['source'] target: str = args['target'] @@ -145,9 +130,7 @@ async def _handle_replace_contents(args: dict[str, Any], content = f.read() if source not in content: - return { - 'error': f'Could not find the exact content to replace in {path}' - } + return {'error': f'Could not find the exact content to replace in {path}'} count = content.count(source) @@ -160,8 +143,7 @@ async def _handle_replace_contents(args: dict[str, Any], return {'error': f'occurrence {occurrence} exceeds total ({count})'} else: parts = content.split(source, occurrence) - updated = source.join(parts[:occurrence]) + target + source.join( - parts[occurrence:]) + updated = source.join(parts[:occurrence]) + target + source.join(parts[occurrence:]) msg = f'Replaced occurrence {occurrence} of {count}' with open(path, 'w', encoding='utf-8') as f: @@ -170,10 +152,8 @@ async def _handle_replace_contents(args: dict[str, Any], return {'result': f'{msg} in {path}'} -async def _handle_replace_lines(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: - workspace = kwargs.get('workspace') or os.environ.get( - 'MS_AGENT_OUTPUT_DIR', '') +async def _handle_replace_lines(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: + workspace = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', '') path = _resolve_path(args['path'], workspace) new_content: str = args['content'] start_line: int = args['start_line'] diff --git a/ms_agent/capabilities/wrappers/lsp_code_server.py b/ms_agent/capabilities/wrappers/lsp_code_server.py index 072c394db..69b39bb3b 100644 --- a/ms_agent/capabilities/wrappers/lsp_code_server.py +++ b/ms_agent/capabilities/wrappers/lsp_code_server.py @@ -9,30 +9,25 @@ name='lsp_check_directory', version='0.1.0', granularity='component', - summary=('Run LSP diagnostics on all code files in a directory. ' - 'Supports TypeScript/JavaScript, Python, and Java.'), + summary=('Run LSP diagnostics on all code files in a directory. Supports TypeScript/JavaScript, Python, and Java.'), description=( 'Starts the appropriate Language Server Protocol backend ' '(typescript-language-server, pyright, or jdtls) and runs ' 'diagnostics on every matching file in the given directory. ' 'Returns structured error/warning information. Useful for ' - 'validating generated code or checking a project for issues.'), + 'validating generated code or checking a project for issues.' + ), input_schema={ 'type': 'object', 'properties': { 'directory': { - 'type': - 'string', - 'description': - 'Path to the directory to check (relative to workspace or absolute)', + 'type': 'string', + 'description': 'Path to the directory to check (relative to workspace or absolute)', }, 'language': { - 'type': - 'string', + 'type': 'string', 'enum': ['typescript', 'python', 'java'], - 'description': - ('Programming language to check. ' - 'typescript covers .ts/.tsx/.js/.jsx/.mjs/.cjs files'), + 'description': ('Programming language to check. typescript covers .ts/.tsx/.js/.jsx/.mjs/.cjs files'), }, }, 'required': ['directory', 'language'], @@ -40,16 +35,10 @@ output_schema={ 'type': 'object', 'properties': { - 'result': { - 'type': 'string', - 'description': 'Diagnostic summary' - }, + 'result': {'type': 'string', 'description': 'Diagnostic summary'}, }, }, - tags=[ - 'code', 'lsp', 'diagnostics', 'validation', 'typescript', 'python', - 'java' - ], + tags=['code', 'lsp', 'diagnostics', 'validation', 'typescript', 'python', 'java'], estimated_duration='minutes', parent='lsp_code_server', requires={'bins': []}, @@ -61,19 +50,19 @@ granularity='tool', summary=( 'Incrementally update a file and check for LSP errors. ' - 'More efficient than a full directory check for single-file edits.'), + 'More efficient than a full directory check for single-file edits.' + ), description=( 'Updates a file with new content and runs LSP diagnostics on it. ' 'The LSP server is reused across calls, making repeated checks on ' - 'the same project very efficient.'), + 'the same project very efficient.' + ), input_schema={ 'type': 'object', 'properties': { 'file_path': { - 'type': - 'string', - 'description': - 'Path to the file (relative to workspace or absolute)', + 'type': 'string', + 'description': 'Path to the file (relative to workspace or absolute)', }, 'content': { 'type': 'string', @@ -90,10 +79,7 @@ output_schema={ 'type': 'object', 'properties': { - 'result': { - 'type': 'string', - 'description': 'Diagnostic output' - }, + 'result': {'type': 'string', 'description': 'Diagnostic output'}, }, }, tags=['code', 'lsp', 'diagnostics', 'validation'], @@ -105,18 +91,17 @@ name='lsp_code_server', version='0.1.0', granularity='component', - summary= - ('LSP-based code validation server supporting TypeScript, Python, and Java. ' - 'Provides directory-wide and incremental file-level diagnostics.'), + summary=( + 'LSP-based code validation server supporting TypeScript, Python, and Java. ' + 'Provides directory-wide and incremental file-level diagnostics.' + ), description=( 'A component that wraps Language Server Protocol backends to provide ' 'code diagnostics without requiring an IDE. Sub-capabilities: ' 'lsp_check_directory (full project scan) and lsp_update_and_check ' - '(incremental single-file validation).'), - input_schema={ - 'type': 'object', - 'properties': {} - }, + '(incremental single-file validation).' + ), + input_schema={'type': 'object', 'properties': {}}, tags=['code', 'lsp', 'diagnostics', 'validation'], estimated_duration='minutes', sub_capabilities=['lsp_check_directory', 'lsp_update_and_check'], @@ -156,16 +141,13 @@ def _resolve_workspace(directory: str, fallback: str) -> str: return fallback -async def _handle_check_directory(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: - fallback = kwargs.get('workspace') or os.environ.get( - 'MS_AGENT_OUTPUT_DIR', os.getcwd()) +async def _handle_check_directory(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: + fallback = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', os.getcwd()) directory = args['directory'] workspace = _resolve_workspace(directory, fallback) lsp = _get_lsp_server(workspace) - rel_dir = os.path.relpath( - directory, workspace) if os.path.isabs(directory) else directory + rel_dir = os.path.relpath(directory, workspace) if os.path.isabs(directory) else directory result = await lsp.call_tool( 'lsp_code_server', @@ -178,18 +160,13 @@ async def _handle_check_directory(args: dict[str, Any], return {'result': result} -async def _handle_update_and_check(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: - fallback = kwargs.get('workspace') or os.environ.get( - 'MS_AGENT_OUTPUT_DIR', os.getcwd()) +async def _handle_update_and_check(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: + fallback = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', os.getcwd()) file_path = args['file_path'] - workspace = _resolve_workspace( - os.path.dirname(file_path), - fallback) if os.path.isabs(file_path) else fallback + workspace = _resolve_workspace(os.path.dirname(file_path), fallback) if os.path.isabs(file_path) else fallback lsp = _get_lsp_server(workspace) - rel_path = os.path.relpath( - file_path, workspace) if os.path.isabs(file_path) else file_path + rel_path = os.path.relpath(file_path, workspace) if os.path.isabs(file_path) else file_path result = await lsp.call_tool( 'lsp_code_server', diff --git a/ms_agent/capabilities/wrappers/web_search.py b/ms_agent/capabilities/wrappers/web_search.py index 02a9e2b56..60fba7d29 100644 --- a/ms_agent/capabilities/wrappers/web_search.py +++ b/ms_agent/capabilities/wrappers/web_search.py @@ -16,6 +16,7 @@ def _get_engine(engine_type: str) -> Any: """Return a cached :class:`SearchEngine` instance for *engine_type*.""" if engine_type not in _engines: from ms_agent.tools.search.websearch_tool import get_search_engine + _engines[engine_type] = get_search_engine(engine_type) return _engines[engine_type] @@ -25,6 +26,7 @@ def _get_fetcher() -> Any: global _fetcher if _fetcher is None: from ms_agent.tools.search.websearch_tool import get_content_fetcher + _fetcher = get_content_fetcher('jina_reader') return _fetcher @@ -33,13 +35,13 @@ def _get_fetcher() -> Any: name='web_search', version='0.1.0', granularity='tool', - summary=('Search the web using multiple engines (exa, serpapi, arxiv) ' - 'and optionally fetch full page content.'), + summary=('Search the web using multiple engines (exa, serpapi, arxiv) and optionally fetch full page content.'), description=( 'Performs a web search and returns structured results including ' 'title, URL, and summary for each hit. Supports exa, serpapi, ' 'and arxiv backends. Set fetch_content=true to additionally ' - 'retrieve and return page text (truncated to 10 000 chars).'), + 'retrieve and return page text (truncated to 10 000 chars).' + ), input_schema={ 'type': 'object', 'properties': { @@ -53,22 +55,14 @@ def _get_fetcher() -> Any: 'default': 5, }, 'engine_type': { - 'type': - 'string', - 'description': - ("Search engine to use: 'exa', 'serpapi', or 'arxiv' " - "(default: 'arxiv')"), - 'default': - 'arxiv', + 'type': 'string', + 'description': ("Search engine to use: 'exa', 'serpapi', or 'arxiv' (default: 'arxiv')"), + 'default': 'arxiv', }, 'fetch_content': { - 'type': - 'boolean', - 'description': - ('Whether to fetch full page content for each result ' - '(default: false)'), - 'default': - False, + 'type': 'boolean', + 'description': ('Whether to fetch full page content for each result (default: false)'), + 'default': False, }, }, 'required': ['query'], @@ -76,21 +70,11 @@ def _get_fetcher() -> Any: output_schema={ 'type': 'object', 'properties': { - 'status': { - 'type': 'string' - }, - 'query': { - 'type': 'string' - }, - 'engine': { - 'type': 'string' - }, - 'count': { - 'type': 'integer' - }, - 'results': { - 'type': 'array' - }, + 'status': {'type': 'string'}, + 'query': {'type': 'string'}, + 'engine': {'type': 'string'}, + 'count': {'type': 'integer'}, + 'results': {'type': 'array'}, }, }, tags=['search', 'web', 'research'], @@ -98,8 +82,7 @@ def _get_fetcher() -> Any: ) -async def _handle_web_search(args: dict[str, Any], - **kwargs: Any) -> dict[str, Any]: +async def _handle_web_search(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Execute a web search and return structured results.""" query = (args.get('query') or '').strip() if not query: @@ -113,10 +96,7 @@ async def _handle_web_search(args: dict[str, Any], try: engine = _get_engine(engine_type) except Exception as exc: - return { - 'error': - f'Failed to initialise search engine {engine_type!r}: {exc}' - } + return {'error': f'Failed to initialise search engine {engine_type!r}: {exc}'} # Build request via the engine's class method engine_cls = type(engine) @@ -143,11 +123,13 @@ async def _handle_web_search(args: dict[str, Any], raw_list = search_result.to_list() if search_result else [] results: list[dict[str, Any]] = [] for item in raw_list[:num_results]: - results.append({ - 'title': item.get('title', ''), - 'url': item.get('url', ''), - 'summary': item.get('summary', ''), - }) + results.append( + { + 'title': item.get('title', ''), + 'url': item.get('url', ''), + 'summary': item.get('summary', ''), + } + ) # Optional content fetching if fetch_content and results: diff --git a/ms_agent/cli/app.py b/ms_agent/cli/app.py index d9a690d7d..70bed1229 100644 --- a/ms_agent/cli/app.py +++ b/ms_agent/cli/app.py @@ -5,8 +5,7 @@ def subparser_func(args): - """ Function which will be called for a specific sub parser. - """ + """Function which will be called for a specific sub parser.""" return AppCMD(args) @@ -30,41 +29,29 @@ def define_args(parsers: argparse.ArgumentParser): '--app_type', type=str, default='doc_research', - help= - 'The app type, supported values: `doc_research`, `fin_research`') + help='The app type, supported values: `doc_research`, `fin_research`', + ) - parser.add_argument( - '--server_name', - type=str, - default='0.0.0.0', - help='The gradio server name to bind to.') + parser.add_argument('--server_name', type=str, default='0.0.0.0', help='The gradio server name to bind to.') - parser.add_argument( - '--server_port', - type=int, - default=7860, - help='The gradio server port to bind to.') + parser.add_argument('--server_port', type=int, default=7860, help='The gradio server port to bind to.') - parser.add_argument( - '--share', - action='store_true', - help='Whether to share the gradio app publicly.') + parser.add_argument('--share', action='store_true', help='Whether to share the gradio app publicly.') parser.set_defaults(func=subparser_func) def execute(self): - if self.args.app_type == 'doc_research': from ms_agent.app.doc_research import launch_server as launch_doc_research + launch_doc_research( - server_name=self.args.server_name, - server_port=self.args.server_port, - share=self.args.share) + server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share + ) elif self.args.app_type == 'fin_research': from ms_agent.app.fin_research import launch_server as launch_fin_research + launch_fin_research( - server_name=self.args.server_name, - server_port=self.args.server_port, - share=self.args.share) + server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share + ) else: raise ValueError(f'Unsupported app type: {self.args.app_type}') diff --git a/ms_agent/cli/cli.py b/ms_agent/cli/cli.py index da709e98d..fafc6e849 100644 --- a/ms_agent/cli/cli.py +++ b/ms_agent/cli/cli.py @@ -10,12 +10,9 @@ def run_cmd(): This cmd imports all other sub commands, for example, `run` and `app`. """ - parser = argparse.ArgumentParser( - 'ModelScope-agent Command Line tool', - usage='ms-agent []') + parser = argparse.ArgumentParser('ModelScope-agent Command Line tool', usage='ms-agent []') - subparsers = parser.add_subparsers( - help='ModelScope-agent commands helpers') + subparsers = parser.add_subparsers(help='ModelScope-agent commands helpers') RunCMD.define_args(subparsers) AppCMD.define_args(subparsers) diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 2accdb40e..855ce4144 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,11 +4,12 @@ import os from importlib import resources as importlib_resources +from omegaconf import OmegaConf + from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII -from omegaconf import OmegaConf from .base import CLICommand @@ -16,8 +17,7 @@ def subparser_func(args): - """ Function which will be called for a specific sub parser. - """ + """Function which will be called for a specific sub parser.""" return RunCMD(args) @@ -36,9 +36,7 @@ def list_builtin_projects(): def project_help_text(): projects = list_builtin_projects() if projects: - return ( - 'Built-in bundled project name under package ms_agent/projects. ' - f'Available: {", ".join(projects)}') + return f'Built-in bundled project name under package ms_agent/projects. Available: {", ".join(projects)}' return 'Built-in bundled project name under package ms_agent/projects.' @@ -76,22 +74,18 @@ def define_args(parsers: argparse.ArgumentParser): type=str, default=None, metavar='PATH', - help= - 'Path to a .env file. If omitted, loads ./.env from the current ' - 'working directory when present; missing file is ignored.') + help='Path to a .env file. If omitted, loads ./.env from the current ' + 'working directory when present; missing file is ignored.', + ) parser.add_argument( '--query', required=False, type=str, - help= - 'The query or prompt to send to the LLM. If not set, will enter an interactive mode.' + help='The query or prompt to send to the LLM. If not set, will enter an interactive mode.', ) parser.add_argument( - '--config', - required=False, - type=str, - default=None, - help='The directory or the repo id of the config file') + '--config', required=False, type=str, default=None, help='The directory or the repo id of the config file' + ) parser.add_argument( '--project', required=False, @@ -105,78 +99,64 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default='false', - help='Trust the code belongs to the config file, default False') + help='Trust the code belongs to the config file, default False', + ) parser.add_argument( '--load_cache', required=False, type=str, default='false', - help= - 'Load previous step histories from cache, this is useful when a query fails and retry' + help='Load previous step histories from cache, this is useful when a query fails and retry', ) + parser.add_argument('--mcp_config', required=False, type=str, default=None, help='The extra mcp server config') parser.add_argument( - '--mcp_config', - required=False, - type=str, - default=None, - help='The extra mcp server config') - parser.add_argument( - '--mcp_server_file', - required=False, - type=str, - default=None, - help='An extra mcp server file.') + '--mcp_server_file', required=False, type=str, default=None, help='An extra mcp server file.' + ) parser.add_argument( '--openai_api_key', required=False, type=str, default=None, - help='API key for accessing an OpenAI-compatible service.') + help='API key for accessing an OpenAI-compatible service.', + ) parser.add_argument( '--modelscope_api_key', required=False, type=str, default=None, - help='API key for accessing ModelScope api-inference services.') + help='API key for accessing ModelScope api-inference services.', + ) parser.add_argument( '--animation_mode', required=False, type=str, choices=['auto', 'human'], default=None, - help= - 'Animation mode for video_generate project: auto (default) or human.' + help='Animation mode for video_generate project: auto (default) or human.', ) parser.add_argument( '--knowledge_search_paths', required=False, type=str, default=None, - help= - 'Comma-separated list of paths for knowledge search.' + help='Comma-separated list of paths for knowledge search.', ) parser.set_defaults(func=subparser_func) def execute(self): if getattr(self.args, 'project', None): if self.args.config: - raise ValueError( - 'Please specify only one of --config or --project') + raise ValueError('Please specify only one of --config or --project') project = self.args.project - project_trav = importlib_resources.files('ms_agent').joinpath( - 'projects', project) + project_trav = importlib_resources.files('ms_agent').joinpath('projects', project) if not project_trav.exists(): - projects_root = importlib_resources.files('ms_agent').joinpath( - 'projects') + projects_root = importlib_resources.files('ms_agent').joinpath('projects') available = [] if projects_root.exists(): - available = [ - p.name for p in projects_root.iterdir() if p.is_dir() - ] - raise ValueError( - f'Unknown project: {project}. Available: {available}') + available = [p.name for p in projects_root.iterdir() if p.is_dir()] + raise ValueError(f'Unknown project: {project}. Available: {available}') # as_file ensures we get a real filesystem path even if installed as zip with importlib_resources.as_file(project_trav) as project_dir: @@ -192,16 +172,14 @@ def _execute_with_config(self): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) else: # Use built-in default agent.yaml from package - default_config_path = importlib_resources.files( - 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) - with importlib_resources.as_file( - default_config_path) as config_file: + default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file(default_config_path) as config_file: self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download + self.args.config = snapshot_download(self.args.config) - self.args.trust_remote_code = strtobool( - self.args.trust_remote_code) # noqa + self.args.trust_remote_code = strtobool(self.args.trust_remote_code) # noqa self.args.load_cache = strtobool(self.args.load_cache) # Propagate animation mode via environment variable for downstream code agents @@ -219,26 +197,19 @@ def _execute_with_config(self): author = f.read() blue_color_prefix = '\033[34m' blue_color_suffix = '\033[0m' - print( - blue_color_prefix + MS_AGENT_ASCII + blue_color_suffix, flush=True) + print(blue_color_prefix + MS_AGENT_ASCII + blue_color_suffix, flush=True) line_start = '═════════════════════════Workflow Contributed By════════════════════════════' line_end = '════════════════════════════════════════════════════════════════════════════' if author: - print( - blue_color_prefix + line_start + blue_color_suffix, flush=True) - print( - blue_color_prefix + author.strip() + blue_color_suffix, - flush=True) + print(blue_color_prefix + line_start + blue_color_suffix, flush=True) + print(blue_color_prefix + author.strip() + blue_color_suffix, flush=True) print(blue_color_prefix + line_end + blue_color_suffix, flush=True) config = Config.from_task(self.args.config) # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): - paths = [ - p.strip() for p in self.args.knowledge_search_paths.split(',') - if p.strip() - ] + paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] if paths: if not hasattr(config, 'tools') or config.tools is None: config['tools'] = OmegaConf.create({}) @@ -249,8 +220,7 @@ def _execute_with_config(self): 'work_path': './.sirchmunk', 'mode': 'FAST', } - config.tools['localsearch'] = OmegaConf.create( - localsearch_config) + config.tools['localsearch'] = OmegaConf.create(localsearch_config) else: existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths @@ -258,18 +228,22 @@ def _execute_with_config(self): if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader + engine = WorkflowLoader.build( config_dir_or_id=self.args.config, config=config, mcp_server_file=self.args.mcp_server_file, load_cache=self.args.load_cache, - trust_remote_code=self.args.trust_remote_code) + trust_remote_code=self.args.trust_remote_code, + ) else: from ms_agent.agent.loader import AgentLoader + engine = AgentLoader.build( config_dir_or_id=self.args.config, config=config, mcp_server_file=self.args.mcp_server_file, load_cache=self.args.load_cache, - trust_remote_code=self.args.trust_remote_code) + trust_remote_code=self.args.trust_remote_code, + ) asyncio.run(engine.run(self.args.query)) diff --git a/ms_agent/cli/ui.py b/ms_agent/cli/ui.py index b10c9bb5d..4d0c40be2 100644 --- a/ms_agent/cli/ui.py +++ b/ms_agent/cli/ui.py @@ -11,8 +11,7 @@ def subparser_func(args): - """ Function which will be called for a specific sub parser. - """ + """Function which will be called for a specific sub parser.""" return UICMD(args) @@ -28,28 +27,11 @@ def __init__(self, args): def define_args(parsers: argparse.ArgumentParser): """Define args for the ui command.""" parser: argparse.ArgumentParser = parsers.add_parser(UICMD.name) - parser.add_argument( - '--host', - type=str, - default='0.0.0.0', - help='The server host to bind to.') - parser.add_argument( - '--port', - type=int, - default=7860, - help='The server port to bind to.') - parser.add_argument( - '--reload', - action='store_true', - help='Enable auto-reload for development.') - parser.add_argument( - '--production', - action='store_true', - help='Run in production mode (serve built frontend).') - parser.add_argument( - '--no-browser', - action='store_true', - help='Do not automatically open browser.') + parser.add_argument('--host', type=str, default='0.0.0.0', help='The server host to bind to.') + parser.add_argument('--port', type=int, default=7860, help='The server port to bind to.') + parser.add_argument('--reload', action='store_true', help='Enable auto-reload for development.') + parser.add_argument('--production', action='store_true', help='Run in production mode (serve built frontend).') + parser.add_argument('--no-browser', action='store_true', help='Do not automatically open browser.') parser.set_defaults(func=subparser_func) def execute(self): @@ -59,6 +41,7 @@ def execute(self): if not webui_dir.exists(): import ms_agent + ms_agent_path = Path(ms_agent.__file__).parent webui_dir = ms_agent_path / 'webui' @@ -73,13 +56,10 @@ def execute(self): sys.exit(1) frontend_dist = frontend_dir / 'dist' - frontend_built = frontend_dist.exists() and (frontend_dist - / 'index.html').exists() + frontend_built = frontend_dist.exists() and (frontend_dist / 'index.html').exists() if self.args.production and not frontend_built: - print( - 'Error: Frontend not built. Please run "npm run build" in webui/frontend first.' - ) + print('Error: Frontend not built. Please run "npm run build" in webui/frontend first.') sys.exit(1) if not self.args.production and not frontend_built: @@ -115,8 +95,7 @@ def open_browser(): time.sleep(1.5) webbrowser.open(browser_url) - browser_thread = threading.Thread( - target=open_browser, daemon=True) + browser_thread = threading.Thread(target=open_browser, daemon=True) browser_thread.start() main() @@ -126,6 +105,7 @@ def open_browser(): except Exception as e: print(f'Error starting WebUI: {e}') import traceback + traceback.print_exc() sys.exit(1) finally: @@ -136,33 +116,33 @@ def _build_frontend(self, frontend_dir: Path) -> bool: import subprocess try: - subprocess.run(['npm', '--version'], - capture_output=True, - check=True, - timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, - FileNotFoundError): + subprocess.run(['npm', '--version'], capture_output=True, check=True, timeout=5) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): return False node_modules = frontend_dir / 'node_modules' if not node_modules.exists(): try: - subprocess.run(['npm', 'install'], - cwd=frontend_dir, - check=True, - timeout=300, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + subprocess.run( + ['npm', 'install'], + cwd=frontend_dir, + check=True, + timeout=300, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) except (subprocess.TimeoutExpired, subprocess.CalledProcessError): return False try: - subprocess.run(['npm', 'run', 'build'], - cwd=frontend_dir, - check=True, - timeout=300, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + subprocess.run( + ['npm', 'run', 'build'], + cwd=frontend_dir, + check=True, + timeout=300, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) return True except (subprocess.TimeoutExpired, subprocess.CalledProcessError): return False diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 2f6175524..93cf3a338 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -5,12 +5,13 @@ from copy import deepcopy from typing import Any, Dict, Union -from ms_agent.prompting import apply_prompt_files -from ms_agent.utils import get_logger +from modelscope import snapshot_download from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.basecontainer import BaseContainer -from modelscope import snapshot_download +from ms_agent.prompting import apply_prompt_files +from ms_agent.utils import get_logger + from ..utils.constants import TOOL_PLUGIN_NAME from .env import Env @@ -18,7 +19,6 @@ class ConfigLifecycleHandler: - def task_begin(self, config: DictConfig, tag: str) -> DictConfig: """Modify config when the task begins. @@ -51,14 +51,10 @@ class Config: """All tasks begin from a config""" tag: str = '' - supported_config_names = [ - 'workflow.yaml', 'workflow.yml', 'agent.yaml', 'agent.yml' - ] + supported_config_names = ['workflow.yaml', 'workflow.yml', 'agent.yaml', 'agent.yml'] @classmethod - def from_task(cls, - config_dir_or_id: str, - env: Dict[str, str] = None) -> Union[DictConfig, ListConfig]: + def from_task(cls, config_dir_or_id: str, env: Dict[str, str] = None) -> Union[DictConfig, ListConfig]: """Read a task config file and return a config object. Args: @@ -88,7 +84,8 @@ def from_task(cls, assert config is not None, ( f'Cannot find any valid config file in {config_dir_or_id}, ' - f'supported configs are: {Config.supported_config_names}') + f'supported configs are: {Config.supported_config_names}' + ) envs = Env.load_env(env) cls._update_config(config, envs) _dict_config = cls.parse_args() @@ -117,10 +114,7 @@ def fill_missing_fields(config: DictConfig) -> DictConfig: @staticmethod def is_workflow(config: DictConfig) -> bool: assert config.name is not None, 'Cannot find a valid name in this config' - return config.name in [ - 'workflow.yaml', 'workflow.yml', 'simple_workflow.yaml', - 'simple_workflow.yml' - ] + return config.name in ['workflow.yaml', 'workflow.yml', 'simple_workflow.yaml', 'simple_workflow.yml'] @staticmethod def parse_args() -> Dict[str, Any]: @@ -131,19 +125,16 @@ def parse_args() -> Dict[str, Any]: for idx in range(1, len(unknown) - 1, 2): key = unknown[idx] value = unknown[idx + 1] - assert key.startswith( - '--'), f'Parameter not correct: {unknown}' + assert key.startswith('--'), f'Parameter not correct: {unknown}' _dict_config[key[2:]] = value return _dict_config @staticmethod - def _update_config(config: Union[DictConfig, ListConfig], - extra: Dict[str, str] = None): + def _update_config(config: Union[DictConfig, ListConfig], extra: Dict[str, str] = None): if not extra: return config - def traverse_config(_config: Union[DictConfig, ListConfig, Any], - path: str = ''): + def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = ''): if isinstance(_config, DictConfig): for name, value in _config.items(): current_path = f'{path}.{name}' if path else name @@ -152,48 +143,45 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], traverse_config(value, current_path) else: if current_path in extra: - logger.info( - f'Replacing {current_path} with extra value.') + logger.info(f'Replacing {current_path} with extra value.') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = extra[current_path] - if name == 'temperature' and isinstance( - value_to_set, str): + if name == 'temperature' and isinstance(value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif name == 'max_tokens' and isinstance( - value_to_set, str): + elif name == 'max_tokens' and isinstance(value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): pass setattr(_config, name, value_to_set) # Find the key in extra that matches name (case-insensitive) - elif (key_match := next( - (key - for key in extra if key.lower() == name.lower()), - None)) is not None: + elif ( + key_match := next((key for key in extra if key.lower() == name.lower()), None) + ) is not None: logger.info(f'Replacing {name} with extra value.') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = extra[key_match] - if name == 'temperature' and isinstance( - value_to_set, str): + if name == 'temperature' and isinstance(value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif name == 'max_tokens' and isinstance( - value_to_set, str): + elif name == 'max_tokens' and isinstance(value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): pass setattr(_config, name, value_to_set) # Handle placeholder replacement like - elif (isinstance(value, str) and value.startswith('<') - and value.endswith('>') - and value[1:-1] in extra): + elif ( + isinstance(value, str) + and value.startswith('<') + and value.endswith('>') + and value[1:-1] in extra + ): logger.info(f'Replacing {value} with extra value.') setattr(_config, name, extra[value[1:-1]]) @@ -203,9 +191,12 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], if isinstance(value, BaseContainer): traverse_config(value, path) else: - if (isinstance(value, str) and value.startswith('<') - and value.endswith('>') - and value[1:-1] in extra): + if ( + isinstance(value, str) + and value.startswith('<') + and value.endswith('>') + and value[1:-1] in extra + ): logger.info(f'Replacing {value} with extra value.') _config[idx] = extra[value[1:-1]] @@ -217,24 +208,20 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], current = config # Navigate/create nested structure for i, part in enumerate(parts[:-1]): - if not hasattr(current, - part) or getattr(current, part) is None: + if not hasattr(current, part) or getattr(current, part) is None: setattr(current, part, DictConfig({})) current = getattr(current, part) final_key = parts[-1] - if not hasattr(current, final_key) or getattr( - current, final_key) is None: + if not hasattr(current, final_key) or getattr(current, final_key) is None: logger.info(f'Adding new config key: {key}') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = value - if final_key == 'temperature' and isinstance( - value_to_set, str): + if final_key == 'temperature' and isinstance(value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif final_key == 'max_tokens' and isinstance( - value_to_set, str): + elif final_key == 'max_tokens' and isinstance(value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): @@ -244,9 +231,7 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], return None @staticmethod - def convert_mcp_servers_to_json( - config: Union[DictConfig, - ListConfig]) -> Dict[str, Dict[str, Any]]: + def convert_mcp_servers_to_json(config: Union[DictConfig, ListConfig]) -> Dict[str, Dict[str, Any]]: """Convert the mcp servers to json mcp config.""" servers = {'mcpServers': {}} if getattr(config, 'tools', None): diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 83553254c..658c58a18 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -7,7 +7,6 @@ class Env: - @staticmethod def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: """Load key=value pairs from a .env file into ``os.environ``. @@ -29,8 +28,7 @@ def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: load_dotenv(default, override=False) @staticmethod - def load_env(envs: Dict[str, str] = None, - dotenv_path: Optional[str] = None) -> Dict[str, str]: + def load_env(envs: Dict[str, str] = None, dotenv_path: Optional[str] = None) -> Dict[str, str]: """Load .env into the process env, then merge with ``envs`` and return.""" Env.load_dotenv_into_environ(dotenv_path) _envs = copy(os.environ) diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py index 5b35bfb5d..5df3c9cfa 100644 --- a/ms_agent/llm/anthropic_llm.py +++ b/ms_agent/llm/anthropic_llm.py @@ -1,13 +1,14 @@ import inspect +import json from typing import Any, Dict, Generator, Iterator, List, Optional, Union import httpx -import json +from omegaconf import DictConfig, OmegaConf + from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import assert_package_exist, retry from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf class _SSEEventInjector(httpx.SyncByteStream): @@ -61,10 +62,7 @@ class DashScopeAnthropicTransport(httpx.BaseTransport): rewrites URL, auth headers, and body so the Anthropic SDK works unmodified. """ - def __init__(self, - dashscope_url: str, - api_key: str, - supplier: Optional[str] = None): + def __init__(self, dashscope_url: str, api_key: str, supplier: Optional[str] = None): self.dashscope_url = dashscope_url self.api_key = api_key self.supplier = supplier @@ -83,10 +81,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: 'content-type': 'application/json', 'authorization': f'Bearer {self.api_key}', } - _skip = frozenset({ - 'x-api-key', 'content-type', 'authorization', 'content-length', - 'host', 'transfer-encoding' - }) + _skip = frozenset({'x-api-key', 'content-type', 'authorization', 'content-length', 'host', 'transfer-encoding'}) for key, value in request.headers.items(): k = key.lower() if k not in _skip and not k.startswith('anthropic'): @@ -116,7 +111,6 @@ def close(self): class Anthropic(LLM): - def __init__( self, config: DictConfig, @@ -129,8 +123,7 @@ def __init__( self.model: str = config.llm.model - base_url = base_url or config.llm.get( - 'anthropic_base_url') or get_service_config('anthropic').base_url + base_url = base_url or config.llm.get('anthropic_base_url') or get_service_config('anthropic').base_url api_key = api_key or config.llm.get('anthropic_api_key') if not api_key: @@ -162,30 +155,28 @@ def __init__( base_url=base_url, ) - self.args: Dict = OmegaConf.to_container( - getattr(config, 'generation_config', DictConfig({}))) + self.args: Dict = OmegaConf.to_container(getattr(config, 'generation_config', DictConfig({}))) - def format_tools(self, - tools: Optional[List[Tool]]) -> Optional[List[Dict]]: + def format_tools(self, tools: Optional[List[Tool]]) -> Optional[List[Dict]]: if not tools: return None formatted_tools = [] for tool in tools: - formatted_tools.append({ - 'name': tool['tool_name'], - 'description': tool.get('description', ''), - 'input_schema': { - 'type': 'object', - 'properties': tool.get('parameters', - {}).get('properties', {}), - 'required': tool.get('parameters', {}).get('required', []), + formatted_tools.append( + { + 'name': tool['tool_name'], + 'description': tool.get('description', ''), + 'input_schema': { + 'type': 'object', + 'properties': tool.get('parameters', {}).get('properties', {}), + 'required': tool.get('parameters', {}).get('required', []), + }, } - }) + ) return formatted_tools - def _format_input_message(self, - messages: List[Message]) -> List[Dict[str, Any]]: + def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the Anthropic API. Args: @@ -203,34 +194,30 @@ def _format_input_message(self, if msg.tool_calls: for tool_call in msg.tool_calls: - content.append({ - 'type': 'tool_use', - 'id': tool_call['id'], - 'name': tool_call['tool_name'], - 'input': tool_call.get('arguments', {}) - }) + content.append( + { + 'type': 'tool_use', + 'id': tool_call['id'], + 'name': tool_call['tool_name'], + 'input': tool_call.get('arguments', {}), + } + ) if msg.role == 'tool': - formatted_messages.append({ - 'role': - 'user', - 'content': [{ - 'type': 'tool_result', - 'tool_use_id': msg.tool_call_id, - 'content': msg.content - }] - }) + formatted_messages.append( + { + 'role': 'user', + 'content': [{'type': 'tool_result', 'tool_use_id': msg.tool_call_id, 'content': msg.content}], + } + ) continue formatted_messages.append({'role': msg.role, 'content': content}) return formatted_messages - def _call_llm(self, - messages: List[Message], - tools: Optional[List[Dict]] = None, - stream: bool = False, - **kwargs) -> Any: - + def _call_llm( + self, messages: List[Message], tools: Optional[List[Dict]] = None, stream: bool = False, **kwargs + ) -> Any: formatted_messages = self._format_input_message(messages) formatted_messages = [m for m in formatted_messages if m['content']] @@ -246,21 +233,14 @@ def _call_llm(self, thinking_type = kwargs.pop('thinking_type', None) raw_extra_body = kwargs.pop('extra_body', {}) or {} - extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, - dict) else {} - enable_thinking = bool( - extra_body.pop('enable_thinking', enable_thinking)) - thinking_budget = extra_body.pop('thinking_budget', - thinking_budget) or max_tokens + extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, dict) else {} + enable_thinking = bool(extra_body.pop('enable_thinking', enable_thinking)) + thinking_budget = extra_body.pop('thinking_budget', thinking_budget) or max_tokens thinking_type = extra_body.pop('thinking_type', thinking_type) for _k in ('show_reasoning', 'reasoning_output'): extra_body.pop(_k, None) - params = { - 'model': self.model, - 'messages': formatted_messages, - 'max_tokens': max_tokens - } + params = {'model': self.model, 'messages': formatted_messages, 'max_tokens': max_tokens} if thinking_type == 'adaptive': params['thinking'] = {'type': 'adaptive'} @@ -284,12 +264,13 @@ def _call_llm(self, return self.client.messages.create(**params) @retry(max_attempts=LLM.retry_count, delay=1.0) - def generate(self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - max_continue_runs: Optional[int] = None, - **kwargs) -> Union[Message, Generator[Message, None, None]]: - + def generate( + self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + max_continue_runs: Optional[int] = None, + **kwargs, + ) -> Union[Message, Generator[Message, None, None]]: formatted_tools = self.format_tools(tools) args = self.args.copy() args.update(kwargs) @@ -298,16 +279,14 @@ def generate(self, sig_params = inspect.signature(self.client.messages.create).parameters filtered_args = {k: v for k, v in args.items() if k in sig_params} - completion = self._call_llm(messages, formatted_tools, stream, - **filtered_args) + completion = self._call_llm(messages, formatted_tools, stream, **filtered_args) if stream: return self._stream_format_output_message(completion) else: return self._format_output_message(completion) - def _stream_format_output_message(self, - stream_manager) -> Iterator[Message]: + def _stream_format_output_message(self, stream_manager) -> Iterator[Message]: current_message = Message( role='assistant', content='', @@ -360,11 +339,11 @@ def _stream_format_output_message(self, current_message.content = full_content current_message.partial = False current_message.completion_tokens = getattr( - final_msg.usage, 'output_tokens', - current_message.completion_tokens) + final_msg.usage, 'output_tokens', current_message.completion_tokens + ) current_message.prompt_tokens = getattr( - final_msg.usage, 'input_tokens', - current_message.prompt_tokens) + final_msg.usage, 'input_tokens', current_message.prompt_tokens + ) yield current_message @@ -392,11 +371,11 @@ def _format_output_message(completion) -> Message: ToolCall( id=block.id, index=len(tool_calls), # index based on appearance - type= - 'function', # or "tool_use" depending on your schema + type='function', # or "tool_use" depending on your schema arguments=block.input, tool_name=block.name, - )) + ) + ) # Anthropic does not have a native "reasoning_content" field reasoning_content = '' @@ -414,34 +393,31 @@ def _format_output_message(completion) -> Message: if __name__ == '__main__': import os + config = { 'llm': { 'model': 'Qwen/Qwen2.5-VL-72B-Instruct', 'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'), - 'anthropic_base_url': 'https://api-inference.modelscope.cn' + 'anthropic_base_url': 'https://api-inference.modelscope.cn', }, 'generation_config': { 'stream': True, - } + }, } - tools = [{ - 'tool_name': 'get_weather', - 'description': 'Get the current weather in a given location', - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'City and state' + tools = [ + { + 'tool_name': 'get_weather', + 'description': 'Get the current weather in a given location', + 'parameters': { + 'type': 'object', + 'properties': { + 'location': {'type': 'string', 'description': 'City and state'}, + 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}, }, - 'unit': { - 'type': 'string', - 'enum': ['celsius', 'fahrenheit'] - } + 'required': ['location'], }, - 'required': ['location'] } - }] + ] messages = [Message(role='user', content='描述杭州,300字')] # messages = [Message(role='user', content='去伦敦现在该带什么样的衣服?')] diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index b4a6ddaa8..d50ca7edb 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -1,29 +1,24 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List +from omegaconf import DictConfig + from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class DashScope(OpenAI): - def __init__(self, config: DictConfig): super().__init__( config, - base_url=config.llm.dashscope_base_url - or get_service_config('dashscope').base_url, - api_key=config.llm.dashscope_api_key) + base_url=config.llm.dashscope_base_url or get_service_config('dashscope').base_url, + api_key=config.llm.dashscope_api_key, + ) - def _call_llm_for_continue_gen(self, - messages: List[Message], - new_message, - tools: List[Tool] = None, - **kwargs): + def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools: List[Tool] = None, **kwargs): # ref: https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=https%3A%2F%2Fhelp.aliyun.com%2Fdocument_detail%2F2862210.html&renderType=iframe # noqa if messages and messages[-1].to_dict().get('partial', False): - messages[-1].reasoning_content += new_message.reasoning_content messages[-1].content += new_message.content if new_message.tool_calls: diff --git a/ms_agent/llm/deepseek_llm.py b/ms_agent/llm/deepseek_llm.py index e565308bc..d379debd3 100644 --- a/ms_agent/llm/deepseek_llm.py +++ b/ms_agent/llm/deepseek_llm.py @@ -1,28 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List +from omegaconf import DictConfig + from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool -from omegaconf import DictConfig class DeepSeek(OpenAI): input_msg = {'role', 'content', 'tool_calls', 'prefix'} def __init__(self, config: DictConfig): - super().__init__( - config, - base_url=config.llm.deepseek_base_url, - api_key=config.llm.deepseek_api_key) - - def _call_llm_for_continue_gen(self, - messages: List[Message], - new_message, - tools: List[Tool] = None, - **kwargs): + super().__init__(config, base_url=config.llm.deepseek_base_url, api_key=config.llm.deepseek_api_key) + + def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools: List[Tool] = None, **kwargs): # ref: https://api-docs.deepseek.com/zh-cn/guides/chat_prefix_completion if messages and messages[-1].to_dict().get('prefix', False): - messages[-1].reasoning_content += new_message.reasoning_content messages[-1].content += new_message.content if new_message.tool_calls: @@ -36,28 +29,30 @@ def _call_llm_for_continue_gen(self, messages = self.format_input_message(messages) stop = kwargs.pop('stop', []).append('```') - return self._call_llm( - messages=messages, tools=tools, stop=stop, **kwargs) + return self._call_llm(messages=messages, tools=tools, stop=stop, **kwargs) if __name__ == '__main__': import os + from omegaconf import OmegaConf # 创建一个嵌套的字典结构 - conf: DictConfig = OmegaConf.create({ - 'llm': { - 'model': 'deepseek-reasoner', - 'deepseek_base_url': 'https://api.deepseek.com/beta/v1', - 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY'), - 'openai_base_url': 'https://api-inference.modelscope.cn/v1', - 'openai_api_key': os.getenv('MODELSCOPE_API_KEY'), - 'generation_config': { - 'stream': True, - 'max_tokens': 500, + conf: DictConfig = OmegaConf.create( + { + 'llm': { + 'model': 'deepseek-reasoner', + 'deepseek_base_url': 'https://api.deepseek.com/beta/v1', + 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY'), + 'openai_base_url': 'https://api-inference.modelscope.cn/v1', + 'openai_api_key': os.getenv('MODELSCOPE_API_KEY'), + 'generation_config': { + 'stream': True, + 'max_tokens': 500, + }, } } - }) + ) messages = [ Message(role='assistant', content='You are a helpful assistant.'), @@ -86,11 +81,7 @@ def _call_llm_for_continue_gen(self, # print(chunk) # kwargs覆盖conf - message = llm.generate( - messages=messages, - tools=tools, - stream=False, - extra_body={'enable_thinking': False}) + message = llm.generate(messages=messages, tools=tools, stream=False, extra_body={'enable_thinking': False}) print(message) messages.append(message) # messages.append(Message(role='tool', content='北京市朝阳区崔各庄阿里巴巴朝阳科技园')) diff --git a/ms_agent/llm/llm.py b/ms_agent/llm/llm.py index 72af53467..e6998c2de 100644 --- a/ms_agent/llm/llm.py +++ b/ms_agent/llm/llm.py @@ -3,15 +3,15 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional -from ms_agent.config import Config from omegaconf import DictConfig +from ms_agent.config import Config + from ..utils.constants import DEFAULT_RETRY_COUNT from .utils import Message, Tool class LLM: - retry_count = int(os.environ.get('LLM_RETRY_COUNT', DEFAULT_RETRY_COUNT)) def __init__(self, config: DictConfig): @@ -23,11 +23,9 @@ def __init__(self, config: DictConfig): self.config = config @abstractmethod - def generate(self, - messages: List[Message], - model: Optional[str] = None, - tools: Optional[List[Tool]] = None, - **kwargs) -> Any: + def generate( + self, messages: List[Message], model: Optional[str] = None, tools: Optional[List[Tool]] = None, **kwargs + ) -> Any: """Generate response by the given messages. Args: @@ -42,10 +40,7 @@ def generate(self, pass @classmethod - def from_task(cls, - config_dir_or_id: str, - *, - env: Optional[Dict[str, str]] = None) -> Any: + def from_task(cls, config_dir_or_id: str, *, env: Optional[Dict[str, str]] = None) -> Any: """Instantiate an LLM instance. Args: @@ -69,7 +64,8 @@ def from_config(cls, config: DictConfig) -> Any: Returns: The LLM instance. """ - from .model_mapping import all_services_mapping, OpenAI + from .model_mapping import OpenAI, all_services_mapping + if config.llm.get('service') in all_services_mapping: return all_services_mapping[config.llm.service](config) else: diff --git a/ms_agent/llm/modelscope_llm.py b/ms_agent/llm/modelscope_llm.py index 7b761c5c0..e1ba329ab 100644 --- a/ms_agent/llm/modelscope_llm.py +++ b/ms_agent/llm/modelscope_llm.py @@ -1,17 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig + from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class ModelScope(OpenAI): - def __init__(self, config: DictConfig): - assert hasattr( - config.llm, 'modelscope_api_key' - ) and config.llm.modelscope_api_key is not None, 'Please provide `modelscope_api_key` in env or cmd.' + assert hasattr(config.llm, 'modelscope_api_key') and config.llm.modelscope_api_key is not None, ( + 'Please provide `modelscope_api_key` in env or cmd.' + ) super().__init__( config, - base_url=config.llm.modelscope_base_url - or get_service_config('modelscope').base_url, - api_key=config.llm.modelscope_api_key) + base_url=config.llm.modelscope_base_url or get_service_config('modelscope').base_url, + api_key=config.llm.modelscope_api_key, + ) diff --git a/ms_agent/llm/openai.py b/ms_agent/llm/openai.py index 390f1d4ab..6b44c61b9 100644 --- a/ms_agent/llm/openai.py +++ b/ms_agent/llm/openai.py @@ -1,11 +1,11 @@ # flake8: noqa +import json import uuid +from openai import OpenAI, Stream +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from typing import TYPE_CHECKING, Any, Dict, List, Literal -import json from ms_agent.utils.logger import get_logger -from openai import OpenAI, Stream -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall logger = get_logger() @@ -14,12 +14,7 @@ class OpenAIChat: - - def __init__(self, - api_key: str = None, - base_url: str = None, - model: str = None, - **kwargs): + def __init__(self, api_key: str = None, base_url: str = None, model: str = None, **kwargs): """ Initialize the OpenAIChat client. """ @@ -31,31 +26,25 @@ def __init__(self, self._model = model self._kwargs = kwargs - def chat(self, - messages: List[Dict[str, Any]], - tools: List[Dict[str, Any]] = None, - **kwargs) -> Dict[str, Any]: - + def chat(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: completion: ChatCompletion = self._client.chat.completions.create( - messages=messages, model=self._model, tools=tools, **kwargs) + messages=messages, model=self._model, tools=tools, **kwargs + ) res_d: Dict[str, Any] = dict( role='assistant', reasoning_content='', content=completion.choices[0].message.content, - tool_calls=completion.choices[0].message.tool_calls if hasattr( - completion.choices[0].message, 'tool_calls') else [], - finish_reason=completion.choices[0]. - finish_reason, # 'stop', 'tool_calls', 'length', None + tool_calls=completion.choices[0].message.tool_calls + if hasattr(completion.choices[0].message, 'tool_calls') + else [], + finish_reason=completion.choices[0].finish_reason, # 'stop', 'tool_calls', 'length', None usage=completion.usage.to_dict(), ) return res_d - def chat_stream(self, - messages: List[Dict[str, Any]], - tools: List[Dict[str, Any]] = None, - **kwargs): + def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]] = None, **kwargs): """ Get chat response from OpenAI API using streaming. @@ -98,9 +87,7 @@ def chat_stream(self, if 'stream' not in kwargs: kwargs['stream'] = True - assert kwargs.get( - 'stream', True - ), "Streaming must be enabled by setting 'stream=True' in kwargs." + assert kwargs.get('stream', True), "Streaming must be enabled by setting 'stream=True' in kwargs." logger.info(f"Temperature: {kwargs.get('temperature', -1)}") @@ -110,7 +97,8 @@ def chat_stream(self, tools=tools, # Note: Gemini2.5-Pro does not support parallel_tool_calls # parallel_tool_calls=True, - **kwargs) + **kwargs, + ) res_d: Dict[str, Any] = dict( role='assistant', @@ -118,11 +106,7 @@ def chat_stream(self, content='', tool_calls=[], finish_reason=None, # 'stop', 'tool_calls', 'length', None - usage={ - 'completion_tokens': 0, - 'prompt_tokens': 0, - 'total_tokens': 0 - }, + usage={'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, ) for chunk in completion: @@ -134,24 +118,21 @@ def chat_stream(self, delta = chunk.choices[0].delta res_d['role'] = delta.role - res_d['reasoning_content'] = delta.reasoning_content if hasattr( - delta, 'reasoning_content') else '' + res_d['reasoning_content'] = delta.reasoning_content if hasattr(delta, 'reasoning_content') else '' res_d['content'] = delta.content - res_d['tool_calls'] = delta.tool_calls if hasattr( - delta, 'tool_calls') else [] + res_d['tool_calls'] = delta.tool_calls if hasattr(delta, 'tool_calls') else [] res_d['finish_reason'] = chunk.choices[0].finish_reason if hasattr(chunk, 'usage') and chunk.usage: res_d['usage'] = { 'completion_tokens': chunk.usage.completion_tokens, 'prompt_tokens': chunk.usage.prompt_tokens, - 'total_tokens': chunk.usage.total_tokens + 'total_tokens': chunk.usage.total_tokens, } yield res_d @staticmethod - def aggregate_stream_chunks( - stream_chunks: List[Dict[str, Any]]) -> Dict[str, Any]: + def aggregate_stream_chunks(stream_chunks: List[Dict[str, Any]]) -> Dict[str, Any]: """ Aggregate the streaming chunks into a single response dictionary within current round of chat. @@ -174,40 +155,29 @@ def aggregate_stream_chunks( content='', tool_calls=[], finish_reason=None, # 'stop', 'tool_calls', 'length', None - usage={ - 'completion_tokens': 0, - 'prompt_tokens': 0, - 'total_tokens': 0 - }, + usage={'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, ) for chunk_d in stream_chunks: res_d['role'] = chunk_d.get('role') - res_d['reasoning_content'] += chunk_d.get( - 'reasoning_content', - '') if chunk_d.get('reasoning_content') is not None else '' - res_d['content'] += chunk_d.get( - 'content', '') if chunk_d.get('content') is not None else '' + res_d['reasoning_content'] += ( + chunk_d.get('reasoning_content', '') if chunk_d.get('reasoning_content') is not None else '' + ) + res_d['content'] += chunk_d.get('content', '') if chunk_d.get('content') is not None else '' if chunk_d.get('tool_calls') is not None: res_d['tool_calls'].extend(chunk_d.get('tool_calls', [])) - res_d['finish_reason'] = chunk_d.get('finish_reason', - res_d['finish_reason']) + res_d['finish_reason'] = chunk_d.get('finish_reason', res_d['finish_reason']) # Get the last usage information as final usage for current round (consider cache tokens) if chunk_d.get('usage') is not None: - res_d['usage']['completion_tokens'] = chunk_d['usage'].get( - 'completion_tokens', 0) - res_d['usage']['prompt_tokens'] = chunk_d['usage'].get( - 'prompt_tokens', 0) - res_d['usage']['total_tokens'] = chunk_d['usage'].get( - 'total_tokens', 0) + res_d['usage']['completion_tokens'] = chunk_d['usage'].get('completion_tokens', 0) + res_d['usage']['prompt_tokens'] = chunk_d['usage'].get('prompt_tokens', 0) + res_d['usage']['total_tokens'] = chunk_d['usage'].get('total_tokens', 0) return res_d @staticmethod - def convert_message(role: Literal['assistant', 'tool'], - round_message: Dict[str, Any]) -> Dict[str, Any]: - + def convert_message(role: Literal['assistant', 'tool'], round_message: Dict[str, Any]) -> Dict[str, Any]: if role == 'assistant': res_msg: Dict[str, Any] = { 'role': 'assistant', @@ -220,34 +190,30 @@ def convert_message(role: Literal['assistant', 'tool'], if isinstance(tool_call, ChoiceDeltaToolCall): if not tool_call.id: tool_call.id = f'tc_{uuid.uuid4().hex}' - tool_call = tool_call.model_dump( - include=['id', 'index', 'type', 'function']) + tool_call = tool_call.model_dump(include=['id', 'index', 'type', 'function']) else: - raise ValueError( - f'Unsupported tool call type: {type(tool_call)}. Expected ChoiceDeltaToolCall.' - ) + raise ValueError(f'Unsupported tool call type: {type(tool_call)}. Expected ChoiceDeltaToolCall.') tmp_tool_calls.append(tool_call) res_msg['tool_calls'] = tmp_tool_calls elif role == 'tool': # TODO: tbd ... - raise ValueError( - '`tool message` is to be implemented in the future.') + raise ValueError('`tool message` is to be implemented in the future.') else: - raise ValueError( - f"Unsupported role: {role}. Supported roles are 'assistant' and 'tool' for now." - ) + raise ValueError(f"Unsupported role: {role}. Supported roles are 'assistant' and 'tool' for now.") return res_msg - def chat_stream_mt(self, - messages: List[Dict[str, Any]], - available_functions: Dict[str, Any], - tools: List[Dict[str, Any]] = None, - history: List[Dict[str, Any]] = None, - **kwargs): + def chat_stream_mt( + self, + messages: List[Dict[str, Any]], + available_functions: Dict[str, Any], + tools: List[Dict[str, Any]] = None, + history: List[Dict[str, Any]] = None, + **kwargs, + ): """ Get chat response from OpenAI API using streaming for multi-turn chat. """ @@ -258,15 +224,10 @@ def chat_stream_mt(self, # Add a system message if not present roles: List[str] = [msg['role'] for msg in messages] if 'system' not in roles: - system_message: Dict[str, Any] = { - 'role': 'system', - 'content': 'You are a helpful assistant.' - } + system_message: Dict[str, Any] = {'role': 'system', 'content': 'You are a helpful assistant.'} messages.insert(0, system_message) - assert len( - messages - ) >= 2, 'At least two messages are required: user and system' + assert len(messages) >= 2, 'At least two messages are required: user and system' ## User Message history.extend(messages) @@ -282,16 +243,13 @@ def chat_stream_mt(self, for chunk_d in self.chat_stream(messages, tools, **kwargs): streaming_chunks.append(chunk_d) - round_d: Dict[str, Any] = self.aggregate_stream_chunks( - streaming_chunks) + round_d: Dict[str, Any] = self.aggregate_stream_chunks(streaming_chunks) yield round_d # Convert `round_d` to OpenAI's chat messages format ## Assistant Message if round_d['role'] == 'assistant': - - assistant_message = self.convert_message( - role='assistant', round_message=round_d) + assistant_message = self.convert_message(role='assistant', round_message=round_d) history.append(assistant_message) # Execute tool calls and append the tool messages @@ -299,11 +257,9 @@ def chat_stream_mt(self, for tool_call in assistant_message.get('tool_calls', []): if tool_call['type'] == 'function': function_name = tool_call['function']['name'] - function_args = json.loads( - tool_call['function']['arguments']) + function_args = json.loads(tool_call['function']['arguments']) # Call the function and get the result - tool_call_result = available_functions[ - function_name](**function_args) + tool_call_result = available_functions[function_name](**function_args) # Construct a tool message with the result # TODO: Check the `tool_call_id` is empty ? @@ -316,9 +272,7 @@ def chat_stream_mt(self, history.append(tool_message) # If the response is complete, break the loop - if round_d['finish_reason'] in [ - 'stop', 'tool_calls', 'length' - ]: + if round_d['finish_reason'] in ['stop', 'tool_calls', 'length']: break except Exception as e: @@ -327,15 +281,9 @@ def chat_stream_mt(self, # Note: must contain role=assistant(with tool_calls) and role=tool if history[-1]['role'] == 'tool': - messages = history + [{ - 'role': - 'user', - 'content': - 'Please output the tool calling results very briefly.' - }] - round_item: dict = self.aggregate_stream_chunks([ - chunk_item for chunk_item in self.chat_stream( - messages=messages, tools=tools, **kwargs) - ]) + messages = history + [{'role': 'user', 'content': 'Please output the tool calling results very briefly.'}] + round_item: dict = self.aggregate_stream_chunks( + [chunk_item for chunk_item in self.chat_stream(messages=messages, tools=tools, **kwargs)] + ) yield round_item diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index fa2df6004..883cafab4 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -1,18 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect +import json from copy import deepcopy from typing import Any, Dict, Generator, Iterable, List, Optional import httpx -import json +from omegaconf import DictConfig, OmegaConf +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function + from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall -from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, - get_logger, retry) +from ms_agent.utils import MAX_CONTINUE_RUNS, assert_package_exist, get_logger, retry from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, Function) logger = get_logger() @@ -28,8 +27,7 @@ class _DashScopeResponsesTransport(httpx.HTTPTransport): def handle_request(self, request): if b'/v1/responses' in request.url.raw_path: - new_path = request.url.raw_path.replace(b'/v1/responses', - b'/v1/chat/completions') + new_path = request.url.raw_path.replace(b'/v1/responses', b'/v1/chat/completions') request.url = request.url.copy_with(raw_path=new_path) return super().handle_request(request) @@ -51,9 +49,8 @@ class OpenAI(LLM): base_url (`Optional[str]`): Custom base URL for the API endpoint. Defaults to None. api_key (`Optional[str]`): Authentication key for the API. Defaults to None. """ - input_msg = { - 'role', 'content', 'tool_calls', 'partial', 'prefix', 'tool_call_id' - } + + input_msg = {'role', 'content', 'tool_calls', 'partial', 'prefix', 'tool_call_id'} # Providers that support cache_control in structured content blocks CACHE_CONTROL_PROVIDERS = ['dashscope', 'anthropic'] @@ -67,12 +64,10 @@ def __init__( super().__init__(config) assert_package_exist('openai') import openai + self.model: str = config.llm.model - self.max_continue_runs = getattr(config.llm, 'max_continue_runs', - None) or MAX_CONTINUE_RUNS - base_url = base_url or getattr( - config.llm, 'openai_base_url', - None) or get_service_config('openai').base_url + self.max_continue_runs = getattr(config.llm, 'max_continue_runs', None) or MAX_CONTINUE_RUNS + base_url = base_url or getattr(config.llm, 'openai_base_url', None) or get_service_config('openai').base_url api_key = api_key or getattr(config.llm, 'openai_api_key', None) self.client = openai.OpenAI( @@ -80,21 +75,17 @@ def __init__( base_url=base_url, ) self.base_url = base_url or '' - self.args: Dict = OmegaConf.to_container( - getattr(config, 'generation_config', DictConfig({}))) + self.args: Dict = OmegaConf.to_container(getattr(config, 'generation_config', DictConfig({}))) # Responses API support - self._use_responses_api = bool( - self.args.get('use_responses_api', False)) + self._use_responses_api = bool(self.args.get('use_responses_api', False)) self._responses_client = None - self._responses_state_mode = str( - self.args.get('responses_state_mode', 'stateless')).lower() + self._responses_state_mode = str(self.args.get('responses_state_mode', 'stateless')).lower() if self._responses_state_mode == 'stateful': self._responses_state_mode = 'previous_response_id' if self._use_responses_api: - self._is_dashscope = bool(base_url - and 'dashscope' in base_url.lower()) + self._is_dashscope = bool(base_url and 'dashscope' in base_url.lower()) if self._is_dashscope: http_client = httpx.Client( transport=_DashScopeResponsesTransport(), @@ -116,8 +107,7 @@ def __init__( # - Special values: 'last_message' (only cache the last message in the list) # Default: ['system'] - system prompt is usually the longest stable prefix self._prefix_cache_enabled = self.args.get('force_prefix_cache', False) - self._prefix_cache_roles = set( - self.args.get('prefix_cache_roles', ['system'])) + self._prefix_cache_roles = set(self.args.get('prefix_cache_roles', ['system'])) self._prefix_cache_provider = self._detect_cache_provider() def _detect_cache_provider(self) -> Optional[str]: @@ -171,8 +161,7 @@ def _to_structured_content( # Add cache_control to text blocks that don't have it new_list = [] for item in content: - if (isinstance(item, dict) and item.get('type') == 'text' - and 'cache_control' not in item): + if isinstance(item, dict) and item.get('type') == 'text' and 'cache_control' not in item: new_item = dict(item) new_item['cache_control'] = {'type': 'ephemeral'} new_list.append(new_item) @@ -183,9 +172,7 @@ def _to_structured_content( # Other types: return as-is return content - def format_tools(self, - tools: Optional[List[Tool]] = None - ) -> List[Dict[str, Any]]: + def format_tools(self, tools: Optional[List[Tool]] = None) -> List[Dict[str, Any]]: """Formats a list of tools into the structure expected by the OpenAI API. If server_name is present in a tool, it will be used as a prefix for the function name. @@ -197,24 +184,29 @@ def format_tools(self, List[Dict[str, Any]]: A list of formatted tool definitions suitable for OpenAI API. """ if tools: - tools = [{ - 'type': 'function', - 'function': { - 'name': tool['tool_name'], - 'description': tool['description'], - 'parameters': tool['parameters'] + tools = [ + { + 'type': 'function', + 'function': { + 'name': tool['tool_name'], + 'description': tool['description'], + 'parameters': tool['parameters'], + }, } - } for tool in tools] + for tool in tools + ] else: tools = None return tools @retry(max_attempts=LLM.retry_count, delay=1.0) - def generate(self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - max_continue_runs: Optional[int] = None, - **kwargs) -> Message | Generator[Message, None, None]: + def generate( + self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + max_continue_runs: Optional[int] = None, + **kwargs, + ) -> Message | Generator[Message, None, None]: """Generates a response based on the given conversation history and optional tools. Args: @@ -236,24 +228,17 @@ def generate(self, else: return self._responses_generate(messages, tools, **args) - parameters = inspect.signature( - self.client.chat.completions.create).parameters + parameters = inspect.signature(self.client.chat.completions.create).parameters args = {key: value for key, value in args.items() if key in parameters} completion = self._call_llm(messages, self.format_tools(tools), **args) max_continue_runs = max_continue_runs or self.max_continue_runs if stream: - return self._stream_continue_generate(messages, completion, tools, - max_continue_runs - 1, - **args) + return self._stream_continue_generate(messages, completion, tools, max_continue_runs - 1, **args) else: - return self._continue_generate(messages, completion, tools, - max_continue_runs - 1, **args) + return self._continue_generate(messages, completion, tools, max_continue_runs - 1, **args) - def _call_llm(self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - **kwargs) -> Any: + def _call_llm(self, messages: List[Message], tools: Optional[List[Tool]] = None, **kwargs) -> Any: """Calls the OpenAI chat completion API with the provided messages and tools. Args: @@ -273,8 +258,7 @@ def _call_llm(self, if is_streaming and stream_options_config.get('include_usage', True): kwargs.setdefault('stream_options', {})['include_usage'] = True - return self.client.chat.completions.create( - model=self.model, messages=messages, tools=tools, **kwargs) + return self.client.chat.completions.create(model=self.model, messages=messages, tools=tools, **kwargs) @staticmethod def _extract_cache_info(usage_obj: Any) -> tuple: @@ -300,12 +284,10 @@ def _extract_cache_info(usage_obj: Any) -> tuple: created = int(details.get('cache_creation_input_tokens', 0) or 0) else: cached = int(getattr(details, 'cached_tokens', 0) or 0) - created = int( - getattr(details, 'cache_creation_input_tokens', 0) or 0) + created = int(getattr(details, 'cache_creation_input_tokens', 0) or 0) return cached, created - def _merge_stream_message(self, pre_message_chunk: Optional[Message], - message_chunk: Message) -> Optional[Message]: + def _merge_stream_message(self, pre_message_chunk: Optional[Message], message_chunk: Message) -> Optional[Message]: """Merges a new chunk of message into the previous chunks during streaming. Used to accumulate partial results into a complete Message object. @@ -330,25 +312,17 @@ def _merge_stream_message(self, pre_message_chunk: Optional[Message], message.content += message_chunk.content if message_chunk.tool_calls: if message.tool_calls: - if message.tool_calls[-1]['index'] == message_chunk.tool_calls[ - 0]['index']: + if message.tool_calls[-1]['index'] == message_chunk.tool_calls[0]['index']: if message_chunk.tool_calls[0]['id']: - message.tool_calls[-1][ - 'id'] = message_chunk.tool_calls[0]['id'] + message.tool_calls[-1]['id'] = message_chunk.tool_calls[0]['id'] if message_chunk.tool_calls[0]['arguments']: if message.tool_calls[-1]['arguments']: - message.tool_calls[-1][ - 'arguments'] += message_chunk.tool_calls[0][ - 'arguments'] + message.tool_calls[-1]['arguments'] += message_chunk.tool_calls[0]['arguments'] else: # message.tool_calls[-1]['arguments'] may be None - message.tool_calls[-1][ - 'arguments'] = message_chunk.tool_calls[0][ - 'arguments'] + message.tool_calls[-1]['arguments'] = message_chunk.tool_calls[0]['arguments'] if message_chunk.tool_calls[0]['tool_name']: - message.tool_calls[-1][ - 'tool_name'] = message_chunk.tool_calls[0][ - 'tool_name'] + message.tool_calls[-1]['tool_name'] = message_chunk.tool_calls[0]['tool_name'] else: message.tool_calls.append( ToolCall( @@ -356,17 +330,21 @@ def _merge_stream_message(self, pre_message_chunk: Optional[Message], arguments=message_chunk.tool_calls[0]['arguments'], type='function', tool_name=message_chunk.tool_calls[0]['tool_name'], - index=message_chunk.tool_calls[0]['index'])) + index=message_chunk.tool_calls[0]['index'], + ) + ) else: message.tool_calls = message_chunk.tool_calls return message - def _stream_continue_generate(self, - messages: List[Message], - completion: Iterable, - tools: Optional[List[Tool]] = None, - max_runs: Optional[int] = None, - **kwargs) -> Generator[Message, None, None]: + def _stream_continue_generate( + self, + messages: List[Message], + completion: Iterable, + tools: Optional[List[Tool]] = None, + max_runs: Optional[int] = None, + **kwargs, + ) -> Generator[Message, None, None]: """Recursively continues generating until the model finishes naturally in streaming mode. Args: @@ -388,8 +366,7 @@ def _stream_continue_generate(self, try: next_chunk = next(completion) message.prompt_tokens += next_chunk.usage.prompt_tokens - cached, created = self._extract_cache_info( - getattr(next_chunk, 'usage', None)) + cached, created = self._extract_cache_info(getattr(next_chunk, 'usage', None)) message.cached_tokens += cached message.cache_creation_input_tokens += created message.completion_tokens += next_chunk.usage.completion_tokens @@ -397,21 +374,14 @@ def _stream_continue_generate(self, # The stream may end without a final usage chunk, which is acceptable. pass first_run = not messages[-1].to_dict().get('partial', False) - if chunk.choices[0].finish_reason in [ - 'length', 'null' - ] and (max_runs is None or max_runs != 0): - logger.info( - f'finish_reason: {chunk.choices[0].finish_reason}, continue generate.' - ) - completion = self._call_llm_for_continue_gen( - messages, message, tools, **kwargs) + if chunk.choices[0].finish_reason in ['length', 'null'] and (max_runs is None or max_runs != 0): + logger.info(f'finish_reason: {chunk.choices[0].finish_reason}, continue generate.') + completion = self._call_llm_for_continue_gen(messages, message, tools, **kwargs) for chunk in self._stream_continue_generate( - messages, completion, tools, - max_runs - 1 if max_runs is not None else None, - **kwargs): + messages, completion, tools, max_runs - 1 if max_runs is not None else None, **kwargs + ): if first_run: - yield self._merge_stream_message( - messages[-1], chunk) + yield self._merge_stream_message(messages[-1], chunk) else: yield chunk elif not first_run: @@ -436,8 +406,7 @@ def _stream_format_output_message(completion_chunk) -> Message: content = '' if completion_chunk.choices and completion_chunk.choices[0].delta: content = completion_chunk.choices[0].delta.content - reasoning_content = getattr(completion_chunk.choices[0].delta, - 'reasoning_content', '') + reasoning_content = getattr(completion_chunk.choices[0].delta, 'reasoning_content', '') if completion_chunk.choices[0].delta.tool_calls: func = completion_chunk.choices[0].delta.tool_calls tool_calls = [ @@ -446,7 +415,8 @@ def _stream_format_output_message(completion_chunk) -> Message: index=tool_call.index, type=tool_call.type, arguments=tool_call.function.arguments, - tool_name=tool_call.function.name) + tool_name=tool_call.function.name, + ) for tool_call in func ] content = content or '' @@ -458,23 +428,22 @@ def _stream_format_output_message(completion_chunk) -> Message: tool_calls=tool_calls, id=completion_chunk.id, prompt_tokens=getattr(completion_chunk.usage, 'prompt_tokens', 0), - completion_tokens=getattr(completion_chunk.usage, - 'completion_tokens', 0)) + completion_tokens=getattr(completion_chunk.usage, 'completion_tokens', 0), + ) @staticmethod def _format_output_message(completion) -> Message: """Formats the full non-streaming response into a Message object. - Args: - completion: The raw response from the OpenAI API. + Args: + completion: The raw response from the OpenAI API. - Returns: - Message: A Message object containing the final response. - """ + Returns: + Message: A Message object containing the final response. + """ content = completion.choices[0].message.content or '' if hasattr(completion.choices[0].message, 'reasoning_content'): - reasoning_content = completion.choices[ - 0].message.reasoning_content or '' + reasoning_content = completion.choices[0].message.reasoning_content or '' else: reasoning_content = '' tool_calls = None @@ -485,11 +454,11 @@ def _format_output_message(completion) -> Message: index=getattr(tool_call, 'index', idx), type=tool_call.type, arguments=tool_call.function.arguments, - tool_name=tool_call.function.name) for idx, tool_call in - enumerate(completion.choices[0].message.tool_calls) + tool_name=tool_call.function.name, + ) + for idx, tool_call in enumerate(completion.choices[0].message.tool_calls) ] - cached, created = OpenAI._extract_cache_info( - getattr(completion, 'usage', None)) + cached, created = OpenAI._extract_cache_info(getattr(completion, 'usage', None)) return Message( role='assistant', content=content, @@ -499,7 +468,8 @@ def _format_output_message(completion) -> Message: prompt_tokens=completion.usage.prompt_tokens, cached_tokens=cached, cache_creation_input_tokens=created, - completion_tokens=completion.usage.completion_tokens) + completion_tokens=completion.usage.completion_tokens, + ) @staticmethod def _merge_partial_message(messages: List[Message], new_message: Message): @@ -513,8 +483,7 @@ def _merge_partial_message(messages: List[Message], new_message: Message): messages[-1].content += new_message.content messages[-1].prompt_tokens += new_message.prompt_tokens messages[-1].cached_tokens += new_message.cached_tokens - messages[ - -1].cache_creation_input_tokens += new_message.cache_creation_input_tokens + messages[-1].cache_creation_input_tokens += new_message.cache_creation_input_tokens messages[-1].completion_tokens += new_message.completion_tokens if new_message.tool_calls: if messages[-1].tool_calls: @@ -522,11 +491,9 @@ def _merge_partial_message(messages: List[Message], new_message: Message): else: messages[-1].tool_calls = new_message.tool_calls - def _call_llm_for_continue_gen(self, - messages: List[Message], - new_message: Message, - tools: List[Tool] = None, - **kwargs) -> Any: + def _call_llm_for_continue_gen( + self, messages: List[Message], new_message: Message, tools: List[Tool] = None, **kwargs + ) -> Any: """Prepares and calls the LLM for continuation when the response is unfinished. If the previous message marked as unfinished, it will be updated with the new content. @@ -555,12 +522,9 @@ def _call_llm_for_continue_gen(self, return self._call_llm(messages, tools, **kwargs) - def _continue_generate(self, - messages: List[Message], - completion, - tools: List[Tool] = None, - max_runs: Optional[int] = None, - **kwargs) -> Message: + def _continue_generate( + self, messages: List[Message], completion, tools: List[Tool] = None, max_runs: Optional[int] = None, **kwargs + ) -> Message: """Recursively continues generating until the model finishes naturally. This method checks whether the generation was stopped due to length limitations, @@ -576,17 +540,12 @@ def _continue_generate(self, Message: A fully formed Message object containing the complete response. """ new_message = self._format_output_message(completion) - if completion.choices[0].finish_reason in [ - 'length', 'null' - ] and (max_runs is None or max_runs != 0): - logger.info( - f'finish_reason: {completion.choices[0].finish_reason}, continue generate.' - ) - completion = self._call_llm_for_continue_gen( - messages, new_message, tools, **kwargs) + if completion.choices[0].finish_reason in ['length', 'null'] and (max_runs is None or max_runs != 0): + logger.info(f'finish_reason: {completion.choices[0].finish_reason}, continue generate.') + completion = self._call_llm_for_continue_gen(messages, new_message, tools, **kwargs) return self._continue_generate( - messages, completion, tools, - max_runs - 1 if max_runs is not None else None, **kwargs) + messages, completion, tools, max_runs - 1 if max_runs is not None else None, **kwargs + ) elif messages[-1].to_dict().get('partial', False): self._merge_partial_message(messages, new_message) messages[-1].partial = False @@ -594,8 +553,7 @@ def _continue_generate(self, else: return new_message - def _build_responses_input( - self, messages: List[Message]) -> List[Dict[str, Any]]: + def _build_responses_input(self, messages: List[Message]) -> List[Dict[str, Any]]: """Convert internal Message list to the ``input`` format expected by the Responses API. @@ -610,60 +568,66 @@ def _build_responses_input( items: List[Dict[str, Any]] = [] for msg in messages: if msg.role == 'system': - items.append({ - 'role': 'developer', - 'content': msg.content, - }) + items.append( + { + 'role': 'developer', + 'content': msg.content, + } + ) elif msg.role == 'assistant': if self._responses_state_mode != 'previous_response_id': # Stateless mode needs explicit passback of opaque reasoning # items returned by the previous response. - for raw_item in getattr(msg, '_responses_output_items', - []): + for raw_item in getattr(msg, '_responses_output_items', []): items.append(raw_item) - if msg.content and not self._is_responses_tool_placeholder( - msg): - items.append({ - 'role': 'assistant', - 'content': msg.content, - }) + if msg.content and not self._is_responses_tool_placeholder(msg): + items.append( + { + 'role': 'assistant', + 'content': msg.content, + } + ) if msg.tool_calls: for tc in msg.tool_calls: arguments = tc.get('arguments', '{}') if not isinstance(arguments, str): - arguments = json.dumps( - arguments, ensure_ascii=False) - items.append({ - 'type': 'function_call', - 'call_id': tc.get('id', ''), - 'name': tc.get('tool_name', ''), - 'arguments': arguments, - }) + arguments = json.dumps(arguments, ensure_ascii=False) + items.append( + { + 'type': 'function_call', + 'call_id': tc.get('id', ''), + 'name': tc.get('tool_name', ''), + 'arguments': arguments, + } + ) elif msg.role == 'tool': content = msg.content if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) - items.append({ - 'type': 'function_call_output', - 'call_id': msg.tool_call_id or '', - 'output': content, - }) + items.append( + { + 'type': 'function_call_output', + 'call_id': msg.tool_call_id or '', + 'output': content, + } + ) else: - items.append({ - 'role': msg.role, - 'content': msg.content, - }) + items.append( + { + 'role': msg.role, + 'content': msg.content, + } + ) return items @staticmethod def _is_responses_tool_placeholder(message: Message) -> bool: """Return True for framework-generated assistant placeholder text.""" - return bool(message.tool_calls - ) and message.content == 'Let me do a tool calling.' + return bool(message.tool_calls) and message.content == 'Let me do a tool calling.' def _prepare_responses_request( - self, messages: List[Message], - args: Dict[str, Any]) -> tuple[List[Message], Dict[str, Any]]: + self, messages: List[Message], args: Dict[str, Any] + ) -> tuple[List[Message], Dict[str, Any]]: """Prepare message slice and request args for Responses API calls.""" request_args = dict(args) @@ -677,22 +641,23 @@ def _prepare_responses_request( msg = messages[idx] if msg.role == 'assistant' and msg.id: request_args['previous_response_id'] = msg.id - return messages[idx + 1:], request_args + return messages[idx + 1 :], request_args return messages, request_args - def _build_responses_tools( - self, - tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: + def _build_responses_tools(self, tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: """Convert internal Tool list to Responses API function tool format.""" if not tools: return None - return [{ - 'type': 'function', - 'name': t['tool_name'], - 'description': t.get('description', ''), - 'parameters': t.get('parameters', {}), - } for t in tools] + return [ + { + 'type': 'function', + 'name': t['tool_name'], + 'description': t.get('description', ''), + 'parameters': t.get('parameters', {}), + } + for t in tools + ] def _build_responses_kwargs(self, args: Dict) -> Dict: """Filter and reshape generation args for ``responses.create``.""" @@ -742,8 +707,7 @@ def _extract_reasoning_summaries_from_response(response) -> str: return '\n'.join(parts) @staticmethod - def _extract_tool_calls_from_response( - response) -> Optional[List[ToolCall]]: + def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: """Extract tool calls from a completed Responses API object.""" tool_calls: List[ToolCall] = [] for item in getattr(response, 'output', []) or []: @@ -753,13 +717,13 @@ def _extract_tool_calls_from_response( arguments = json.dumps(arguments, ensure_ascii=False) tool_calls.append( ToolCall( - id=getattr(item, 'call_id', '') - or getattr(item, 'id', ''), + id=getattr(item, 'call_id', '') or getattr(item, 'id', ''), index=len(tool_calls), type='function', tool_name=getattr(item, 'name', ''), arguments=arguments, - )) + ) + ) return tool_calls if tool_calls else None @staticmethod @@ -781,10 +745,7 @@ def _to_jsonable(value: Any) -> Any: if isinstance(value, list): return [OpenAI._to_jsonable(item) for item in value] if isinstance(value, dict): - return { - key: OpenAI._to_jsonable(item) - for key, item in value.items() - } + return {key: OpenAI._to_jsonable(item) for key, item in value.items()} if hasattr(value, 'model_dump'): return OpenAI._to_jsonable(value.model_dump()) if hasattr(value, 'to_dict'): @@ -802,10 +763,8 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: item_type = getattr(item, 'type', None) if item_type == 'reasoning': passback_item: Dict[str, Any] = { - 'type': - 'reasoning', - 'summary': - self._to_jsonable(getattr(item, 'summary', []) or []), + 'type': 'reasoning', + 'summary': self._to_jsonable(getattr(item, 'summary', []) or []), } encrypted_content = getattr(item, 'encrypted_content', None) if encrypted_content: @@ -817,13 +776,9 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: items.append(passback_item) return items - def _responses_generate(self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - **args) -> Message: + def _responses_generate(self, messages: List[Message], tools: Optional[List[Tool]] = None, **args) -> Message: """Non-streaming Responses API call.""" - request_messages, request_args = self._prepare_responses_request( - messages, args) + request_messages, request_args = self._prepare_responses_request(messages, args) input_items = self._build_responses_input(request_messages) resp_tools = self._build_responses_tools(tools) kwargs = self._build_responses_kwargs(request_args) @@ -838,8 +793,7 @@ def _responses_generate(self, text = getattr(response, 'output_text', '') or '' reasoning = self._extract_reasoning_summaries_from_response(response) resp_tool_calls = self._extract_tool_calls_from_response(response) - prompt_tokens, completion_tokens = self._extract_usage_from_response( - response) + prompt_tokens, completion_tokens = self._extract_usage_from_response(response) passback = self._collect_passback_items(response) return Message( @@ -863,10 +817,9 @@ def _extract_reasoning_from_item(item) -> str: parts.append(text) return '\n'.join(parts) - def _responses_stream_generate(self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - **args) -> Generator[Message, None, None]: + def _responses_stream_generate( + self, messages: List[Message], tools: Optional[List[Tool]] = None, **args + ) -> Generator[Message, None, None]: """Streaming Responses API call. Yields incremental ``Message`` objects. Reasoning summaries are @@ -874,8 +827,7 @@ def _responses_stream_generate(self, which arrive *before* the first text delta, so the agent layer can display the thinking header before content begins streaming. """ - request_messages, request_args = self._prepare_responses_request( - messages, args) + request_messages, request_args = self._prepare_responses_request(messages, args) input_items = self._build_responses_input(request_messages) resp_tools = self._build_responses_tools(tools) kwargs = self._build_responses_kwargs(request_args) @@ -908,8 +860,7 @@ def _responses_stream_generate(self, summary_text = self._extract_reasoning_from_item(item) if summary_text: reasoning_parts.append(summary_text) - current_message.reasoning_content = '\n'.join( - reasoning_parts) + current_message.reasoning_content = '\n'.join(reasoning_parts) yield current_message elif event_type == 'response.output_text.delta': @@ -932,35 +883,29 @@ def _responses_stream_generate(self, elif event_type == 'response.failed': failed_response = getattr(event, 'response', None) failed_error = getattr(failed_response, 'error', None) - response_error_msg = getattr(failed_error, 'message', - '') or str(failed_error) + response_error_msg = getattr(failed_error, 'message', '') or str(failed_error) if final_response: if not reasoning_parts: - reasoning = self._extract_reasoning_summaries_from_response( - final_response) + reasoning = self._extract_reasoning_summaries_from_response(final_response) if reasoning: current_message.reasoning_content = reasoning - resp_tool_calls = self._extract_tool_calls_from_response( - final_response) + resp_tool_calls = self._extract_tool_calls_from_response(final_response) if resp_tool_calls: current_message.tool_calls = resp_tool_calls passback = self._collect_passback_items(final_response) if passback: current_message._responses_output_items = passback - prompt_tokens, completion_tokens = self._extract_usage_from_response( - final_response) + prompt_tokens, completion_tokens = self._extract_usage_from_response(final_response) current_message.prompt_tokens = prompt_tokens current_message.completion_tokens = completion_tokens current_message.id = getattr(final_response, 'id', '') yield current_message elif response_error_msg: logger.error(f'Responses API failed: {response_error_msg}') - raise RuntimeError( - f'Responses API call failed: {response_error_msg}') + raise RuntimeError(f'Responses API call failed: {response_error_msg}') - def _format_input_message(self, - messages: List[Message]) -> List[Dict[str, Any]]: + def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the OpenAI API. Args: @@ -982,8 +927,7 @@ def _format_input_message(self, # Check for role-based caching role_cache = self._prefix_cache_roles - {'last_message'} for idx, msg in enumerate(messages): - msg_role = msg.role if isinstance(msg, Message) else msg.get( - 'role', '') + msg_role = msg.role if isinstance(msg, Message) else msg.get('role', '') if msg_role in role_cache: cache_indices.add(idx) cache_indice = max(cache_indices) if cache_indices else None @@ -1007,9 +951,8 @@ def _format_input_message(self, # Only for string content, multimodal content is already structured if cache_indice is not None and idx == cache_indice: content = self._to_structured_content( - content, - add_cache_control=True, - provider=self._prefix_cache_provider) + content, add_cache_control=True, provider=self._prefix_cache_provider + ) # Build the message dict, handling both string and multimodal content formatted_message = {} diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 4ae5833bf..c2320d932 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional, Union -import json from typing_extensions import Literal, Required, TypedDict @@ -86,19 +86,20 @@ def to_dict_clean(self): 'function': { 'name': tool_call['tool_name'], 'arguments': tool_call['arguments'], - } + }, } required = ['content', 'role'] # Never send UI-only fields to model providers. rm = [ - 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', - 'searching_detail', 'search_result', '_responses_output_items', + 'completion_tokens', + 'prompt_tokens', + 'api_calls', + 'tool_detail', + 'searching_detail', + 'search_result', + '_responses_output_items', ] - return { - key: value - for key, value in raw_dict.items() - if (value or key in required) and key not in rm - } + return {key: value for key, value in raw_dict.items() if (value or key in required) and key not in rm} @dataclass @@ -127,9 +128,6 @@ def from_raw(raw): text=str(model_text), resources=raw.get('resources', []), tool_detail=None if td is None else str(td), - extra={ - k: v - for k, v in raw.items() - if k not in ['text', 'resources', 'result', 'tool_detail'] - }) + extra={k: v for k, v in raw.items() if k not in ['text', 'resources', 'result', 'tool_detail']}, + ) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/memory/base.py b/ms_agent/memory/base.py index fde42fb56..ed55e7fe5 100644 --- a/ms_agent/memory/base.py +++ b/ms_agent/memory/base.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from typing import List +from omegaconf import DictConfig + from ms_agent.llm.utils import Message from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from omegaconf import DictConfig class Memory(ABC): @@ -12,8 +13,7 @@ class Memory(ABC): def __init__(self, config): self.config = config - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) self.base_config = None @abstractmethod diff --git a/ms_agent/memory/condenser/code_condenser.py b/ms_agent/memory/condenser/code_condenser.py index 12fe626e5..ff4a7d0a9 100644 --- a/ms_agent/memory/condenser/code_condenser.py +++ b/ms_agent/memory/condenser/code_condenser.py @@ -1,19 +1,17 @@ +import json import os from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils import get_logger -from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, - DEFAULT_OUTPUT_WRAPPER) +from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_OUTPUT_WRAPPER from ms_agent.utils.utils import extract_code_blocks, file_lock logger = get_logger() class CodeCondenser(Memory): - system = """你是一个帮我简化代码并返回缩略信息的机器人。你缩略的文件会给与另一个LLM用来编写代码,因此你生成的缩略文件需要具有充足的供其他文件依赖的信息。 需要保留的信息: @@ -97,7 +95,7 @@ class CodeCondenser(Memory): 你的优化目标: 1. 【优先】保留充足的信息供其它代码使用 2. 【其次】保留尽量少的token数量 -""" # noqa +""" # noqa def __init__(self, config): super().__init__(config) @@ -108,8 +106,7 @@ def __init__(self, config): index_dir = getattr(config, 'index_cache_dir', DEFAULT_INDEX_DIR) self.index_dir = os.path.join(self.output_dir, index_dir) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) - self.code_wrapper = getattr(mem_config, 'code_wrapper', - DEFAULT_OUTPUT_WRAPPER) + self.code_wrapper = getattr(mem_config, 'code_wrapper', DEFAULT_OUTPUT_WRAPPER) def condense_code(self, message: Message): prefix = 'Your generated code was replaced by a index version:\n' @@ -122,22 +119,17 @@ def condense_code(self, message: Message): arguments = json.loads(arguments) code_file = arguments['path'] content = arguments['content'] - index_content = self.generate_index_file( - code_file, content) + index_content = self.generate_index_file(code_file, content) arguments['content'] = f'{prefix}{index_content}' - tool_call['arguments'] = json.dumps( - arguments, ensure_ascii=False) - elif self.code_wrapper[0] in message.content and self.code_wrapper[ - 1] in message.content: - result, remaining_text = extract_code_blocks( - message.content, file_wrapper=self.code_wrapper) + tool_call['arguments'] = json.dumps(arguments, ensure_ascii=False) + elif self.code_wrapper[0] in message.content and self.code_wrapper[1] in message.content: + result, remaining_text = extract_code_blocks(message.content, file_wrapper=self.code_wrapper) if result: final_content = remaining_text + prefix for code_block in result: code_file = code_block['filename'] content = code_block['code'] - index_content = self.generate_index_file( - code_file, content) + index_content = self.generate_index_file(code_file, content) final_content += index_content + '\n' message.content = final_content @@ -172,8 +164,7 @@ def generate_index_file(self, file: str, content: str = None): error = None for i in range(3): try: - response_message = self.llm.generate( - messages, stream=False) + response_message = self.llm.generate(messages, stream=False) content = response_message.content.split('\n') if '```' in content[0]: content = content[1:] @@ -183,14 +174,11 @@ def generate_index_file(self, file: str, content: str = None): os.makedirs(os.path.dirname(index_file), exist_ok=True) with open(index_file, 'w') as f: f.write(content) - json.loads( - content - ) # try to load once to ensure the json format is ok + json.loads(content) # try to load once to ensure the json format is ok break except Exception as e: error = e - logger.error( - f'Code index file generate failed because of {e}') + logger.error(f'Code index file generate failed because of {e}') if content is None: raise error return content diff --git a/ms_agent/memory/condenser/context_compressor.py b/ms_agent/memory/condenser/context_compressor.py index 9bec9bcf1..035a12a6f 100644 --- a/ms_agent/memory/condenser/context_compressor.py +++ b/ms_agent/memory/condenser/context_compressor.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from typing import List, Optional -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils.logger import logger @@ -44,8 +44,7 @@ def __init__(self, config): self.reserved_buffer = getattr(mem_config, 'reserved_buffer', 20000) # Summary prompt - self.summary_prompt = getattr(mem_config, 'summary_prompt', - SUMMARY_PROMPT) + self.summary_prompt = getattr(mem_config, 'summary_prompt', SUMMARY_PROMPT) # LLM for summarization self.llm: Optional[LLM] = None @@ -67,9 +66,7 @@ def _estimate_message_tokens_from_content(self, msg: Message) -> int: """Heuristic token count from message body (no API usage fields).""" total = 0 if msg.content: - content = msg.content if isinstance( - msg.content, str) else json.dumps( - msg.content, ensure_ascii=False) + content = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False) total += self.estimate_tokens(content) if msg.tool_calls: total += self.estimate_tokens(json.dumps(msg.tool_calls)) @@ -99,11 +96,8 @@ def estimate_total_tokens(self, messages: List[Message]) -> int: break if last_usage_idx >= 0: m = messages[last_usage_idx] - base = int(getattr(m, 'prompt_tokens', 0) or 0) + int( - getattr(m, 'completion_tokens', 0) or 0) - tail = sum( - self._estimate_message_tokens_from_content(x) - for x in messages[last_usage_idx + 1:]) + base = int(getattr(m, 'prompt_tokens', 0) or 0) + int(getattr(m, 'completion_tokens', 0) or 0) + tail = sum(self._estimate_message_tokens_from_content(x) for x in messages[last_usage_idx + 1 :]) return base + tail return sum(self.estimate_message_tokens(m) for m in messages) @@ -128,9 +122,7 @@ def prune_tool_outputs(self, messages: List[Message]) -> List[Message]: msg = messages[idx] if msg.role != 'tool' or not msg.content: continue - content_str = msg.content if isinstance( - msg.content, str) else json.dumps( - msg.content, ensure_ascii=False) + content_str = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False) tokens = self.estimate_tokens(content_str) total_tool_tokens += tokens @@ -152,8 +144,7 @@ def summarize(self, messages: List[Message]) -> Optional[str]: conv_parts = [] for msg in messages: role = msg.role.upper() - content = msg.content if isinstance(msg.content, str) else str( - msg.content) + content = msg.content if isinstance(msg.content, str) else str(msg.content) if content: conv_parts.append(f'{role}: {content[:2000]}') @@ -161,8 +152,7 @@ def summarize(self, messages: List[Message]) -> Optional[str]: query = f'{self.summary_prompt}\n\n---\n{conversation}' try: - response = self.llm.generate([Message(role='user', content=query)], - stream=False) + response = self.llm.generate([Message(role='user', content=query)], stream=False) return response.content except Exception as e: logger.error(f'Summary generation failed: {e}') @@ -199,10 +189,8 @@ def compress(self, messages: List[Message]) -> List[Message]: break result.append( - Message( - role='user', - content=f'[Conversation Summary]\n{summary}\n\n' - 'Please continue based on this summary.')) + Message(role='user', content=f'[Conversation Summary]\n{summary}\n\nPlease continue based on this summary.') + ) # Keep the most recent user message if different if messages and messages[-1].role == 'user': @@ -210,8 +198,7 @@ def compress(self, messages: List[Message]) -> List[Message]: if last_user.content and last_user.content != result[-1].content: result.append(last_user) - logger.info( - f'Compressed {len(messages)} messages to {len(result)} messages') + logger.info(f'Compressed {len(messages)} messages to {len(result)} messages') return result async def run(self, messages: List[Message]) -> List[Message]: diff --git a/ms_agent/memory/condenser/refine_condenser.py b/ms_agent/memory/condenser/refine_condenser.py index 557e2a1ff..d779e4b73 100644 --- a/ms_agent/memory/condenser/refine_condenser.py +++ b/ms_agent/memory/condenser/refine_condenser.py @@ -1,6 +1,6 @@ +import json from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory @@ -68,8 +68,7 @@ def __init__(self, config): self.threshold = getattr(mem_config, 'threshold', 60000) async def condense_memory(self, messages): - if len(str(messages)) > self.threshold and messages[-1].role in ( - 'user', 'tool'): + if len(str(messages)) > self.threshold and messages[-1].role in ('user', 'tool'): keep_messages = messages[:2] # keep system and user keep_messages_tail = [] i = 0 @@ -80,24 +79,23 @@ async def condense_memory(self, messages): keep_messages_tail = reversed(keep_messages_tail) compress_messages = json.dumps( - [message.to_dict_clean() for message in messages[2:-i - 1]], - ensure_ascii=False, - indent=2) + [message.to_dict_clean() for message in messages[2 : -i - 1]], ensure_ascii=False, indent=2 + ) keep_messages_json = json.dumps( - [message.to_dict_clean() for message in keep_messages], - ensure_ascii=False, - indent=2) + [message.to_dict_clean() for message in keep_messages], ensure_ascii=False, indent=2 + ) keep_messages_tail_json = json.dumps( - [message.to_dict_clean() for message in keep_messages_tail], - ensure_ascii=False, - indent=2) + [message.to_dict_clean() for message in keep_messages_tail], ensure_ascii=False, indent=2 + ) - query = (f'# Messages to be retained\n' - f'## system and user: {keep_messages_json}\n' - f'## Last assistant response: {keep_messages_tail_json}\n' - f'# Messages to be compressed' - f'## These messages are located between system/user ' - f'and the last assistant response: {compress_messages}') + query = ( + f'# Messages to be retained\n' + f'## system and user: {keep_messages_json}\n' + f'## Last assistant response: {keep_messages_tail_json}\n' + f'# Messages to be compressed' + f'## These messages are located between system/user ' + f'and the last assistant response: {compress_messages}' + ) _messages = [ Message(role='system', content=self.system), @@ -108,17 +106,21 @@ async def condense_memory(self, messages): keep_messages.append( Message( role='user', - content= - f'Intermediate messages are compressed, here is the compressed message:\n{content}\n' - )) - messages = keep_messages + list(keep_messages_tail) + [ - Message( - role='user', - content= - 'History messages are compressed due to a long sequence, now ' - 'continue solve your problem according to ' - 'the messages and the tool calling:\n') - ] + content=f'Intermediate messages are compressed, here is the compressed message:\n{content}\n', + ) + ) + messages = ( + keep_messages + + list(keep_messages_tail) + + [ + Message( + role='user', + content='History messages are compressed due to a long sequence, now ' + 'continue solve your problem according to ' + 'the messages and the tool calling:\n', + ) + ] + ) return messages else: return messages diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index b087a2dd9..943334c86 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -2,6 +2,7 @@ import asyncio import hashlib import importlib +import json import os import re import traceback @@ -10,15 +11,14 @@ from inspect import signature from typing import Any, Dict, List, Optional, Tuple -import json import json5 +from omegaconf import DictConfig, OmegaConf + from ms_agent.llm.utils import Message from ms_agent.memory import Memory from ms_agent.utils import get_fact_retrieval_prompt -from ms_agent.utils.constants import (DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, - DEFAULT_USER, get_service_config) +from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, DEFAULT_USER, get_service_config from ms_agent.utils.logger import logger -from omegaconf import DictConfig, OmegaConf class MemoryMapping: @@ -28,8 +28,7 @@ class MemoryMapping: enable_idxs: List[int] = [] disable_idx: int = -1 - def __init__(self, memory_id: str, value: str, enable_idxs: int - or List[int]): + def __init__(self, memory_id: str, value: str, enable_idxs: int or List[int]): self.memory_id = memory_id self.value = value self.valid = True @@ -59,20 +58,15 @@ def to_dict(self) -> Dict: 'memory_id': self.memory_id, 'value': self.value, 'valid': self.valid, - 'enable_idxs': self.enable_idxs.copy( - ), # Return a copy to prevent external modification - 'disable_idx': self.disable_idx + 'enable_idxs': self.enable_idxs.copy(), # Return a copy to prevent external modification + 'disable_idx': self.disable_idx, } @classmethod def from_dict(cls, data: Dict) -> 'MemoryMapping': - instance = cls( - memory_id=data['memory_id'], - value=data['value'], - enable_idxs=data['enable_idxs']) + instance = cls(memory_id=data['memory_id'], value=data['value'], enable_idxs=data['enable_idxs']) instance.valid = data['valid'] - instance.disable_idx = data.get('disable_idx', - -1) # Compatible with old data + instance.disable_idx = data.get('disable_idx', -1) # Compatible with old data return instance @@ -82,22 +76,16 @@ class DefaultMemory(Memory): def __init__(self, config: DictConfig): super().__init__(config) memory_config = config.memory.default_memory - self.user_id: Optional[str] = getattr(memory_config, 'user_id', - DEFAULT_USER) + self.user_id: Optional[str] = getattr(memory_config, 'user_id', DEFAULT_USER) self.agent_id: Optional[str] = getattr(memory_config, 'agent_id', None) self.run_id: Optional[str] = getattr(memory_config, 'run_id', None) self.compress: Optional[bool] = getattr(config, 'compress', True) self.is_retrieve: Optional[bool] = getattr(config, 'is_retrieve', True) - self.path: Optional[str] = getattr( - memory_config, 'path', - os.path.join(DEFAULT_OUTPUT_DIR, '.default_memory')) + self.path: Optional[str] = getattr(memory_config, 'path', os.path.join(DEFAULT_OUTPUT_DIR, '.default_memory')) self.history_mode = getattr(memory_config, 'history_mode', 'add') - self.ignore_roles: List[str] = getattr(memory_config, 'ignore_roles', - ['tool', 'system']) - self.ignore_fields: List[str] = getattr(memory_config, 'ignore_fields', - ['reasoning_content']) - self.search_limit: int = getattr(memory_config, 'search_limit', - DEFAULT_SEARCH_LIMIT) + self.ignore_roles: List[str] = getattr(memory_config, 'ignore_roles', ['tool', 'system']) + self.ignore_fields: List[str] = getattr(memory_config, 'ignore_fields', ['reasoning_content']) + self.search_limit: int = getattr(memory_config, 'search_limit', DEFAULT_SEARCH_LIMIT) # Add lock for thread safety in shared usage self._lock = asyncio.Lock() self.memory = self._init_memory_obj() @@ -123,7 +111,7 @@ def save_cache(self): str(k): ([msg.to_dict() for msg in msg_list], _hash) for k, (msg_list, _hash) in self.cache_messages.items() }, - 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot] + 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot], } with open(cache_file, 'w', encoding='utf-8') as f: @@ -157,10 +145,7 @@ def load_cache(self): self.cache_messages = cache_messages # Parse memory_snapshot - self.memory_snapshot = [ - MemoryMapping.from_dict(d) - for d in data.get('memory_snapshot', []) - ] + self.memory_snapshot = [MemoryMapping.from_dict(d) for d in data.get('memory_snapshot', [])] except (json.JSONDecodeError, KeyError, Exception) as e: logger.warning(f'Failed to load cache: {e}') @@ -179,7 +164,6 @@ def _delete_single(self, msg_id: int): idx = 0 while idx < len(self.memory_snapshot): - enable_ids = self.memory_snapshot[idx].enable_idxs disable_id = self.memory_snapshot[idx].disable_idx if msg_id == disable_id: @@ -191,9 +175,8 @@ def _delete_single(self, msg_id: int): metadata['run_id'] = self.run_id try: self.memory._create_memory( - data=self.memory_snapshot[idx].value, - existing_embeddings={}, - metadata=metadata) + data=self.memory_snapshot[idx].value, existing_embeddings={}, metadata=metadata + ) except Exception as e: logger.warning(f'Failed to recover memory: {e}') if msg_id in enable_ids: @@ -206,13 +189,15 @@ def _delete_single(self, msg_id: int): idx += 1 - async def add_single(self, - messages: List[Message], - user_id: Optional[int] = None, - agent_id: Optional[int] = None, - run_id: Optional[int] = None, - memory_type: Optional[str] = None, - msg_id: Optional[int] = None) -> None: + async def add_single( + self, + messages: List[Message], + user_id: Optional[int] = None, + agent_id: Optional[int] = None, + run_id: Optional[int] = None, + memory_type: Optional[str] = None, + msg_id: Optional[int] = None, + ) -> None: messages_dict = [] for message in messages: if isinstance(message, Message): @@ -233,16 +218,16 @@ async def add_single(self, user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id, - memory_type=memory_type) + memory_type=memory_type, + ) logger.info('Add memory success.') except Exception as e: logger.warning(f'Failed to add memory: {e}') if self.history_mode == 'overwrite': res = self.memory.get_all( - user_id=user_id or self.user_id, - agent_id=agent_id or self.agent_id, - run_id=run_id or self.run_id) # sorted + user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id + ) # sorted res = [(item['id'], item['memory']) for item in res['results']] if len(res): logger.info('All memory info:') @@ -266,14 +251,11 @@ async def add_single(self, for item in self.memory_snapshot: if item.memory_id not in valids: item.disable(msg_id) - for (id, memory) in unmatched: - m = MemoryMapping( - memory_id=id, value=memory, enable_idxs=msg_id) + for id, memory in unmatched: + m = MemoryMapping(memory_id=id, value=memory, enable_idxs=msg_id) self.memory_snapshot.append(m) - def search(self, - query: str, - meta_infos: List[Dict[str, Any]] = None) -> List[str]: + def search(self, query: str, meta_infos: List[Dict[str, Any]] = None) -> List[str]: """ Search for relevant memories based on a query string and optional metadata filters. @@ -302,12 +284,14 @@ def search(self, (self.user_id, self.agent_id, etc.) is used as fallback. """ if meta_infos is None: - meta_infos = [{ - 'user_id': self.user_id, - 'agent_id': self.agent_id, - 'run_id': self.run_id, - 'limit': self.search_limit, - }] + meta_infos = [ + { + 'user_id': self.user_id, + 'agent_id': self.agent_id, + 'run_id': self.run_id, + 'limit': self.search_limit, + } + ] memories = [] for meta_info in meta_infos: user_id = meta_info.get('user_id', None) @@ -319,13 +303,12 @@ def search(self, user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id, - limit=limit) - memories.extend( - [entry['memory'] for entry in relevant_memories['results']]) + limit=limit, + ) + memories.extend([entry['memory'] for entry in relevant_memories['results']]) return memories - def _split_into_blocks(self, - messages: List[Message]) -> List[List[Message]]: + def _split_into_blocks(self, messages: List[Message]) -> List[List[Message]]: """ Split messages into blocks where each block starts with a 'user' message and includes all following non-user messages until the next 'user' (exclusive). @@ -362,25 +345,20 @@ def _hash_block(self, block: List[Message]) -> str: """Compute sha256 hash of a message block for comparison""" data = [message.to_dict_clean() for message in block] allow_role = ['user', 'system', 'assistant', 'tool'] - allow_role = [ - role for role in allow_role if role not in self.ignore_roles - ] + allow_role = [role for role in allow_role if role not in self.ignore_roles] allow_fields = ['reasoning_content', 'content', 'tool_calls', 'role'] - allow_fields = [ - field for field in allow_fields if field not in self.ignore_fields - ] + allow_fields = [field for field in allow_fields if field not in self.ignore_fields] - data = [{ - field: value - for field, value in msg.items() if field in allow_fields - } for msg in data if msg['role'] in allow_role] + data = [ + {field: value for field, value in msg.items() if field in allow_fields} + for msg in data + if msg['role'] in allow_role + ] block_data = json5.dumps(data) return hashlib.sha256(block_data.encode('utf-8')).hexdigest() - def _analyze_messages( - self, - messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: + def _analyze_messages(self, messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: """ Analyze incoming messages against cache. @@ -390,8 +368,7 @@ def _analyze_messages( """ new_blocks = self._split_into_blocks(messages) self.cache_messages = dict(sorted(self.cache_messages.items())) - cache_messages = [(key, value) - for key, value in self.cache_messages.items()] + cache_messages = [(key, value) for key, value in self.cache_messages.items()] first_unmatched_idx = -1 @@ -399,8 +376,7 @@ def _analyze_messages( block_hash = self._hash_block(new_blocks[idx]) # Must allow comparison up to the last cache entry - if idx < len(cache_messages) and str(block_hash) == str( - cache_messages[idx][1][1]): + if idx < len(cache_messages) and str(block_hash) == str(cache_messages[idx][1][1]): continue # mismatch @@ -410,16 +386,12 @@ def _analyze_messages( # If all new_blocks match but the cache has extra entries → delete the extra cache entries if first_unmatched_idx == -1: should_add_messages = [] - should_delete = [ - item[0] for item in cache_messages[len(new_blocks):] - ] + should_delete = [item[0] for item in cache_messages[len(new_blocks) :]] return should_add_messages, should_delete # On mismatch: add all new blocks and delete all cache entries starting from the mismatch index should_add_messages = new_blocks[first_unmatched_idx:] - should_delete = [ - item[0] for item in cache_messages[first_unmatched_idx:] - ] + should_delete = [item[0] for item in cache_messages[first_unmatched_idx:]] return should_add_messages, should_delete @@ -445,9 +417,8 @@ async def add( for msg_id in should_delete: self._delete_single(msg_id=msg_id) res = self.memory.get_all( - user_id=user_id or self.user_id, - agent_id=agent_id or self.agent_id, - run_id=run_id or self.run_id) # sorted + user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id + ) # sorted res = [(item['id'], item['memory']) for item in res['results']] logger.info('Roll back success. All memory info:') for item in res: @@ -456,11 +427,8 @@ async def add( for messages in should_add_messages: messages = self.parse_messages(messages) await self.add_single( - messages, - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - memory_type=memory_type) + messages, user_id=user_id, agent_id=agent_id, run_id=run_id, memory_type=memory_type + ) self.save_cache() def parse_messages(self, messages: List[Message]) -> List[Message]: @@ -480,16 +448,17 @@ def parse_messages(self, messages: List[Message]) -> List[Message]: return new_messages - def delete(self, - user_id: Optional[str] = None, - agent_id: Optional[str] = None, - run_id: Optional[str] = None, - memory_ids: Optional[List[str]] = None) -> Tuple[bool, str]: + def delete( + self, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + memory_ids: Optional[List[str]] = None, + ) -> Tuple[bool, str]: failed = {} if memory_ids is None: try: - self.memory.delete_all( - user_id=user_id, agent_id=agent_id, run_id=run_id) + self.memory.delete_all(user_id=user_id, agent_id=agent_id, run_id=run_id) return True, '' except Exception as e: return False, str(e) + '\n' + traceback.format_exc() @@ -497,9 +466,7 @@ def delete(self, try: self.memory.delete(memory_id=memory_id) except IndexError: - failed[ - memory_id] = 'This memory_id does not exist in the database.\n' + traceback.format_exc( - ) # noqa + failed[memory_id] = 'This memory_id does not exist in the database.\n' + traceback.format_exc() # noqa except Exception as e: failed[memory_id] = str(e) + '\n' + traceback.format_exc() if failed: @@ -507,54 +474,42 @@ def delete(self, else: return True, '' - def get_all(self, - user_id: Optional[str] = None, - agent_id: Optional[str] = None, - run_id: Optional[str] = None): + def get_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None): try: - res = self.memory.get_all( - user_id=user_id or self.user_id, - agent_id=agent_id, - run_id=run_id) + res = self.memory.get_all(user_id=user_id or self.user_id, agent_id=agent_id, run_id=run_id) return res['results'] except Exception: return [] - def _get_latest_user_message(self, - messages: List[Message]) -> Optional[str]: + def _get_latest_user_message(self, messages: List[Message]) -> Optional[str]: """Get the latest user message content.""" for message in reversed(messages): if message.role == 'user' and hasattr(message, 'content'): return message.content return None - def _inject_memories_into_messages(self, messages: List[Message], - memories: List[str], - keep_details) -> List[Message]: + def _inject_memories_into_messages( + self, messages: List[Message], memories: List[str], keep_details + ) -> List[Message]: """Inject relevant memories into the system message.""" # Format memories for injection - memories_str = 'User Memories:\n' + '\n'.join(f'- {memory}' - for memory in memories) + memories_str = 'User Memories:\n' + '\n'.join(f'- {memory}' for memory in memories) # Remove the messages section corresponding to memory, and add the related memory_str information if getattr(messages[0], 'role') == 'system': - system_prompt = getattr( - messages[0], 'content') + f'\nUser Memories: {memories_str}' + system_prompt = getattr(messages[0], 'content') + f'\nUser Memories: {memories_str}' remain_idx = 1 else: - system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ - f'User Memories: {memories_str}' + system_prompt = ( + f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' + f'User Memories: {memories_str}' + ) remain_idx = 0 if not keep_details: - should_add_messages, should_delete = self._analyze_messages( - messages) - remain_idx = max( - remain_idx, - len(messages) - - sum([len(block) for block in should_add_messages])) - - new_messages = [Message(role='system', content=system_prompt) - ] + messages[remain_idx:] + should_add_messages, should_delete = self._analyze_messages(messages) + remain_idx = max(remain_idx, len(messages) - sum([len(block) for block in should_add_messages])) + + new_messages = [Message(role='system', content=system_prompt)] + messages[remain_idx:] return new_messages async def run( @@ -576,51 +531,46 @@ async def run( logger.warning(f'Failed to search memories: {search_error}') memories = [] if memories: - messages = self._inject_memories_into_messages( - messages, memories, keep_details) + messages = self._inject_memories_into_messages(messages, memories, keep_details) return messages def _init_memory_obj(self): try: import mem0 except ImportError as e: - logger.error( - f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.' - ) + logger.error(f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.') raise capture_event_origin = mem0.memory.main.capture_event @wraps(capture_event_origin) - def patched_capture_event(event_name, - memory_instance, - additional_data=None): + def patched_capture_event(event_name, memory_instance, additional_data=None): pass - mem0.memory.main.capture_event = partial(patched_capture_event, ) + mem0.memory.main.capture_event = partial( + patched_capture_event, + ) # emb config embedder = None - embedder_config = getattr(self.config.memory.default_memory, - 'embedder', OmegaConf.create({})) + embedder_config = getattr(self.config.memory.default_memory, 'embedder', OmegaConf.create({})) service = getattr(embedder_config, 'service', 'modelscope') api_key = getattr(embedder_config, 'api_key', None) - emb_model = getattr(embedder_config, 'model', - 'Qwen/Qwen3-Embedding-8B') - embedding_dims = getattr(embedder_config, 'embedding_dims', - None) # for vector store config + emb_model = getattr(embedder_config, 'model', 'Qwen/Qwen3-Embedding-8B') + embedding_dims = getattr(embedder_config, 'embedding_dims', None) # for vector store config if self.is_retrieve: - embedder = OmegaConf.create({ - 'provider': 'openai', - 'config': { - 'api_key': api_key - or os.getenv(f'{service.upper()}_API_KEY'), - 'openai_base_url': get_service_config(service).base_url, - 'model': emb_model, - 'embedding_dims': embedding_dims + embedder = OmegaConf.create( + { + 'provider': 'openai', + 'config': { + 'api_key': api_key or os.getenv(f'{service.upper()}_API_KEY'), + 'openai_base_url': get_service_config(service).base_url, + 'model': emb_model, + 'embedding_dims': embedding_dims, + }, } - }) + ) # llm config llm = None @@ -628,32 +578,25 @@ def patched_capture_event(event_name, llm_config = getattr(self.config, 'llm', None) if llm_config is not None: service = getattr(llm_config, 'service', 'modelscope') - llm_model = getattr(llm_config, 'model', - 'Qwen/Qwen3-Coder-30B-A3B-Instruct') + llm_model = getattr(llm_config, 'model', 'Qwen/Qwen3-Coder-30B-A3B-Instruct') api_key = getattr(llm_config, f'{service}_api_key', None) - openai_base_url = getattr(llm_config, f'{service}_base_url', - None) + openai_base_url = getattr(llm_config, f'{service}_base_url', None) gen_cfg = getattr(self.config, 'generation_config', None) max_tokens = getattr(gen_cfg, 'max_tokens', None) llm = { 'provider': 'openai', 'config': { - 'model': - llm_model, - 'api_key': - api_key or os.getenv(f'{service.upper()}_API_KEY'), - 'openai_base_url': - openai_base_url - or get_service_config(service).base_url, - } + 'model': llm_model, + 'api_key': api_key or os.getenv(f'{service.upper()}_API_KEY'), + 'openai_base_url': openai_base_url or get_service_config(service).base_url, + }, } if max_tokens is not None: llm['config']['max_tokens'] = max_tokens # vector_store config - def sanitize_database_name(ori_name: str, - default_name: str = 'default') -> str: + def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: if not ori_name or not isinstance(ori_name, str): return default_name sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', ori_name) @@ -665,10 +608,8 @@ def sanitize_database_name(ori_name: str, sanitized = f'col_{sanitized}' return sanitized - vector_store_config = getattr(self.config.memory.default_memory, - 'vector_store', OmegaConf.create({})) - vector_store_provider = getattr(vector_store_config, 'service', - 'qdrant') + vector_store_config = getattr(self.config.memory.default_memory, 'vector_store', OmegaConf.create({})) + vector_store_provider = getattr(vector_store_config, 'service', 'qdrant') on_disk = getattr(vector_store_config, 'on_disk', True) path = getattr(vector_store_config, 'path', self.path) db_name = getattr(vector_store_config, 'db_name', None) @@ -677,13 +618,12 @@ def sanitize_database_name(ori_name: str, collection_name = getattr(vector_store_config, 'collection_name', path) db_name = sanitize_database_name(db_name) if db_name else None - collection_name = sanitize_database_name( - collection_name) if collection_name else None + collection_name = sanitize_database_name(collection_name) if collection_name else None # check value from mem0.memory.main import VectorStoreFactory - class_type = VectorStoreFactory.provider_to_class.get( - vector_store_provider) + + class_type = VectorStoreFactory.provider_to_class.get(vector_store_provider) if class_type: module_path, class_name = class_type.rsplit('.', 1) module = importlib.import_module(module_path) @@ -697,17 +637,10 @@ def sanitize_database_name(ori_name: str, 'url': url, 'token': token, 'db_name': db_name, - 'embedding_model_dims': embedding_dims - } - config_format = { - key: value - for key, value in config_raw.items() - if value and key in parameters - } - vector_store = { - 'provider': vector_store_provider, - 'config': config_format + 'embedding_model_dims': embedding_dims, } + config_format = {key: value for key, value in config_raw.items() if value and key in parameters} + vector_store = {'provider': vector_store_provider, 'config': config_format} else: vector_store = {} @@ -719,13 +652,14 @@ def sanitize_database_name(ori_name: str, logger.info(f'Memory config: {mem0_config}') # Prompt content is too long, default logging reduces readability custom_fact_extraction_prompt = getattr( - self.config.memory.default_memory, 'fact_retrieval_prompt', - getattr(self.config.memory.default_memory, - 'custom_fact_extraction_prompt', None)) + self.config.memory.default_memory, + 'fact_retrieval_prompt', + getattr(self.config.memory.default_memory, 'custom_fact_extraction_prompt', None), + ) if custom_fact_extraction_prompt is not None: mem0_config['custom_fact_extraction_prompt'] = ( - custom_fact_extraction_prompt - + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.') + custom_fact_extraction_prompt + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.' + ) try: memory = mem0.Memory.from_config(mem0_config) memory._telemetry_vector_store = None diff --git a/ms_agent/memory/diversity.py b/ms_agent/memory/diversity.py index 775e80da1..75cb885ff 100644 --- a/ms_agent/memory/diversity.py +++ b/ms_agent/memory/diversity.py @@ -3,9 +3,10 @@ from copy import deepcopy from typing import List -from ms_agent.utils import get_logger from omegaconf import DictConfig +from ms_agent.utils import get_logger + from ..llm import LLM, Message from .base import Memory @@ -13,7 +14,6 @@ class Diversity(Memory): - div_system1 = """You are an inspiration bot. You will be given an original requirement, and you need to provide keywords that you associate with it. The keywords must meet the following conditions: 1. The keywords you provide should be terms, such as "security", "independent module", "aesthetics", "style", "examples", etc. @@ -27,7 +27,7 @@ class Diversity(Memory): 6. Your keywords must be in the same language as the original requirement Here is the original query: -""" # noqa +""" # noqa div_system2 = """You are an inspiration bot. You will be given a series of keywords, and you need to provide related words that you associate with based on these keywords. The words must meet the following conditions: @@ -37,7 +37,7 @@ class Diversity(Memory): 4. Your keywords must be in the same language as the input keywords Here are the keywords: -""" # noqa +""" # noqa div_system3 = """You are an inspiration bot. You will be given a series of keywords and an original requirement. You need to carefully analyze the relationship between the original requirement and the keywords, and provide your suggestions for completing the original requirement based on the keywords: @@ -53,7 +53,7 @@ class Diversity(Memory): 7. Wrap your final suggestions with only one wrapper Here are the original query and the keywords: -""" # noqa +""" # noqa def __init__(self, config): super().__init__(config) @@ -72,6 +72,7 @@ def __init__(self, config): async def _run_tasks_sequential(self, tasks: list) -> str: """Run a list of {system, query} tasks sequentially using LLMAgent.""" from ms_agent.agent import LLMAgent + res = [] for i, task in enumerate(tasks): system = task.get('system', '') @@ -130,10 +131,7 @@ async def run(self, messages: List[Message]): pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): - all_keywords.extend([ - keyword.strip() for keyword in keywords.split(',') - if keyword.strip() - ]) + all_keywords.extend([keyword.strip() for keyword in keywords.split(',') if keyword.strip()]) arguments = [] _query = ','.join(set(all_keywords)) @@ -149,15 +147,11 @@ async def run(self, messages: List[Message]): pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): - all_keywords.extend([ - keyword.strip() for keyword in keywords.split(',') - if keyword.strip() - ]) + all_keywords.extend([keyword.strip() for keyword in keywords.split(',') if keyword.strip()]) _query = ','.join(set(all_keywords)) logger.info(f'Diversity second round keywords: {_query}') - _query = (f'Original query: {query}\n' - f'Keywords generated by LLMs: {all_keywords}') + _query = f'Original query: {query}\nKeywords generated by LLMs: {all_keywords}' _messages = [ Message(role='system', content=self.div_system3), Message(role='user', content=_query), @@ -174,7 +168,8 @@ async def run(self, messages: List[Message]): suggestions = ( '\nNow Additional suggestions and findings are given to you, ' 'you need to consider these suggestions and carefully process the query:\n' - f'{suggestions}') + f'{suggestions}' + ) if system != query: system = system + suggestions messages[0].content = system diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 5a203505d..1c2622aee 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -1,21 +1,22 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Dict +from omegaconf import DictConfig + from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER -from omegaconf import DictConfig logger = get_logger() class SharedMemoryManager: """Manager for shared memory instances across different agents.""" + _instances: Dict[str, Memory] = {} @classmethod - async def get_shared_memory(cls, config: DictConfig, - mem_instance_type: str) -> Memory: + async def get_shared_memory(cls, config: DictConfig, mem_instance_type: str) -> Memory: """Get or create a shared memory instance based on configuration.""" user_id: str = getattr(config, 'user_id', DEFAULT_USER) path: str = getattr(config, 'path', DEFAULT_OUTPUT_DIR) @@ -26,8 +27,7 @@ async def get_shared_memory(cls, config: DictConfig, logger.info(f'Creating new shared memory instance for key: {key}') cls._instances[key] = memory_mapping[mem_instance_type](config) else: - logger.info( - f'Reusing existing shared memory instance for key: {key}') + logger.info(f'Reusing existing shared memory instance for key: {key}') return cls._instances[key] @@ -46,5 +46,4 @@ def clear_shared_memory(cls, config: DictConfig, mem_instance_type: str): del cls._instances[key] logger.info(f'Cleared shared memory instance for key: {key}') else: - logger.warning( - f'No shared memory instance found for key: {key}') + logger.warning(f'No shared memory instance found for key: {key}') diff --git a/ms_agent/memory/utils.py b/ms_agent/memory/utils.py index b7e20ad30..b778327c6 100644 --- a/ms_agent/memory/utils.py +++ b/ms_agent/memory/utils.py @@ -16,9 +16,7 @@ } -def get_memory_meta_safe(config: DictConfig, - key: str, - default_user_id: str | None = None): +def get_memory_meta_safe(config: DictConfig, key: str, default_user_id: str | None = None): if not hasattr(config, key): return None, None, None, None trigger_config = getattr(config, key, OmegaConf.create({})) diff --git a/ms_agent/prompting/file_resolver.py b/ms_agent/prompting/file_resolver.py index c04b8bdde..4c998b1b9 100644 --- a/ms_agent/prompting/file_resolver.py +++ b/ms_agent/prompting/file_resolver.py @@ -24,15 +24,14 @@ def candidate_paths(self) -> List[str]: paths = [] if family: - paths.extend([ - os.path.join(root, agent, lang, f'{family}.txt'), - os.path.join(root, agent, lang, f'{family}.md'), - ]) + paths.extend( + [ + os.path.join(root, agent, lang, f'{family}.txt'), + os.path.join(root, agent, lang, f'{family}.md'), + ] + ) # base fallback - paths.extend([ - os.path.join(root, agent, lang, 'base.txt'), - os.path.join(root, agent, lang, 'base.md') - ]) + paths.extend([os.path.join(root, agent, lang, 'base.txt'), os.path.join(root, agent, lang, 'base.md')]) return paths @@ -135,18 +134,15 @@ def _get_prompt_lang_and_family(config: DictConfig) -> Tuple[str, str]: prompt_cfg = getattr(config, 'prompt', None) # lang - env_lang = os.environ.get('MS_AGENT_PROMPT_LANG') or os.environ.get( - 'MS_AGENT_LANG') - cfg_lang = getattr(prompt_cfg, 'lang', None) if isinstance( - prompt_cfg, DictConfig) else None + env_lang = os.environ.get('MS_AGENT_PROMPT_LANG') or os.environ.get('MS_AGENT_LANG') + cfg_lang = getattr(prompt_cfg, 'lang', None) if isinstance(prompt_cfg, DictConfig) else None lang = _norm_lang(cfg_lang or env_lang or 'zh') # family env_family = os.environ.get('MS_AGENT_PROMPT_FAMILY') - cfg_family = getattr(prompt_cfg, 'family', None) if isinstance( - prompt_cfg, DictConfig) else None + cfg_family = getattr(prompt_cfg, 'family', None) if isinstance(prompt_cfg, DictConfig) else None - family = (cfg_family or env_family or 'auto') + family = cfg_family or env_family or 'auto' family = str(family).strip() if not family: family = 'auto' @@ -226,7 +222,6 @@ def apply_prompt_files(config: DictConfig) -> DictConfig: if not hasattr(config, 'prompt') or config.prompt is None: config.prompt = DictConfig({}) - if getattr(config.prompt, 'system', None) is None or not str( - getattr(config.prompt, 'system', '')).strip(): + if getattr(config.prompt, 'system', None) is None or not str(getattr(config.prompt, 'system', '')).strip(): config.prompt.system = prompt_text return config diff --git a/ms_agent/rag/base.py b/ms_agent/rag/base.py index 2d983a2c6..318b836cb 100644 --- a/ms_agent/rag/base.py +++ b/ms_agent/rag/base.py @@ -33,11 +33,7 @@ async def query(self, query: str) -> str: pass @abstractmethod - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.7, - **filters) -> List[Any]: + async def retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters) -> List[Any]: """Retrieve documents Args: diff --git a/ms_agent/rag/extraction.py b/ms_agent/rag/extraction.py index c62b10067..0215d8b7b 100644 --- a/ms_agent/rag/extraction.py +++ b/ms_agent/rag/extraction.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable from abc import ABC, abstractmethod -from typing import Any, Dict, List - from docling_core.transforms.chunker import BaseChunk from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem, DocItemLabel +from typing import Any, Dict, List + from ms_agent.rag.schema import KeyInformation from ms_agent.tools.docling.chunker import HybridDocumentChunker from ms_agent.tools.docling.doc_loader import DocLoader diff --git a/ms_agent/rag/extraction_manager.py b/ms_agent/rag/extraction_manager.py index dfee086cd..909949205 100644 --- a/ms_agent/rag/extraction_manager.py +++ b/ms_agent/rag/extraction_manager.py @@ -9,14 +9,16 @@ try: import ray # type: ignore + _RAY_AVAILABLE = True except Exception: # pragma: no cover - optional dependency ray = None # type: ignore _RAY_AVAILABLE = False logger.warning( 'Ray is not available. Install it for faster information extraction:\n' - ' pip install \"ray[default]\"\n' - 'Program will run without acceleration.') + ' pip install "ray[default]"\n' + 'Program will run without acceleration.' + ) class InformationExtractionManager: @@ -24,19 +26,19 @@ class InformationExtractionManager: Optimized key information extraction with optional Ray acceleration. """ - def __init__(self, - verbose: bool = False, - use_ray: bool = False, - ray_num_workers: Optional[int] = None, - ray_cpus_per_task: float = 1.0): + def __init__( + self, + verbose: bool = False, + use_ray: bool = False, + ray_num_workers: Optional[int] = None, + ray_cpus_per_task: float = 1.0, + ): self._verbose = verbose self._use_ray = use_ray and _RAY_AVAILABLE self._ray_num_workers = ray_num_workers self._ray_cpus_per_task = ray_cpus_per_task - def extract( - self, urls_or_files: List[str] - ) -> Tuple[List[KeyInformation], Dict[str, str]]: + def extract(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: """ Extract key information from URLs or files. @@ -51,38 +53,27 @@ def extract( try: return self._extract_with_ray(urls_or_files) except Exception as e: - logger.warning( - f'Ray extraction failed, falling back to sequential: {e}') + logger.warning(f'Ray extraction failed, falling back to sequential: {e}') # Use sequential extraction if Ray is disabled or failed if not _RAY_AVAILABLE: - logger.warning( - 'Ray is not available, falling back to sequential extraction.') + logger.warning('Ray is not available, falling back to sequential extraction.') return self._extract_sequential(urls_or_files) - def _extract_sequential( - self, urls_or_files: List[str] - ) -> Tuple[List[KeyInformation], Dict[str, str]]: + def _extract_sequential(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: """Sequential extraction using the original implementation.""" - extractor = HierarchicalKeyInformationExtraction( - urls_or_files=urls_or_files, verbose=self._verbose) + extractor = HierarchicalKeyInformationExtraction(urls_or_files=urls_or_files, verbose=self._verbose) key_info_list = extractor.extract() return key_info_list, extractor.all_ref_items - def _extract_with_ray( - self, urls_or_files: List[str] - ) -> Tuple[List[KeyInformation], Dict[str, str]]: + def _extract_with_ray(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: """Ray-accelerated extraction.""" if not ray.is_initialized(): - ray.init( - ignore_reinit_error=True, - include_dashboard=False, - log_to_driver=False) + ray.init(ignore_reinit_error=True, include_dashboard=False, log_to_driver=False) # Determine optimal worker count - max_workers = self._ray_num_workers or min( - len(urls_or_files), (os.cpu_count() or 4)) + max_workers = self._ray_num_workers or min(len(urls_or_files), (os.cpu_count() or 4)) max_workers = max(1, max_workers) # Partition URLs/files among workers: should be balanced @@ -93,7 +84,8 @@ def _extract_with_ray( # Create actors and dispatch tasks actors = [ _ExtractionWorker.options(num_cpus=self._ray_cpus_per_task).remote( - urls_or_files=partitions[i], verbose=self._verbose) + urls_or_files=partitions[i], verbose=self._verbose + ) for i in range(max_workers) ] @@ -101,8 +93,7 @@ def _extract_with_ray( for exraction_actor in actors: futures.append(exraction_actor.process_partition.remote()) - results: List[Tuple[List[KeyInformation], - Dict[str, str]]] = ray.get(futures) + results: List[Tuple[List[KeyInformation], Dict[str, str]]] = ray.get(futures) # Merge results merged_infos: List[KeyInformation] = [] @@ -124,11 +115,9 @@ class _ExtractionWorker: def __init__(self, urls_or_files: List[str], verbose: bool = False): self._verbose = verbose self._urls_or_files = urls_or_files - self.extractor = HierarchicalKeyInformationExtraction( - urls_or_files=self._urls_or_files, verbose=verbose) + self.extractor = HierarchicalKeyInformationExtraction(urls_or_files=self._urls_or_files, verbose=verbose) - def process_partition( - self) -> Tuple[List[KeyInformation], Dict[str, str]]: + def process_partition(self) -> Tuple[List[KeyInformation], Dict[str, str]]: """Process a partition of URLs/files and return extracted information.""" try: key_info_list_partition = self.extractor.extract() @@ -145,7 +134,7 @@ def extract_key_information( use_ray: bool = False, verbose: bool = False, ray_num_workers: Optional[int] = None, - ray_cpus_per_task: float = 1.0 + ray_cpus_per_task: float = 1.0, ) -> Tuple[List[KeyInformation], Dict[str, str]]: """ High-level function to extract key information with optional Ray acceleration. @@ -161,9 +150,7 @@ def extract_key_information( Tuple of (key_info_list, resource_map) """ extractor = InformationExtractionManager( - verbose=verbose, - use_ray=use_ray, - ray_num_workers=ray_num_workers, - ray_cpus_per_task=ray_cpus_per_task) + verbose=verbose, use_ray=use_ray, ray_num_workers=ray_num_workers, ray_cpus_per_task=ray_cpus_per_task + ) return extractor.extract(urls_or_files) diff --git a/ms_agent/rag/llama_index_rag.py b/ms_agent/rag/llama_index_rag.py index e4535d38d..8156b284b 100644 --- a/ms_agent/rag/llama_index_rag.py +++ b/ms_agent/rag/llama_index_rag.py @@ -2,10 +2,11 @@ import shutil from typing import Any, List, Optional -from ms_agent.utils import assert_package_exist +from modelscope import snapshot_download from omegaconf import DictConfig -from modelscope import snapshot_download +from ms_agent.utils import assert_package_exist + from ..llm import LLM, Message from .base import RAG @@ -28,8 +29,7 @@ def __init__(self, config: DictConfig): super().__init__(config) self._validate_config(config) - self.embedding_model = getattr(config.rag, 'embedding', - 'Qwen/Qwen3-Embedding-0.6B') + self.embedding_model = getattr(config.rag, 'embedding', 'Qwen/Qwen3-Embedding-0.6B') self.llm_model = getattr(config.rag, 'llm', None) self.chunk_size = getattr(config.rag, 'chunk_size', 512) self.chunk_overlap = getattr(config.rag, 'chunk_overlap', 50) @@ -41,22 +41,21 @@ def __init__(self, config: DictConfig): from llama_index.core import Settings from llama_index.core.node_parser import SentenceSplitter + # Set node parser - Settings.node_parser = SentenceSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + Settings.node_parser = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) # If retrieve only, don't set LLM if self.retrieve_only: Settings.llm = None else: + from llama_index.core.base.llms.types import CompletionResponse, LLMMetadata from llama_index.core.llms import CustomLLM - from llama_index.core.base.llms.types import LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback - from llama_index.core.base.llms.types import CompletionResponse + self._llm_instance = LLM.from_config(self.config) class MSCustomLLM(CustomLLM): - @property def metadata(_self) -> LLMMetadata: return LLMMetadata( @@ -66,23 +65,17 @@ def metadata(_self) -> LLMMetadata: ) @llm_completion_callback() - def complete(_self, prompt: str, - **kwargs) -> CompletionResponse: + def complete(_self, prompt: str, **kwargs) -> CompletionResponse: message: Message = self._llm_instance.generate( - messages=[Message(role='user', content=prompt)], - stream=False, - **kwargs) + messages=[Message(role='user', content=prompt)], stream=False, **kwargs + ) return CompletionResponse(text=message.content) @llm_completion_callback() - def stream_complete(_self, - prompt: str, - formatted: bool = False, - **kwargs: Any): + def stream_complete(_self, prompt: str, formatted: bool = False, **kwargs: Any): for message in self._llm_instance.generate( - messages=[Message(role='user', content=prompt)], - stream=True, - **kwargs): + messages=[Message(role='user', content=prompt)], stream=True, **kwargs + ): yield CompletionResponse(text=message.content) Settings.llm = MSCustomLLM() @@ -95,28 +88,28 @@ def _validate_requirements(self): 'llama_index', 'Please install llama_index to support llama-index-rag:\n' '> pip install -U llama-index-core llama-index-embeddings-huggingface ' - 'llama-index-llms-openai llama-index-llms-replicate\n') + 'llama-index-llms-openai llama-index-llms-replicate\n', + ) def _validate_config(self, config: DictConfig): """Validate configuration parameters""" if not hasattr(config, 'rag') or not hasattr(config.rag, 'embedding'): - raise ValueError( - 'Missing rag.embedding parameter in configuration') + raise ValueError('Missing rag.embedding parameter in configuration') chunk_size = getattr(config.rag, 'chunk_size', 512) if chunk_size <= 0: raise ValueError('chunk_size must be greater than 0') def _setup_embedding_model(self, config: DictConfig): - from llama_index.core import (Settings) + from llama_index.core import Settings from llama_index.embeddings.huggingface import HuggingFaceEmbedding + try: use_hf = getattr(config, 'use_huggingface', False) if not use_hf: self.embedding_model = snapshot_download(self.embedding_model) - Settings.embed_model = HuggingFaceEmbedding( - model_name=self.embedding_model, device='cpu') + Settings.embed_model = HuggingFaceEmbedding(model_name=self.embedding_model, device='cpu') except Exception as e: raise RuntimeError(f'Failed to load embedding model: {e}') @@ -124,7 +117,8 @@ def _setup_embedding_model(self, config: DictConfig): async def add_documents(self, documents: List[str]): if not documents: raise ValueError('Document list cannot be empty') - from llama_index.core import (Document, VectorStoreIndex) + from llama_index.core import Document, VectorStoreIndex + docs = [Document(text=doc) for doc in documents] self.index = VectorStoreIndex.from_documents(docs) if not self.retrieve_only: @@ -136,6 +130,7 @@ async def add_documents_from_files(self, file_paths: List[str]): from llama_index.core import VectorStoreIndex from llama_index.core.readers import SimpleDirectoryReader + documents = [] for file_path in file_paths: if not os.path.exists(file_path): @@ -159,18 +154,14 @@ async def _setup_query_engine(self): return from llama_index.core import Settings + # Check if LLM is set if Settings.llm is None and not self.retrieve_only: return - self.query_engine = self.index.as_query_engine( - similarity_top_k=5, response_mode='compact') + self.query_engine = self.index.as_query_engine(similarity_top_k=5, response_mode='compact') - async def _retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.0, - **filters) -> List[dict]: + async def _retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.0, **filters) -> List[dict]: if self.index is None: return [] @@ -178,58 +169,55 @@ async def _retrieve(self, return [] from llama_index.core.retrievers import VectorIndexRetriever - retriever = VectorIndexRetriever( - index=self.index, similarity_top_k=limit) + + retriever = VectorIndexRetriever(index=self.index, similarity_top_k=limit) nodes = retriever.retrieve(query) results = [] for node in nodes: if node.score >= score_threshold: - results.append({ - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id - }) + results.append( + { + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id, + } + ) return results - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.0, - **filters) -> List[dict]: + async def retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.0, **filters) -> List[dict]: if self.retrieve_only: - return await self._retrieve(query, limit, score_threshold, - **filters) + return await self._retrieve(query, limit, score_threshold, **filters) from llama_index.core import Settings from llama_index.core.postprocessor import SimilarityPostprocessor from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.retrievers import VectorIndexRetriever + if self.index is None or Settings.llm is None: return [] - retriever = VectorIndexRetriever( - index=self.index, similarity_top_k=limit) + retriever = VectorIndexRetriever(index=self.index, similarity_top_k=limit) - postprocessor = SimilarityPostprocessor( - similarity_cutoff=score_threshold) + postprocessor = SimilarityPostprocessor(similarity_cutoff=score_threshold) - query_engine = RetrieverQueryEngine( - retriever=retriever, node_postprocessors=[postprocessor]) + query_engine = RetrieverQueryEngine(retriever=retriever, node_postprocessors=[postprocessor]) response = query_engine.query(query) results = [] for node in response.source_nodes: - results.append({ - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id - }) + results.append( + { + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id, + } + ) return results @@ -239,17 +227,18 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: return [] from llama_index.core.retrievers import VectorIndexRetriever + # Try to import BM25 related modules try: - from llama_index.retrievers.bm25 import BM25Retriever from llama_index.core.retrievers import QueryFusionRetriever + from llama_index.retrievers.bm25 import BM25Retriever + bm25_available = True except ImportError: bm25_available = False # Vector retriever - vector_retriever = VectorIndexRetriever( - index=self.index, similarity_top_k=top_k) + vector_retriever = VectorIndexRetriever(index=self.index, similarity_top_k=top_k) if not bm25_available: # Use vector retrieval only @@ -257,13 +246,11 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: else: # Use hybrid retrieval try: - bm25_retriever = BM25Retriever.from_defaults( - docstore=self.index.docstore, similarity_top_k=top_k) + bm25_retriever = BM25Retriever.from_defaults(docstore=self.index.docstore, similarity_top_k=top_k) fusion_retriever = QueryFusionRetriever( - retrievers=[vector_retriever, bm25_retriever], - similarity_top_k=top_k, - num_queries=1) + retrievers=[vector_retriever, bm25_retriever], similarity_top_k=top_k, num_queries=1 + ) nodes = fusion_retriever.retrieve(query) @@ -272,25 +259,23 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: results = [] for node in nodes: - results.append({ - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id - }) + results.append( + { + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id, + } + ) return results async def query(self, query: str) -> str: if self.query_engine is None: if self.retrieve_only: - raise ValueError( - 'Current mode is retrieve only, question answering not supported' - ) + raise ValueError('Current mode is retrieve only, question answering not supported') else: - raise ValueError( - 'Query engine not initialized, please add documents and set LLM first' - ) + raise ValueError('Query engine not initialized, please add documents and set LLM first') try: response = self.query_engine.query(query) @@ -313,10 +298,10 @@ async def load_index(self, persist_dir: Optional[str] = None): load_dir = persist_dir or self.storage_dir if not os.path.exists(load_dir): - raise FileNotFoundError( - f'Index directory does not exist: {load_dir}') + raise FileNotFoundError(f'Index directory does not exist: {load_dir}') + + from llama_index.core import StorageContext, load_index_from_storage - from llama_index.core import (StorageContext, load_index_from_storage) storage_context = StorageContext.from_defaults(persist_dir=load_dir) self.index = load_index_from_storage(storage_context) @@ -336,7 +321,7 @@ def get_index_info(self) -> dict: 'retrieve_only': self.retrieve_only, 'chunk_size': self.chunk_size, 'chunk_overlap': self.chunk_overlap, - 'embedding_model': self.embedding_model + 'embedding_model': self.embedding_model, } async def remove_all_documents(self): diff --git a/ms_agent/rag/schema.py b/ms_agent/rag/schema.py index 9142d4945..f6f997153 100644 --- a/ms_agent/rag/schema.py +++ b/ms_agent/rag/schema.py @@ -18,6 +18,7 @@ class KeyInformation: including images, tables, or other relevant data. [{'id': 'doc_file_name@binary_hash@self_ref', 'content': PILImage.Image}, ...] """ + text: str resources: List[Dict[str, Any]] diff --git a/ms_agent/retriever/hybrid_retriever.py b/ms_agent/retriever/hybrid_retriever.py index e84bc8398..283366652 100644 --- a/ms_agent/retriever/hybrid_retriever.py +++ b/ms_agent/retriever/hybrid_retriever.py @@ -6,6 +6,7 @@ import faiss import numpy as np + from ms_agent.utils.tokenizer_util import TokenizerUtil os.environ['OMP_NUM_THREADS'] = '1' @@ -18,10 +19,7 @@ class BM25Retriever: Sparse retriever based on BM25 algorithm. """ - def __init__(self, - tokenized_corpus: List[List[str]], - k1: float = 1.5, - b: float = 0.75): + def __init__(self, tokenized_corpus: List[List[str]], k1: float = 1.5, b: float = 0.75): self.k1 = k1 self.b = b self.corpus_size = len(tokenized_corpus) @@ -50,8 +48,7 @@ def _initialize(self, tokenized_corpus: List[List[str]]): self.avgdl = total_length / doc_count if doc_count > 0 else 0 for word, freq in self.idf.items(): - self.idf[word] = math.log((self.corpus_size - freq + 0.5) - / (freq + 0.5) + 1) + self.idf[word] = math.log((self.corpus_size - freq + 0.5) / (freq + 0.5) + 1) for doc_tokens in tokenized_corpus: freqs = {} @@ -68,11 +65,11 @@ def get_scores(self, tokenized_query: List[str]) -> List[float]: idf_score = self.idf[token] for index, doc_freqs in enumerate(self.doc_term_freqs): freq = doc_freqs.get(token, 0) - if freq == 0: continue # noqa: E701 + if freq == 0: + continue # noqa: E701 doc_len = self.doc_len[index] numerator = freq * (self.k1 + 1) - denominator = freq + self.k1 * ( - 1 - self.b + self.b * (doc_len / self.avgdl)) # noqa: W504 + denominator = freq + self.k1 * (1 - self.b + self.b * (doc_len / self.avgdl)) # noqa: W504 scores[index] += idf_score * (numerator / denominator) return scores @@ -83,13 +80,13 @@ class HybridRetriever: """ def __init__( - self, - corpus: List[str] = None, - embed_model: - str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', # noqa - tokenizer_model_id: str = 'Qwen/Qwen3-8B', - bm25_k1: float = 1.5, - bm25_b: float = 0.75): + self, + corpus: List[str] = None, + embed_model: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', # noqa + tokenizer_model_id: str = 'Qwen/Qwen3-8B', + bm25_k1: float = 1.5, + bm25_b: float = 0.75, + ): """ Initialize Hybrid Retriever with both Dense and Sparse indices. @@ -149,10 +146,8 @@ def __init__( # Initialize Dense Retriever (FAISS) embed_model_path: str = self._load_model( model_id=embed_model, - ignore_patterns=[ - 'openvino/*', 'onnx/*', 'pytorch_model.bin', 'rust_model.ot', - 'tf_model.h5' - ]) + ignore_patterns=['openvino/*', 'onnx/*', 'pytorch_model.bin', 'rust_model.ot', 'tf_model.h5'], + ) from sentence_transformers import SentenceTransformer @@ -171,15 +166,11 @@ def _load_model(model_id: str, ignore_patterns: List[str] = None) -> str: from modelscope import snapshot_download try: - return snapshot_download( - model_id=model_id, ignore_patterns=ignore_patterns) + return snapshot_download(model_id=model_id, ignore_patterns=ignore_patterns) except Exception as e: raise RuntimeError(f'Failed to load model {model_id}: {e}') from e - def _init_corpus(self, - corpus: List[str], - bm25_k1: float = 1.5, - bm25_b: float = 0.75): + def _init_corpus(self, corpus: List[str], bm25_k1: float = 1.5, bm25_b: float = 0.75): """ Initialize corpus and build both Dense and Sparse indices. @@ -204,9 +195,7 @@ def _init_corpus(self, # Initialize Sparse Retriever (BM25) print('Building BM25 index...') - self.tokenized_corpus = [ - self.tokenizer_util.segment(doc) for doc in self.corpus - ] + self.tokenized_corpus = [self.tokenizer_util.segment(doc) for doc in self.corpus] self.bm25 = BM25Retriever( tokenized_corpus=self.tokenized_corpus, k1=bm25_k1, @@ -223,17 +212,17 @@ def _build_dense_index(self, texts: List[str]): faiss.normalize_L2(embeddings) self.index = faiss.IndexFlatIP(embeddings.shape[1]) self.index.add(embeddings) - print( - f'Successfully indexed {len(texts)} documents for Dense Retrieval.' - ) + print(f'Successfully indexed {len(texts)} documents for Dense Retrieval.') @staticmethod def _z_score_normalization(scores: List[float]) -> List[float]: """Apply Z-score normalization: z = (x - mean) / std.""" - if not scores: return [] # noqa: E701 + if not scores: + return [] # noqa: E701 arr = np.array(scores) std = np.std(arr) - if std == 0: return [0.0] * len(scores) # noqa: E701 + if std == 0: + return [0.0] * len(scores) # noqa: E701 mean = np.mean(arr) return ((arr - mean) / std).tolist() @@ -257,9 +246,7 @@ def _validate_corpus(self, corpus: List[str] = None): if corpus is not None and corpus != self.corpus: self._init_corpus(corpus=corpus) elif self.corpus is None: - raise ValueError( - 'Corpus is empty. Please provide a valid corpus for searching.' - ) + raise ValueError('Corpus is empty. Please provide a valid corpus for searching.') if self.index is None: raise ValueError('Index not built.') @@ -278,11 +265,7 @@ def _compute_dense_scores(self, query: str) -> List[float]: search_k: int = min(len(self.corpus), 500) dense_dists, dense_indices = self.index.search(x=query_vec, k=search_k) - dense_scores_map = { - idx: float(score) - for idx, score in zip(dense_indices[0], dense_dists[0]) - if idx != -1 - } + dense_scores_map = {idx: float(score) for idx, score in zip(dense_indices[0], dense_dists[0]) if idx != -1} return [dense_scores_map.get(i, 0.0) for i in range(len(self.corpus))] def _compute_sparse_scores(self, query: str) -> List[float]: @@ -394,8 +377,7 @@ def search( raw_bm25_scores = self._compute_sparse_scores(query) # Fuse and normalize scores - candidates = self._fuse_and_normalize_scores(raw_dense_scores, - raw_bm25_scores, alpha) + candidates = self._fuse_and_normalize_scores(raw_dense_scores, raw_bm25_scores, alpha) # Filter and rank results return self._filter_and_rank(candidates, top_k, min_score) @@ -432,12 +414,10 @@ async def async_search( dense_task = asyncio.to_thread(self._compute_dense_scores, query) sparse_task = asyncio.to_thread(self._compute_sparse_scores, query) - raw_dense_scores, raw_bm25_scores = await asyncio.gather( - dense_task, sparse_task) + raw_dense_scores, raw_bm25_scores = await asyncio.gather(dense_task, sparse_task) # Fuse and normalize scores - candidates = self._fuse_and_normalize_scores(raw_dense_scores, - raw_bm25_scores, alpha) + candidates = self._fuse_and_normalize_scores(raw_dense_scores, raw_bm25_scores, alpha) # Filter and rank results return self._filter_and_rank(candidates, top_k, min_score) diff --git a/ms_agent/sandbox/sandbox.py b/ms_agent/sandbox/sandbox.py index 8a5753761..5974781ae 100644 --- a/ms_agent/sandbox/sandbox.py +++ b/ms_agent/sandbox/sandbox.py @@ -9,8 +9,7 @@ class Sandbox: Base class for sandbox environments. """ - def __init__(self): - ... + def __init__(self): ... async def async_execute(self, *args, **kwargs): """ @@ -46,7 +45,7 @@ def __init__(self, **kwargs): super().__init__() self._init() - from ms_enclave.sandbox import SandboxConfig, DockerSandboxConfig + from ms_enclave.sandbox import DockerSandboxConfig, SandboxConfig # Mount host directories into the sandbox container if provided _volumes = kwargs.pop('volumes', None) or [] @@ -55,19 +54,12 @@ def __init__(self, **kwargs): for host_path, container_path, mode in _volumes: host_path = str(host_path) container_path = str(container_path) - self.volume_dict[host_path] = { - 'bind': container_path, - 'mode': mode - } + self.volume_dict[host_path] = {'bind': container_path, 'mode': mode} self.sandbox_config: SandboxConfig = DockerSandboxConfig( image=kwargs.pop('image', None) or 'python:3.11-slim', memory_limit=kwargs.pop('memory_limit', None) or '512m', - tools_config={ - 'python_executor': {}, - 'file_operation': {}, - 'shell_executor': {} - }, + tools_config={'python_executor': {}, 'file_operation': {}, 'shell_executor': {}}, volumes=self.volume_dict, ) @@ -81,17 +73,16 @@ def _init(): """ logger.info('Installing ms-enclave package...') try: - install_package( - package_name='ms-enclave', - import_name='ms_enclave', - extend_module='docker') + install_package(package_name='ms-enclave', import_name='ms_enclave', extend_module='docker') except Exception as e: raise e - async def async_execute(self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None) -> Dict[str, Any]: + async def async_execute( + self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None, + ) -> Dict[str, Any]: """ Asynchronously execute Python code and shell commands within the sandbox. @@ -132,20 +123,15 @@ async def async_execute(self, 'shell_executor': [], } - async with SandboxFactory.create_sandbox( - SandboxType.DOCKER, self.sandbox_config) as sandbox: - + async with SandboxFactory.create_sandbox(SandboxType.DOCKER, self.sandbox_config) as sandbox: if requirements: requirements_file = f'/{str(uuid.uuid4())}/requirements.txt' await sandbox.execute_tool( - 'file_operation', { - 'operation': 'write', - 'file_path': f'{requirements_file}', - 'content': '\n'.join(requirements) - }) - - result_requirements = await sandbox.execute_command( - f'pip install -r {requirements_file}') + 'file_operation', + {'operation': 'write', 'file_path': f'{requirements_file}', 'content': '\n'.join(requirements)}, + ) + + result_requirements = await sandbox.execute_command(f'pip install -r {requirements_file}') logger.info(result_requirements.stdout) if python_code: @@ -153,17 +139,11 @@ async def async_execute(self, python_code = [python_code] for py_item in python_code: - py_result = await sandbox.execute_tool( - 'python_executor', {'code': py_item}) - - results['python_executor'].append({ - 'output': - py_result.output, - 'error': - py_result.error, - 'status': - py_result.status - }) + py_result = await sandbox.execute_tool('python_executor', {'code': py_item}) + + results['python_executor'].append( + {'output': py_result.output, 'error': py_result.error, 'status': py_result.status} + ) if shell_command: if isinstance(shell_command, str): @@ -172,21 +152,18 @@ async def async_execute(self, for shell_item in shell_command: shell_result = await sandbox.execute_command(shell_item) - results['shell_executor'].append({ - 'output': - shell_result.stdout, - 'error': - shell_result.stderr, - 'status': - shell_result.status - }) + results['shell_executor'].append( + {'output': shell_result.stdout, 'error': shell_result.stderr, 'status': shell_result.status} + ) return results - def execute(self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None) -> Dict[str, Any]: + def execute( + self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None, + ) -> Dict[str, Any]: """ Synchronously execute Python code and shell commands within the sandbox. @@ -203,7 +180,5 @@ def execute(self, import asyncio return asyncio.run( - self.async_execute( - python_code=python_code, - shell_command=shell_command, - requirements=requirements)) + self.async_execute(python_code=python_code, shell_command=shell_command, requirements=requirements) + ) diff --git a/ms_agent/skill/auto_skills.py b/ms_agent/skill/auto_skills.py index 170c49c91..36b09eb9d 100644 --- a/ms_agent/skill/auto_skills.py +++ b/ms_agent/skill/auto_skills.py @@ -625,7 +625,7 @@ async def _execute_with_progressive_analysis( skill_id=skill_id, success=False, error= - f'Skill cannot handle query: {context.plan.reasoning if context.plan else "No plan"}' + f'Skill cannot handle query: {context.plan.reasoning if context.plan else 'No plan'}' ) if not commands: @@ -872,7 +872,7 @@ async def _execute_command_with_retry( additional_reqs = analysis.get('additional_requirements', []) logger.info( - f'[{skill_id}] Error analysis: type={error_info.get("error_type")}, ' + f'[{skill_id}] Error analysis: type={error_info.get('error_type')}, ' f'fixable={is_fixable}') # Apply fix if available @@ -1402,13 +1402,13 @@ def _filter_skills( else: logger.info( f'Removing skill [{sid}]: cannot execute - ' - f'{analysis.get("reason", "")[:200]}' + f'{analysis.get('reason', '')[:200]}' ) filtered_ids = final_ids logger.info( f'Filter ({mode}): {len(skill_ids)} -> {len(filtered_ids)} skills. ' - f'Reason: {parsed.get("reasoning", "")[:1000]}' + f'Reason: {parsed.get('reasoning', '')[:1000]}' ) return set(filtered_ids) diff --git a/ms_agent/skill/container.py b/ms_agent/skill/container.py index 51d96f6f3..d9806097e 100644 --- a/ms_agent/skill/container.py +++ b/ms_agent/skill/container.py @@ -10,6 +10,7 @@ - use_sandbox=True: Execute in Docker sandbox (default, recommended for untrusted code) - use_sandbox=False: Execute locally with security checks (for trusted code or no Docker) """ + import asyncio import os import platform @@ -54,6 +55,7 @@ class ExecutorType(Enum): """Supported executor types for skill execution.""" + PYTHON_SCRIPT = 'python_script' PYTHON_CODE = 'python_code' PYTHON_FUNCTION = 'python_function' @@ -63,6 +65,7 @@ class ExecutorType(Enum): class ExecutionStatus(Enum): """Execution status codes.""" + PENDING = 'pending' RUNNING = 'running' SUCCESS = 'success' @@ -86,6 +89,7 @@ class ExecutionInput: working_dir: Working directory for execution. requirements: Python packages to install before execution. """ + args: List[Any] = field(default_factory=list) kwargs: Dict[str, Any] = field(default_factory=dict) env_vars: Dict[str, str] = field(default_factory=dict) @@ -99,8 +103,7 @@ def to_dict(self) -> Dict[str, Any]: 'args': self.args, 'kwargs': self.kwargs, 'env_vars': self.env_vars, - 'input_files': {k: str(v) - for k, v in self.input_files.items()}, + 'input_files': {k: str(v) for k, v in self.input_files.items()}, 'stdin': self.stdin, 'working_dir': str(self.working_dir) if self.working_dir else None, 'requirements': self.requirements, @@ -121,6 +124,7 @@ class ExecutionOutput: artifacts: Any generated artifacts (data, objects, etc.). duration_ms: Execution duration in milliseconds. """ + return_value: Any = None stdout: str = '' stderr: str = '' @@ -131,13 +135,11 @@ class ExecutionOutput: def to_dict(self) -> Dict[str, Any]: return { - 'return_value': - str(self.return_value) if self.return_value else None, + 'return_value': str(self.return_value) if self.return_value else None, 'stdout': self.stdout, 'stderr': self.stderr, 'exit_code': self.exit_code, - 'output_files': {k: str(v) - for k, v in self.output_files.items()}, + 'output_files': {k: str(v) for k, v in self.output_files.items()}, 'artifacts': list(self.artifacts.keys()), 'duration_ms': self.duration_ms, } @@ -162,6 +164,7 @@ class ExecutionRecord: error_message: Error message if failed. sandbox_used: Whether sandbox was used for execution. """ + execution_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) skill_id: str = '' executor_type: ExecutorType = ExecutorType.PYTHON_SCRIPT @@ -209,8 +212,7 @@ def to_markdown(self) -> str: for name, path in self.input_spec.input_files.items(): lines.append(f' - `{name}`: `{path}`') if self.input_spec.requirements: - lines.append( - f'- **Requirements**: `{self.input_spec.requirements}`') + lines.append(f'- **Requirements**: `{self.input_spec.requirements}`') # Output section lines.extend(['', '#### Output', '']) @@ -228,8 +230,7 @@ def to_markdown(self) -> str: lines.append(f' - `{name}`: `{path}`') if self.error_message: - lines.extend( - ['', '#### Error', '', f'```\n{self.error_message}\n```']) + lines.extend(['', '#### Error', '', f'```\n{self.error_message}\n```']) lines.append('') return '\n'.join(lines) @@ -248,6 +249,7 @@ class ExecutionSpec: created_at: Creation timestamp. upstream_outputs: Outputs from upstream skills available as inputs. """ + spec_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) title: str = 'Skill Execution Spec' description: str = '' @@ -285,26 +287,25 @@ def to_markdown(self) -> str: # Summary total = len(self.records) - success = sum(1 for r in self.records - if r.status == ExecutionStatus.SUCCESS) - failed = sum(1 for r in self.records - if r.status == ExecutionStatus.FAILED) - blocked = sum(1 for r in self.records - if r.status == ExecutionStatus.SECURITY_BLOCKED) - - lines.extend([ - '## Summary', - '', - f'- **Total Executions**: {total}', - f'- **Successful**: {success}', - f'- **Failed**: {failed}', - f'- **Security Blocked**: {blocked}', - '', - '---', - '', - '## Execution Records', - '', - ]) + success = sum(1 for r in self.records if r.status == ExecutionStatus.SUCCESS) + failed = sum(1 for r in self.records if r.status == ExecutionStatus.FAILED) + blocked = sum(1 for r in self.records if r.status == ExecutionStatus.SECURITY_BLOCKED) + + lines.extend( + [ + '## Summary', + '', + f'- **Total Executions**: {total}', + f'- **Successful**: {success}', + f'- **Failed**: {failed}', + f'- **Security Blocked**: {blocked}', + '', + '---', + '', + '## Execution Records', + '', + ] + ) for record in self.records: lines.append(record.to_markdown()) @@ -342,14 +343,16 @@ class SkillContainer: SANDBOX_OUTPUT_DIR = '/sandbox/outputs' SANDBOX_WORK_DIR = '/sandbox/scripts' - def __init__(self, - workspace_dir: Optional[Union[str, Path]] = None, - timeout: int = 300, - image: str = 'python:3.11-slim', - memory_limit: str = '512m', - enable_security_check: bool = True, - network_enabled: bool = False, - use_sandbox: bool = True): + def __init__( + self, + workspace_dir: Optional[Union[str, Path]] = None, + timeout: int = 300, + image: str = 'python:3.11-slim', + memory_limit: str = '512m', + enable_security_check: bool = True, + network_enabled: bool = False, + use_sandbox: bool = True, + ): """ Initialize the skill container. @@ -366,8 +369,7 @@ def __init__(self, if workspace_dir: self.workspace_dir = Path(workspace_dir).resolve() else: - self.workspace_dir = Path( - tempfile.mkdtemp(prefix='skill_container_')).resolve() + self.workspace_dir = Path(tempfile.mkdtemp(prefix='skill_container_')).resolve() self.workspace_dir.mkdir(parents=True, exist_ok=True) self.timeout = timeout @@ -397,10 +399,12 @@ def __init__(self, logger.warning( 'SkillContainer running in LOCAL mode (use_sandbox=False). ' 'Scripts will execute directly on this machine. ' - 'Ensure you trust the code being executed!') + 'Ensure you trust the code being executed!' + ) - logger.info(f'SkillContainer initialized at: {self.workspace_dir} ' - f'[mode: {"sandbox" if self.use_sandbox else "local"}]') + logger.info( + f'SkillContainer initialized at: {self.workspace_dir} [mode: {"sandbox" if self.use_sandbox else "local"}]' + ) def _get_sandbox(self): """ @@ -423,8 +427,7 @@ def _get_sandbox(self): for skill_id, skill_dir in self._skill_dirs.items(): safe_id = skill_id.replace('@', '_').replace('/', '_') sandbox_path = f'{self.SANDBOX_ROOT}/skills/{safe_id}' - volumes.append( - (str(Path(skill_dir).resolve()), sandbox_path, 'ro')) + volumes.append((str(Path(skill_dir).resolve()), sandbox_path, 'ro')) self._sandbox = EnclaveSandbox( image=self.image, @@ -433,8 +436,7 @@ def _get_sandbox(self): ) return self._sandbox - def mount_skill_directory(self, skill_id: str, skill_dir: Union[str, - Path]): + def mount_skill_directory(self, skill_id: str, skill_dir: Union[str, Path]): """ Mount a skill directory for sandbox access. @@ -459,9 +461,7 @@ def get_skill_sandbox_path(self, skill_id: str) -> str: safe_id = skill_id.replace('@', '_').replace('/', '_') return f'{self.SANDBOX_ROOT}/skills/{safe_id}' - def _security_check(self, - code: str, - is_local: bool = False) -> tuple[bool, str]: + def _security_check(self, code: str, is_local: bool = False) -> tuple[bool, str]: """ Check code for potentially dangerous patterns. @@ -523,13 +523,15 @@ def _collect_output_files(self) -> Dict[str, Path]: outputs[f.name] = f return outputs - def _create_record(self, - skill_id: str, - executor_type: ExecutorType, - input_spec: ExecutionInput, - script_path: str = None, - function_name: str = None, - sandbox_used: bool = None) -> ExecutionRecord: + def _create_record( + self, + skill_id: str, + executor_type: ExecutorType, + input_spec: ExecutionInput, + script_path: str = None, + function_name: str = None, + sandbox_used: bool = None, + ) -> ExecutionRecord: """Create a new execution record.""" return ExecutionRecord( skill_id=skill_id, @@ -538,18 +540,16 @@ def _create_record(self, function_name=function_name, input_spec=input_spec, status=ExecutionStatus.PENDING, - sandbox_used=sandbox_used - if sandbox_used is not None else self.use_sandbox) + sandbox_used=sandbox_used if sandbox_used is not None else self.use_sandbox, + ) # ------------------------------------------------------------------------- # Local Execution Helpers (for use_sandbox=False mode) # ------------------------------------------------------------------------- - def _local_run_subprocess(self, - cmd: List[str], - env: Dict[str, str] = None, - cwd: Path = None, - stdin_input: str = None) -> tuple[str, str, int]: + def _local_run_subprocess( + self, cmd: List[str], env: Dict[str, str] = None, cwd: Path = None, stdin_input: str = None + ) -> tuple[str, str, int]: """ Run subprocess locally with security restrictions. @@ -607,8 +607,7 @@ def _get_node_executable(self) -> str: return 'node.exe' return 'node' - async def _local_install_requirements( - self, requirements: List[str]) -> tuple[bool, str]: + async def _local_install_requirements(self, requirements: List[str]) -> tuple[bool, str]: """ Install Python requirements locally using pip. @@ -623,8 +622,12 @@ async def _local_install_requirements( try: cmd = [ - self._get_python_executable(), '-m', 'pip', 'install', - '--quiet', '--disable-pip-version-check' + self._get_python_executable(), + '-m', + 'pip', + 'install', + '--quiet', + '--disable-pip-version-check', ] + requirements stdout, stderr, exit_code = self._local_run_subprocess(cmd) @@ -639,9 +642,7 @@ async def _local_install_requirements( logger.error(f'Error installing requirements: {e}') return False, str(e) - async def _local_execute_python_code( - self, code: str, - input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_python_code(self, code: str, input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute Python code locally. @@ -654,8 +655,7 @@ async def _local_execute_python_code( """ # Install requirements first if any if input_spec.requirements: - success, error = await self._local_install_requirements( - input_spec.requirements) + success, error = await self._local_install_requirements(input_spec.requirements) if not success: return '', f'Failed to install requirements: {error}', -1 @@ -677,10 +677,8 @@ async def _local_execute_python_code( cwd = input_spec.working_dir if input_spec.working_dir else None stdout, stderr, exit_code = self._local_run_subprocess( - cmd, - env=input_spec.env_vars, - cwd=cwd, - stdin_input=input_spec.stdin) + cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin + ) # Keep script in scripts folder for logging/debugging return stdout, stderr, exit_code @@ -688,9 +686,7 @@ async def _local_execute_python_code( logger.error(f'Local Python execution failed: {e}') raise - async def _local_execute_shell( - self, command: str, - input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_shell(self, command: str, input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute shell command locally. @@ -707,30 +703,20 @@ async def _local_execute_shell( if platform.system() == 'Windows': # Windows: use set for environment env_cmds = [f'set {k}={v}' for k, v in input_spec.env_vars.items()] - full_cmd = ' && '.join(env_cmds - + [command]) if env_cmds else command + full_cmd = ' && '.join(env_cmds + [command]) if env_cmds else command cmd = shell_exec + [full_cmd] else: # Unix: use export - env_cmds = [ - f"export {k}='{v}'" for k, v in input_spec.env_vars.items() - ] - full_cmd = ' && '.join(env_cmds - + [command]) if env_cmds else command + env_cmds = [f"export {k}='{v}'" for k, v in input_spec.env_vars.items()] + full_cmd = ' && '.join(env_cmds + [command]) if env_cmds else command cmd = shell_exec + [full_cmd] # Use working_dir from input_spec for proper resource access cwd = input_spec.working_dir if input_spec.working_dir else None - return self._local_run_subprocess( - cmd, - env=input_spec.env_vars, - cwd=cwd, - stdin_input=input_spec.stdin) + return self._local_run_subprocess(cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin) - async def _local_execute_javascript( - self, js_code: str, - input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_javascript(self, js_code: str, input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute JavaScript code locally via Node.js. @@ -759,11 +745,7 @@ async def _local_execute_javascript( cwd = input_spec.working_dir if input_spec.working_dir else None # Keep script in scripts folder for logging/debugging - return self._local_run_subprocess( - cmd, - env=input_spec.env_vars, - cwd=cwd, - stdin_input=input_spec.stdin) + return self._local_run_subprocess(cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin) except Exception as e: logger.error(f'Local JavaScript execution failed: {e}') raise @@ -790,16 +772,18 @@ def _generate_local_env_setup(self, input_spec: ExecutionInput) -> str: # Add working directory to sys.path for imports and change to it if input_spec.working_dir: work_dir = str(input_spec.working_dir) - lines.extend([ - '', - '# Setup working directory for resource access (READ-ONLY for resources)', - f'_skill_dir = {repr(work_dir)}', - "os.environ['SKILL_DIR'] = _skill_dir", - 'SKILL_DIR = _skill_dir', - 'if _skill_dir not in sys.path:', - ' sys.path.insert(0, _skill_dir)', - 'os.chdir(_skill_dir)', - ]) + lines.extend( + [ + '', + '# Setup working directory for resource access (READ-ONLY for resources)', + f'_skill_dir = {repr(work_dir)}', + "os.environ['SKILL_DIR'] = _skill_dir", + 'SKILL_DIR = _skill_dir', + 'if _skill_dir not in sys.path:', + ' sys.path.insert(0, _skill_dir)', + 'os.chdir(_skill_dir)', + ] + ) # Add custom env vars for key, value in input_spec.env_vars.items(): @@ -830,8 +814,7 @@ def _generate_local_js_env_setup(self, input_spec: ExecutionInput) -> str: lines.append('') return '\n'.join(lines) - def _parse_sandbox_result(self, - results: Dict[str, Any]) -> tuple[str, str, int]: + def _parse_sandbox_result(self, results: Dict[str, Any]) -> tuple[str, str, int]: """Parse sandbox execution results into stdout, stderr, exit_code.""" stdout_parts = [] stderr_parts = [] @@ -850,22 +833,20 @@ def _parse_sandbox_result(self, return '\n'.join(stdout_parts), '\n'.join(stderr_parts), exit_code async def _execute_in_sandbox( - self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None) -> Dict[str, Any]: + self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None, + ) -> Dict[str, Any]: """Execute code in EnclaveSandbox.""" sandbox = self._get_sandbox() return await sandbox.async_execute( - python_code=python_code, - shell_command=shell_command, - requirements=requirements) + python_code=python_code, shell_command=shell_command, requirements=requirements + ) async def execute_python_script( - self, - script_path: Union[str, Path], - skill_id: str = 'unknown', - input_spec: ExecutionInput = None) -> ExecutionOutput: + self, script_path: Union[str, Path], skill_id: str = 'unknown', input_spec: ExecutionInput = None + ) -> ExecutionOutput: """ Execute a Python script file. @@ -886,7 +867,8 @@ async def execute_python_script( skill_id=skill_id, executor_type=ExecutorType.PYTHON_SCRIPT, input_spec=input_spec, - script_path=str(script_path)) + script_path=str(script_path), + ) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING @@ -897,13 +879,11 @@ async def execute_python_script( code = f.read() # Security check (stricter for local mode) - is_safe, reason = self._security_check( - code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check(code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput( - stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -916,14 +896,11 @@ async def execute_python_script( env_setup = self._generate_env_setup(input_spec, {}) full_code = env_setup + '\n' + code - results = await self._execute_in_sandbox( - python_code=full_code, - requirements=input_spec.requirements) + results = await self._execute_in_sandbox(python_code=full_code, requirements=input_spec.requirements) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode: execute directly - stdout, stderr, exit_code = await self._local_execute_python_code( - code, input_spec) + stdout, stderr, exit_code = await self._local_execute_python_code(code, input_spec) end_time = datetime.now() @@ -932,11 +909,10 @@ async def execute_python_script( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000) + duration_ms=(end_time - start_time).total_seconds() * 1000, + ) - record.status = ( - ExecutionStatus.SUCCESS - if exit_code == 0 else ExecutionStatus.FAILED) + record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -950,10 +926,8 @@ async def execute_python_script( return output async def execute_python_code( - self, - code: str, - skill_id: str = 'unknown', - input_spec: ExecutionInput = None) -> ExecutionOutput: + self, code: str, skill_id: str = 'unknown', input_spec: ExecutionInput = None + ) -> ExecutionOutput: """ Execute Python code string. @@ -970,23 +944,19 @@ async def execute_python_code( input_spec = input_spec or ExecutionInput() record = self._create_record( - skill_id=skill_id, - executor_type=ExecutorType.PYTHON_CODE, - input_spec=input_spec, - script_path='') + skill_id=skill_id, executor_type=ExecutorType.PYTHON_CODE, input_spec=input_spec, script_path='' + ) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING try: # Security check (stricter for local mode) - is_safe, reason = self._security_check( - code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check(code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput( - stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -999,14 +969,11 @@ async def execute_python_code( env_setup = self._generate_env_setup(input_spec, {}) full_code = env_setup + '\n' + code - results = await self._execute_in_sandbox( - python_code=full_code, - requirements=input_spec.requirements) + results = await self._execute_in_sandbox(python_code=full_code, requirements=input_spec.requirements) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_python_code( - code, input_spec) + stdout, stderr, exit_code = await self._local_execute_python_code(code, input_spec) end_time = datetime.now() @@ -1015,11 +982,10 @@ async def execute_python_code( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000) + duration_ms=(end_time - start_time).total_seconds() * 1000, + ) - record.status = ( - ExecutionStatus.SUCCESS - if exit_code == 0 else ExecutionStatus.FAILED) + record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -1032,8 +998,7 @@ async def execute_python_code( self.spec.add_record(record) return output - def _generate_env_setup(self, input_spec: ExecutionInput, - sandbox_files: Dict[str, str]) -> str: + def _generate_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict[str, str]) -> str: """Generate Python code to setup environment variables and paths.""" sandbox_logs_dir = f'{self.SANDBOX_ROOT}/logs' lines = [ @@ -1071,10 +1036,8 @@ def _generate_env_setup(self, input_spec: ExecutionInput, return '\n'.join(lines) def execute_python_function( - self, - func: Callable, - skill_id: str = 'unknown', - input_spec: ExecutionInput = None) -> ExecutionOutput: + self, func: Callable, skill_id: str = 'unknown', input_spec: ExecutionInput = None + ) -> ExecutionOutput: """ Execute a Python function directly (local execution, not sandboxed). @@ -1095,7 +1058,8 @@ def execute_python_function( skill_id=skill_id, executor_type=ExecutorType.PYTHON_FUNCTION, input_spec=input_spec, - function_name=func.__name__) + function_name=func.__name__, + ) record.sandbox_used = False # Local execution record.start_time = datetime.now() @@ -1114,7 +1078,8 @@ def execute_python_function( return_value=return_value, exit_code=0, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000) + duration_ms=(end_time - start_time).total_seconds() * 1000, + ) record.status = ExecutionStatus.SUCCESS @@ -1130,10 +1095,8 @@ def execute_python_function( return output async def execute_shell( - self, - command: Union[str, List[str]], - skill_id: str = 'unknown', - input_spec: ExecutionInput = None) -> ExecutionOutput: + self, command: Union[str, List[str]], skill_id: str = 'unknown', input_spec: ExecutionInput = None + ) -> ExecutionOutput: """ Execute a shell command. @@ -1152,23 +1115,19 @@ async def execute_shell( cmd_str = command if isinstance(command, str) else ' && '.join(command) record = self._create_record( - skill_id=skill_id, - executor_type=ExecutorType.SHELL, - input_spec=input_spec, - script_path=cmd_str[:200]) + skill_id=skill_id, executor_type=ExecutorType.SHELL, input_spec=input_spec, script_path=cmd_str[:200] + ) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING try: # Security check (stricter for local mode) - is_safe, reason = self._security_check( - cmd_str, is_local=not self.use_sandbox) + is_safe, reason = self._security_check(cmd_str, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput( - stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -1187,13 +1146,11 @@ async def execute_shell( full_cmd = ' && '.join(env_exports + [cmd_str]) - results = await self._execute_in_sandbox(shell_command=full_cmd - ) + results = await self._execute_in_sandbox(shell_command=full_cmd) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_shell( - cmd_str, input_spec) + stdout, stderr, exit_code = await self._local_execute_shell(cmd_str, input_spec) end_time = datetime.now() @@ -1202,11 +1159,10 @@ async def execute_shell( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000) + duration_ms=(end_time - start_time).total_seconds() * 1000, + ) - record.status = ( - ExecutionStatus.SUCCESS - if exit_code == 0 else ExecutionStatus.FAILED) + record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -1219,12 +1175,14 @@ async def execute_shell( self.spec.add_record(record) return output - async def execute_javascript(self, - script_path: Union[str, Path] = None, - code: str = None, - skill_id: str = 'unknown', - input_spec: ExecutionInput = None, - runtime: str = 'node') -> ExecutionOutput: + async def execute_javascript( + self, + script_path: Union[str, Path] = None, + code: str = None, + skill_id: str = 'unknown', + input_spec: ExecutionInput = None, + runtime: str = 'node', + ) -> ExecutionOutput: """ Execute JavaScript code via Node.js. @@ -1246,7 +1204,8 @@ async def execute_javascript(self, skill_id=skill_id, executor_type=ExecutorType.JAVASCRIPT, input_spec=input_spec, - script_path=str(script_path) if script_path else '') + script_path=str(script_path) if script_path else '', + ) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING @@ -1262,13 +1221,11 @@ async def execute_javascript(self, raise ValueError('Either script_path or code must be provided') # Security check (stricter for local mode) - is_safe, reason = self._security_check( - js_code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check(js_code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput( - stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -1293,13 +1250,11 @@ async def execute_javascript(self, args_str = ' '.join(f'"{arg}"' for arg in input_spec.args) shell_cmd = f'{runtime} {sandbox_js_path} {args_str}' - results = await self._execute_in_sandbox( - shell_command=shell_cmd) + results = await self._execute_in_sandbox(shell_command=shell_cmd) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_javascript( - js_code, input_spec) + stdout, stderr, exit_code = await self._local_execute_javascript(js_code, input_spec) end_time = datetime.now() @@ -1308,11 +1263,10 @@ async def execute_javascript(self, stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000) + duration_ms=(end_time - start_time).total_seconds() * 1000, + ) - record.status = ( - ExecutionStatus.SUCCESS - if exit_code == 0 else ExecutionStatus.FAILED) + record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -1325,8 +1279,7 @@ async def execute_javascript(self, self.spec.add_record(record) return output - def _generate_js_env_setup(self, input_spec: ExecutionInput, - sandbox_files: Dict[str, str]) -> str: + def _generate_js_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict[str, str]) -> str: """Generate JavaScript code to setup environment.""" lines = [ '// Environment setup', @@ -1340,15 +1293,17 @@ def _generate_js_env_setup(self, input_spec: ExecutionInput, lines.append('') return '\n'.join(lines) - async def execute(self, - executor_type: ExecutorType, - skill_id: str = 'unknown', - script_path: Union[str, Path] = None, - func: Callable = None, - command: Union[str, List[str]] = None, - code: str = None, - input_spec: ExecutionInput = None, - **kwargs) -> ExecutionOutput: + async def execute( + self, + executor_type: ExecutorType, + skill_id: str = 'unknown', + script_path: Union[str, Path] = None, + func: Callable = None, + command: Union[str, List[str]] = None, + code: str = None, + input_spec: ExecutionInput = None, + **kwargs, + ) -> ExecutionOutput: """ Unified async execution interface. @@ -1366,40 +1321,25 @@ async def execute(self, ExecutionOutput with results. """ if executor_type == ExecutorType.PYTHON_SCRIPT: - return await self.execute_python_script( - script_path=script_path, - skill_id=skill_id, - input_spec=input_spec) + return await self.execute_python_script(script_path=script_path, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.PYTHON_CODE: - return await self.execute_python_code( - code=code, skill_id=skill_id, input_spec=input_spec) + return await self.execute_python_code(code=code, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.PYTHON_FUNCTION: - return self.execute_python_function( - func=func, skill_id=skill_id, input_spec=input_spec) + return self.execute_python_function(func=func, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.SHELL: - return await self.execute_shell( - command=command, skill_id=skill_id, input_spec=input_spec) + return await self.execute_shell(command=command, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.JAVASCRIPT: return await self.execute_javascript( - script_path=script_path, - code=code, - skill_id=skill_id, - input_spec=input_spec, - **kwargs) + script_path=script_path, code=code, skill_id=skill_id, input_spec=input_spec, **kwargs + ) else: raise ValueError(f'Unsupported executor type: {executor_type}') - def execute_sync(self, - executor_type: ExecutorType, - skill_id: str = 'unknown', - **kwargs) -> ExecutionOutput: + def execute_sync(self, executor_type: ExecutorType, skill_id: str = 'unknown', **kwargs) -> ExecutionOutput: """Synchronous wrapper for execute().""" return asyncio.run(self.execute(executor_type, skill_id, **kwargs)) - def link_skills(self, - upstream_skill_id: str, - downstream_input_key: str, - output_key: str = None) -> Optional[Any]: + def link_skills(self, upstream_skill_id: str, downstream_input_key: str, output_key: str = None) -> Optional[Any]: """ Link output from upstream skill to downstream skill input. diff --git a/ms_agent/skill/loader.py b/ms_agent/skill/loader.py index 1f5dca2a7..0aae6cd1a 100644 --- a/ms_agent/skill/loader.py +++ b/ms_agent/skill/loader.py @@ -21,9 +21,7 @@ def __init__(self): self.loaded_skills: Dict[str, SkillSchema] = {} self.parser = SkillSchemaParser() - def load_skills( - self, skills: Union[str, List[str], List[SkillSchema]] - ) -> Dict[str, SkillSchema]: + def load_skills(self, skills: Union[str, List[str], List[SkillSchema]]) -> Dict[str, SkillSchema]: """ Load agent skills from various sources. @@ -42,30 +40,27 @@ def load_skills( return all_skills def is_skill_id(s: str) -> bool: - return '/' in s and len(s.split('/')) == 2 and all( - s.split('/')) and not os.path.exists(s) + return '/' in s and len(s.split('/')) == 2 and all(s.split('/')) and not os.path.exists(s) if isinstance(skills, str): # Could be a single skill path, root path of skills, or skill ID on ModelScope hub skill_list = [skills] - elif all(isinstance(s, str) for s in skills) or all( - isinstance(s, SkillSchema) for s in skills): + elif all(isinstance(s, str) for s in skills) or all(isinstance(s, SkillSchema) for s in skills): skill_list = skills else: raise ValueError('Invalid skills input type.') for skill in skill_list: - if is_skill_id(skill): from modelscope import snapshot_download + skill_path: str = snapshot_download(repo_id=skill) skill = skill_path if isinstance(skill, SkillSchema): skill_key = self._get_skill_key(skill=skill) all_skills[skill_key] = skill - logger.info( - f'Loaded skill from SkillSchema object: {skill_key}') + logger.info(f'Loaded skill from SkillSchema object: {skill_key}') continue skill_dir: Path = Path(skill) @@ -81,8 +76,7 @@ def is_skill_id(s: str) -> bool: all_skills[skill_key] = skill_schema # logger.info(f'Successfully loaded skill: {skill_key}') else: - skill_schema_dict: Dict[ - str, SkillSchema] = self._scan_and_load_skills(skill_dir) + skill_schema_dict: Dict[str, SkillSchema] = self._scan_and_load_skills(skill_dir) all_skills.update(skill_schema_dict) self.loaded_skills.update(all_skills) @@ -227,9 +221,7 @@ def reload_skill(self, skill_path: str) -> Optional[SkillSchema]: return skill -def load_skills( - skills: Union[str, List[str], - List[SkillSchema]]) -> Dict[str, SkillSchema]: +def load_skills(skills: Union[str, List[str], List[SkillSchema]]) -> Dict[str, SkillSchema]: """ Convenience function to load skills without creating a SkillLoader instance. diff --git a/ms_agent/skill/schema.py b/ms_agent/skill/schema.py index 722e0acc4..312556885 100644 --- a/ms_agent/skill/schema.py +++ b/ms_agent/skill/schema.py @@ -5,19 +5,20 @@ Defines the data structure and validation logic for Agent Skills. Each Skill is represented as a self-contained directory with metadata. """ + import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union import yaml + from ms_agent.utils.logger import logger from .spec import Spec SUPPORTED_SCRIPT_EXT = ('.py', '.sh', '.js') -SUPPORTED_READ_EXT = ('.md', '.txt', '.py', '.json', '.yaml', '.yml', '.sh', - '.js', '.html', '.xml') +SUPPORTED_READ_EXT = ('.md', '.txt', '.py', '.json', '.yaml', '.yml', '.sh', '.js', '.html', '.xml') @dataclass @@ -31,6 +32,7 @@ class SkillFile: path: Relative path within Skill directory required: Whether this file is required """ + name: str type: str path: Path @@ -55,12 +57,7 @@ def to_dict(self): Returns: Dictionary containing file information """ - return { - 'name': self.name, - 'type': self.type, - 'path': str(self.path), - 'required': self.required - } + return {'name': self.name, 'type': self.type, 'path': str(self.path), 'required': self.required} @dataclass @@ -81,6 +78,7 @@ class SkillSchema: scripts: List of script files (optional) references: List of reference documents (optional) """ + skill_id: str name: str description: str @@ -141,8 +139,7 @@ def validate(self) -> bool: return True except Exception as e: - logger.error( - f'Skill validation failed with an unexpected error: {e}') + logger.error(f'Skill validation failed with an unexpected error: {e}') return False def get_file_by_name(self, name: str) -> Optional[SkillFile]: @@ -168,32 +165,17 @@ def to_dict(self) -> Dict[str, Any]: Dictionary containing all schema information """ return { - 'skill_id': - self.skill_id, - 'name': - self.name, - 'description': - self.description, - 'version': - self.version, - 'author': - self.author, - 'tags': - self.tags, - 'skill_path': - str(self.skill_path), - 'files': [{ - 'name': f.name, - 'type': f.type, - 'path': f.path, - 'required': f.required - } for f in self.files], - 'scripts': - self.scripts, - 'references': - self.references, - 'resources': - self.resources, + 'skill_id': self.skill_id, + 'name': self.name, + 'description': self.description, + 'version': self.version, + 'author': self.author, + 'tags': self.tags, + 'skill_path': str(self.skill_path), + 'files': [{'name': f.name, 'type': f.type, 'path': f.path, 'required': f.required} for f in self.files], + 'scripts': self.scripts, + 'references': self.references, + 'resources': self.resources, } @@ -235,10 +217,7 @@ def is_ignored_path(p: Path) -> bool: Returns: True if path should be ignored, False otherwise """ - ignored_names = { - '.DS_Store', '__pycache__', '.git', '.gitignore', '.pytest_cache', - '.mypy_cache' - } + ignored_names = {'.DS_Store', '__pycache__', '.git', '.gitignore', '.pytest_cache', '.mypy_cache'} ignored_suffixes = {'.pyc', '.pyo'} return (p.name in ignored_names) or (p.suffix in ignored_suffixes) @@ -287,17 +266,14 @@ def parse_skill_directory(directory_path: Path) -> Optional[SkillSchema]: file_type = file_path.suffix if file_path.suffix else '.unknown' skill_file = SkillFile( - name=file_path.name, - type=file_type, - path=file_path, - required=(file_path.name == 'SKILL.md')) + name=file_path.name, type=file_type, path=file_path, required=(file_path.name == 'SKILL.md') + ) files.append(skill_file) # Get scripts, references and resources if skill_file.type in SUPPORTED_SCRIPT_EXT: scripts.append(skill_file) - elif skill_file.type in ['.md' - ] and skill_file.name != 'SKILL.md': + elif skill_file.type in ['.md'] and skill_file.name != 'SKILL.md': references.append(skill_file) else: resources.append(skill_file) @@ -370,6 +346,7 @@ class SkillExecutionPlan: parameters: Parameters extracted from user query. reasoning: Explanation of the plan. """ + can_handle: bool = False plan_summary: str = '' steps: List[Dict[str, Any]] = field(default_factory=list) @@ -396,8 +373,7 @@ class SkillContext: query: str = '' # The working directory (absolute path to skills folder's parent directory) - root_path: Path = field( - default_factory=lambda: Path.cwd().parent.resolve()) + root_path: Path = field(default_factory=lambda: Path.cwd().parent.resolve()) # Execution plan from progressive analysis plan: Optional[SkillExecutionPlan] = None @@ -464,10 +440,7 @@ def get_references_list(self) -> List[str]: def get_resources_list(self) -> List[str]: """Get list of available resource names without loading content.""" - return [ - r.name for r in self.skill.resources - if r.name not in ['SKILL.md', 'LICENSE.txt'] - ] + return [r.name for r in self.skill.resources if r.name not in ['SKILL.md', 'LICENSE.txt']] def _get_resource_path(self, file_path: Path) -> str: """ @@ -505,13 +478,15 @@ def load_scripts(self, names: List[str] = None) -> List[Dict[str, Any]]: loaded = [] for script in target_scripts: abs_path = script.path.resolve() - loaded.append({ - 'name': script.name, - 'file': script.to_dict(), - 'path': self._get_resource_path(script.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - }) + loaded.append( + { + 'name': script.name, + 'file': script.to_dict(), + 'path': self._get_resource_path(script.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + } + ) self.scripts.extend(loaded) return loaded @@ -532,13 +507,15 @@ def load_references(self, names: List[str] = None) -> List[Dict[str, Any]]: loaded = [] for ref in target_refs: abs_path = ref.path.resolve() - loaded.append({ - 'name': ref.name, - 'file': ref.to_dict(), - 'path': self._get_resource_path(ref.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - }) + loaded.append( + { + 'name': ref.name, + 'file': ref.to_dict(), + 'path': self._get_resource_path(ref.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + } + ) self.references.extend(loaded) return loaded @@ -552,23 +529,22 @@ def load_resources(self, names: List[str] = None) -> List[Dict[str, Any]]: Returns: List of loaded resource dictionaries with content. """ - target_res = [ - r for r in self.skill.resources - if r.name not in ['SKILL.md', 'LICENSE.txt'] - ] + target_res = [r for r in self.skill.resources if r.name not in ['SKILL.md', 'LICENSE.txt']] if names: target_res = [r for r in target_res if r.name in names] loaded = [] for res in target_res: abs_path = res.path.resolve() - loaded.append({ - 'name': res.name, - 'file': res.to_dict(), - 'path': self._get_resource_path(res.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - }) + loaded.append( + { + 'name': res.name, + 'file': res.to_dict(), + 'path': self._get_resource_path(res.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + } + ) self.resources.extend(loaded) return loaded diff --git a/ms_agent/skill/spec.py b/ms_agent/skill/spec.py index 3c666b8d4..676ed219c 100644 --- a/ms_agent/skill/spec.py +++ b/ms_agent/skill/spec.py @@ -18,7 +18,6 @@ class Spec: implementation: str = '' def __post_init__(self): - if not self.plan: self.plan = DEFAULT_PLAN @@ -41,20 +40,13 @@ def dump(self, output_dir: str) -> str: output_path: str = os.path.join(output_dir, '.spec') os.makedirs(output_path, exist_ok=True) - with open( - os.path.join(output_path, 'plan.md'), 'w', - encoding='utf-8') as f: + with open(os.path.join(output_path, 'plan.md'), 'w', encoding='utf-8') as f: f.write(self.plan) - with open( - os.path.join(output_path, 'tasks.md'), 'w', - encoding='utf-8') as f: + with open(os.path.join(output_path, 'tasks.md'), 'w', encoding='utf-8') as f: f.write(self.tasks) - with open( - os.path.join(output_path, 'implementation.md'), - 'w', - encoding='utf-8') as f: + with open(os.path.join(output_path, 'implementation.md'), 'w', encoding='utf-8') as f: f.write(self.implementation) return output_path diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 02a57b8d5..1b4e2c38a 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import multiprocessing as mp import os import threading @@ -11,16 +12,14 @@ from queue import Full as QueueFull from typing import Any, Callable, Dict, List, Optional, Union -import json +from omegaconf import DictConfig, ListConfig, OmegaConf + from ms_agent.agent.loader import AgentLoader from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from ms_agent.utils.stats import (append_stats, build_timing_record, - get_stats_path, monotonic, now_iso, - summarize_usage) +from ms_agent.utils.stats import append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage from ms_agent.utils.stream_writer import SubAgentStreamWriter -from omegaconf import DictConfig, ListConfig, OmegaConf logger = get_logger() @@ -76,8 +75,7 @@ def _message_from_data(data: Any) -> Message: def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): if spec.inline_config is not None: container = _to_container(spec.inline_config) - base_override = OmegaConf.create(container) if isinstance( - container, dict) else OmegaConf.create({}) + base_override = OmegaConf.create(container) if isinstance(container, dict) else OmegaConf.create({}) else: base_override = OmegaConf.create({}) # Sub-agents default snapshots off in LLMAgent unless enable_snapshots is set @@ -142,18 +140,11 @@ async def _runner(): history = chunk if stream_events and event_queue is not None: serialized_chunk = { - 'kind': - 'messages', - 'messages': [ - _message_from_data(msg).to_dict() - for msg in (history or []) - ], + 'kind': 'messages', + 'messages': [_message_from_data(msg).to_dict() for msg in (history or [])], } try: - event_queue.put_nowait({ - 'type': 'chunk', - 'history': serialized_chunk - }) + event_queue.put_nowait({'type': 'chunk', 'history': serialized_chunk}) except QueueFull: # Avoid blocking sub-agent progress if UI/event consumer # is temporarily slower than chunk production. @@ -162,16 +153,11 @@ async def _runner(): result = history if isinstance(result, list): return { - 'kind': - 'messages', - 'messages': - [_message_from_data(msg).to_dict() for msg in result], - 'streamed_chunks': - chunk_count, - 'agent_tag': - getattr(sub_agent, 'tag', None), - 'agent_type': - getattr(sub_agent, 'AGENT_NAME', None), + 'kind': 'messages', + 'messages': [_message_from_data(msg).to_dict() for msg in result], + 'streamed_chunks': chunk_count, + 'agent_tag': getattr(sub_agent, 'tag', None), + 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), } return { 'kind': 'raw', @@ -183,13 +169,15 @@ async def _runner(): result_queue.put({'ok': True, 'result': asyncio.run(_runner())}) except BaseException as exc: # pragma: no cover - result_queue.put({ - 'ok': False, - 'error': str(exc), - 'traceback': traceback.format_exc(), - 'agent_tag': getattr(sub_agent, 'tag', None), - 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), - }) + result_queue.put( + { + 'ok': False, + 'error': str(exc), + 'traceback': traceback.format_exc(), + 'agent_tag': getattr(sub_agent, 'tag', None), + 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), + } + ) class AgentTool(ToolBase): @@ -226,7 +214,8 @@ def enabled(self) -> bool: 'split a website generation task into sub tasks, ' 'you plan the framework, include code files and classes and functions, and give the detail ' 'information to the system and query field of the subtask, then ' - 'let each subtask to write a single file') + 'let each subtask to write a single file' + ) _SPLIT_TASK_PARAMETERS = { 'type': 'object', @@ -235,7 +224,8 @@ def enabled(self) -> bool: 'type': 'array', 'description': ( 'MANDATORY: Each element is a dict, which must contains two fields: ' - '`system`(str) and `query`(str) to start one sub task.'), + '`system`(str) and `query`(str) to start one sub task.' + ), }, 'execution_mode': { 'type': 'string', @@ -284,13 +274,10 @@ def _load_specs(self): self._build_server_index() return - if isinstance(agent_tools_cfg, DictConfig) and hasattr( - agent_tools_cfg, 'definitions'): + if isinstance(agent_tools_cfg, DictConfig) and hasattr(agent_tools_cfg, 'definitions'): definitions = agent_tools_cfg.definitions - server_name = getattr(agent_tools_cfg, 'server_name', - self.DEFAULT_SERVER) - self._enable_stats = bool( - getattr(agent_tools_cfg, 'enable_stats', False)) + server_name = getattr(agent_tools_cfg, 'server_name', self.DEFAULT_SERVER) + self._enable_stats = bool(getattr(agent_tools_cfg, 'enable_stats', False)) else: definitions = agent_tools_cfg server_name = self.DEFAULT_SERVER @@ -314,25 +301,22 @@ def _load_specs(self): continue if spec.tool_name in self._specs: logger.warning( - 'Duplicate agent tool name detected: %s, overriding previous definition.', - spec.tool_name) + 'Duplicate agent tool name detected: %s, overriding previous definition.', spec.tool_name + ) self._specs[spec.tool_name] = spec self._build_server_index() - def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], - default_server, idx: int) -> Optional[_AgentToolSpec]: + def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, idx: int) -> Optional[_AgentToolSpec]: cfg = cfg or {} cfg = cfg if isinstance(cfg, DictConfig) else DictConfig(cfg) - tool_name = getattr(cfg, 'tool_name', None) or getattr( - cfg, 'name', None) + tool_name = getattr(cfg, 'tool_name', None) or getattr(cfg, 'name', None) if not tool_name: - logger.warning( - 'agent_tools[%s] missing tool_name/name field, skip.', idx) + logger.warning('agent_tools[%s] missing tool_name/name field, skip.', idx) return None mode = getattr(cfg, 'mode', None) - is_dynamic = (mode == 'dynamic') + is_dynamic = mode == 'dynamic' agent_cfg = getattr(cfg, 'agent', None) config_path = getattr(cfg, 'config_path', None) @@ -340,17 +324,13 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], if agent_cfg is not None: config_path = getattr(agent_cfg, 'config_path', config_path) inline_cfg = getattr(agent_cfg, 'config', inline_cfg) - inline_cfg = _to_container( - inline_cfg) if inline_cfg is not None else None + inline_cfg = _to_container(inline_cfg) if inline_cfg is not None else None if not is_dynamic and not config_path and inline_cfg is None: - logger.warning( - 'agent_tools[%s] (%s) missing config_path/config definition.', - idx, tool_name) + logger.warning('agent_tools[%s] (%s) missing config_path/config definition.', idx, tool_name) return None - description = getattr(cfg, 'description', - f'Invoke agent "{tool_name}" as a tool.') + description = getattr(cfg, 'description', f'Invoke agent "{tool_name}" as a tool.') parameters = getattr(cfg, 'parameters', None) if parameters is None: if is_dynamic: @@ -360,10 +340,8 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], 'type': 'object', 'properties': { 'request': { - 'type': - 'string', - 'description': - f'Task description forwarded to the sub-agent {tool_name}.' + 'type': 'string', + 'description': f'Task description forwarded to the sub-agent {tool_name}.', }, }, 'required': ['request'], @@ -372,9 +350,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], else: parameters = _to_container(parameters) - tag_prefix = getattr( - cfg, 'tag_prefix', - f'{getattr(self.config, "tag", "agent")}-{tool_name}-') + tag_prefix = getattr(cfg, 'tag_prefix', f'{getattr(self.config, "tag", "agent")}-{tool_name}-') request_field = getattr(cfg, 'request_field', 'request') input_template = getattr(cfg, 'input_template', None) @@ -400,8 +376,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], if config_path and not os.path.isabs(config_path): base_dir = getattr(self.config, 'local_dir', None) if base_dir: - config_path = os.path.normpath( - os.path.join(base_dir, config_path)) + config_path = os.path.normpath(os.path.join(base_dir, config_path)) return _AgentToolSpec( tool_name=tool_name, @@ -436,7 +411,8 @@ def _build_server_index(self): server_name=spec.server_name, description=spec.description, parameters=spec.parameters, - )) + ) + ) self._server_tools = server_map async def connect(self): @@ -481,8 +457,7 @@ def _stream_file_enabled(self) -> bool: 2. ``config.agent_stream_file`` Defaults to ``False``. """ - agent_tools_cfg = getattr( - getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: val = getattr(agent_tools_cfg, 'enable_stream_file', None) if val is not None: @@ -495,8 +470,7 @@ def _stream_file_dir(self) -> str: Checks ``config.tools.agent_tools.stream_file_dir`` first, then falls back to ``config.output_dir``. """ - agent_tools_cfg = getattr( - getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: override = getattr(agent_tools_cfg, 'stream_file_dir', None) if override: @@ -509,23 +483,19 @@ def _stream_include_in_result(self) -> bool: Controlled by ``config.tools.agent_tools.stream_include_in_result`` (defaults to ``True`` when stream files are enabled). """ - agent_tools_cfg = getattr( - getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: val = getattr(agent_tools_cfg, 'stream_include_in_result', None) if val is not None: return bool(val) return True - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if tool_name not in self._specs: raise ValueError(f'Agent tool "{tool_name}" not registered.') spec = self._specs[tool_name] if spec.server_name != server_name: - raise ValueError( - f'Agent tool "{tool_name}" is not part of server "{server_name}".' - ) + raise ValueError(f'Agent tool "{tool_name}" is not part of server "{server_name}".') call_id = None if isinstance(tool_args, dict) and '__call_id' in tool_args: @@ -545,8 +515,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, use_subprocess = spec.run_in_thread and spec.run_in_process if use_subprocess: - messages = await self._run_agent( - None, payload, spec, call_id=effective_call_id) + messages = await self._run_agent(None, payload, spec, call_id=effective_call_id) result_str = self._format_output(messages, spec) return self._maybe_append_stream_path(result_str, effective_call_id) @@ -582,8 +551,7 @@ def _maybe_append_stream_path(self, result_str: str, effective_call_id: str) -> def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) - async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, - call_id: Optional[str]) -> Any: + async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, call_id: Optional[str]) -> Any: """Run sub-agent inline (pure async/await). If spec.sync_timeout_s is set, the call auto-escapes to background after @@ -596,17 +564,14 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, escape_event = asyncio.Event() effective_call_id = call_id or uuid.uuid4().hex[:12] - run_task = asyncio.create_task( - self._run_agent(None, payload, spec, call_id=effective_call_id)) + run_task = asyncio.create_task(self._run_agent(None, payload, spec, call_id=effective_call_id)) - self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, - escape_event) + self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, escape_event) try: if spec.sync_timeout_s and spec.sync_timeout_s > 0: escape_wait_task = asyncio.create_task(escape_event.wait()) - sleep_task = asyncio.create_task( - asyncio.sleep(spec.sync_timeout_s)) + sleep_task = asyncio.create_task(asyncio.sleep(spec.sync_timeout_s)) _, pending = await asyncio.wait( [run_task, escape_wait_task, sleep_task], return_when=asyncio.FIRST_COMPLETED, @@ -614,8 +579,7 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, for t in pending: t.cancel() if not run_task.done(): - return await self._escape_running_task( - effective_call_id, run_task, spec, payload) + return await self._escape_running_task(effective_call_id, run_task, spec, payload) else: # No timeout: wait for completion or explicit escape signal. escape_task = asyncio.create_task(escape_event.wait()) @@ -626,22 +590,20 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, for t in pending: t.cancel() if not run_task.done(): - return await self._escape_running_task( - effective_call_id, run_task, spec, payload) + return await self._escape_running_task(effective_call_id, run_task, spec, payload) return run_task.result() finally: self._active_sync_tasks.pop(effective_call_id, None) - async def _escape_running_task(self, call_id: str, - run_task: 'asyncio.Task[Any]', - spec: _AgentToolSpec, - payload: Any) -> str: + async def _escape_running_task( + self, call_id: str, run_task: 'asyncio.Task[Any]', spec: _AgentToolSpec, payload: Any + ) -> str: """Cancel the in-progress sync task and re-launch it as a background subprocess.""" if self._task_manager is None: raise RuntimeError( - f'AgentTool "{spec.tool_name}" tried to escape to background but ' - 'no TaskManager is attached.') + f'AgentTool "{spec.tool_name}" tried to escape to background but no TaskManager is attached.' + ) run_task.cancel() try: @@ -672,22 +634,21 @@ def escape_to_background(self, call_id: str) -> bool: escape_event.set() return True - async def _launch_background(self, payload: Any, spec: _AgentToolSpec, - call_id: Optional[str]) -> str: + async def _launch_background(self, payload: Any, spec: _AgentToolSpec, call_id: Optional[str]) -> str: """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" if self._task_manager is None: raise RuntimeError( f'AgentTool "{spec.tool_name}" has run_in_background=true but ' 'no TaskManager is attached. Ensure LLMAgent wires task_manager ' - 'into AgentTool via set_task_manager().') + 'into AgentTool via set_task_manager().' + ) ctx = mp.get_context('spawn') result_queue = ctx.Queue(maxsize=1) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, - args=(spec, self._trust_remote_code, process_payload, False, None, - result_queue), + args=(spec, self._trust_remote_code, process_payload, False, None, result_queue), name=f'agent_tool_bg_{spec.tool_name}', ) proc.start() @@ -728,11 +689,14 @@ async def _watcher(): self._watcher_tasks.add(t) t.add_done_callback(self._watcher_tasks.discard) - return json.dumps({ - 'status': 'async_launched', - 'task_id': task_id, - 'tool_name': spec.tool_name, - }, ensure_ascii=False) + return json.dumps( + { + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': spec.tool_name, + }, + ensure_ascii=False, + ) async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: tasks = tool_args.get('tasks', []) @@ -799,7 +763,7 @@ async def _run_one(i: int, task: dict) -> str: formatted = '' for i, content in enumerate(res_list): if len(content) > spec.max_subtask_output_chars: - content = content[:spec.max_subtask_output_chars] + content = content[: spec.max_subtask_output_chars] formatted += f'SubTask{i}:{content}\n' return formatted @@ -853,11 +817,9 @@ def _terminate_all_active_processes(self, *, reason: str) -> None: for _, proc in active: self._terminate_process(proc, reason=reason) - async def _wait_process_result(self, - proc: mp.Process, - result_queue: Any, - on_poll: Optional[Callable[[], - None]] = None): + async def _wait_process_result( + self, proc: mp.Process, result_queue: Any, on_poll: Optional[Callable[[], None]] = None + ): exited_at = None while True: if on_poll is not None: @@ -872,16 +834,13 @@ async def _wait_process_result(self, if not proc.is_alive(): if exited_at is None: exited_at = monotonic() - elif (monotonic() - - exited_at) >= self._PROCESS_EXIT_RESULT_GRACE_S: + elif (monotonic() - exited_at) >= self._PROCESS_EXIT_RESULT_GRACE_S: return None await asyncio.sleep(self._PROCESS_POLL_INTERVAL_S) @staticmethod - def _drain_process_event_queue( - event_queue: Any, on_event: Callable[[Dict[str, Any]], - None]) -> None: + def _drain_process_event_queue(event_queue: Any, on_event: Callable[[Dict[str, Any]], None]) -> None: if event_queue is None: return while True: @@ -905,11 +864,7 @@ def _restore_process_result(result_payload: Dict[str, Any]) -> Any: return [_message_from_data(msg) for msg in messages] return result_payload.get('raw', '') - async def _run_agent(self, - agent, - payload, - spec: _AgentToolSpec, - call_id: Optional[str] = None): + async def _run_agent(self, agent, payload, spec: _AgentToolSpec, call_id: Optional[str] = None): runtime_agent = agent runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) @@ -925,7 +880,9 @@ async def _run_agent(self, ) logger.info( '[stream] %s (call_id=%s) streaming to %s', - spec.tool_name, _effective_call_id, _writer.stream_path, + spec.tool_name, + _effective_call_id, + _writer.stream_path, ) # ─────────────────────────────────────────────────────────────────── @@ -941,53 +898,67 @@ async def _run_and_collect(): result = await runtime_agent.run(payload) if hasattr(result, '__aiter__'): history = None - self._emit_chunk_event('start', { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }) + self._emit_chunk_event( + 'start', + { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }, + ) if _writer is not None: _writer.on_start(runtime_agent_tag) async for chunk in result: history = chunk self._emit_chunk_event( - 'chunk', { + 'chunk', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': chunk, - }) + }, + ) if _writer is not None: _writer.on_chunk(chunk) if history is not None: self._emit_chunk_event( - 'end', { + 'end', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': history, - }) + }, + ) if _writer is not None: _writer.on_end(history) result = history else: - self._emit_chunk_event('start', { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }) + self._emit_chunk_event( + 'start', + { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }, + ) if _writer is not None: _writer.on_start(runtime_agent_tag) self._emit_chunk_event( - 'chunk', { + 'chunk', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, - }) + }, + ) if _writer is not None: _writer.on_chunk(result) self._emit_chunk_event( - 'end', { + 'end', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, - }) + }, + ) if _writer is not None: _writer.on_end(result) return result @@ -1010,48 +981,45 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: history = self._restore_process_result(history_payload) if self._chunk_cb: self._emit_chunk_event( - 'chunk', { + 'chunk', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': history, - }) + }, + ) if _writer is not None: _writer.on_chunk(history) try: if self._chunk_cb: - self._emit_chunk_event('start', { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }) + self._emit_chunk_event( + 'start', + { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }, + ) if _writer is not None: # agent_tag unknown until subprocess completes; pass None _writer.on_start(None) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, - args=(spec, self._trust_remote_code, process_payload, - need_events, event_queue, result_queue), + args=(spec, self._trust_remote_code, process_payload, need_events, event_queue, result_queue), name=f'agent_tool_{spec.tool_name}', ) proc.start() self._register_process(run_id, proc) result = await self._wait_process_result( - proc, - result_queue, - on_poll=lambda: self._drain_process_event_queue( - event_queue, _emit_stream_event)) + proc, result_queue, on_poll=lambda: self._drain_process_event_queue(event_queue, _emit_stream_event) + ) if result is None: - raise RuntimeError( - f'AgentTool subprocess exited without result: {spec.tool_name}' - ) - self._drain_process_event_queue(event_queue, - _emit_stream_event) + raise RuntimeError(f'AgentTool subprocess exited without result: {spec.tool_name}') + self._drain_process_event_queue(event_queue, _emit_stream_event) if not result.get('ok'): - runtime_agent_tag = result.get( - 'agent_tag') or runtime_agent_tag - runtime_agent_type = result.get( - 'agent_type') or runtime_agent_type + runtime_agent_tag = result.get('agent_tag') or runtime_agent_tag + runtime_agent_type = result.get('agent_type') or runtime_agent_type tb = result.get('traceback', '') if tb: logger.warning(tb) @@ -1060,27 +1028,28 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: _writer.on_error(err_msg) raise RuntimeError(err_msg) result_payload = result.get('result', {}) or {} - runtime_agent_tag = result_payload.get( - 'agent_tag') or runtime_agent_tag - runtime_agent_type = result_payload.get( - 'agent_type') or runtime_agent_type + runtime_agent_tag = result_payload.get('agent_tag') or runtime_agent_tag + runtime_agent_type = result_payload.get('agent_type') or runtime_agent_type restored = self._restore_process_result(result_payload) - streamed_chunks = int( - result_payload.get('streamed_chunks', 0) or 0) + streamed_chunks = int(result_payload.get('streamed_chunks', 0) or 0) if self._chunk_cb: if streamed_chunks <= 0: self._emit_chunk_event( - 'chunk', { + 'chunk', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': restored, - }) + }, + ) self._emit_chunk_event( - 'end', { + 'end', + { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': restored, - }) + }, + ) # Always finalise the writer regardless of _chunk_cb. if _writer is not None: _writer.on_end(restored) @@ -1103,8 +1072,7 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: except Exception: pass if proc.is_alive(): - self._terminate_process( - proc, reason='did not exit after result handling') + self._terminate_process(proc, reason='did not exit after result handling') try: result_queue.close() result_queue.join_thread() @@ -1145,8 +1113,7 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: self._stream_paths[store_key] = _writer.stream_path return result except BaseException as exc: - status = 'cancelled' if isinstance( - exc, asyncio.CancelledError) else 'error' + status = 'cancelled' if isinstance(exc, asyncio.CancelledError) else 'error' raise finally: end_ts = now_iso() @@ -1170,9 +1137,7 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: try: await append_stats(get_stats_path(self.config), record) except Exception as exc: - logger.warning( - f'Failed to write agent tool stats for {spec.tool_name}: {exc}' - ) + logger.warning(f'Failed to write agent tool stats for {spec.tool_name}: {exc}') def _save_transcript(self, messages: Any, agent_tag: Optional[str]) -> None: if not isinstance(messages, list) or not agent_tag: @@ -1194,9 +1159,7 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): field = spec.request_field or 'messages' raw_messages = tool_args.get(field) if not isinstance(raw_messages, list): - raise ValueError( - f'Agent tool "{spec.tool_name}" expects "{field}" to be a list of messages.' - ) + raise ValueError(f'Agent tool "{spec.tool_name}" expects "{field}" to be a list of messages.') return [ Message( role=msg.get('role', 'user'), @@ -1205,7 +1168,8 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): tool_call_id=msg.get('tool_call_id'), name=msg.get('name'), reasoning_content=msg.get('reasoning_content', ''), - ) for msg in raw_messages # TODO: Change role to user or not + ) + for msg in raw_messages # TODO: Change role to user or not ] if spec.input_template: @@ -1215,7 +1179,9 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): except Exception as exc: logger.warning( 'Failed to render input template for tool %s: %s. Falling back to JSON payload.', - spec.tool_name, exc) + spec.tool_name, + exc, + ) field = spec.request_field or 'request' if field in tool_args and isinstance(tool_args[field], str): @@ -1229,31 +1195,25 @@ def _format_output(self, messages: Any, spec: _AgentToolSpec) -> str: if spec.output_mode == 'history': serialized = [self._serialize_message(msg) for msg in messages] - return self._truncate( - json.dumps(serialized, ensure_ascii=False, indent=2), - spec.max_output_chars) + return self._truncate(json.dumps(serialized, ensure_ascii=False, indent=2), spec.max_output_chars) if spec.output_mode == 'raw_json': serialized = [msg.to_dict() for msg in messages] # type: ignore - return self._truncate( - json.dumps(serialized, ensure_ascii=False), - spec.max_output_chars) + return self._truncate(json.dumps(serialized, ensure_ascii=False), spec.max_output_chars) # Default: return final assistant message text for msg in reversed(messages): if getattr(msg, 'role', '') == 'assistant': return self._truncate(msg.content or '', spec.max_output_chars) - return self._truncate(messages[-1].content or '', - spec.max_output_chars) + return self._truncate(messages[-1].content or '', spec.max_output_chars) def _serialize_message(self, message: Message) -> Dict[str, Any]: data = message.to_dict() if data.get('tool_calls'): for call in data['tool_calls']: if isinstance(call.get('arguments'), dict): - call['arguments'] = json.dumps( - call['arguments'], ensure_ascii=False) + call['arguments'] = json.dumps(call['arguments'], ensure_ascii=False) return data @staticmethod diff --git a/ms_agent/tools/audio_generator/audio_gen.py b/ms_agent/tools/audio_generator/audio_gen.py index 2b533c08b..e9897d0e7 100644 --- a/ms_agent/tools/audio_generator/audio_gen.py +++ b/ms_agent/tools/audio_generator/audio_gen.py @@ -6,15 +6,14 @@ class AudioGenerator(ToolBase): - def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', - 'audio_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', 'audio_generator') os.makedirs(self.temp_dir, exist_ok=True) audio_generator = self.config.audio_generator if audio_generator.type == 'edge_tts': from .edge_tts import EdgeTTSGenerator + self.generator = EdgeTTSGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -28,25 +27,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_audio', server_name='audio_generator', - description= - 'Generate audio with a prompt, and return the audio file path.', + description='Generate audio with a prompt, and return the audio file path.', parameters={ 'type': 'object', 'properties': { - 'text': { - 'type': 'string', - 'description': 'The text to generate speech' - }, + 'text': {'type': 'string', 'description': 'The text to generate speech'}, }, 'required': ['text'], - 'additionalProperties': False - }) + 'additionalProperties': False, + }, + ) ] } async def generate_audio(self, text, **kwargs): return await self.generator.generate_audio(text, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await self.generate_audio(**tool_args) diff --git a/ms_agent/tools/audio_generator/edge_tts.py b/ms_agent/tools/audio_generator/edge_tts.py index 7efa5ff77..3e05a7e40 100644 --- a/ms_agent/tools/audio_generator/edge_tts.py +++ b/ms_agent/tools/audio_generator/edge_tts.py @@ -3,7 +3,6 @@ class EdgeTTSGenerator: - def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir @@ -15,25 +14,19 @@ async def generate_audio(self, text, **kwargs): return output_file @staticmethod - async def edge_tts_generate(text, - output_file, - speaker='zh-CN-YunjianNeural', - rate='+0%', - pitch='+0Hz'): + async def edge_tts_generate(text, output_file, speaker='zh-CN-YunjianNeural', rate='+0%', pitch='+0Hz'): import edge_tts + output_dir = os.path.dirname(output_file) or '.' os.makedirs(output_dir, exist_ok=True) text = text.replace('[', '').replace(']', '') - communicate = edge_tts.Communicate( - text=text, voice=speaker, rate=rate, pitch=pitch) + communicate = edge_tts.Communicate(text=text, voice=speaker, rate=rate, pitch=pitch) audio_data = b'' async for chunk in communicate.stream(): if chunk['type'] == 'audio': audio_data += chunk['data'] - assert len( - audio_data - ) > 0, 'Audio generation failed: no data received from edge_tts.' + assert len(audio_data) > 0, 'Audio generation failed: no data received from edge_tts.' with open(output_file, 'wb') as f: f.write(audio_data) diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index 12ece9948..ad867522e 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -2,9 +2,10 @@ from abc import abstractmethod from typing import Any, Dict -from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from omegaconf import DictConfig +from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR + class ToolBase: """The base class for all tools. @@ -16,17 +17,16 @@ def __init__(self, config): self.config = config self.exclude_functions = [] self.include_functions = [] - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) def exclude_func(self, tool_config: DictConfig): if tool_config is not None: self.exclude_functions = getattr(tool_config, 'exclude', []) self.include_functions = getattr(tool_config, 'include', []) - assert (not self.exclude_functions) or ( - not self.include_functions - ), 'Set either `include` or `exclude` in tools config.' + assert (not self.exclude_functions) or (not self.include_functions), ( + 'Set either `include` or `exclude` in tools config.' + ) @abstractmethod async def connect(self) -> None: @@ -76,8 +76,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: pass @abstractmethod - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: """Call a tool. Args: diff --git a/ms_agent/tools/code/code_executor.py b/ms_agent/tools/code/code_executor.py index 1df89e39c..6cce39b3a 100644 --- a/ms_agent/tools/code/code_executor.py +++ b/ms_agent/tools/code/code_executor.py @@ -1,17 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import socket from pathlib import Path from typing import Any, Dict, Optional, Union -import json +from omegaconf import DictConfig + from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.code.sandbox_manager import SandboxManagerFactory from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package -from omegaconf import DictConfig logger = get_logger() @@ -40,9 +41,7 @@ def check_port_available(port: int, host: str = '127.0.0.1') -> bool: if e.errno == 98 or e.errno == 48: # Address already in use return False # Port is occupied except Exception as e: - logger.warning( - f'Bind test failed for port {port}, falling back to connection test: {e}' - ) + logger.warning(f'Bind test failed for port {port}, falling back to connection test: {e}') # Second try: connection test (fallback method) try: @@ -55,9 +54,7 @@ def check_port_available(port: int, host: str = '127.0.0.1') -> bool: return False # Be conservative: assume occupied if we can't check reliably -def find_available_port(start_port: int = 8888, - max_attempts: int = 100, - host: str = '127.0.0.1') -> Optional[int]: +def find_available_port(start_port: int = 8888, max_attempts: int = 100, host: str = '127.0.0.1') -> Optional[int]: """ Find an available port starting from start_port. @@ -74,9 +71,7 @@ def find_available_port(start_port: int = 8888, logger.info(f'Found available port: {port}') return port - logger.error( - f'Could not find available port in range {start_port}-{start_port + max_attempts - 1}' - ) + logger.error(f'Could not find available port in range {start_port}-{start_port + max_attempts - 1}') return None @@ -94,8 +89,7 @@ class CodeExecutionTool(ToolBase): def __init__(self, config): logger.info('Installing ms-enclave package...') try: - install_package( - package_name='ms-enclave', import_name='ms_enclave') + install_package(package_name='ms-enclave', import_name='ms_enclave') except Exception as e: raise e @@ -113,19 +107,15 @@ def __init__(self, config): logger.info('CodeExecutionTool initialized (ms-enclave based)') - def _build_sandbox_config( - self, - config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: + def _build_sandbox_config(self, config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: """Build sandbox configuration from agent config""" from ms_enclave.sandbox.model import DockerNotebookConfig, DockerSandboxConfig, SandboxType # Get sandbox-specific config or use defaults - if isinstance(config, DictConfig) and hasattr( - config, 'tools') and hasattr(config.tools, 'code_executor'): + if isinstance(config, DictConfig) and hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): sandbox_cfg = getattr(config.tools.code_executor, 'sandbox', {}) else: - sandbox_cfg = getattr(config, 'sandbox', {}) or getattr( - config, 'tools', {}).get('sandbox', {}) + sandbox_cfg = getattr(config, 'sandbox', {}) or getattr(config, 'tools', {}).get('sandbox', {}) # Get output directory for data mounting output_dir = Path(getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)) @@ -145,51 +135,39 @@ def _build_sandbox_config( 'env_vars': env_vars, } if hasattr(sandbox_cfg, '__getitem__'): - self.sandbox_type = sandbox_cfg.get('type', - SandboxType.DOCKER_NOTEBOOK) - - config_dict.update({ - 'image': - sandbox_cfg.get('image', 'jupyter-kernel-gateway'), - 'command': - sandbox_cfg.get('command', None), - 'ports': - sandbox_cfg.get('ports', {}), - 'network': - sandbox_cfg.get('network', 'bridge'), - 'memory_limit': - sandbox_cfg.get('memory_limit', '2g'), - 'cpu_limit': - sandbox_cfg.get('cpu_limit', 2.0), - 'network_enabled': - sandbox_cfg.get('network_enabled', True), - 'privileged': - sandbox_cfg.get('privileged', False), - 'remove_on_exit': - sandbox_cfg.get('remove_on_exit', True), - 'timeout': - sandbox_cfg.get('timeout', 30), - 'tools_config': - sandbox_cfg.get('tools_config', {}), - 'working_dir': - sandbox_cfg.get('working_dir', '/workspace'), - 'resource_limits': - sandbox_cfg.get('resource_limits', {}), - }) + self.sandbox_type = sandbox_cfg.get('type', SandboxType.DOCKER_NOTEBOOK) + + config_dict.update( + { + 'image': sandbox_cfg.get('image', 'jupyter-kernel-gateway'), + 'command': sandbox_cfg.get('command', None), + 'ports': sandbox_cfg.get('ports', {}), + 'network': sandbox_cfg.get('network', 'bridge'), + 'memory_limit': sandbox_cfg.get('memory_limit', '2g'), + 'cpu_limit': sandbox_cfg.get('cpu_limit', 2.0), + 'network_enabled': sandbox_cfg.get('network_enabled', True), + 'privileged': sandbox_cfg.get('privileged', False), + 'remove_on_exit': sandbox_cfg.get('remove_on_exit', True), + 'timeout': sandbox_cfg.get('timeout', 30), + 'tools_config': sandbox_cfg.get('tools_config', {}), + 'working_dir': sandbox_cfg.get('working_dir', '/workspace'), + 'resource_limits': sandbox_cfg.get('resource_limits', {}), + } + ) if self.sandbox_type == SandboxType.DOCKER_NOTEBOOK: - config_dict.update({ - 'host': sandbox_cfg.get('host', '127.0.0.1'), - 'port': sandbox_cfg.get('port', 8888), - 'token': sandbox_cfg.get('token', None), - }) + config_dict.update( + { + 'host': sandbox_cfg.get('host', '127.0.0.1'), + 'port': sandbox_cfg.get('port', 8888), + 'token': sandbox_cfg.get('token', None), + } + ) # Store original port for retry logic self._original_port = config_dict['port'] - self._port_retry_enabled = sandbox_cfg.get( - 'port_retry_enabled', True) - self._max_port_retries = sandbox_cfg.get( - 'max_port_retries', 10) + self._port_retry_enabled = sandbox_cfg.get('port_retry_enabled', True) + self._max_port_retries = sandbox_cfg.get('max_port_retries', 10) logger.info(f'Sandbox config: type={self.sandbox_type}') @@ -212,8 +190,7 @@ async def connect(self) -> None: logger.info('Initializing sandbox manager...') # Create manager using factory - self.manager = await SandboxManagerFactory.create_manager( - self.config) + self.manager = await SandboxManagerFactory.create_manager(self.config) await self.manager.start() logger.info('Creating sandbox instance...') @@ -226,8 +203,8 @@ async def connect(self) -> None: while retry_count < max_retries: try: self.sandbox_id = await self.manager.create_sandbox( - sandbox_type=self.sandbox_type, - config=self.sandbox_config) + sandbox_type=self.sandbox_type, config=self.sandbox_config + ) logger.info(f'Sandbox created: {self.sandbox_id}') @@ -243,68 +220,55 @@ async def connect(self) -> None: last_error = e # Check if it's a port conflict error - is_port_conflict = any(keyword in error_msg - for keyword in [ - 'address already in use', - 'port is already allocated', - 'bind: address already in use', - 'port already in use' - ]) - - if is_port_conflict and self._port_retry_enabled and retry_count < ( - max_retries - 1): + is_port_conflict = any( + keyword in error_msg + for keyword in [ + 'address already in use', + 'port is already allocated', + 'bind: address already in use', + 'port already in use', + ] + ) + + if is_port_conflict and self._port_retry_enabled and retry_count < (max_retries - 1): retry_count += 1 - logger.warning( - f'Port conflict detected (attempt {retry_count}/{max_retries}): {e}' - ) + logger.warning(f'Port conflict detected (attempt {retry_count}/{max_retries}): {e}') # Try to find a new available port if self.sandbox_type == SandboxType.DOCKER_NOTEBOOK: new_port = find_available_port( - start_port=self.sandbox_config.port + 1, - max_attempts=100, - host=self.sandbox_config.host) + start_port=self.sandbox_config.port + 1, max_attempts=100, host=self.sandbox_config.host + ) if new_port: - logger.info( - f'Retrying with new port: {new_port} (was {self.sandbox_config.port})' - ) + logger.info(f'Retrying with new port: {new_port} (was {self.sandbox_config.port})') # Update the config with new port self.sandbox_config.port = new_port # Clean up failed sandbox if it was created if self.sandbox_id: try: - await self.manager.delete_sandbox( - self.sandbox_id) + await self.manager.delete_sandbox(self.sandbox_id) self.sandbox_id = None except Exception as cleanup_error: - logger.warning( - f'Failed to cleanup sandbox: {cleanup_error}' - ) + logger.warning(f'Failed to cleanup sandbox: {cleanup_error}') # Wait a bit before retry await asyncio.sleep(1) continue else: - logger.error( - 'Could not find available port for retry') - raise RuntimeError( - f'Port conflict and no available ports found: {e}' - ) from e + logger.error('Could not find available port for retry') + raise RuntimeError(f'Port conflict and no available ports found: {e}') from e else: # For non-notebook sandbox, just retry - logger.info( - f'Retrying sandbox creation (attempt {retry_count}/{max_retries})...' - ) + logger.info(f'Retrying sandbox creation (attempt {retry_count}/{max_retries})...') await asyncio.sleep(1) continue else: # Not a port conflict or retries exhausted raise - logger.error( - f'Failed to create sandbox after {max_retries} attempts') + logger.error(f'Failed to create sandbox after {max_retries} attempts') raise RuntimeError( f'Sandbox initialization failed after {max_retries} attempts: {last_error}' ) from last_error @@ -338,166 +302,130 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='notebook_executor', server_name='code_executor', - description= - ('Execute Python code in an isolated Docker sandbox with state ' - 'persistence in a Jupyter kernel environment. Variables, imports, and ' - 'data are preserved across multiple calls within the same session. ' - 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' - 'Data files in the output directory are accessible at /data/ path. ' - 'Use print() to output results.'), + description=( + 'Execute Python code in an isolated Docker sandbox with state ' + 'persistence in a Jupyter kernel environment. Variables, imports, and ' + 'data are preserved across multiple calls within the same session. ' + 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' + 'Data files in the output directory are accessible at /data/ path. ' + 'Use print() to output results.' + ), parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Python code to execute. Can access previously defined variables. ' - 'Data files are at /data/ (e.g., pd.read_csv(\'/data/file.csv\')). ' - 'Use print() for output.') - }, - 'description': { - 'type': - 'string', - 'description': - 'Brief description of what the code does' + 'type': 'string', + 'description': ( + 'Python code to execute. Can access previously defined variables. ' + 'Data files are at /data/ (e.g., pd.read_csv(\'/data/file.csv\')). ' + 'Use print() for output.' + ), }, + 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'minimum': 1, 'maximum': 600, 'description': 'Execution timeout in seconds', - 'default': 60 - } + 'default': 60, + }, }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='python_executor', server_name='code_executor', - description= - ('Execute Python code in an isolated environment. ' - 'Supports pandas, numpy, matplotlib, seaborn and other libraries you need for data analysis. ' - 'Data files in the output directory are accessible at /data/ path. ' - 'Use print() to output results.'), + description=( + 'Execute Python code in an isolated environment. ' + 'Supports pandas, numpy, matplotlib, seaborn and other libraries you need for data analysis. ' + 'Data files in the output directory are accessible at /data/ path. ' + 'Use print() to output results.' + ), parameters={ 'type': 'object', 'properties': { - 'code': { - 'type': 'string', - 'description': 'Python code to execute' - }, - 'description': { - 'type': - 'string', - 'description': - 'Brief description of what the code does' - }, + 'code': {'type': 'string', 'description': 'Python code to execute'}, + 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': 30 - } + 'default': 30, + }, }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='shell_executor', server_name='code_executor', - description= - ('Execute one shell command in an isolated environment. ' - 'Commands will be executed directly without shell parsing. ' - 'For shell syntax (cd, &&, ||, pipes, redirection), use explicit wrapper like sh -lc "...". ' - 'Supports basic operations like ls, mkdir, rm, mv, npm, pip, etc. ' - 'Data files in the output directory are accessible at /data/ path. ' - ), + description=( + 'Execute one shell command in an isolated environment. ' + 'Commands will be executed directly without shell parsing. ' + 'For shell syntax (cd, &&, ||, pipes, redirection), use explicit wrapper like sh -lc "...". ' + 'Supports basic operations like ls, mkdir, rm, mv, npm, pip, etc. ' + 'Data files in the output directory are accessible at /data/ path. ' + ), parameters={ 'type': 'object', 'properties': { - 'command': { - 'type': 'string', - 'description': 'Shell command to execute' - }, + 'command': {'type': 'string', 'description': 'Shell command to execute'}, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': 900 - } + 'default': 900, + }, }, 'required': ['command'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='file_operation', server_name='code_executor', - description= - 'Perform file operations like read, write, delete, and list files', + description='Perform file operations like read, write, delete, and list files', parameters={ 'type': 'object', 'properties': { 'operation': { - 'type': - 'string', - 'description': - 'Type of file operation to perform', - 'enum': [ - 'create', 'read', 'write', 'delete', - 'list', 'exists' - ] - }, - 'file_path': { 'type': 'string', - 'description': 'Path to the file or directory' + 'description': 'Type of file operation to perform', + 'enum': ['create', 'read', 'write', 'delete', 'list', 'exists'], }, + 'file_path': {'type': 'string', 'description': 'Path to the file or directory'}, 'content': { - 'type': - 'string', - 'description': - 'Content to write to file (only for write operation)' - }, - 'encoding': { 'type': 'string', - 'description': 'File encoding', - 'default': 'utf-8' - } + 'description': 'Content to write to file (only for write operation)', + }, + 'encoding': {'type': 'string', 'description': 'File encoding', 'default': 'utf-8'}, }, 'required': ['operation', 'file_path'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='reset_executor', server_name='code_executor', - description= - ('Reset the sandbox state by restarting the kernel. ' - 'All variables, imports, and session state will be cleared.' - ), - parameters={ - 'type': 'object', - 'properties': {}, - 'required': [], - 'additionalProperties': False - }, + description=( + 'Reset the sandbox state by restarting the kernel. ' + 'All variables, imports, and session state will be cleared.' + ), + parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, ), Tool( tool_name='get_executor_info', server_name='code_executor', description='Get current sandbox status and information', - parameters={ - 'type': 'object', - 'properties': {}, - 'required': [], - 'additionalProperties': False - }, - ) + parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, + ), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: """Route tool calls to appropriate methods""" if not self._initialized: await self.connect() @@ -506,26 +434,12 @@ async def call_tool(self, server_name: str, *, tool_name: str, method = getattr(self, tool_name) return await method(**tool_args) except AttributeError: - return json.dumps( - { - 'success': False, - 'error': f'Unknown tool: {tool_name}' - }, - indent=2) + return json.dumps({'success': False, 'error': f'Unknown tool: {tool_name}'}, indent=2) except Exception as e: - logger.error( - f'Tool execution error ({tool_name}): {e}', exc_info=True) - return json.dumps( - { - 'success': False, - 'error': f'Tool execution error: {str(e)}' - }, - indent=2) + logger.error(f'Tool execution error ({tool_name}): {e}', exc_info=True) + return json.dumps({'success': False, 'error': f'Tool execution error: {str(e)}'}, indent=2) - async def notebook_executor(self, - code: str, - description: str = '', - timeout: Optional[int] = None) -> str: + async def notebook_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: """ Execute Python code in the sandbox using notebook_executor. @@ -546,10 +460,8 @@ async def notebook_executor(self, result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='notebook_executor', - parameters={ - 'code': code, - 'timeout': timeout or 60 - }) + parameters={'code': code, 'timeout': timeout or 60}, + ) success = result.status == ExecutionStatus.SUCCESS @@ -563,24 +475,16 @@ async def notebook_executor(self, 'success': success, 'description': description, 'output': result.output or '', - 'error': result.error if result.error else None + 'error': result.error if result.error else None, }, - indent=2) + indent=2, + ) except Exception as e: logger.error(f'Execute python failed: {e}', exc_info=True) - return json.dumps( - { - 'success': False, - 'description': description, - 'error': str(e) - }, - indent=2) + return json.dumps({'success': False, 'description': description, 'error': str(e)}, indent=2) - async def python_executor(self, - code: str, - description: str = '', - timeout: Optional[int] = None) -> str: + async def python_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: """ Execute Python code in the sandbox. @@ -601,10 +505,8 @@ async def python_executor(self, result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='python_executor', - parameters={ - 'code': code, - 'timeout': timeout or 60 - }) + parameters={'code': code, 'timeout': timeout or 60}, + ) success = result.status == ExecutionStatus.SUCCESS @@ -618,23 +520,16 @@ async def python_executor(self, 'success': success, 'description': description, 'output': result.output or '', - 'error': result.error if result.error else None + 'error': result.error if result.error else None, }, - indent=2) + indent=2, + ) except Exception as e: logger.error(f'Execute python failed: {e}', exc_info=True) - return json.dumps( - { - 'success': False, - 'description': description, - 'error': str(e) - }, - indent=2) + return json.dumps({'success': False, 'description': description, 'error': str(e)}, indent=2) - async def shell_executor(self, - command: str, - timeout: Optional[int] = None) -> str: + async def shell_executor(self, command: str, timeout: Optional[int] = None) -> str: """ Execute shell commands in the sandbox. @@ -650,23 +545,19 @@ async def shell_executor(self, try: logger.info(f'Executing command: {command[:50]}...') - shell_meta = ('&&', '||', '|', ';', '>', '<', '`', '$(', 'cd ', - 'export ') - already_wrapped = command.lstrip().startswith( - ('sh ', 'bash ', '/bin/sh ', '/bin/bash ')) - if not already_wrapped and any(meta in command - for meta in shell_meta): + shell_meta = ('&&', '||', '|', ';', '>', '<', '`', '$(', 'cd ', 'export ') + already_wrapped = command.lstrip().startswith(('sh ', 'bash ', '/bin/sh ', '/bin/bash ')) + if not already_wrapped and any(meta in command for meta in shell_meta): import shlex + command = f'sh -lc {shlex.quote(command)}' # Execute via shell_executor result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='shell_executor', - parameters={ - 'command': command, - 'timeout': timeout or 900 - }) + parameters={'command': command, 'timeout': timeout or 900}, + ) success = result.status == ExecutionStatus.SUCCESS if success: @@ -675,22 +566,17 @@ async def shell_executor(self, logger.warning(f'Command execution failed: {result.error}') return json.dumps( - { - 'success': success, - 'output': result.output or '', - 'error': result.error if result.error else None - }, - indent=2) + {'success': success, 'output': result.output or '', 'error': result.error if result.error else None}, + indent=2, + ) except Exception as e: logger.error(f'Execute shell failed: {e}', exc_info=True) return json.dumps({'success': False, 'error': str(e)}, indent=2) - async def file_operation(self, - operation: str, - file_path: str, - content: Optional[str] = None, - encoding: Optional[str] = 'utf-8') -> str: + async def file_operation( + self, operation: str, file_path: str, content: Optional[str] = None, encoding: Optional[str] = 'utf-8' + ) -> str: """ Perform file operations like read, write, delete, and list files in the sandbox. @@ -709,41 +595,29 @@ async def file_operation(self, result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='file_operation', - parameters={ - 'operation': operation, - 'file_path': file_path, - 'content': content, - 'encoding': encoding - }) + parameters={'operation': operation, 'file_path': file_path, 'content': content, 'encoding': encoding}, + ) success = result.status == ExecutionStatus.SUCCESS if success: - logger.info( - f'File operation {operation} successful for {file_path}') + logger.info(f'File operation {operation} successful for {file_path}') else: - logger.warning( - f'File operation {operation} failed for {file_path}: {result.error}' - ) + logger.warning(f'File operation {operation} failed for {file_path}: {result.error}') return json.dumps( { 'success': success, 'file_path': file_path, 'output': result.output if success else '', - 'error': result.error if result.error else None + 'error': result.error if result.error else None, }, - indent=2) + indent=2, + ) except Exception as e: logger.error(f'Read file failed: {e}', exc_info=True) - return json.dumps( - { - 'success': False, - 'file_path': file_path, - 'error': str(e) - }, - indent=2) + return json.dumps({'success': False, 'file_path': file_path, 'error': str(e)}, indent=2) async def reset_executor(self) -> str: """ @@ -763,8 +637,8 @@ async def reset_executor(self) -> str: # Create new sandbox self.sandbox_id = await self.manager.create_sandbox( - sandbox_type=SandboxType.DOCKER_NOTEBOOK, - config=self.sandbox_config) + sandbox_type=SandboxType.DOCKER_NOTEBOOK, config=self.sandbox_config + ) # Wait for it to be ready await self._wait_for_sandbox_ready() @@ -774,11 +648,11 @@ async def reset_executor(self) -> str: return json.dumps( { 'success': True, - 'message': - 'Sandbox reset successfully. All variables and state cleared.', - 'new_sandbox_id': self.sandbox_id + 'message': 'Sandbox reset successfully. All variables and state cleared.', + 'new_sandbox_id': self.sandbox_id, }, - indent=2) + indent=2, + ) except Exception as e: logger.error(f'Reset sandbox failed: {e}', exc_info=True) @@ -807,18 +681,14 @@ async def get_executor_info(self) -> str: 'config': { 'memory_limit': self.sandbox_config.memory_limit, 'cpu_limit': self.sandbox_config.cpu_limit, - 'timeout': self.sandbox_config.timeout - } + 'timeout': self.sandbox_config.timeout, + }, }, indent=2, - default=str) + default=str, + ) else: - return json.dumps( - { - 'success': False, - 'error': 'Sandbox info not available' - }, - indent=2) + return json.dumps({'success': False, 'error': 'Sandbox info not available'}, indent=2) except Exception as e: logger.error(f'Get sandbox info failed: {e}', exc_info=True) @@ -846,16 +716,12 @@ async def _wait_for_sandbox_ready(self, max_wait: int = 60) -> None: logger.info('Sandbox is running and ready') return elif info.status == SandboxStatus.ERROR: - error_msg = info.metadata.get( - 'error') or f'Unknown error: {info.metadata}' + error_msg = info.metadata.get('error') or f'Unknown error: {info.metadata}' raise RuntimeError(f'Sandbox failed to start: {error_msg}') if i % 5 == 0: - logger.debug( - f'Waiting for sandbox... ({i}/{max_wait}s, status={info.status.value})' - ) + logger.debug(f'Waiting for sandbox... ({i}/{max_wait}s, status={info.status.value})') await asyncio.sleep(1) - raise TimeoutError( - f'Sandbox failed to become ready within {max_wait} seconds') + raise TimeoutError(f'Sandbox failed to become ready within {max_wait} seconds') diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index d1e2104c5..3ceb6d1a4 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -2,6 +2,7 @@ import asyncio.subprocess as ai_subprocess import inspect import io +import json import os import shlex import shutil @@ -10,7 +11,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger @@ -41,11 +41,13 @@ def _coerce_str(value: Optional[bytes]) -> str: class LocalKernelSession: """Manage a local ipykernel instance for stateful notebook execution.""" - def __init__(self, - working_dir: Path, - env: Optional[Dict[str, str]] = None, - kernel_name: str = 'python3', - extra_arguments: Optional[List[str]] = None): + def __init__( + self, + working_dir: Path, + env: Optional[Dict[str, str]] = None, + kernel_name: str = 'python3', + extra_arguments: Optional[List[str]] = None, + ): self.working_dir = working_dir self.env = env or {} self.kernel_name = kernel_name @@ -67,9 +69,8 @@ async def start(self) -> None: logger.info('Starting local ipykernel session...') self._km = AsyncKernelManager( - kernel_name=self.kernel_name, - env=self.env, - cwd=str(self.working_dir)) # cwd may be ignored here + kernel_name=self.kernel_name, env=self.env, cwd=str(self.working_dir) + ) # cwd may be ignored here start_kernel_result = self._km.start_kernel( extra_arguments=self.extra_arguments, @@ -147,10 +148,8 @@ async def execute(self, code: str, timeout: int) -> Dict[str, Any]: if not self._client: raise RuntimeError('Kernel client not initialized') - execute_call = self._client.execute( - code=code, allow_stdin=False, stop_on_error=False) - msg_id = await execute_call if inspect.isawaitable( - execute_call) else execute_call + execute_call = self._client.execute(code=code, allow_stdin=False, stop_on_error=False) + msg_id = await execute_call if inspect.isawaitable(execute_call) else execute_call stdout_parts: List[str] = [] stderr_parts: List[str] = [] @@ -174,8 +173,7 @@ async def _drain() -> None: msg_type = msg['msg_type'] content = msg.get('content', {}) - if msg_type == 'status' and content.get( - 'execution_state') == 'idle': + if msg_type == 'status' and content.get('execution_state') == 'idle': break if msg_type == 'stream': name = content.get('name', 'stdout') @@ -191,8 +189,7 @@ async def _drain() -> None: elif 'text/html' in data: display_parts.append(data['text/html']) elif data: - display_parts.append( - json.dumps(data, ensure_ascii=False)) + display_parts.append(json.dumps(data, ensure_ascii=False)) elif msg_type == 'error': error_payload = { 'ename': content.get('ename'), @@ -209,23 +206,15 @@ async def _drain() -> None: except asyncio.TimeoutError as exc: logger.warning('Notebook execution timed out, interrupting kernel') await self.interrupt() - raise TimeoutError( - f'Notebook execution timed out after {timeout} seconds' - ) from exc + raise TimeoutError(f'Notebook execution timed out after {timeout} seconds') from exc self.execution_count += 1 stdout = ''.join(stdout_parts).strip('\n') stderr = ''.join(stderr_parts).strip('\n') displays = '\n'.join(display_parts).strip('\n') - output_segments = [ - segment for segment in [stdout, displays] if segment - ] - - return { - 'output': '\n'.join(output_segments), - 'stderr': stderr, - 'error': error_payload - } + output_segments = [segment for segment in [stdout, displays] if segment] + + return {'output': '\n'.join(output_segments), 'stderr': stderr, 'error': error_payload} class LocalCodeExecutionTool(ToolBase): @@ -233,23 +222,17 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) - self.output_dir = Path( - getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() + self.output_dir = Path(getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() self.output_dir.mkdir(parents=True, exist_ok=True) - self.tool_config = getattr( - getattr(config, 'tools', None), 'code_executor', None) - self._notebook_timeout = getattr(self.tool_config, 'notebook_timeout', - 60) if self.tool_config else 60 - self._python_timeout = getattr(self.tool_config, 'python_timeout', - 30) if self.tool_config else 30 - self._shell_timeout = getattr(self.tool_config, 'shell_timeout', - 60) if self.tool_config else 60 + self.tool_config = getattr(getattr(config, 'tools', None), 'code_executor', None) + self._notebook_timeout = getattr(self.tool_config, 'notebook_timeout', 60) if self.tool_config else 60 + self._python_timeout = getattr(self.tool_config, 'python_timeout', 30) if self.tool_config else 30 + self._shell_timeout = getattr(self.tool_config, 'shell_timeout', 60) if self.tool_config else 60 kernel_env = self._build_env('kernel_env', inherit=False) shell_env = self._build_env('shell_env', inherit=False) - self.kernel_session = LocalKernelSession( - working_dir=self.output_dir, env=kernel_env) + self.kernel_session = LocalKernelSession(working_dir=self.output_dir, env=kernel_env) self.shell_env = shell_env self._kernel_lock = asyncio.Lock() self._initialized = False @@ -265,12 +248,9 @@ def __init__(self, config): if dg: deny_globs = list(dg) shell_cfg = getattr(self.tool_config, 'shell', None) if self.tool_config else None - shell_mode = getattr(shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False) - ) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) - ) if shell_cfg else 8192 + shell_mode = getattr(shell_cfg, 'default_mode', 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192)) if shell_cfg else 8192 self._policy = WorkspacePolicyKernel( self.output_dir, extra_allow_roots=extra_allow, @@ -282,19 +262,14 @@ def __init__(self, config): max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) - self._artifacts = ArtifactManager( - self.output_dir, max_combined_bytes=max_kb * 1024) + self._artifacts = ArtifactManager(self.output_dir, max_combined_bytes=max_kb * 1024) - self.exclude_func( - getattr(getattr(config, 'tools', None), 'code_executor', None)) + self.exclude_func(getattr(getattr(config, 'tools', None), 'code_executor', None)) if 'file_operation' not in self.exclude_functions: - logger.warning( - 'file_operation is not suggested to be included in local code execution tool.' - ) + logger.warning('file_operation is not suggested to be included in local code execution tool.') results = self._check_dependencies() - logger.info(f'Dependency check results: {results}\n' - f'Make sure to install the missing dependencies.') + logger.info(f'Dependency check results: {results}\nMake sure to install the missing dependencies.') logger.info('LocalCodeExecutionTool initialized (ipykernel based)') @@ -328,13 +303,11 @@ def _check_dependencies(self) -> None: install_package(pip_name, import_name) module = importlib.import_module(import_name) except Exception as e: - logger.error( - f'Failed to install or import {pip_name}: {e}') + logger.error(f'Failed to install or import {pip_name}: {e}') results[pip_name] = None continue except Exception as e: - logger.error( - f'Unexpected error when importing {pip_name}: {e}') + logger.error(f'Unexpected error when importing {pip_name}: {e}') results[pip_name] = None continue @@ -345,8 +318,7 @@ def _check_dependencies(self) -> None: def _build_env(self, field: str, inherit: bool = False) -> Dict[str, str]: if inherit: env: Dict[str, str] = dict(os.environ) - logger.warning( - "It's not safe to inherit from the parent environment.") + logger.warning("It's not safe to inherit from the parent environment.") else: env: Dict[str, str] = { 'INHERITED_FROM_LOCAL': 'False', @@ -394,72 +366,63 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='notebook_executor', server_name='code_executor', - description= - ('Execute Python code locally with state ' - 'persistence in a Jupyter kernel environment. Variables, imports, and ' - 'data are preserved across multiple calls within the same session. ' - 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' - 'Use print() to output results.'), + description=( + 'Execute Python code locally with state ' + 'persistence in a Jupyter kernel environment. Variables, imports, and ' + 'data are preserved across multiple calls within the same session. ' + 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' + 'Use print() to output results.' + ), parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Python code to execute in the notebook session. ' - 'Can access previously defined variables. ' - 'Use print() for output.') - }, - 'description': { - 'type': - 'string', - 'description': - 'Brief description of what the code does' + 'type': 'string', + 'description': ( + 'Python code to execute in the notebook session. ' + 'Can access previously defined variables. ' + 'Use print() for output.' + ), }, + 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'minimum': 1, 'maximum': 600, 'description': 'Execution timeout in seconds', - 'default': self._notebook_timeout - } + 'default': self._notebook_timeout, + }, }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='python_executor', server_name='code_executor', - description= - ('Execute stateless Python code locally. ' - 'Each call runs in an isolated environment without ' - 'persisting context between invocations. ' - 'Supports pandas, numpy, matplotlib, seaborn, and other ' - 'libraries you need for data analysis. ' - 'Use print() to output results.'), + description=( + 'Execute stateless Python code locally. ' + 'Each call runs in an isolated environment without ' + 'persisting context between invocations. ' + 'Supports pandas, numpy, matplotlib, seaborn, and other ' + 'libraries you need for data analysis. ' + 'Use print() to output results.' + ), parameters={ 'type': 'object', 'properties': { - 'code': { - 'type': 'string', - 'description': 'Python code to execute' - }, - 'description': { - 'type': - 'string', - 'description': - 'Brief description of what the code does' - }, + 'code': {'type': 'string', 'description': 'Python code to execute'}, + 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': self._python_timeout - } + 'default': self._python_timeout, + }, }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='shell_executor', server_name='code_executor', @@ -472,19 +435,15 @@ async def _get_tools_inner(self) -> Dict[str, Any]: parameters={ 'type': 'object', 'properties': { - 'command': { - 'type': 'string', - 'description': 'Shell command to execute' - }, + 'command': {'type': 'string', 'description': 'Shell command to execute'}, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': self._shell_timeout + 'default': self._shell_timeout, }, 'run_in_background': { 'type': 'boolean', - 'description': - 'If true, start the command asynchronously and return task_id (requires TaskManager).', + 'description': 'If true, start the command asynchronously and return task_id (requires TaskManager).', 'default': False, }, '__call_id': { @@ -493,76 +452,50 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, }, 'required': ['command'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='file_operation', server_name='code_executor', - description= - 'Perform file operations inside the local output directory', + description='Perform file operations inside the local output directory', parameters={ 'type': 'object', 'properties': { 'operation': { - 'type': - 'string', - 'description': - 'Type of file operation to perform', - 'enum': [ - 'create', 'read', 'write', 'delete', - 'list', 'exists' - ] - }, - 'file_path': { 'type': 'string', - 'description': 'Path to the file or directory' + 'description': 'Type of file operation to perform', + 'enum': ['create', 'read', 'write', 'delete', 'list', 'exists'], }, - 'content': { - 'type': - 'string', - 'description': - 'Content for write/create operations' - }, - 'encoding': { - 'type': 'string', - 'description': 'File encoding to use', - 'default': 'utf-8' - } + 'file_path': {'type': 'string', 'description': 'Path to the file or directory'}, + 'content': {'type': 'string', 'description': 'Content for write/create operations'}, + 'encoding': {'type': 'string', 'description': 'File encoding to use', 'default': 'utf-8'}, }, 'required': ['operation', 'file_path'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='reset_executor', server_name='code_executor', - description= - ('Restart the local ipykernel session to clear state. ' - 'All variables, imports, and session state will be cleared.' - ), - parameters={ - 'type': 'object', - 'properties': {}, - 'required': [], - 'additionalProperties': False - }), + description=( + 'Restart the local ipykernel session to clear state. ' + 'All variables, imports, and session state will be cleared.' + ), + parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, + ), Tool( tool_name='get_executor_info', server_name='code_executor', - description= - 'Get information about the local execution environment.', - parameters={ - 'type': 'object', - 'properties': {}, - 'required': [], - 'additionalProperties': False - }), + description='Get information about the local execution environment.', + parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, + ), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if not self._initialized: await self.connect() @@ -570,28 +503,12 @@ async def call_tool(self, server_name: str, *, tool_name: str, method = getattr(self, tool_name) return await method(**tool_args) except AttributeError: - return json.dumps( - { - 'success': False, - 'error': f'Unknown tool: {tool_name}' - }, - ensure_ascii=False, - indent=2) + return json.dumps({'success': False, 'error': f'Unknown tool: {tool_name}'}, ensure_ascii=False, indent=2) except Exception as exc: - logger.error( - f'Tool execution error ({tool_name}): {exc}', exc_info=True) - return json.dumps( - { - 'success': False, - 'error': f'Tool execution error: {exc}' - }, - ensure_ascii=False, - indent=2) + logger.error(f'Tool execution error ({tool_name}): {exc}', exc_info=True) + return json.dumps({'success': False, 'error': f'Tool execution error: {exc}'}, ensure_ascii=False, indent=2) - async def notebook_executor(self, - code: str, - description: str = '', - timeout: Optional[int] = None) -> str: + async def notebook_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: exec_timeout = timeout or self._notebook_timeout try: @@ -599,13 +516,8 @@ async def notebook_executor(self, result = await self.kernel_session.execute(code, exec_timeout) except Exception as exc: return json.dumps( - { - 'success': False, - 'description': description, - 'error': str(exc) - }, - ensure_ascii=False, - indent=2) + {'success': False, 'description': description, 'error': str(exc)}, ensure_ascii=False, indent=2 + ) error_payload = result.get('error') stderr = result.get('stderr') or '' @@ -622,15 +534,13 @@ async def notebook_executor(self, 'success': error_payload is None, 'description': description, 'output': result.get('output', ''), - 'error': stderr or None + 'error': stderr or None, }, ensure_ascii=False, - indent=2) + indent=2, + ) - async def python_executor(self, - code: str, - description: str = '', - timeout: Optional[int] = None) -> str: + async def python_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: exec_timeout = timeout or self._python_timeout def _exec_code(): @@ -638,35 +548,26 @@ def _exec_code(): stderr_buffer = io.StringIO() globals_dict: Dict[str, Any] = {'__builtins__': __builtins__} locals_dict: Dict[str, Any] = {} - with redirect_stdout(stdout_buffer), redirect_stderr( - stderr_buffer): + with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer): exec(code, globals_dict, locals_dict) return stdout_buffer.getvalue(), stderr_buffer.getvalue() try: - stdout, stderr = await asyncio.wait_for( - asyncio.to_thread(_exec_code), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for(asyncio.to_thread(_exec_code), timeout=exec_timeout) except asyncio.TimeoutError: - return json.dumps( - { - 'success': - False, - 'description': - description, - 'error': - f'Python execution timed out after {exec_timeout} seconds' - }, - ensure_ascii=False, - indent=2) - except Exception as exc: return json.dumps( { 'success': False, 'description': description, - 'error': str(exc) + 'error': f'Python execution timed out after {exec_timeout} seconds', }, ensure_ascii=False, - indent=2) + indent=2, + ) + except Exception as exc: + return json.dumps( + {'success': False, 'description': description, 'error': str(exc)}, ensure_ascii=False, indent=2 + ) if not stderr: logger.info('Python code executed successfully') @@ -678,16 +579,19 @@ def _exec_code(): 'success': not stderr, 'description': description, 'output': stdout.strip('\n'), - 'error': stderr.strip('\n') or None + 'error': stderr.strip('\n') or None, }, ensure_ascii=False, - indent=2) + indent=2, + ) - async def shell_executor(self, - command: str, - timeout: Optional[int] = None, - run_in_background: bool = False, - __call_id: Optional[str] = None) -> str: + async def shell_executor( + self, + command: str, + timeout: Optional[int] = None, + run_in_background: bool = False, + __call_id: Optional[str] = None, + ) -> str: exec_timeout = timeout or self._shell_timeout call_id = __call_id or f'shell-{os.urandom(4).hex()}' @@ -695,10 +599,7 @@ async def shell_executor(self, self._policy.assert_shell_command_allowed(command) except WorkspacePolicyError as e: return json.dumps( - { - 'success': False, - 'error': str(e) - }, + {'success': False, 'error': str(e)}, ensure_ascii=False, indent=2, ) @@ -710,8 +611,7 @@ async def shell_executor(self, return json.dumps( { 'success': False, - 'error': - 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', + 'error': 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', }, ensure_ascii=False, indent=2, @@ -726,10 +626,7 @@ async def shell_executor(self, ) except FileNotFoundError as exc: return json.dumps( - { - 'success': False, - 'error': f'Shell not available: {exc}' - }, + {'success': False, 'error': f'Shell not available: {exc}'}, ensure_ascii=False, indent=2, ) @@ -743,8 +640,7 @@ async def shell_executor(self, async def _watcher() -> None: try: - stdout, stderr = await asyncio.wait_for( - process.communicate(), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=exec_timeout) stdout_text = _coerce_str(stdout).strip('\n') stderr_text = _coerce_str(stderr).strip('\n') success = process.returncode == 0 @@ -798,17 +694,13 @@ async def _watcher() -> None: ) except FileNotFoundError as exc: return json.dumps( - { - 'success': False, - 'error': f'Shell not available: {exc}' - }, + {'success': False, 'error': f'Shell not available: {exc}'}, ensure_ascii=False, indent=2, ) try: - stdout, stderr = await asyncio.wait_for( - process.communicate(), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=exec_timeout) except asyncio.TimeoutError: process.kill() try: @@ -816,12 +708,7 @@ async def _watcher() -> None: except Exception: # noqa: B902 pass return json.dumps( - { - 'success': - False, - 'error': - f'Shell command timed out after {exec_timeout} seconds' - }, + {'success': False, 'error': f'Shell command timed out after {exec_timeout} seconds'}, ensure_ascii=False, indent=2, ) @@ -841,22 +728,15 @@ async def _watcher() -> None: payload=payload, ) - async def file_operation(self, - operation: str, - file_path: str, - content: Optional[str] = None, - encoding: Optional[str] = 'utf-8') -> str: + async def file_operation( + self, operation: str, file_path: str, content: Optional[str] = None, encoding: Optional[str] = 'utf-8' + ) -> str: try: target = self._resolve_path(file_path) except ValueError as exc: return json.dumps( - { - 'success': False, - 'error': str(exc), - 'file_path': file_path - }, - ensure_ascii=False, - indent=2) + {'success': False, 'error': str(exc), 'file_path': file_path}, ensure_ascii=False, indent=2 + ) op = operation.lower() @@ -864,69 +744,40 @@ async def file_operation(self, if op == 'create': target.parent.mkdir(parents=True, exist_ok=True) target.touch(exist_ok=True) - result = { - 'success': True, - 'file_path': str(target), - 'message': 'File created' - } + result = {'success': True, 'file_path': str(target), 'message': 'File created'} elif op == 'read': data = target.read_text(encoding=encoding or 'utf-8') - result = { - 'success': True, - 'file_path': str(target), - 'output': data - } + result = {'success': True, 'file_path': str(target), 'output': data} elif op == 'write': if content is None: raise ValueError('Content is required for write operation') target.parent.mkdir(parents=True, exist_ok=True) target.write_text(content, encoding=encoding or 'utf-8') - result = { - 'success': True, - 'file_path': str(target), - 'message': 'File written' - } + result = {'success': True, 'file_path': str(target), 'message': 'File written'} elif op == 'delete': if target.is_dir(): shutil.rmtree(target) else: target.unlink(missing_ok=True) - result = { - 'success': True, - 'file_path': str(target), - 'message': 'Deleted successfully' - } + result = {'success': True, 'file_path': str(target), 'message': 'Deleted successfully'} elif op == 'list': if not target.is_dir(): - raise ValueError( - 'List operation requires a directory path') - entries = [{ - 'name': - child.name, - 'is_dir': - child.is_dir(), - 'size': - child.stat().st_size if child.is_file() else None - } for child in sorted(target.iterdir())] - result = { - 'success': True, - 'file_path': str(target), - 'entries': entries - } + raise ValueError('List operation requires a directory path') + entries = [ + { + 'name': child.name, + 'is_dir': child.is_dir(), + 'size': child.stat().st_size if child.is_file() else None, + } + for child in sorted(target.iterdir()) + ] + result = {'success': True, 'file_path': str(target), 'entries': entries} elif op == 'exists': - result = { - 'success': True, - 'file_path': str(target), - 'exists': target.exists() - } + result = {'success': True, 'file_path': str(target), 'exists': target.exists()} else: raise ValueError(f'Unsupported file operation: {operation}') except Exception as exc: - result = { - 'success': False, - 'file_path': str(target), - 'error': str(exc) - } + result = {'success': False, 'file_path': str(target), 'error': str(exc)} return json.dumps(result, ensure_ascii=False, indent=2, default=str) @@ -935,14 +786,10 @@ async def reset_executor(self) -> str: async with self._kernel_lock: await self.kernel_session.restart() return json.dumps( - { - 'success': - True, - 'message': - 'Local kernel session restarted. State has been cleared.' - }, + {'success': True, 'message': 'Local kernel session restarted. State has been cleared.'}, ensure_ascii=False, - indent=2) + indent=2, + ) except Exception as exc: return json.dumps({'success': False, 'error': str(exc)}, ensure_ascii=False, indent=2) # yapf: disable @@ -964,6 +811,5 @@ def _resolve_path(self, file_path: str) -> Path: else: raw_path = raw_path.resolve() if not _is_relative_to(raw_path, self.output_dir): - raise ValueError( - 'Access outside the output directory is not permitted') + raise ValueError('Access outside the output directory is not permitted') return raw_path diff --git a/ms_agent/tools/code/sandbox_manager.py b/ms_agent/tools/code/sandbox_manager.py index 9744ca7f8..8a3880a90 100644 --- a/ms_agent/tools/code/sandbox_manager.py +++ b/ms_agent/tools/code/sandbox_manager.py @@ -1,9 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Union -from ms_agent.utils import get_logger from omegaconf import DictConfig +from ms_agent.utils import get_logger + logger = get_logger() @@ -30,22 +31,18 @@ def ensure_local_image_exists(image: str) -> bool: try: client = docker.from_env() - image_exists = any(image in img.tags - for img in client.images.list() if img.tags) + image_exists = any(image in img.tags for img in client.images.list() if img.tags) if image_exists: logger.info(f'Image exists in local Docker registry: {image}') else: - logger.info( - f'Image does not exist in local Docker registry: {image}') + logger.info(f'Image does not exist in local Docker registry: {image}') return image_exists except Exception as e: logger.error(f'Error checking if image exists: {e}') raise RuntimeError(f'Failed to check image existence: {e}') from e @staticmethod - async def create_manager( - config: Union[DictConfig, dict] - ) -> Union['LocalSandboxManager', 'HttpSandboxManager']: + async def create_manager(config: Union[DictConfig, dict]) -> Union['LocalSandboxManager', 'HttpSandboxManager']: """ Create and initialize a sandbox manager based on configuration. @@ -61,13 +58,12 @@ async def create_manager( from ms_enclave.sandbox.manager import HttpSandboxManager, LocalSandboxManager # Extract sandbox configuration - if isinstance(config, DictConfig) and hasattr( - config, 'tools') and hasattr(config.tools, 'code_executor'): + if isinstance(config, DictConfig) and hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): sandbox_config = getattr(config.tools.code_executor, 'sandbox', {}) elif isinstance(config, (DictConfig, dict)): - sandbox_config = config.get('tools', {}).get( - 'code_executor', {}).get('sandbox', {}) or config.get( - 'sandbox', {}) + sandbox_config = config.get('tools', {}).get('code_executor', {}).get('sandbox', {}) or config.get( + 'sandbox', {} + ) else: raise ValueError(f'Unknown config type: {type(config)}') @@ -78,24 +74,16 @@ async def create_manager( if mode == 'local': cleanup_interval = sandbox_config.get('cleanup_interval', 300) manager = LocalSandboxManager(cleanup_interval=cleanup_interval) - logger.info( - f'Created LocalSandboxManager with cleanup_interval={cleanup_interval}s' - ) + logger.info(f'Created LocalSandboxManager with cleanup_interval={cleanup_interval}s') if image: try: - if not SandboxManagerFactory.ensure_local_image_exists( - image): - raise ValueError( - f'Image "{image}" does not exist in local Docker registry' - ) + if not SandboxManagerFactory.ensure_local_image_exists(image): + raise ValueError(f'Image "{image}" does not exist in local Docker registry') except RuntimeError as e: - raise ValueError( - f'Error checking if image exists: {e}') from e + raise ValueError(f'Error checking if image exists: {e}') from e else: - logger.warning( - 'No image specified for LocalSandboxManager, using default' - ) + logger.warning('No image specified for LocalSandboxManager, using default') elif mode == 'http': base_url = sandbox_config.get('http_url', 'http://localhost:8000') @@ -103,7 +91,6 @@ async def create_manager( logger.info(f'Created HttpSandboxManager with base_url={base_url}') else: - raise ValueError( - f"Unknown sandbox mode: {mode}. Must be 'local' or 'http'") + raise ValueError(f"Unknown sandbox mode: {mode}. Must be 'local' or 'http'") return manager diff --git a/ms_agent/tools/code_server/lsp_code_server.py b/ms_agent/tools/code_server/lsp_code_server.py index df1e043e5..64431c84a 100644 --- a/ms_agent/tools/code_server/lsp_code_server.py +++ b/ms_agent/tools/code_server/lsp_code_server.py @@ -1,15 +1,14 @@ import asyncio +import json import os import shutil import sys from pathlib import Path from typing import Any, Dict, List, Optional -import json from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, - DEFAULT_OUTPUT_DIR) +from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_OUTPUT_DIR logger = get_logger() @@ -24,8 +23,7 @@ def __init__(self, config): self.stdout = None self.message_id = 0 self.initialized = False - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) self.workspace_dir = Path(self.output_dir).resolve() self.index_dir = os.path.join(self.output_dir, DEFAULT_INDEX_DIR) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) @@ -59,12 +57,7 @@ async def send_request(self, method: str, params: dict = None) -> dict: self.message_id += 1 request_id = self.message_id - request = { - 'jsonrpc': '2.0', - 'id': request_id, - 'method': method, - 'params': params or {} - } + request = {'jsonrpc': '2.0', 'id': request_id, 'method': method, 'params': params or {}} content = json.dumps(request) message = f'Content-Length: {len(content)}\r\n\r\n{content}' @@ -84,14 +77,10 @@ async def send_request(self, method: str, params: dict = None) -> dict: # It's a notification (no id) or response for different request # Log and continue reading if 'method' in msg: - logger.debug( - f"Received notification during request: {msg.get('method')}" - ) + logger.debug(f"Received notification during request: {msg.get('method')}") continue - logger.warning( - f'No response received for request {request_id} after {max_retries} attempts' - ) + logger.warning(f'No response received for request {request_id} after {max_retries} attempts') return {'error': 'No response received'} except Exception as e: @@ -103,11 +92,7 @@ async def send_notification(self, method: str, params: dict = None): if not self.process or not self.stdin: raise RuntimeError('LSP server not started') - notification = { - 'jsonrpc': '2.0', - 'method': method, - 'params': params or {} - } + notification = {'jsonrpc': '2.0', 'method': method, 'params': params or {}} content = json.dumps(notification) message = f'Content-Length: {len(content)}\r\n\r\n{content}' @@ -147,55 +132,45 @@ async def _read_message(self) -> dict: async def initialize(self): """Initialize the LSP server and wait for it to be ready""" response = await self.send_request( - 'initialize', { - 'processId': - os.getpid(), - 'rootUri': - self.workspace_dir.as_uri(), - 'rootPath': - str(self.workspace_dir), - 'workspaceFolders': [{ - 'uri': self.workspace_dir.as_uri(), - 'name': self.workspace_dir.name - }], + 'initialize', + { + 'processId': os.getpid(), + 'rootUri': self.workspace_dir.as_uri(), + 'rootPath': str(self.workspace_dir), + 'workspaceFolders': [{'uri': self.workspace_dir.as_uri(), 'name': self.workspace_dir.name}], 'capabilities': { 'textDocument': { 'publishDiagnostics': {}, - 'synchronization': { - 'didOpen': True, - 'didChange': True, - 'didClose': True - } + 'synchronization': {'didOpen': True, 'didChange': True, 'didClose': True}, } - } - }) + }, + }, + ) if 'result' in response: await self.send_notification('initialized', {}) # CRITICAL: Wait for server to be fully ready # Read and discard any startup messages - await asyncio.sleep( - 1.0) # Give server time to complete initialization + await asyncio.sleep(1.0) # Give server time to complete initialization await self.send_notification( - 'workspace/didChangeConfiguration', { + 'workspace/didChangeConfiguration', + { 'settings': { 'python': { 'pythonPath': sys.executable, }, - 'pyright': { - 'extraPaths': [str(self.workspace_dir)] - }, + 'pyright': {'extraPaths': [str(self.workspace_dir)]}, } - }) + }, + ) # Consume any pending messages (like "starting" notifications) try: for _ in range(10): try: - await asyncio.wait_for( - self._read_message(), timeout=2.0) + await asyncio.wait_for(self._read_message(), timeout=2.0) except asyncio.TimeoutError: break except Exception as e: @@ -208,13 +183,11 @@ async def initialize(self): logger.error(f'LSP initialization failed: {response}') return False - async def open_document(self, file_path: str, content: str, - language_id: str): + async def open_document(self, file_path: str, content: str, language_id: str): """Open a document in the LSP server""" file_uri = Path(file_path).resolve().as_uri() changes = [{'uri': file_uri, 'type': 1}] - await self.send_notification('workspace/didChangeWatchedFiles', - {'changes': changes}) + await self.send_notification('workspace/didChangeWatchedFiles', {'changes': changes}) if file_path.endswith('.tsx'): language_id = 'typescriptreact' @@ -226,45 +199,25 @@ async def open_document(self, file_path: str, content: str, language_id = 'javascript' await self.send_notification( - 'textDocument/didOpen', { - 'textDocument': { - 'uri': file_uri, - 'languageId': language_id, - 'version': 1, - 'text': content - } - }) + 'textDocument/didOpen', + {'textDocument': {'uri': file_uri, 'languageId': language_id, 'version': 1, 'text': content}}, + ) await asyncio.sleep(2.0) async def close_document(self, file_path: str): """Close a document to clean up old index""" file_uri = Path(file_path).resolve().as_uri() - await self.send_notification('textDocument/didClose', - {'textDocument': { - 'uri': file_uri - }}) - - async def update_document(self, - file_path: str, - content: str, - version: int = 2): + await self.send_notification('textDocument/didClose', {'textDocument': {'uri': file_uri}}) + + async def update_document(self, file_path: str, content: str, version: int = 2): """Update a document in the LSP server""" file_uri = Path(file_path).resolve().as_uri() await self.send_notification( - 'textDocument/didChange', { - 'textDocument': { - 'uri': file_uri, - 'version': version - }, - 'contentChanges': [{ - 'text': content - }] - }) - - async def get_diagnostics(self, - file_path: str, - wait_time: float = 2.0, - use_cache: bool = True) -> List[dict]: + 'textDocument/didChange', + {'textDocument': {'uri': file_uri, 'version': version}, 'contentChanges': [{'text': content}]}, + ) + + async def get_diagnostics(self, file_path: str, wait_time: float = 2.0, use_cache: bool = True) -> List[dict]: await asyncio.sleep(wait_time) file_uri = Path(file_path).resolve().as_uri() @@ -281,8 +234,7 @@ async def get_diagnostics(self, if msg.get('method') == 'textDocument/publishDiagnostics': current_uri = msg.get('params', {}).get('uri') - current_diags = msg.get('params', - {}).get('diagnostics', []) + current_diags = msg.get('params', {}).get('diagnostics', []) self.diagnostics_cache[current_uri] = current_diags logger.debug(f'Cached diagnostics for {current_uri}') @@ -290,15 +242,12 @@ async def get_diagnostics(self, if current_uri == file_uri: diagnostics = current_diags found_target = True - logger.debug( - f'Found target diagnostics for {file_uri}') + logger.debug(f'Found target diagnostics for {file_uri}') except asyncio.TimeoutError: consecutive_timeouts += 1 if consecutive_timeouts >= 3: - logger.debug( - f'Stopped after {consecutive_timeouts} consecutive timeouts' - ) + logger.debug(f'Stopped after {consecutive_timeouts} consecutive timeouts') break else: continue @@ -329,17 +278,12 @@ async def start(self) -> bool: try: # Check if typescript is installed check_process = await asyncio.create_subprocess_exec( - 'npx', - 'tsc', - '--version', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) + 'npx', 'tsc', '--version', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) await check_process.communicate() if check_process.returncode != 0: - logger.error( - 'TypeScript not found. Install with: npm install -g typescript' - ) + logger.error('TypeScript not found. Install with: npm install -g typescript') return False # Start typescript-language-server @@ -350,7 +294,8 @@ async def start(self) -> bool: stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self.workspace_dir)) + cwd=str(self.workspace_dir), + ) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -394,8 +339,7 @@ def _clean_env_for_node() -> dict[str, str]: env = dict(os.environ) removed = env.pop('PYTHONPATH', None) if removed: - logger.debug('Removed PYTHONPATH=%r from pyright subprocess env', - removed) + logger.debug('Removed PYTHONPATH=%r from pyright subprocess env', removed) return env async def start(self) -> bool: @@ -405,15 +349,12 @@ async def start(self) -> bool: # Check if pyright is installed check_process = await asyncio.create_subprocess_exec( - 'pyright', - '--version', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) + 'pyright', '--version', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) await check_process.communicate() if check_process.returncode != 0: - logger.warning( - 'Pyright not found. Install with: pip install pyright') + logger.warning('Pyright not found. Install with: pip install pyright') return False # Start pyright langserver @@ -424,7 +365,8 @@ async def start(self) -> bool: stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=str(self.workspace_dir), - env=clean_env) + env=clean_env, + ) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -434,8 +376,7 @@ async def _read_server_stderr(process): line = await process.stderr.readline() if not line: break - logger.error( - f"LSP: {line.decode(errors='ignore').rstrip()}") + logger.error(f"LSP: {line.decode(errors='ignore').rstrip()}") asyncio.create_task(_read_server_stderr(self.process)) @@ -443,9 +384,7 @@ async def _read_server_stderr(process): return await self.initialize() except FileNotFoundError: - logger.error( - 'pyright-langserver not found. Install with: pip install pyright' - ) + logger.error('pyright-langserver not found. Install with: pip install pyright') return False except Exception as e: logger.error(f'Failed to start Python LSP server: {e}') @@ -475,10 +414,8 @@ async def start(self) -> bool: if not jdtls_cmd: # Try to find in PATH check_process = await asyncio.create_subprocess_exec( - 'which', - 'jdtls', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) + 'which', 'jdtls', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) stdout, _ = await check_process.communicate() if check_process.returncode == 0: jdtls_cmd = stdout.decode('utf-8').strip() @@ -503,7 +440,8 @@ async def start(self) -> bool: stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self.workspace_dir)) + cwd=str(self.workspace_dir), + ) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -524,12 +462,20 @@ async def start(self) -> bool: class LSPCodeServer(ToolBase): - skip_files = [ - 'vite.config.ts', 'vite.config.js', 'webpack.config.js', - 'webpack.config.ts', 'rollup.config.js', 'rollup.config.ts', - 'next.config.js', 'next.config.ts', 'tsconfig.json', 'jsconfig.json', - 'package.json', 'pom.xml', 'build.gradle' + 'vite.config.ts', + 'vite.config.js', + 'webpack.config.js', + 'webpack.config.ts', + 'rollup.config.js', + 'rollup.config.ts', + 'next.config.js', + 'next.config.ts', + 'tsconfig.json', + 'jsconfig.json', + 'package.json', + 'pom.xml', + 'build.gradle', ] language_mapping = { @@ -544,10 +490,8 @@ def __init__(self, config): super().__init__(config) self.servers: Dict[str, LSPServer] = {} self.file_versions: Dict[str, int] = {} - self.opened_documents: Dict[str, str] = { - } # Track opened documents: file_path -> language - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + self.opened_documents: Dict[str, str] = {} # Track opened documents: file_path -> language + self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) self.workspace_dir = self.output_dir self.index_dir = os.path.join(self.output_dir, DEFAULT_INDEX_DIR) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) @@ -560,10 +504,8 @@ async def connect(self) -> None: def cleanup_lsp_index_dirs(self): cleanup_dirs = [ os.path.join(self.output_dir, '.jdtls_workspace'), # Java LSP - os.path.join(self.output_dir, - '.pyright'), # Python LSP (if exists) - os.path.join(self.output_dir, 'node_modules', - '.cache'), # TypeScript LSP cache + os.path.join(self.output_dir, '.pyright'), # Python LSP (if exists) + os.path.join(self.output_dir, 'node_modules', '.cache'), # TypeScript LSP cache ] for dir_path in cleanup_dirs: @@ -571,9 +513,7 @@ def cleanup_lsp_index_dirs(self): try: shutil.rmtree(dir_path, ignore_errors=True) except Exception as e: # noqa - logger.warning( - f'Failed to cleanup LSP index directory {dir_path}: {e}' - ) + logger.warning(f'Failed to cleanup LSP index directory {dir_path}: {e}') async def cleanup(self) -> None: """Stop all LSP servers and clear indexes""" @@ -599,81 +539,65 @@ async def cleanup(self) -> None: async def _get_tools_inner(self) -> Dict[str, Any]: """Get available tools""" return { - 'lsp_code_server': [{ - 'tool_name': - 'check_directory', - 'description': - ('Check all code files in a directory for errors and issues. ' - 'Supports TypeScript/JavaScript, Python, Java files. ' - 'Returns a summary of all diagnostics found.'), - 'parameters': { - 'type': 'object', - 'properties': { - 'directory': { - 'type': - 'string', - 'description': - 'Path to the directory to check (relative to workspace)' + 'lsp_code_server': [ + { + 'tool_name': 'check_directory', + 'description': ( + 'Check all code files in a directory for errors and issues. ' + 'Supports TypeScript/JavaScript, Python, Java files. ' + 'Returns a summary of all diagnostics found.' + ), + 'parameters': { + 'type': 'object', + 'properties': { + 'directory': { + 'type': 'string', + 'description': 'Path to the directory to check (relative to workspace)', + }, + 'language': { + 'type': 'string', + 'enum': ['typescript', 'python', 'java'], + 'description': 'Programming language to check (typescript for JS/TS, python for Python, java for Java)', + }, }, - 'language': { - 'type': - 'string', - 'enum': ['typescript', 'python', 'java'], - 'description': - 'Programming language to check (typescript for JS/TS, python for Python, java for Java)' - } + 'required': ['directory', 'language'], }, - 'required': ['directory', 'language'] - } - }, { - 'tool_name': - 'update_and_check', - 'description': - ("Incrementally update a file's content and check for errors. " - 'Used during code generation to validate each N lines. ' - 'More efficient than checking from scratch each time.'), - 'parameters': { - 'type': 'object', - 'properties': { - 'file_path': { - 'type': - 'string', - 'description': - 'Path to the file (relative to workspace)' - }, - 'content': { - 'type': 'string', - 'description': 'Updated file content' + }, + { + 'tool_name': 'update_and_check', + 'description': ( + "Incrementally update a file's content and check for errors. " + 'Used during code generation to validate each N lines. ' + 'More efficient than checking from scratch each time.' + ), + 'parameters': { + 'type': 'object', + 'properties': { + 'file_path': {'type': 'string', 'description': 'Path to the file (relative to workspace)'}, + 'content': {'type': 'string', 'description': 'Updated file content'}, + 'language': { + 'type': 'string', + 'enum': ['typescript', 'python', 'java'], + 'description': 'Programming language to check (typescript for JS/TS, python for Python, java for Java)', + }, }, - 'language': { - 'type': - 'string', - 'enum': ['typescript', 'python', 'java'], - 'description': - 'Programming language to check (typescript for JS/TS, python for Python, java for Java)' - } + 'required': ['file_path', 'content', 'language'], }, - 'required': ['file_path', 'content', 'language'] - } - }] + }, + ] } - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: """Call a tool""" if tool_name == 'check_directory': - return await self._check_directory(tool_args['directory'], - tool_args['language']) + return await self._check_directory(tool_args['directory'], tool_args['language']) elif tool_name == 'update_and_check': - return await self._update_and_check(tool_args['file_path'], - tool_args['content'], - tool_args['language']) + return await self._update_and_check(tool_args['file_path'], tool_args['content'], tool_args['language']) else: return json.dumps({'error': f'Unknown tool: {tool_name}'}) - async def _get_or_create_server(self, - language: str) -> Optional[LSPServer]: + async def _get_or_create_server(self, language: str) -> Optional[LSPServer]: """Get or create an LSP server for the given language""" if language in self.servers: return self.servers[language] @@ -699,19 +623,16 @@ async def _check_directory(self, directory: str, language: str) -> str: language = language.lower() server = await self._get_or_create_server(language) if not server: - return json.dumps( - {'error': f'Failed to start LSP server for {language}'}) + return json.dumps({'error': f'Failed to start LSP server for {language}'}) dir_path = Path(self.workspace_dir) / directory if not dir_path.exists() or not dir_path.is_dir(): - return json.dumps( - {'error': f'Directory not found: {directory}'}) + return json.dumps({'error': f'Directory not found: {directory}'}) extensions = self.language_mapping.get(language) if not extensions: - return json.dumps( - {'error': f'No extensions found for language: {language}'}) + return json.dumps({'error': f'No extensions found for language: {language}'}) all_files = [] for ext in extensions: @@ -726,26 +647,18 @@ async def _check_directory(self, directory: str, language: str) -> str: configs = ('xml', 'json', 'yaml', 'yml', 'txt', 'md', 'gradle') if filename.endswith(configs): continue - if any([ - filename.startswith(prefix) - for prefix in self.skip_prefixes - ]): + if any([filename.startswith(prefix) for prefix in self.skip_prefixes]): continue rel_path = file.relative_to(dir_path) - if any([ - part.startswith(prefix) for part in rel_path.parts - for prefix in self.skip_prefixes - ]): + if any([part.startswith(prefix) for part in rel_path.parts for prefix in self.skip_prefixes]): continue cleaned_files.append(file) all_files = cleaned_files if not all_files: - return json.dumps({ - 'message': f'No {language} files found in {directory}', - 'file_count': 0, - 'diagnostics': [] - }) + return json.dumps( + {'message': f'No {language} files found in {directory}', 'file_count': 0, 'diagnostics': []} + ) all_diagnostics = [] for file_path in all_files: @@ -753,8 +666,7 @@ async def _check_directory(self, directory: str, language: str) -> str: content = file_path.read_text(encoding='utf-8') rel_path = file_path.relative_to(Path(self.workspace_dir)) self.file_versions[str(rel_path)] = 1 - await server.open_document( - str(file_path), content, language) + await server.open_document(str(file_path), content, language) self.opened_documents[str(file_path)] = language # Skip diagnostics for index-only mode (trust existing files) @@ -775,9 +687,10 @@ async def _check_directory(self, directory: str, language: str) -> str: 'file_count': len(all_files), 'diagnostics': all_diagnostics, 'files_indexed': len(all_files) - len(all_diagnostics), - 'status': 'indexed' + 'status': 'indexed', }, - indent=2) + indent=2, + ) except Exception as e: logger.error(f'Error checking directory: {e}') @@ -785,7 +698,6 @@ async def _check_directory(self, directory: str, language: str) -> str: @staticmethod def _format_diag_results(diagnostics_result): - ignored_errors = [ # 'cannot be assigned to', 'is not assignable to', 'cannot assign to', '"none"', @@ -793,17 +705,17 @@ def _format_diag_results(diagnostics_result): 'unused', 'never used', 'never read', - 'implicitly has' + 'implicitly has', ] if diagnostics_result.get('has_errors'): issues = diagnostics_result.get('diagnostics', []) # Filter critical errors only critical_errors = [ - d for d in issues if d.get('severity') == 'Error' and not any([ - ignore in d.get('message', '').lower() - for ignore in ignored_errors - ]) + d + for d in issues + if d.get('severity') == 'Error' + and not any([ignore in d.get('message', '').lower() for ignore in ignored_errors]) ] if critical_errors: @@ -818,14 +730,12 @@ def _format_diag_results(diagnostics_result): else: return '' - async def _update_and_check(self, file_path: str, content: str, - language: str) -> str: + async def _update_and_check(self, file_path: str, content: str, language: str) -> str: """Update file content and check for errors""" try: server = await self._get_or_create_server(language) if not server: - return json.dumps( - {'error': f'Failed to start LSP server for {language}'}) + return json.dumps({'error': f'Failed to start LSP server for {language}'}) full_path = Path(self.workspace_dir) / file_path full_path_str = str(full_path) @@ -836,10 +746,7 @@ async def _update_and_check(self, file_path: str, content: str, self.opened_documents[full_path_str] = language else: self.file_versions[file_path] += 1 - await server.update_document( - full_path_str, - content, - version=self.file_versions[file_path]) + await server.update_document(full_path_str, content, version=self.file_versions[file_path]) diagnostics = await server.get_diagnostics(str(full_path)) @@ -849,7 +756,7 @@ async def _update_and_check(self, file_path: str, content: str, 'version': self.file_versions[file_path], 'has_errors': len(diagnostics) > 0, 'diagnostic_count': len(diagnostics), - 'diagnostics': self._format_diagnostics(diagnostics) + 'diagnostics': self._format_diagnostics(diagnostics), } return self._format_diag_results(diagnostics_result) @@ -863,26 +770,17 @@ def _format_diagnostics(diagnostics: List[dict]) -> List[dict]: """Format diagnostics for better readability""" formatted = [] for diag in diagnostics: - severity_map = { - 1: 'Error', - 2: 'Warning', - 3: 'Information', - 4: 'Hint' - } + severity_map = {1: 'Error', 2: 'Warning', 3: 'Information', 4: 'Hint'} - formatted.append({ - 'severity': - severity_map.get(diag.get('severity', 1), 'Error'), - 'message': - diag.get('message', ''), - 'line': - diag.get('range', {}).get('start', {}).get('line', 0) + 1, - 'column': - diag.get('range', {}).get('start', {}).get('character', 0) + 1, - 'source': - diag.get('source', ''), - 'code': - diag.get('code', '') - }) + formatted.append( + { + 'severity': severity_map.get(diag.get('severity', 1), 'Error'), + 'message': diag.get('message', ''), + 'line': diag.get('range', {}).get('start', {}).get('line', 0) + 1, + 'column': diag.get('range', {}).get('start', {}).get('character', 0) + 1, + 'source': diag.get('source', ''), + 'code': diag.get('code', ''), + } + ) return formatted diff --git a/ms_agent/tools/docling/chunker.py b/ms_agent/tools/docling/chunker.py index 3781cb76b..9b30d27b4 100644 --- a/ms_agent/tools/docling/chunker.py +++ b/ms_agent/tools/docling/chunker.py @@ -1,20 +1,18 @@ from typing import Iterable, Iterator, List, Union from docling_core.transforms.chunker import BaseChunk, DocChunk -from docling_core.transforms.chunker.hierarchical_chunker import ( - ChunkingDocSerializer, ChunkingSerializerProvider) +from docling_core.transforms.chunker.hierarchical_chunker import ChunkingDocSerializer, ChunkingSerializerProvider from docling_core.transforms.chunker.hybrid_chunker import HybridChunker from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer -from docling_core.transforms.chunker.tokenizer.huggingface import \ - HuggingFaceTokenizer +from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer from docling_core.transforms.serializer.markdown import MarkdownParams from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel -from ms_agent.utils.logger import get_logger +from modelscope import AutoTokenizer from rich.console import Console from rich.panel import Panel -from modelscope import AutoTokenizer +from ms_agent.utils.logger import get_logger logger = get_logger() @@ -24,11 +22,12 @@ class ImgPlaceholderSerializerProvider(ChunkingSerializerProvider): - def get_serializer(self, doc): return ChunkingDocSerializer( doc=doc, - params=MarkdownParams(image_placeholder='', ), + params=MarkdownParams( + image_placeholder='', + ), ) @@ -36,13 +35,10 @@ def get_serializer(self, doc): class HybridDocumentChunker: - EMBED_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2' MAX_TOKENS = 1024 - def __init__(self, - embed_model_id: str = EMBED_MODEL_ID, - max_tokens: int = MAX_TOKENS): + def __init__(self, embed_model_id: str = EMBED_MODEL_ID, max_tokens: int = MAX_TOKENS): """ Hybrid chunker that splits interleaved picture, table, and text into chunks. @@ -66,9 +62,7 @@ def __init__(self, ) @staticmethod - def find_n_th_chunk_with_label( - chunks: List[BaseChunk], n: int, - label: DocItemLabel) -> tuple[int, BaseChunk]: + def find_n_th_chunk_with_label(chunks: List[BaseChunk], n: int, label: DocItemLabel) -> tuple[int, BaseChunk]: """ Find the n-th chunk with the specified label in an iterable of chunks. @@ -88,8 +82,7 @@ def find_n_th_chunk_with_label( return None, None @staticmethod - def find_all_chunks_with_label(chunks: List[BaseChunk], - label: DocItemLabel) -> List[BaseChunk]: + def find_all_chunks_with_label(chunks: List[BaseChunk], label: DocItemLabel) -> List[BaseChunk]: """ Find all chunks with the specified label in an iterable of chunks. @@ -101,15 +94,11 @@ def find_all_chunks_with_label(chunks: List[BaseChunk], List[BaseChunk]: A list of BaseChunk objects that match the label. """ return [ - chunk for chunk in chunks - if any(it.label == label - for it in DocChunk.model_validate(chunk).meta.doc_items) + chunk for chunk in chunks if any(it.label == label for it in DocChunk.model_validate(chunk).meta.doc_items) ] @staticmethod - def find_all_chunks_with_labels( - chunks: List[BaseChunk], - labels: List[DocItemLabel]) -> List[BaseChunk]: + def find_all_chunks_with_labels(chunks: List[BaseChunk], labels: List[DocItemLabel]) -> List[BaseChunk]: """ Find all chunks with any of the specified labels in an iterable of chunks. @@ -121,9 +110,7 @@ def find_all_chunks_with_labels( List[BaseChunk]: A list of BaseChunk objects that match any of the labels. """ return [ - chunk for chunk in chunks if any( - it.label in labels - for it in DocChunk.model_validate(chunk).meta.doc_items) + chunk for chunk in chunks if any(it.label in labels for it in DocChunk.model_validate(chunk).meta.doc_items) ] def print_chunk(self, chunks: List[BaseChunk], chunk_pos: int) -> None: @@ -157,6 +144,7 @@ def chunk(self, docs: Iterable[DoclingDocument]) -> Iterator[BaseChunk]: if __name__ == '__main__': from ms_agent.tools.docling.doc_loader import DocLoader + urls = [ 'https://arxiv.org/pdf/2408.09869', 'https://arxiv.org/pdf/2502.15214', diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py index 5daf1652b..18e30e635 100644 --- a/ms_agent/tools/docling/doc_loader.py +++ b/ms_agent/tools/docling/doc_loader.py @@ -2,9 +2,6 @@ # yapf: disable import ast import os -from typing import Dict, Iterator, List, Optional, Tuple, Union -from unittest.mock import patch as mock_patch - from docling.backend.html_backend import HTMLDocumentBackend from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.base_models import InputFormat @@ -12,19 +9,17 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.datamodel.settings import DEFAULT_PAGE_RANGE, PageRange from docling.document_converter import DocumentConverter, PdfFormatOption -from docling.models.document_picture_classifier import \ - DocumentPictureClassifier +from docling.models.document_picture_classifier import DocumentPictureClassifier from docling.models.layout_model import LayoutModel from docling.models.table_structure_model import TableStructureModel from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem +from typing import Dict, Iterator, List, Optional, Tuple, Union +from unittest.mock import patch as mock_patch + from ms_agent.tools.docling.doc_postprocess import PostProcess -from ms_agent.tools.docling.patches import (download_models_ms, - download_models_pic_classifier_ms, - html_handle_figure, - html_handle_image, - patch_easyocr_models, - requests_get_with_timeout) +from ms_agent.tools.docling.patches import (download_models_ms, download_models_pic_classifier_ms, html_handle_figure, + html_handle_image, patch_easyocr_models, requests_get_with_timeout,) from ms_agent.utils.logger import get_logger from ms_agent.utils.patcher import patch from ms_agent.utils.utils import normalize_url_or_file, txt_to_html diff --git a/ms_agent/tools/docling/doc_postprocess.py b/ms_agent/tools/docling/doc_postprocess.py index 2e21051d9..b1dbf091b 100644 --- a/ms_agent/tools/docling/doc_postprocess.py +++ b/ms_agent/tools/docling/doc_postprocess.py @@ -4,11 +4,9 @@ class PostProcess: - MIN_PICTURE_SIZE = 200.0 * 200.0 # Minimum size for pictures in pixels - def __init__(self): - ... + def __init__(self): ... @staticmethod def filter(doc: DoclingDocument) -> Union[DoclingDocument, None]: diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py index b2ec77afd..a3a873da7 100644 --- a/ms_agent/tools/docling/patches.py +++ b/ms_agent/tools/docling/patches.py @@ -1,13 +1,12 @@ # flake8: noqa import sys -from pathlib import Path - from bs4 import Tag from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel, ImageRef +from pathlib import Path + from ms_agent.utils.logger import get_logger -from ms_agent.utils.utils import (load_image_from_uri_to_pil, - load_image_from_url_to_pil, validate_url) +from ms_agent.utils.utils import load_image_from_uri_to_pil, load_image_from_url_to_pil, validate_url logger = get_logger() @@ -16,9 +15,7 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: """ Patch the `docling.backend.html_backend.HTMLDocumentBackend.handle_figure` method. """ - logger.debug( - f'Patching HTMLDocumentBackend.handle_figure for {doc.origin.filename}' - ) + logger.debug(f'Patching HTMLDocumentBackend.handle_figure for {doc.origin.filename}') img_element: Tag = element.find('img') if isinstance(img_element, Tag): @@ -32,8 +29,7 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: else: if not img_url.startswith('http'): img_url = validate_url(img_url=img_url, backend=self) - img_pil = load_image_from_url_to_pil( - img_url) if img_url.startswith('http') else None + img_pil = load_image_from_url_to_pil(img_url) if img_url.startswith('http') else None else: img_pil = None @@ -77,8 +73,7 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None: """ Patch the `docling.backend.html_backend.HTMLDocumentBackend.handle_image` method to use the custom. """ - logger.debug( - f'Patching HTMLDocumentBackend.handle_image for {doc.origin.filename}') + logger.debug(f'Patching HTMLDocumentBackend.handle_image for {doc.origin.filename}') # Get the image from element img_url: str = element.attrs.get('src', None) @@ -147,30 +142,41 @@ def patch_easyocr_models(): logger.info('Patching EasyOCR models URLs for ModelScope...') # Patch detection models - detection_models['craft'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/craft_mlt_25k/resolve/master/craft_mlt_25k.zip' - detection_models['dbnet18'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res18/resolve/master/pretrained_ic15_res18.zip' - detection_models['dbnet50'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res50/resolve/master/pretrained_ic15_res50.zip' + detection_models['craft']['url'] = ( + 'https://modelscope.cn/models/ms-agent/craft_mlt_25k/resolve/master/craft_mlt_25k.zip' + ) + detection_models['dbnet18']['url'] = ( + 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res18/resolve/master/pretrained_ic15_res18.zip' + ) + detection_models['dbnet50']['url'] = ( + 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res50/resolve/master/pretrained_ic15_res50.zip' + ) # Patch recognition models - recognition_models['gen2']['english_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/english_g2/resolve/master/english_g2.zip' - recognition_models['gen2']['latin_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/latin_g2/resolve/master/latin_g2.zip' - recognition_models['gen2']['zh_sim_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/zh_sim_g2/resolve/master/zh_sim_g2.zip' - recognition_models['gen2']['japanese_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/japanese_g2/resolve/master/japanese_g2.zip' - recognition_models['gen2']['korean_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/korean_g2/resolve/master/korean_g2.zip' - recognition_models['gen2']['telugu_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/telugu_g2/resolve/master/telugu_g2.zip' - recognition_models['gen2']['kannada_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip' - recognition_models['gen2']['cyrillic_g2'][ - 'url'] = 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip' + recognition_models['gen2']['english_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/english_g2/resolve/master/english_g2.zip' + ) + recognition_models['gen2']['latin_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/latin_g2/resolve/master/latin_g2.zip' + ) + recognition_models['gen2']['zh_sim_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/zh_sim_g2/resolve/master/zh_sim_g2.zip' + ) + recognition_models['gen2']['japanese_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/japanese_g2/resolve/master/japanese_g2.zip' + ) + recognition_models['gen2']['korean_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/korean_g2/resolve/master/korean_g2.zip' + ) + recognition_models['gen2']['telugu_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/telugu_g2/resolve/master/telugu_g2.zip' + ) + recognition_models['gen2']['kannada_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip' + ) + recognition_models['gen2']['cyrillic_g2']['url'] = ( + 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip' + ) def requests_get_with_timeout( diff --git a/ms_agent/tools/fetch_playwright_fallback.py b/ms_agent/tools/fetch_playwright_fallback.py index 0ff89a91d..62b14b231 100644 --- a/ms_agent/tools/fetch_playwright_fallback.py +++ b/ms_agent/tools/fetch_playwright_fallback.py @@ -10,6 +10,7 @@ We keep **one browser per thread** (e.g. each ``ThreadPoolExecutor`` worker) and reuse it across URLs instead of launching Chromium for every fetch. """ + from __future__ import annotations import atexit @@ -36,9 +37,9 @@ def _chromium_launch_args() -> List[str]: '--blink-settings=imagesEnabled=false', ] if os.getenv('MS_AGENT_PLAYWRIGHT_NO_SANDBOX', '').lower() in ( - '1', - 'true', - 'yes', + '1', + 'true', + 'yes', ): args.extend(('--no-sandbox', '--disable-setuid-sandbox')) return args @@ -100,7 +101,8 @@ def _thread_browser() -> object: except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium') + 'Install with: pip install playwright && playwright install chromium' + ) raise RuntimeError('playwright not installed') from None pw = sync_playwright().start() @@ -135,7 +137,8 @@ def try_playwright_inner_text( except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium') + 'Install with: pip install playwright && playwright install chromium' + ) return '' text = '' @@ -147,13 +150,11 @@ def try_playwright_inner_text( page.goto(url, wait_until='domcontentloaded', timeout=timeout_ms) if settle_ms: page.wait_for_timeout(settle_ms) - raw = page.evaluate( - """() => { + raw = page.evaluate("""() => { const b = document.body; if (!b) return ''; return b.innerText || ''; - }""" - ) + }""") if isinstance(raw, str): text = raw[:_MAX_INNER_TEXT_CHARS] finally: @@ -179,10 +180,7 @@ def looks_like_spa_shell_html(raw_html: str) -> bool: if not raw_html or len(raw_html) < 80: return False low = raw_html.lower() - if any( - x in low - for x in ('enable javascript', 'javascript is required', - 'you need to enable javascript')): + if any(x in low for x in ('enable javascript', 'javascript is required', 'you need to enable javascript')): return True if re.search(r']+\bid=["\']root["\'][^>]*>\s*
', low): return True diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index caa58c58b..987973adc 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -23,11 +23,46 @@ _FS_TOOL_ALIASES = {'read': 'read_file', 'edit': 'edit_file', 'write': 'write_file'} _TEXT_SUFFIXES = { - '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', - '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', - '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', - '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', - '.env', '.gitignore', '.dockerignore', 'Dockerfile', + '.py', + '.md', + '.txt', + '.yaml', + '.yml', + '.json', + '.toml', + '.cfg', + '.ini', + '.sh', + '.bash', + '.js', + '.ts', + '.tsx', + '.jsx', + '.css', + '.html', + '.xml', + '.rs', + '.go', + '.java', + '.c', + '.h', + '.cpp', + '.hpp', + '.cs', + '.rb', + '.php', + '.sql', + '.vue', + '.svelte', + '.m', + '.swift', + '.kt', + '.gradle', + '.properties', + '.env', + '.gitignore', + '.dockerignore', + 'Dockerfile', } @@ -38,8 +73,10 @@ class FileSystemTool(ToolBase): IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) # Curly quote → straight quote mapping for fuzzy matching CURLY_QUOTE_MAP = { - '\u2018': "'", '\u2019': "'", # ' ' - '\u201c': '"', '\u201d': '"', # " " + '\u2018': "'", + '\u2019': "'", # ' ' + '\u201c': '"', + '\u201d': '"', # " " } SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -60,18 +97,12 @@ def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) if self.include_functions: - self.include_functions = [ - _FS_TOOL_ALIASES.get(n, n) for n in self.include_functions - ] + self.include_functions = [_FS_TOOL_ALIASES.get(n, n) for n in self.include_functions] if self.exclude_functions: - self.exclude_functions = [ - _FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions - ] + self.exclude_functions = [_FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions] self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) self.trust_remote_code = kwargs.get('trust_remote_code', False) - self.allow_read_all_files = getattr( - getattr(config.tools, 'file_system', {}), 'allow_read_all_files', - False) + self.allow_read_all_files = getattr(getattr(config.tools, 'file_system', {}), 'allow_read_all_files', False) if not self.trust_remote_code: self.allow_read_all_files = False if hasattr(self.config, 'llm'): @@ -86,22 +117,17 @@ def __init__(self, config, **kwargs): fs_cfg = getattr(config.tools, 'file_system', None) self._grep_timeout = int(getattr(fs_cfg, 'grep_timeout_s', 120) or 120) - self._default_grep_head = int( - getattr(fs_cfg, 'grep_head_limit', 250) or 250) + self._default_grep_head = int(getattr(fs_cfg, 'grep_head_limit', 250) or 250) self._glob_max_files = int(getattr(fs_cfg, 'glob_max_files', 100) or 100) wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] - shell_cfg = getattr( - getattr(config.tools, 'code_executor', None), 'shell', None) - shell_mode = getattr(shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False) - ) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) - ) if shell_cfg else 8192 + shell_cfg = getattr(getattr(config.tools, 'code_executor', None), 'shell', None) + shell_mode = getattr(shell_cfg, 'default_mode', 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192)) if shell_cfg else 8192 _out_p = Path(self.output_dir).expanduser().resolve() try: @@ -119,13 +145,13 @@ def __init__(self, config, **kwargs): max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) - self._fs_artifacts = ArtifactManager( - _out_p, max_combined_bytes=max_kb * 1024) + self._fs_artifacts = ArtifactManager(_out_p, max_combined_bytes=max_kb * 1024) async def connect(self): logger.warning_once( '[IMPORTANT]FileSystemTool is not implemented with sandbox, please consider other similar ' - 'tools if you want to run dangerous code.') + 'tools if you want to run dangerous code.' + ) async def _get_tools_inner(self): tools = { @@ -155,8 +181,9 @@ async def _get_tools_inner(self): }, }, 'required': ['path', 'content'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='read_file', server_name='file_system', @@ -179,37 +206,33 @@ async def _get_tools_inner(self): 'paths': { 'type': 'array', 'items': {'type': 'string'}, - 'description': - 'List of relative file path(s) to read. ' + 'description': 'List of relative file path(s) to read. ' 'Use this OR `path` (single file).', }, 'path': { 'type': 'string', - 'description': - 'Single relative file path to read (alias for `paths` of length 1).', + 'description': 'Single relative file path to read (alias for `paths` of length 1).', }, 'offset': { 'type': 'integer', - 'description': - 'Line number to start reading from (1-based). ' + 'description': 'Line number to start reading from (1-based). ' 'Only provide if the file is too large to read at once.', }, 'limit': { 'type': 'integer', - 'description': - 'Number of lines to read. ' + 'description': 'Number of lines to read. ' 'Only provide if the file is too large to read at once.', }, 'abbreviate': { 'type': 'boolean', - 'description': - 'If true, return an LLM-generated summary instead of raw content. ' + 'description': 'If true, return an LLM-generated summary instead of raw content. ' 'Useful for large files or quick structural overview.', }, }, 'required': [], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='edit_file', server_name='file_system', @@ -245,13 +268,13 @@ async def _get_tools_inner(self): }, 'replace_all': { 'type': 'boolean', - 'description': - 'If true, replace all occurrences. Default is false (replace only the first).', + 'description': 'If true, replace all occurrences. Default is false (replace only the first).', }, }, 'required': ['path', 'old_string', 'new_string'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='grep', server_name='file_system', @@ -269,8 +292,7 @@ async def _get_tools_inner(self): }, 'path': { 'type': 'string', - 'description': - 'Directory or file to search (relative to output_dir if not absolute). Default ".".', + 'description': 'Directory or file to search (relative to output_dir if not absolute). Default ".".', }, 'glob': { 'type': 'string', @@ -279,8 +301,7 @@ async def _get_tools_inner(self): 'output_mode': { 'type': 'string', 'enum': ['content', 'files_with_matches', 'count'], - 'description': - 'content: matching lines; files_with_matches: paths only; count: per-file counts', + 'description': 'content: matching lines; files_with_matches: paths only; count: per-file counts', }, 'head_limit': { 'type': 'integer', @@ -315,21 +336,18 @@ async def _get_tools_inner(self): }, 'path': { 'type': 'string', - 'description': - 'Base directory (relative to output_dir if not absolute).', + 'description': 'Base directory (relative to output_dir if not absolute).', }, }, 'required': ['pattern'], 'additionalProperties': False, }, ), - ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) async def grep( @@ -343,8 +361,7 @@ async def grep( case_insensitive: bool = False, ) -> str: call_id = f'grep-{pattern[:40]}' - head_limit = (head_limit if head_limit is not None else - self._default_grep_head) + head_limit = head_limit if head_limit is not None else self._default_grep_head offset = offset or 0 path = path or '.' try: @@ -378,13 +395,13 @@ async def grep( try: rg = shutil.which('rg') if rg and root.is_file(): - lines = await self._grep_rg_file(rg, pattern, root, - case_insensitive, output_mode, - head_limit, offset, glob) + lines = await self._grep_rg_file( + rg, pattern, root, case_insensitive, output_mode, head_limit, offset, glob + ) elif rg and root.is_dir(): - lines = await self._grep_rg_dir(rg, pattern, root, - case_insensitive, output_mode, - head_limit, offset, glob) + lines = await self._grep_rg_dir( + rg, pattern, root, case_insensitive, output_mode, head_limit, offset, glob + ) else: lines = self._grep_python( pattern, @@ -398,12 +415,7 @@ async def grep( except Exception as e: err = str(e) # Expected user/tooling failures (bad regex, rg rules) — log without traceback noise. - _quiet = ( - 'rg:' in err - or 'exited' in err.lower() - or 'regex' in err.lower() - or 'pattern' in err.lower() - ) + _quiet = 'rg:' in err or 'exited' in err.lower() or 'regex' in err.lower() or 'pattern' in err.lower() logger.warning('grep failed: %s', e, exc_info=not _quiet) return json.dumps({'success': False, 'error': str(e)}, indent=2) @@ -449,8 +461,7 @@ async def _grep_rg_file( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for(proc.communicate(), timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -486,8 +497,7 @@ async def _grep_rg_dir( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for(proc.communicate(), timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -516,8 +526,7 @@ def _grep_python( def consider_file(fp: Path) -> bool: if glob_pat: rel = str(fp.relative_to(root)) if root.is_dir() else fp.name - if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( - rel, glob_pat): + if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch(rel, glob_pat): return False suf = fp.suffix.lower() if suf not in _TEXT_SUFFIXES and fp.suffix == '': @@ -529,8 +538,7 @@ def consider_file(fp: Path) -> bool: if root.is_file(): files = [root] else: - for fp in _walk_files_limited(root, self._fs_policy.deny_globs, - 50_000): + for fp in _walk_files_limited(root, self._fs_policy.deny_globs, 50_000): if consider_file(fp): files.append(fp) @@ -539,9 +547,11 @@ def consider_file(fp: Path) -> bool: text = fp.read_text(encoding='utf-8', errors='replace') except OSError: continue - rel = str(fp.relative_to(self._fs_policy.workspace_root) - ) if _is_relative(fp, self._fs_policy.workspace_root) else str( - fp) + rel = ( + str(fp.relative_to(self._fs_policy.workspace_root)) + if _is_relative(fp, self._fs_policy.workspace_root) + else str(fp) + ) if output_mode == 'files_with_matches': if rx.search(text): lines_out.append(rel) @@ -589,9 +599,11 @@ async def glob(self, pattern: str, path: str = '') -> str: continue if _is_denied_path(rp, base, deny): continue - rel = str(p.relative_to(self._fs_policy.workspace_root) - ) if _is_relative(p, self._fs_policy.workspace_root - ) else str(p) + rel = ( + str(p.relative_to(self._fs_policy.workspace_root)) + if _is_relative(p, self._fs_policy.workspace_root) + else str(p) + ) matches.append(rel) if len(matches) >= self._glob_max_files: truncated = True @@ -697,21 +709,17 @@ def get_real_path(self, path): # Check if path is absolute or already starts with output_dir if os.path.isabs(path): target_path = path - elif path.startswith(self.output_dir + os.sep) or path.startswith( - self.output_dir): + elif path.startswith(self.output_dir + os.sep) or path.startswith(self.output_dir): # Path already includes output_dir as prefix target_path = path else: target_path = os.path.join(self.output_dir, path) target_path_real = os.path.realpath(target_path) output_dir_real = os.path.realpath(self.output_dir) - is_in_output_dir = target_path_real.startswith( - output_dir_real + os.sep) or target_path_real == output_dir_real + is_in_output_dir = target_path_real.startswith(output_dir_real + os.sep) or target_path_real == output_dir_real if not is_in_output_dir and not self.allow_read_all_files: - logger.warning( - f'Attempt to read file outside output directory blocked: {path} -> {target_path_real}' - ) + logger.warning(f'Attempt to read file outside output directory blocked: {path} -> {target_path_real}') return None else: return target_path_real @@ -723,20 +731,19 @@ def _normalize_read_paths(self, paths, path) -> List[str]: if isinstance(paths, str) and paths.strip(): out = [paths.strip()] elif isinstance(paths, list): - out = [ - p.strip() for p in paths - if isinstance(p, str) and p.strip() - ] + out = [p.strip() for p in paths if isinstance(p, str) and p.strip()] if not out and path is not None and isinstance(path, str) and path.strip(): out = [path.strip()] return out - async def read_file(self, - paths: Optional[List[str]] = None, - path: Optional[str] = None, - offset: int = None, - limit: int = None, - abbreviate: bool = False): + async def read_file( + self, + paths: Optional[List[str]] = None, + path: Optional[str] = None, + offset: int = None, + limit: int = None, + abbreviate: bool = False, + ): """Read the content of file(s). Args: @@ -765,8 +772,7 @@ async def read_file(self, return await self._read_files_abbreviated(paths) results = {} - use_line_range = len(paths) == 1 and (offset is not None - or limit is not None) + use_line_range = len(paths) == 1 and (offset is not None or limit is not None) for path in paths: try: @@ -774,7 +780,8 @@ async def read_file(self, if target_path_real is None: results[path] = ( f'Access denied: Reading file <{path}> outside output directory is not allowed. ' - f'Set allow_read_all_files=true in config to enable.') + f'Set allow_read_all_files=true in config to enable.' + ) continue ext = os.path.splitext(path)[1].lstrip('.').lower() @@ -796,16 +803,14 @@ async def read_file(self, if file_size > self.MAX_READ_BYTES and not use_line_range: results[path] = ( f'Error: File <{path}> is too large ({file_size} bytes). ' - f'Use offset and limit to read specific portions.') + f'Use offset and limit to read specific portions.' + ) continue # Dedup: return stub if file unchanged since last read mtime = os.path.getmtime(target_path_real) cached = self._read_cache.get(target_path_real) - if (cached - and cached['mtime'] == mtime - and cached['offset'] == offset - and cached['limit'] == limit): + if cached and cached['mtime'] == mtime and cached['offset'] == offset and cached['limit'] == limit: results[path] = { 'type': 'file_unchanged', 'message': 'File has not changed since last read.', @@ -819,8 +824,8 @@ async def read_file(self, content = raw_bytes.decode('utf-8') except UnicodeDecodeError: results[path] = ( - f'Error: File <{path}> appears to be binary. ' - f'Only text and image files are supported.') + f'Error: File <{path}> appears to be binary. Only text and image files are supported.' + ) continue # Normalize line endings @@ -835,16 +840,13 @@ async def read_file(self, if actual_start > total_lines: results[path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' continue - selected = lines[actual_start - 1:actual_end] + selected = lines[actual_start - 1 : actual_end] start_lineno = actual_start else: selected = lines start_lineno = 1 - results[path] = ''.join( - f'{start_lineno + i}\t{line}' - for i, line in enumerate(selected) - ) + results[path] = ''.join(f'{start_lineno + i}\t{line}' for i, line in enumerate(selected)) # Update dedup cache self._read_cache[target_path_real] = { @@ -901,11 +903,9 @@ def process_file(path): return json.dumps(results, indent=2, ensure_ascii=False) - async def edit_file(self, - path: str = None, - old_string: str = None, - new_string: str = None, - replace_all: bool = False): + async def edit_file( + self, path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False + ): """Edit a file by replacing an exact string with new content. Args: @@ -973,7 +973,7 @@ async def edit_file(self, norm_content = self._normalize_quotes(content) idx = norm_content.find(norm_old) if idx != -1: - actual_old = content[idx:idx + len(old_string)] + actual_old = content[idx : idx + len(old_string)] if actual_old is None: return ( @@ -1014,8 +1014,7 @@ async def edit_file(self, return f'Edit file <{path}> failed, error: ' + str(e) -def _apply_offset_limit(lines: List[str], offset: int, - head_limit: int) -> List[str]: +def _apply_offset_limit(lines: List[str], offset: int, head_limit: int) -> List[str]: if offset: lines = lines[offset:] if head_limit and head_limit > 0: @@ -1044,11 +1043,9 @@ def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: return False -def _walk_files_limited(root: Path, deny: tuple[str, ...], - max_files: int) -> List[Path]: +def _walk_files_limited(root: Path, deny: tuple[str, ...], max_files: int) -> List[Path]: out: List[Path] = [] - for dirpath, dirnames, filenames in os.walk( - root, topdown=True, followlinks=False): + for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=False): dp = Path(dirpath) pruned = [] for d in list(dirnames): diff --git a/ms_agent/tools/findata/__init__.py b/ms_agent/tools/findata/__init__.py index 9ae760a1c..e7c91afd9 100644 --- a/ms_agent/tools/findata/__init__.py +++ b/ms_agent/tools/findata/__init__.py @@ -1,7 +1,6 @@ from .akshare_source import AKShareDataSource from .baostock_source import BaoStockDataSource -from .data_source_base import (DataSourceError, FinancialDataSource, - NoDataFoundError) +from .data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError from .findata_fetcher import FinancialDataFetcher from .hybrid_source import HybridDataSource diff --git a/ms_agent/tools/findata/akshare_source.py b/ms_agent/tools/findata/akshare_source.py index 6e68aaf41..01f86bb4c 100644 --- a/ms_agent/tools/findata/akshare_source.py +++ b/ms_agent/tools/findata/akshare_source.py @@ -3,9 +3,8 @@ from typing import Any, Dict, List, Optional import pandas as pd -from ms_agent.tools.findata.data_source_base import (DataSourceError, - FinancialDataSource, - NoDataFoundError) + +from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError from ms_agent.utils import get_logger from ms_agent.utils.utils import install_package @@ -67,8 +66,7 @@ def _convert_date(self, date: str) -> str: """Convert date to AKShare format""" return date.replace('-', '') - def _standardize_columns(self, df: pd.DataFrame, - code: str) -> pd.DataFrame: + def _standardize_columns(self, df: pd.DataFrame, code: str) -> pd.DataFrame: """Standardize column names for compatibility with BaoStock format""" if df.empty: return df @@ -147,31 +145,21 @@ def get_historical_k_data( ed_date = self._convert_date(end_date) # Route by market heuristics - if code.startswith('sh.') or code.startswith( - 'sz.') or code.startswith('bj'): + if code.startswith('sh.') or code.startswith('sz.') or code.startswith('bj'): clean_code = self._convert_code(code, market='A') df = akshare.stock_zh_a_hist( - symbol=clean_code, - period=period, - start_date=st_date, - end_date=ed_date, - adjust=adjust) + symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust + ) elif code.startswith('hk'): clean_code = self._convert_code(code, market='HK') df = akshare.stock_hk_hist( - symbol=clean_code, - period=period, - start_date=st_date, - end_date=ed_date, - adjust=adjust) + symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust + ) else: clean_code = self._convert_code(code, market='US') df = akshare.stock_us_hist( - symbol=clean_code, - period=period, - start_date=st_date, - end_date=ed_date, - adjust=adjust) + symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust + ) if df.empty: raise NoDataFoundError(f'No K-data found for {code}') @@ -216,14 +204,16 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame: df_base_info = akshare.stock_hk_spot_em() stock_info = df_base_info[df_base_info['代码'] == clean_code] if not stock_info.empty: - df_stock_info = pd.DataFrame({ - 'code': [code], - 'code_name': [stock_info['名称'].iloc[0]], - 'listingDate': [''], # listing date might not be available - 'outDate': [''], - 'type': ['2'], # type of stock - 'status': ['1'] - }) + df_stock_info = pd.DataFrame( + { + 'code': [code], + 'code_name': [stock_info['名称'].iloc[0]], + 'listingDate': [''], # listing date might not be available + 'outDate': [''], + 'type': ['2'], # type of stock + 'status': ['1'], + } + ) except Exception: logger.warning(f'Failed to fetch HK stock base info for {code}') @@ -249,11 +239,11 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame: '联系电话': 'contact number', '核数师': 'auditor', '传真': 'fax', - '公司介绍': 'company description' - }) + '公司介绍': 'company description', + } + ) except Exception: - logger.warning( - f'Failed to fetch HK stock business info for {code}') + logger.warning(f'Failed to fetch HK stock business info for {code}') if df_stock_info.empty and df_business_info.empty: raise NoDataFoundError(f'No basic info found for {code}') @@ -269,23 +259,23 @@ def _get_us_basic_info(self, code: str) -> pd.DataFrame: stock_info = df[df['代码'] == symbol] if stock_info.empty: - raise NoDataFoundError( - f'No US stock basic info found for {code}') - - result_df = pd.DataFrame({ - 'code': [code], - 'code_name': [stock_info['名称'].iloc[0]], - 'listingDate': [''], - 'outDate': [''], - 'type': ['3'], - 'status': ['1'] - }) + raise NoDataFoundError(f'No US stock basic info found for {code}') + + result_df = pd.DataFrame( + { + 'code': [code], + 'code_name': [stock_info['名称'].iloc[0]], + 'listingDate': [''], + 'outDate': [''], + 'type': ['3'], + 'status': ['1'], + } + ) return result_df except Exception as e: - raise DataSourceError( - f'Error fetching US stock basic info for {code}: {e}') + raise DataSourceError(f'Error fetching US stock basic info for {code}: {e}') def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: """Get A-share stock basic information""" @@ -297,24 +287,24 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: if df_base_info.empty: raise NoDataFoundError(f'No basic info found for {code}') - result_df = pd.DataFrame({ - 'code': [code], - 'code_name': [ - df_base_info.loc[df_base_info['item'] == '股票简称', - 'value'].iloc[0] - if not df_base_info.loc[df_base_info['item'] == '股票简称', - 'value'].empty else '' - ], - 'listingDate': [ - df_base_info.loc[df_base_info['item'] == '上市时间', - 'value'].iloc[0] - if not df_base_info.loc[df_base_info['item'] == '上市时间', - 'value'].empty else '' - ], - 'outDate': [''], - 'type': ['1'], - 'status': ['1'] - }) + result_df = pd.DataFrame( + { + 'code': [code], + 'code_name': [ + df_base_info.loc[df_base_info['item'] == '股票简称', 'value'].iloc[0] + if not df_base_info.loc[df_base_info['item'] == '股票简称', 'value'].empty + else '' + ], + 'listingDate': [ + df_base_info.loc[df_base_info['item'] == '上市时间', 'value'].iloc[0] + if not df_base_info.loc[df_base_info['item'] == '上市时间', 'value'].empty + else '' + ], + 'outDate': [''], + 'type': ['1'], + 'status': ['1'], + } + ) df_business_info = akshare.stock_zyjs_ths(symbol=clean_code) if df_business_info.empty: @@ -326,46 +316,35 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: '主营业务': 'main business', '产品类型': 'product type', '产品名称': 'product name', - '经营范围': 'business scope' - }) + '经营范围': 'business scope', + } + ) return pd.concat([result_df, df_business_info], axis=1) except Exception as e: - raise DataSourceError( - f'Error fetching A-share basic info for {code}: {e}') + raise DataSourceError(f'Error fetching A-share basic info for {code}: {e}') - def get_dividend_data(self, - code: str, - year: Optional[str] = None, - year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: """Dividend info is not provided via a unified endpoint across markets in AKShare.""" - raise DataSourceError( - 'get_dividend_data is not supported by AKShareDataSource; use BaoStock or Hybrid' - ) + raise DataSourceError('get_dividend_data is not supported by AKShareDataSource; use BaoStock or Hybrid') - def get_adjust_factor_data(self, code: str, start_date: str, - end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: """Adjust factor via AKShare varies by function; not standardized here.""" - raise DataSourceError( - 'get_adjust_factor_data is not supported by AKShareDataSource; use BaoStock or Hybrid' - ) + raise DataSourceError('get_adjust_factor_data is not supported by AKShareDataSource; use BaoStock or Hybrid') - def get_financial_data(self, code: str, year: str, quarter: int, - data_types: List[str]) -> dict: + def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> dict: """ Get financial data for multiple categories in one call. """ - logger.info( - f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}' - ) + logger.info(f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}') if code.startswith(('hk.', 'us.')): logger.warning( 'For U.S. and Hong Kong stocks, only a single complete financial indicators table is ' - 'currently supported, covering all data types.') - clean_code = self._convert_code( - code, market='HK' if code.startswith('hk.') else 'US') + 'currently supported, covering all data types.' + ) + clean_code = self._convert_code(code, market='HK' if code.startswith('hk.') else 'US') elif code.startswith(('sh.', 'sz.', 'bj.')): clean_code = self._convert_code(code, market='A') else: @@ -394,8 +373,7 @@ def _select_row_by_report(df: pd.DataFrame) -> pd.DataFrame: # convert REPORT_DATE to date if 'REPORT_DATE' in d.columns: - d['_dt'] = pd.to_datetime( - d['REPORT_DATE']).dt.date.astype('str') + d['_dt'] = pd.to_datetime(d['REPORT_DATE']).dt.date.astype('str') hit = d[d['_dt'] == target_date] if not hit.empty: return hit.drop(columns=['_dt']) @@ -403,29 +381,22 @@ def _select_row_by_report(df: pd.DataFrame) -> pd.DataFrame: # match report period name (中报/一季报/三季报/年报) for col in ('REPORT_DATE_NAME', 'REPORT_TYPE'): if col in d.columns: - hit = d[d[col].astype(str).str.contains(str(year)) - & d[col].astype(str).str.contains(target_qname)] + hit = d[d[col].astype(str).str.contains(str(year)) & d[col].astype(str).str.contains(target_qname)] if not hit.empty: return hit # Fallback: Select the row closest to the target_date if 'REPORT_DATE' in d.columns: - d['_dt'] = pd.to_datetime( - d['REPORT_DATE']).dt.date.astype('str') - d['_diff'] = (pd.to_datetime(d['_dt']) - - pd.to_datetime(target_date)).abs() + d['_dt'] = pd.to_datetime(d['REPORT_DATE']).dt.date.astype('str') + d['_diff'] = (pd.to_datetime(d['_dt']) - pd.to_datetime(target_date)).abs() d = d.sort_values('_diff') return d.drop(columns=['_dt', '_diff']).head(1) return d.head(1) - META_KEEP = [ - 'REPORT_DATE', 'REPORT_TYPE', 'REPORT_DATE_NAME', 'NOTICE_DATE', - 'UPDATE_DATE' - ] + META_KEEP = ['REPORT_DATE', 'REPORT_TYPE', 'REPORT_DATE_NAME', 'NOTICE_DATE', 'UPDATE_DATE'] - def _filter_columns(row_df: pd.DataFrame, - category: str) -> pd.DataFrame: + def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: if row_df.empty: return row_df cols = list(row_df.columns) @@ -460,13 +431,10 @@ def _filter_columns(row_df: pd.DataFrame, ], } - keep = set( - c for c in cols if any( - re.match(p, c) for p in PATTERNS[category])) + keep = set(c for c in cols if any(re.match(p, c) for p in PATTERNS[category])) keep |= set(c for c in META_KEEP if c in cols) - out = row_df.loc[:, - [c for c in row_df.columns if c in keep]].copy() + out = row_df.loc[:, [c for c in row_df.columns if c in keep]].copy() # Dupont net profit margin fallback calculation: PARENTNETPROFIT / TOTALOPERATEREVE if category == 'dupont' and 'XSJLL' not in out.columns: @@ -477,8 +445,7 @@ def _filter_columns(row_df: pd.DataFrame, den = float(row_df.iloc[0]['TOTALOPERATEREVE']) out['XSJLL_calc'] = (val / den) if den else pd.NA except Exception as e: - logger.warning( - f'Failed to calculate XSJLL_calc for {code}: {e}') + logger.warning(f'Failed to calculate XSJLL_calc for {code}: {e}') out['XSJLL_calc'] = pd.NA out.insert(0, 'code', code) @@ -488,9 +455,11 @@ def _filter_columns(row_df: pd.DataFrame, ind_df = pd.DataFrame() if code.startswith(('hk.', 'us.')): try: - ind_df = akshare.stock_financial_hk_analysis_indicator_em( - symbol=clean_code) if code.startswith('hk.') else \ - akshare.stock_financial_us_analysis_indicator_em(symbol=clean_code) + ind_df = ( + akshare.stock_financial_hk_analysis_indicator_em(symbol=clean_code) + if code.startswith('hk.') + else akshare.stock_financial_us_analysis_indicator_em(symbol=clean_code) + ) ind_df = _select_row_by_report(ind_df) except Exception as e: logger.warning( @@ -499,43 +468,34 @@ def _filter_columns(row_df: pd.DataFrame, result['financial_indicators'] = ind_df elif code.startswith(('sh.', 'sz.', 'bj.')): - needs_indicator = any( - dt in ('profit', 'operation', 'growth', 'dupont') - for dt in data_types) + needs_indicator = any(dt in ('profit', 'operation', 'growth', 'dupont') for dt in data_types) if needs_indicator: try: - ind_df = akshare.stock_financial_analysis_indicator( - symbol=clean_code) + ind_df = akshare.stock_financial_analysis_indicator(symbol=clean_code) ind_df = _select_row_by_report(ind_df) except Exception as e: - logger.warning( - f'Failed to fetch financial_analysis_indicator: {e}') + logger.warning(f'Failed to fetch financial_analysis_indicator: {e}') ind_df = pd.DataFrame() for data_type in data_types: try: result[data_type] = pd.DataFrame() - if data_type in ('profit', 'operation', 'growth', - 'dupont'): + if data_type in ('profit', 'operation', 'growth', 'dupont'): if ind_df.empty: - logger.warning( - f'No indicator row for {code} {year}Q{quarter}' - ) + logger.warning(f'No indicator row for {code} {year}Q{quarter}') continue result[data_type] = _filter_columns(ind_df, data_type) continue elif data_type == 'balance': - df = akshare.stock_balance_sheet_by_report_em( - symbol=code.replace('.', '').upper()) + df = akshare.stock_balance_sheet_by_report_em(symbol=code.replace('.', '').upper()) row = _select_row_by_report(df) if not row.empty: result[data_type] = row elif data_type == 'cash_flow': - df = akshare.stock_cash_flow_sheet_by_report_em( - symbol=code.replace('.', '').upper()) + df = akshare.stock_cash_flow_sheet_by_report_em(symbol=code.replace('.', '').upper()) row = _select_row_by_report(df) if not row.empty: result[data_type] = row @@ -549,34 +509,23 @@ def _filter_columns(row_df: pd.DataFrame, continue if not result or all(df.empty for df in result.values()): - raise NoDataFoundError( - f'No financial data found for {code} ({year}Q{quarter})') + raise NoDataFoundError(f'No financial data found for {code} ({year}Q{quarter})') return result - def get_report(self, - code: str, - start_date: str, - end_date: str, - report_type: str = 'performance_express') -> pd.DataFrame: + def get_report( + self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' + ) -> pd.DataFrame: """Report data is not supported by AKShare.""" - raise DataSourceError( - 'get_report is not supported by AKShareDataSource; use BaoStock or Hybrid' - ) + raise DataSourceError('get_report is not supported by AKShareDataSource; use BaoStock or Hybrid') def get_stock_industry(self, code: str, date: str) -> pd.DataFrame: """Industry classification is not supported by AKShare.""" - raise DataSourceError( - 'get_stock_industry is not supported by AKShareDataSource; use BaoStock or Hybrid' - ) + raise DataSourceError('get_stock_industry is not supported by AKShareDataSource; use BaoStock or Hybrid') - def get_stock_list(self, - date: str, - data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list (A-shares only, index constituents not supported).""" - logger.info( - f'Fetching stock list for {data_type}, only support a_share and latest data' - ) + logger.info(f'Fetching stock list for {data_type}, only support a_share and latest data') try: if data_type == 'sse50': @@ -596,9 +545,7 @@ def get_stock_list(self, except Exception as e: raise DataSourceError(f'Failed to fetch stock list: {e}') - def get_trade_dates(self, - start_date: Optional[str] = None, - end_date: Optional[str] = None) -> pd.DataFrame: + def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: """Get trading calendar""" logger.info(f'Fetching trade dates ({start_date} to {end_date})') @@ -624,17 +571,15 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, pd.DataFrame]: """Macroeconomic data.""" if data_types is None: data_types = [] if extra_kwargs is None: extra_kwargs = {} - logger.info( - f'Fetching macroeconomic data ({start_date} to {end_date}) {data_types}' - ) + logger.info(f'Fetching macroeconomic data ({start_date} to {end_date}) {data_types}') if not data_types: raise ValueError('data_types cannot be empty') @@ -645,14 +590,11 @@ def get_macro_data( if data_type in ('deposit_rate', 'loan_rate'): result[data_type] = akshare.rate_interbank() elif data_type in ('required_reserve_ratio'): - raise DataSourceError( - 'Required reserve ratio is not supported by AKShare') + raise DataSourceError('Required reserve ratio is not supported by AKShare') elif data_type == 'money_supply_year': - result[data_type] = self._get_money_supply_data_year( - start_date, end_date) + result[data_type] = self._get_money_supply_data_year(start_date, end_date) elif data_type == 'money_supply_month': - result[data_type] = self._get_money_supply_data_month( - start_date, end_date) + result[data_type] = self._get_money_supply_data_month(start_date, end_date) else: raise ValueError(f'Invalid data type: {data_type}') @@ -662,41 +604,32 @@ def get_macro_data( continue if not result: - raise NoDataFoundError( - 'No macro data found for the specified criteria') + raise NoDataFoundError('No macro data found for the specified criteria') return result def _get_money_supply_data_month( - self, - start_date: Optional[str] = None, - end_date: Optional[str] = None) -> pd.DataFrame: + self, start_date: Optional[str] = None, end_date: Optional[str] = None + ) -> pd.DataFrame: try: df = akshare.macro_china_money_supply() # from 2008-01 to now - df['月份'] = pd.to_datetime(df['月份'].str.replace('月份', - '').str.replace( - '年', '-')) + df['月份'] = pd.to_datetime(df['月份'].str.replace('月份', '').str.replace('年', '-')) df['月份'] = df['月份'].dt.to_period('M') if start_date: - df = df[ - df['月份'] >= pd.to_datetime(start_date).strftime('%Y-%m')] + df = df[df['月份'] >= pd.to_datetime(start_date).strftime('%Y-%m')] if end_date: df = df[df['月份'] <= pd.to_datetime(end_date).strftime('%Y-%m')] return df.sort_values('月份').reset_index(drop=True) except Exception as e: - raise DataSourceError( - f'Error fetching monthly money supply data: {e}') + raise DataSourceError(f'Error fetching monthly money supply data: {e}') def _get_money_supply_data_year( - self, - start_date: Optional[str] = None, - end_date: Optional[str] = None) -> pd.DataFrame: + self, start_date: Optional[str] = None, end_date: Optional[str] = None + ) -> pd.DataFrame: month_df = self._get_money_supply_data_month() # Take the last issue of each year (usually December; if missing, take the last available entry of that year). month_df['年'] = month_df['月份'].dt.year - last_in_year = ( - month_df.sort_values('月份').groupby( - '年', as_index=False).tail(1).reset_index(drop=True)) + last_in_year = month_df.sort_values('月份').groupby('年', as_index=False).tail(1).reset_index(drop=True) cols = [ '货币和准货币(M2)-数量(亿元)', '货币和准货币(M2)-同比增长', @@ -705,8 +638,7 @@ def _get_money_supply_data_year( '流通中的现金(M0)-数量(亿元)', '流通中的现金(M0)-同比增长', ] - year_df = last_in_year[ - ['年'] + [c for c in cols if c in last_in_year.columns]] + year_df = last_in_year[['年'] + [c for c in cols if c in last_in_year.columns]] if start_date: year_df = year_df[year_df['年'] >= pd.to_datetime(start_date).year] diff --git a/ms_agent/tools/findata/baostock_source.py b/ms_agent/tools/findata/baostock_source.py index dd05b7a1f..f14437604 100644 --- a/ms_agent/tools/findata/baostock_source.py +++ b/ms_agent/tools/findata/baostock_source.py @@ -5,9 +5,8 @@ from typing import Any, Dict, List, Optional import pandas as pd -from ms_agent.tools.findata.data_source_base import (DataSourceError, - FinancialDataSource, - NoDataFoundError) + +from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError from ms_agent.utils import get_logger from ms_agent.utils.utils import install_package @@ -16,6 +15,7 @@ class BaoStockSessionManager: """Thread-safe BaoStock session manager with connection reuse""" + _instance = None _lock = threading.Lock() _session_lock = threading.Lock() @@ -57,8 +57,7 @@ def ensure_login(self): if not self._is_logged_in: lg = baostock.login() if lg.error_code != '0': - raise DataSourceError( - f'BaoStock login failed: {lg.error_msg}') + raise DataSourceError(f'BaoStock login failed: {lg.error_msg}') self._is_logged_in = True self._login_count = 1 logger.debug('BaoStock session established') @@ -66,8 +65,7 @@ def ensure_login(self): self._login_count += 1 # Someone reused the session within idle timeout; cancel scheduled logout self._cancel_logout() - logger.debug( - f'BaoStock session reused (count: {self._login_count})') + logger.debug(f'BaoStock session reused (count: {self._login_count})') def release(self): """Release session (logout only when no active users)""" @@ -101,9 +99,24 @@ class BaoStockDataSource(FinancialDataSource): """ DEFAULT_K_FIELDS = [ - 'date', 'code', 'open', 'high', 'low', 'close', 'preclose', 'volume', - 'amount', 'adjustflag', 'turn', 'tradestatus', 'pctChg', 'peTTM', - 'pbMRQ', 'psTTM', 'pcfNcfTTM', 'isST' + 'date', + 'code', + 'open', + 'high', + 'low', + 'close', + 'preclose', + 'volume', + 'amount', + 'adjustflag', + 'turn', + 'tradestatus', + 'pctChg', + 'peTTM', + 'pbMRQ', + 'psTTM', + 'pcfNcfTTM', + 'isST', ] def __init__(self): @@ -124,11 +137,9 @@ def __init__(self): def _query_to_dataframe(self, rs, data_type: str = 'data') -> pd.DataFrame: """Convert BaoStock query result to DataFrame""" if rs.error_code != '0': - if 'no record found' in rs.error_msg.lower( - ) or rs.error_code == '10002': + if 'no record found' in rs.error_msg.lower() or rs.error_code == '10002': raise NoDataFoundError(f'No {data_type} found: {rs.error_msg}') - raise DataSourceError( - f'BaoStock API error: {rs.error_msg} (code: {rs.error_code})') + raise DataSourceError(f'BaoStock API error: {rs.error_msg} (code: {rs.error_code})') data_list = [] while rs.next(): @@ -160,7 +171,8 @@ def get_historical_k_data( start_date=start_date, end_date=end_date, frequency=frequency, - adjustflag=adjust_flag) + adjustflag=adjust_flag, + ) return self._query_to_dataframe(rs, f'K-data for {code}') def get_stock_basic_info(self, code: str) -> pd.DataFrame: @@ -171,40 +183,25 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame: rs = baostock.query_stock_basic(code=code) return self._query_to_dataframe(rs, f'basic info for {code}') - def get_dividend_data(self, - code: str, - year: Optional[str] = None, - year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: """Get dividend data""" logger.info(f'Fetching dividend data for {code} ({year} {year_type})') with baostock_session(): - rs = baostock.query_dividend_data( - code=code, year=year, yearType=year_type) - return self._query_to_dataframe( - rs, f'dividend data for {code} ({year} {year_type})') + rs = baostock.query_dividend_data(code=code, year=year, yearType=year_type) + return self._query_to_dataframe(rs, f'dividend data for {code} ({year} {year_type})') - def get_adjust_factor_data(self, code: str, start_date: str, - end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: """Get adjustment factor data""" - logger.info( - f'Fetching adjustment factor data for {code} ({start_date} to {end_date})' - ) + logger.info(f'Fetching adjustment factor data for {code} ({start_date} to {end_date})') with baostock_session(): - rs = baostock.query_adjust_factor( - code=code, start_date=start_date, end_date=end_date) - return self._query_to_dataframe( - rs, - f'adjustment factor data for {code} ({start_date} to {end_date})' - ) + rs = baostock.query_adjust_factor(code=code, start_date=start_date, end_date=end_date) + return self._query_to_dataframe(rs, f'adjustment factor data for {code} ({start_date} to {end_date})') - def get_financial_data(self, code: str, year: str, quarter: int, - data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data""" - logger.info( - f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}' - ) + logger.info(f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}') if not data_types: raise ValueError('data_types cannot be empty') @@ -227,72 +224,49 @@ def get_financial_data(self, code: str, year: str, quarter: int, else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_financial_data(query_func, data_type, code, - year, quarter) + df = self._query_financial_data(query_func, data_type, code, year, quarter) result[data_type] = df if not result: - raise NoDataFoundError( - f'No financial data found for {code} ({year}Q{quarter})') + raise NoDataFoundError(f'No financial data found for {code} ({year}Q{quarter})') return result - def _query_financial_data(self, query_func, data_type: str, code: str, - year: str, quarter: int) -> pd.DataFrame: + def _query_financial_data(self, query_func, data_type: str, code: str, year: str, quarter: int) -> pd.DataFrame: """Query financial data using provided function (assumes session is already active)""" logger.info(f'Fetching {data_type} for {code} ({year}Q{quarter})') rs = query_func(code=code, year=year, quarter=quarter) return self._query_to_dataframe(rs, f'{data_type} for {code}') - def get_report(self, - code: str, - start_date: str, - end_date: str, - report_type: str = '') -> pd.DataFrame: + def get_report(self, code: str, start_date: str, end_date: str, report_type: str = '') -> pd.DataFrame: """Get report data""" - logger.info( - f'Fetching report data for {code} ({start_date} to {end_date}) {report_type}' - ) + logger.info(f'Fetching report data for {code} ({start_date} to {end_date}) {report_type}') if not report_type: raise ValueError('report_type cannot be empty') with baostock_session(): if report_type == 'performance_express': - rs = baostock.query_performance_express_report( - code=code, start_date=start_date, end_date=end_date) + rs = baostock.query_performance_express_report(code=code, start_date=start_date, end_date=end_date) elif report_type == 'performance_forecast': - rs = baostock.query_forecast_report( - code=code, start_date=start_date, end_date=end_date) + rs = baostock.query_forecast_report(code=code, start_date=start_date, end_date=end_date) else: raise ValueError(f'Invalid report type: {report_type}') - return self._query_to_dataframe( - rs, - f'report data for {code} ({start_date} to {end_date}) {report_type}' - ) + return self._query_to_dataframe(rs, f'report data for {code} ({start_date} to {end_date}) {report_type}') - def get_stock_industry(self, - code: Optional[str] = None, - date: Optional[str] = None) -> pd.DataFrame: + def get_stock_industry(self, code: Optional[str] = None, date: Optional[str] = None) -> pd.DataFrame: """Get stock industry""" - logger.info( - f"Fetching stock industry for code={code or 'all'}, date={date or 'latest'}" - ) + logger.info(f"Fetching stock industry for code={code or 'all'}, date={date or 'latest'}") with baostock_session(): rs = baostock.query_stock_industry(code=code, date=date) - return self._query_to_dataframe( - rs, f'stock industry for {code or "all"} ({date or "latest"})') + return self._query_to_dataframe(rs, f'stock industry for {code or "all"} ({date or "latest"})') - def get_stock_list(self, - date: str, - data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list or index constituents""" - logger.info( - f'Fetching stock list for {date} {data_type}, only support a_share' - ) + logger.info(f'Fetching stock list for {date} {data_type}, only support a_share') with baostock_session(): if data_type == 'sse50': @@ -306,23 +280,17 @@ def get_stock_list(self, else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_to_dataframe( - rs, f'stock list for {date} {data_type}') + df = self._query_to_dataframe(rs, f'stock list for {date} {data_type}') logger.info(f'Stock list for {date} {data_type}: {df.head()}') return df - def get_trade_dates(self, - start_date: Optional[str] = None, - end_date: Optional[str] = None) -> pd.DataFrame: + def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: """Get trading calendar""" - logger.info( - f"Fetching trade dates ({start_date or 'default'} to {end_date or 'default'})" - ) + logger.info(f"Fetching trade dates ({start_date or 'default'} to {end_date or 'default'})") with baostock_session(): - rs = baostock.query_trade_dates( - start_date=start_date, end_date=end_date) + rs = baostock.query_trade_dates(start_date=start_date, end_date=end_date) return self._query_to_dataframe(rs, 'trade dates') def get_macro_data( @@ -330,16 +298,15 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, pd.DataFrame]: """Fetch macroeconomic data""" if data_types is None: data_types = [] if extra_kwargs is None: extra_kwargs = {} - logger.info( - f'Fetching macro data ({start_date} to {end_date}) {data_types}') + logger.info(f'Fetching macro data ({start_date} to {end_date}) {data_types}') result = {} with baostock_session(): @@ -364,25 +331,20 @@ def get_macro_data( elif data_type == 'money_supply_month': query_func = baostock.query_money_supply_data_month - parsed_start_date = pd.to_datetime( - start_date).strftime('%Y-%m') - parsed_end_date = pd.to_datetime(end_date).strftime( - '%Y-%m') + parsed_start_date = pd.to_datetime(start_date).strftime('%Y-%m') + parsed_end_date = pd.to_datetime(end_date).strftime('%Y-%m') elif data_type == 'money_supply_year': query_func = baostock.query_money_supply_data_year - parsed_start_date = pd.to_datetime( - start_date).strftime('%Y') - parsed_end_date = pd.to_datetime(end_date).strftime( - '%Y') + parsed_start_date = pd.to_datetime(start_date).strftime('%Y') + parsed_end_date = pd.to_datetime(end_date).strftime('%Y') else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_macro_data(query_func, data_type, - parsed_start_date, - parsed_end_date, - **parsed_extra_kwargs) + df = self._query_macro_data( + query_func, data_type, parsed_start_date, parsed_end_date, **parsed_extra_kwargs + ) result[data_type] = df except Exception as e: @@ -391,19 +353,16 @@ def get_macro_data( continue if not result: - raise NoDataFoundError( - 'No macro data found for the specified criteria') + raise NoDataFoundError('No macro data found for the specified criteria') return result - def _query_macro_data(self, query_func, data_type: str, start_date: str, - end_date: str, **kwargs) -> pd.DataFrame: + def _query_macro_data(self, query_func, data_type: str, start_date: str, end_date: str, **kwargs) -> pd.DataFrame: """Query macro data using provided function (assumes session is already active)""" logger.info(f'Fetching {data_type} for {start_date} to {end_date}') try: rs = query_func(start_date=start_date, end_date=end_date, **kwargs) - return self._query_to_dataframe( - rs, f'{data_type} for {start_date} to {end_date}') + return self._query_to_dataframe(rs, f'{data_type} for {start_date} to {end_date}') except Exception as e: logger.warning(f'Failed to fetch {data_type} data: {e}') diff --git a/ms_agent/tools/findata/data_source_base.py b/ms_agent/tools/findata/data_source_base.py index d287aebdb..ecd1cc64f 100644 --- a/ms_agent/tools/findata/data_source_base.py +++ b/ms_agent/tools/findata/data_source_base.py @@ -7,11 +7,13 @@ class DataSourceError(Exception): """Base data source error class""" + pass class NoDataFoundError(DataSourceError): """Data not found exception""" + pass @@ -63,22 +65,17 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame: pass @abstractmethod - def get_dividend_data(self, - code: str, - year: str, - year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, code: str, year: str, year_type: str = 'report') -> pd.DataFrame: """Get dividend information""" pass @abstractmethod - def get_adjust_factor_data(self, code: str, start_date: str, - end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: """Get adjustment factor data""" pass @abstractmethod - def get_financial_data(self, code: str, year: str, quarter: int, - data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data for multiple categories in one call Returns: @@ -87,11 +84,9 @@ def get_financial_data(self, code: str, year: str, quarter: int, pass @abstractmethod - def get_report(self, - code: str, - start_date: str, - end_date: str, - report_type: str = 'performance_express') -> pd.DataFrame: + def get_report( + self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' + ) -> pd.DataFrame: """Get report data (performance express/forecast)""" pass @@ -116,8 +111,8 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, pd.DataFrame]: """Get macroeconomic data for multiple categories in one call""" pass diff --git a/ms_agent/tools/findata/findata_fetcher.py b/ms_agent/tools/findata/findata_fetcher.py index 996f28781..d18c4cd23 100644 --- a/ms_agent/tools/findata/findata_fetcher.py +++ b/ms_agent/tools/findata/findata_fetcher.py @@ -1,25 +1,24 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json from concurrent.futures import ThreadPoolExecutor from datetime import date, datetime from functools import partial from pathlib import Path from typing import Any, Dict, Optional, Union -import json import numpy as np import pandas as pd +from omegaconf import DictConfig + from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource -from ms_agent.tools.findata.data_source_base import (DataSourceError, - FinancialDataSource, - NoDataFoundError) +from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError from ms_agent.tools.findata.hybrid_source import HybridDataSource from ms_agent.utils import get_logger from ms_agent.utils.rate_limiter import AdaptiveRateLimiter, RateLimiter -from omegaconf import DictConfig logger = get_logger() @@ -55,8 +54,7 @@ class FinancialDataFetcher(ToolBase): def __init__(self, config: Optional[DictConfig] = None): super().__init__(config) - tools_cfg = getattr(config, 'tools', - None) if config is not None else None + tools_cfg = getattr(config, 'tools', None) if config is not None else None self.exclude_func(getattr(tools_cfg, 'financial_data_fetcher', None)) self.save_dir = getattr(config, 'output_dir', './output') @@ -78,25 +76,21 @@ def __init__(self, config: Optional[DictConfig] = None): thread_name_prefix='financial_data_fetcher_', ) - logger.info( - f'Initializing FinancialDataFetcher with source: {self.source_type}' - ) - logger.info( - f'Financial data will be saved to: {self.financial_data_dir}') + logger.info(f'Initializing FinancialDataFetcher with source: {self.source_type}') + logger.info(f'Financial data will be saved to: {self.financial_data_dir}') def _get_source_type(self, config: Optional[DictConfig]) -> str: """Get data source type from config""" - if isinstance(config, - DictConfig) and hasattr(config, 'tools') and hasattr( - config.tools, 'financial_data_fetcher'): - return getattr(config.tools.financial_data_fetcher, 'source_type', - 'hybrid') + if ( + isinstance(config, DictConfig) + and hasattr(config, 'tools') + and hasattr(config.tools, 'financial_data_fetcher') + ): + return getattr(config.tools.financial_data_fetcher, 'source_type', 'hybrid') return 'hybrid' - def _create_rate_limiter( - self, config: Optional[DictConfig] - ) -> Optional[Union[RateLimiter, AdaptiveRateLimiter]]: + def _create_rate_limiter(self, config: Optional[DictConfig]) -> Optional[Union[RateLimiter, AdaptiveRateLimiter]]: """ Create rate limiter from config. @@ -122,16 +116,17 @@ def _create_rate_limiter( ``` """ # Check if rate limiter is configured - if not (isinstance(config, DictConfig) and hasattr(config, 'tools') - and hasattr(config.tools, 'financial_data_fetcher')): - logger.info( - 'No rate limiter configured, running without rate limiting') + if not ( + isinstance(config, DictConfig) + and hasattr(config, 'tools') + and hasattr(config.tools, 'financial_data_fetcher') + ): + logger.info('No rate limiter configured, running without rate limiting') return None fetcher_config = config.tools.financial_data_fetcher if not hasattr(fetcher_config, 'rate_limiter'): - logger.info( - 'No rate limiter configured, running without rate limiting') + logger.info('No rate limiter configured, running without rate limiting') return None rl_config = fetcher_config.rate_limiter @@ -146,24 +141,15 @@ def _create_rate_limiter( if limiter_type == 'adaptive': # Create AdaptiveRateLimiter params = { - 'initial_requests_per_second': - getattr(rl_config, 'initial_requests_per_second', 2), - 'min_requests_per_second': - getattr(rl_config, 'min_requests_per_second', 1), - 'max_requests_per_second': - getattr(rl_config, 'max_requests_per_second', 5), - 'min_request_interval': - getattr(rl_config, 'min_request_interval', 0.5), - 'max_concurrent': - getattr(rl_config, 'max_concurrent', 3), - 'backoff_factor': - getattr(rl_config, 'backoff_factor', 0.5), - 'recovery_factor': - getattr(rl_config, 'recovery_factor', 1.2), - 'error_threshold': - getattr(rl_config, 'error_threshold', 3), - 'success_threshold': - getattr(rl_config, 'success_threshold', 10), + 'initial_requests_per_second': getattr(rl_config, 'initial_requests_per_second', 2), + 'min_requests_per_second': getattr(rl_config, 'min_requests_per_second', 1), + 'max_requests_per_second': getattr(rl_config, 'max_requests_per_second', 5), + 'min_request_interval': getattr(rl_config, 'min_request_interval', 0.5), + 'max_concurrent': getattr(rl_config, 'max_concurrent', 3), + 'backoff_factor': getattr(rl_config, 'backoff_factor', 0.5), + 'recovery_factor': getattr(rl_config, 'recovery_factor', 1.2), + 'error_threshold': getattr(rl_config, 'error_threshold', 3), + 'success_threshold': getattr(rl_config, 'success_threshold', 10), } logger.info(f'Creating AdaptiveRateLimiter with params: {params}') return AdaptiveRateLimiter(**params) @@ -171,20 +157,15 @@ def _create_rate_limiter( elif limiter_type == 'basic': # Create basic RateLimiter params = { - 'max_requests_per_second': - getattr(rl_config, 'max_requests_per_second', 2), - 'min_request_interval': - getattr(rl_config, 'min_request_interval', 0.5), - 'max_concurrent': - getattr(rl_config, 'max_concurrent', 3), + 'max_requests_per_second': getattr(rl_config, 'max_requests_per_second', 2), + 'min_request_interval': getattr(rl_config, 'min_request_interval', 0.5), + 'max_concurrent': getattr(rl_config, 'max_concurrent', 3), } logger.info(f'Creating RateLimiter with params: {params}') return RateLimiter(**params) else: - logger.warning( - f'Unknown rate limiter type: {limiter_type}, running without rate limiting' - ) + logger.warning(f'Unknown rate limiter type: {limiter_type}, running without rate limiting') return None def _create_data_source(self) -> FinancialDataSource: @@ -197,8 +178,7 @@ def _create_data_source(self) -> FinancialDataSource: source_class = source_map.get(self.source_type.lower()) if not source_class: - logger.warning( - f'Unknown source type: {self.source_type}, using hybrid') + logger.warning(f'Unknown source type: {self.source_type}, using hybrid') source_class = HybridDataSource return source_class() @@ -234,8 +214,7 @@ async def _execute_with_rate_limit(self, func, *args, **kwargs): func_with_args = partial(func, *args, **kwargs) async with self.rate_limiter: - result = await loop.run_in_executor(self.thread_pool, - func_with_args) + result = await loop.run_in_executor(self.thread_pool, func_with_args) # Record success if using adaptive rate limiter if isinstance(self.rate_limiter, AdaptiveRateLimiter): @@ -246,9 +225,9 @@ async def _execute_with_rate_limit(self, func, *args, **kwargs): except Exception as e: if isinstance(self.rate_limiter, AdaptiveRateLimiter): error_msg = str(e).lower() - is_rate_limit_error = any(keyword in error_msg for keyword in [ - 'rate limit', 'too many requests', 'quota exceeded', '429' - ]) + is_rate_limit_error = any( + keyword in error_msg for keyword in ['rate limit', 'too many requests', 'quota exceeded', '429'] + ) self.rate_limiter.record_error(is_rate_limit_error) raise @@ -270,14 +249,10 @@ def _save_dataframe(self, df, filename: str) -> str: logger.info(f'Data saved to: {filepath}') return str(filepath) except Exception as e: - logger.error( - f'Failed to save data to {filename}: {e}', exc_info=True) + logger.error(f'Failed to save data to {filename}: {e}', exc_info=True) return '' - def _create_success_response(self, - df, - saved_path: str, - metadata: Optional[Dict] = None) -> str: + def _create_success_response(self, df, saved_path: str, metadata: Optional[Dict] = None) -> str: """ Create success response with sample data. @@ -302,8 +277,8 @@ def _create_success_response(self, response['example_data'] = sample_df.to_dict(orient='records') if len(df) > self.sample_rows: response['note'] = ( - f'Showing {self.sample_rows} sample rows out of {len(df)} ' - f'total rows. Full data saved to file.') + f'Showing {self.sample_rows} sample rows out of {len(df)} total rows. Full data saved to file.' + ) else: response['example_data'] = [] response['note'] = 'No data returned' @@ -312,11 +287,9 @@ def _create_success_response(self, if metadata: response.update(metadata) - return json.dumps( - response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) - def _create_error_response(self, error: Exception, operation: str, - params: Dict) -> str: + def _create_error_response(self, error: Exception, operation: str, params: Dict) -> str: """ Create standardized error response. @@ -336,7 +309,7 @@ def _create_error_response(self, error: Exception, operation: str, 'operation': operation, 'error_type': error_type, 'error': error_msg, - 'parameters': params + 'parameters': params, } # Only log with traceback for unexpected errors @@ -344,12 +317,9 @@ def _create_error_response(self, error: Exception, operation: str, if isinstance(error, (DataSourceError, NoDataFoundError)): logger.warning(f'{operation}: {error_msg}') else: - logger.error( - f"Operation '{operation}' failed: {error_type} - {error_msg}", - exc_info=True) + logger.error(f"Operation '{operation}' failed: {error_type} - {error_msg}", exc_info=True) - return json.dumps( - response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) async def _get_tools_inner(self) -> Dict[str, Any]: """Return tool definitions""" @@ -358,66 +328,56 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='get_historical_k_data', server_name='financial_data_fetcher', - description= - 'Get historical K-line data (daily, weekly, monthly, etc.)', + description='Get historical K-line data (daily, weekly, monthly, etc.)', parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') - }, - 'start_date': { 'type': 'string', - 'description': 'Start date, format: YYYY-MM-DD' - }, - 'end_date': { - 'type': 'string', - 'description': 'End date, format: YYYY-MM-DD' + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), }, + 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, + 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, 'frequency': { 'type': 'string', - 'description': - 'Data frequency: d(daily), w(weekly), m(monthly), 5/15/30/60(minutes)', - 'default': 'd' + 'description': 'Data frequency: d(daily), w(weekly), m(monthly), 5/15/30/60(minutes)', + 'default': 'd', }, 'adjust_flag': { - 'type': - 'string', - 'description': - ('Adjustment flag for historical data.' - 'Adjust type: 1(backward adjusted), 2(forward adjusted), 3(non-adjusted)' - ), - 'default': - '3' - } + 'type': 'string', + 'description': ( + 'Adjustment flag for historical data.' + 'Adjust type: 1(backward adjusted), 2(forward adjusted), 3(non-adjusted)' + ), + 'default': '3', + }, }, - 'required': - ['code', 'start_date', 'end_date', 'frequency'], - 'additionalProperties': False - }), + 'required': ['code', 'start_date', 'end_date', 'frequency'], + 'additionalProperties': False, + }, + ), Tool( tool_name='get_stock_basic_info', server_name='financial_data_fetcher', - description= - 'Get stock basic information (name, industry, listing date, etc.)', + description='Get stock basic information (name, industry, listing date, etc.)', parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') + 'type': 'string', + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), } }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_dividend_data', server_name='financial_data_fetcher', @@ -426,33 +386,30 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') + 'type': 'string', + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), }, 'year': { - 'type': - 'string', - 'description': - 'Year, e.g. 2023. If not provided, the current year will be used' + 'type': 'string', + 'description': 'Year, e.g. 2023. If not provided, the current year will be used', }, 'year_type': { - 'type': - 'string', - 'description': - ('Year category, default is "report": Year of the preliminary ' - 'announcement, optional "operate": Year of ex-dividend and ex-rights' - ), - 'default': - 'report', - 'enum': ['report', 'operate'] - } + 'type': 'string', + 'description': ( + 'Year category, default is "report": Year of the preliminary ' + 'announcement, optional "operate": Year of ex-dividend and ex-rights' + ), + 'default': 'report', + 'enum': ['report', 'operate'], + }, }, 'required': ['code'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_adjust_factor_data', server_name='financial_data_fetcher', @@ -464,156 +421,124 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'string', 'description': 'Stock code', }, - 'start_date': { - 'type': 'string', - 'description': - 'Start date, format: YYYY-MM-DD.' - }, - 'end_date': { - 'type': 'string', - 'description': 'End date, format: YYYY-MM-DD.' - } + 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD.'}, + 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD.'}, }, 'required': ['code', 'start_date', 'end_date'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_financial_data', server_name='financial_data_fetcher', - description= - ('Get quarterly financial data for a given stock.' - 'Supported data types: profit, operation, growth, balance, cash_flow, dupont.' - 'You can specify one or multiple data types to get the corresponding data.' - ), + description=( + 'Get quarterly financial data for a given stock.' + 'Supported data types: profit, operation, growth, balance, cash_flow, dupont.' + 'You can specify one or multiple data types to get the corresponding data.' + ), parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') - }, - 'year': { 'type': 'string', - 'description': 'Year, e.g. 2023' + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), }, + 'year': {'type': 'string', 'description': 'Year, e.g. 2023'}, 'quarter': { - 'type': - 'integer', - 'description': - ('Quarter, 1-4, e.g. 1 for first quarter, 2 for second ' - 'quarter, 3 for third quarter, 4 for fourth quarter' - ) + 'type': 'integer', + 'description': ( + 'Quarter, 1-4, e.g. 1 for first quarter, 2 for second ' + 'quarter, 3 for third quarter, 4 for fourth quarter' + ), }, 'data_types': { 'type': 'array', 'description': 'Data types to get.', 'items': { - 'type': - 'string', - 'enum': [ - 'profit', 'operation', 'growth', - 'balance', 'cash_flow', 'dupont' - ] - } - } + 'type': 'string', + 'enum': ['profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont'], + }, + }, }, 'required': ['code', 'year', 'quarter'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_report', server_name='financial_data_fetcher', - description= - ('Get report data for a given stock. Support for performance express ' - 'reports and performance forecast reports'), + description=( + 'Get report data for a given stock. Support for performance express ' + 'reports and performance forecast reports' + ), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') - }, - 'start_date': { 'type': 'string', - 'description': 'Start date, format: YYYY-MM-DD' + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), }, - 'end_date': { + 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, + 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, + 'report_type': { 'type': 'string', - 'description': 'End date, format: YYYY-MM-DD' + 'description': 'Report type', + 'default': 'performance_express', + 'enum': ['performance_express', 'performance_forecast'], }, - 'report_type': { - 'type': - 'string', - 'description': - 'Report type', - 'default': - 'performance_express', - 'enum': [ - 'performance_express', - 'performance_forecast' - ] - } }, - 'required': - ['code', 'start_date', 'end_date', 'report_type'], - 'additionalProperties': - False - }), + 'required': ['code', 'start_date', 'end_date', 'report_type'], + 'additionalProperties': False, + }, + ), Tool( tool_name='get_stock_industry', server_name='financial_data_fetcher', - description= - 'Get industry classification for a given stock and date', + description='Get industry classification for a given stock and date', parameters={ 'type': 'object', 'properties': { 'code': { - 'type': - 'string', - 'description': - ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)') - }, - 'date': { 'type': 'string', - 'description': 'Query date, format: YYYY-MM-DD' - } + 'description': ( + 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)' + ), + }, + 'date': {'type': 'string', 'description': 'Query date, format: YYYY-MM-DD'}, }, 'required': ['code', 'date'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_stock_list', server_name='financial_data_fetcher', - description= - ('Get stock list for a given date, support for SSE 50 index constituents (sse50), ' - 'CSI 300 index constituents (hs300), CSI 500 index constituents (zz500) ' - 'and all a-share stocks (all_a_share)'), + description=( + 'Get stock list for a given date, support for SSE 50 index constituents (sse50), ' + 'CSI 300 index constituents (hs300), CSI 500 index constituents (zz500) ' + 'and all a-share stocks (all_a_share)' + ), parameters={ 'type': 'object', 'properties': { - 'date': { - 'type': 'string', - 'description': 'Query date, format: YYYY-MM-DD' - }, + 'date': {'type': 'string', 'description': 'Query date, format: YYYY-MM-DD'}, 'data_type': { 'type': 'string', - 'description': - 'Data type to get. Default is "all_a_share"', - 'enum': - ['sse50', 'hs300', 'zz500', 'all_a_share'] - } + 'description': 'Data type to get. Default is "all_a_share"', + 'enum': ['sse50', 'hs300', 'zz500', 'all_a_share'], + }, }, 'required': ['date'], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_trade_dates', server_name='financial_data_fetcher', @@ -621,104 +546,86 @@ async def _get_tools_inner(self) -> Dict[str, Any]: parameters={ 'type': 'object', 'properties': { - 'start_date': { - 'type': 'string', - 'description': 'Start date, format: YYYY-MM-DD' - }, - 'end_date': { - 'type': 'string', - 'description': 'End date, format: YYYY-MM-DD' - } + 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, + 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, }, 'required': [], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), Tool( tool_name='get_macro_data', server_name='financial_data_fetcher', - description= - ('Get macro data for a given range of dates' - 'Supported data types: deposit_rate, loan_rate, required_reserve_ratio, money_supply_month, ' - 'money_supply_year'), + description=( + 'Get macro data for a given range of dates' + 'Supported data types: deposit_rate, loan_rate, required_reserve_ratio, money_supply_month, ' + 'money_supply_year' + ), parameters={ 'type': 'object', 'properties': { - 'start_date': { - 'type': 'string', - 'description': 'Start date, format: YYYY-MM-DD' - }, - 'end_date': { - 'type': 'string', - 'description': 'End date, format: YYYY-MM-DD' - }, + 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, + 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, 'data_types': { 'type': 'array', - 'description': - 'Data types to get. Default is all data types', + 'description': 'Data types to get. Default is all data types', 'items': { - 'type': - 'string', + 'type': 'string', 'enum': [ - 'deposit_rate', 'loan_rate', + 'deposit_rate', + 'loan_rate', 'required_reserve_ratio', 'money_supply_month', - 'money_supply_year' - ] + 'money_supply_year', + ], }, }, 'extra_kwargs': { 'type': 'object', - 'description': - 'Extra keyword arguments for the macro data', + 'description': 'Extra keyword arguments for the macro data', 'properties': { 'yearType': { - 'type': - 'string', - 'description': - ('Year Type, default value 0 means "announcement date," ' - 'and 1 means "effective date".'), - 'default': - '0' + 'type': 'string', + 'description': ( + 'Year Type, default value 0 means "announcement date," ' + 'and 1 means "effective date".' + ), + 'default': '0', } }, 'required': [], # yearType is optional - 'additionalProperties': False - } + 'additionalProperties': False, + }, }, 'required': ['start_date', 'end_date', 'data_types'], - 'additionalProperties': False + 'additionalProperties': False, }, - ) + ), ] } # Update tools by source type - if self.data_source is not None and hasattr(self.data_source, - 'get_extra_tools'): + if self.data_source is not None and hasattr(self.data_source, 'get_extra_tools'): tools.update(self.data_source.get_extra_tools()) return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: """Call tool method""" if self.data_source is None: await self.connect() return await getattr(self, tool_name)(**tool_args) - async def get_historical_k_data(self, - code: str, - start_date: str, - end_date: str, - frequency: str = 'd', - adjust_flag: str = '3') -> str: + async def get_historical_k_data( + self, code: str, start_date: str, end_date: str, frequency: str = 'd', adjust_flag: str = '3' + ) -> str: """Get historical K-line data""" params = { 'code': code, 'start_date': start_date, 'end_date': end_date, 'frequency': frequency, - 'adjust_flag': adjust_flag + 'adjust_flag': adjust_flag, } try: @@ -728,7 +635,8 @@ async def get_historical_k_data(self, start_date=start_date, end_date=end_date, frequency=frequency, - adjust_flag=adjust_flag) + adjust_flag=adjust_flag, + ) # Generate filename with key parameters clean_code = code.replace('.', '_') @@ -742,21 +650,19 @@ async def get_historical_k_data(self, 'code': code, 'date_range': f'{start_date} to {end_date}', 'frequency': frequency, - 'adjust_flag': adjust_flag + 'adjust_flag': adjust_flag, } return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_historical_k_data', - params) + return self._create_error_response(e, 'get_historical_k_data', params) async def get_stock_basic_info(self, code: str) -> str: """Get stock basic information""" params = {'code': code} try: - df = await self._execute_with_rate_limit( - self.data_source.get_stock_basic_info, code=code) + df = await self._execute_with_rate_limit(self.data_source.get_stock_basic_info, code=code) # Generate filename clean_code = code.replace('.', '_') @@ -770,22 +676,16 @@ async def get_stock_basic_info(self, code: str) -> str: return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_stock_basic_info', - params) + return self._create_error_response(e, 'get_stock_basic_info', params) - async def get_dividend_data(self, - code: str, - year: Optional[str] = None, - year_type: str = 'report') -> str: + async def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> str: """Get dividend information (BaoStock).""" params = {'code': code, 'year': year, 'year_type': year_type} try: df = await self._execute_with_rate_limit( - self.data_source.get_dividend_data, - code=code, - year=year, - year_type=year_type) + self.data_source.get_dividend_data, code=code, year=year, year_type=year_type + ) # Generate filename clean_code = code.replace('.', '_') @@ -802,17 +702,14 @@ async def get_dividend_data(self, except Exception as e: return self._create_error_response(e, 'get_dividend_data', params) - async def get_adjust_factor_data(self, code: str, start_date: str, - end_date: str) -> str: + async def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> str: """Get adjustment factor data (BaoStock).""" params = {'code': code, 'start_date': start_date, 'end_date': end_date} try: df = await self._execute_with_rate_limit( - self.data_source.get_adjust_factor_data, - code=code, - start_date=start_date, - end_date=end_date) + self.data_source.get_adjust_factor_data, code=code, start_date=start_date, end_date=end_date + ) # Generate filename clean_code = code.replace('.', '_') @@ -822,39 +719,21 @@ async def get_adjust_factor_data(self, code: str, start_date: str, saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = { - 'code': code, - 'date_range': f'{start_date} to {end_date}' - } + metadata = {'code': code, 'date_range': f'{start_date} to {end_date}'} return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_adjust_factor_data', - params) - - async def get_financial_data(self, - code: str, - year: str, - quarter: int, - data_types: Optional[list] = None) -> str: + return self._create_error_response(e, 'get_adjust_factor_data', params) + + async def get_financial_data(self, code: str, year: str, quarter: int, data_types: Optional[list] = None) -> str: """Get multiple categories of financial data in one call.""" - data_types = data_types or [ - 'profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont' - ] - params = { - 'code': code, - 'year': year, - 'quarter': quarter, - 'data_types': data_types - } + data_types = data_types or ['profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont'] + params = {'code': code, 'year': year, 'quarter': quarter, 'data_types': data_types} try: result = await self._execute_with_rate_limit( - self.data_source.get_financial_data, - code=code, - year=year, - quarter=quarter, - data_types=data_types) + self.data_source.get_financial_data, code=code, year=year, quarter=quarter, data_types=data_types + ) # Save each data type and prepare response clean_code = code.replace('.', '_') @@ -875,42 +754,26 @@ async def get_financial_data(self, example_data[key] = value response = { - 'success': - True, - 'code': - code, - 'year': - year, - 'quarter': - quarter, - 'data_types': - list(result.keys()), - 'saved_files': - saved_files, - 'example_data': - example_data, - 'note': - 'Financial data saved to separate files. Showing sample rows for each data type.' + 'success': True, + 'code': code, + 'year': year, + 'quarter': quarter, + 'data_types': list(result.keys()), + 'saved_files': saved_files, + 'example_data': example_data, + 'note': 'Financial data saved to separate files. Showing sample rows for each data type.', } - return json.dumps( - response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) except Exception as e: return self._create_error_response(e, 'get_financial_data', params) - async def get_report(self, - code: str, - start_date: str, - end_date: str, - report_type: str = 'performance_express') -> str: + async def get_report( + self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' + ) -> str: """Get performance express/forecast reports (BaoStock).""" - params = { - 'code': code, - 'start_date': start_date, - 'end_date': end_date, - 'report_type': report_type - } + params = {'code': code, 'start_date': start_date, 'end_date': end_date, 'report_type': report_type} try: df = await self._execute_with_rate_limit( @@ -918,7 +781,8 @@ async def get_report(self, code=code, start_date=start_date, end_date=end_date, - report_type=report_type) + report_type=report_type, + ) # Generate filename clean_code = code.replace('.', '_') @@ -928,11 +792,7 @@ async def get_report(self, saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = { - 'code': code, - 'date_range': f'{start_date} to {end_date}', - 'report_type': report_type - } + metadata = {'code': code, 'date_range': f'{start_date} to {end_date}', 'report_type': report_type} return self._create_success_response(df, saved_path, metadata) except Exception as e: @@ -944,9 +804,8 @@ async def get_trade_dates(self, start_date: str, end_date: str) -> str: try: df = await self._execute_with_rate_limit( - self.data_source.get_trade_dates, - start_date=start_date, - end_date=end_date) + self.data_source.get_trade_dates, start_date=start_date, end_date=end_date + ) # Generate filename filename = f'trade_dates_{start_date}_{end_date}' @@ -966,8 +825,7 @@ async def get_stock_industry(self, code: str, date: str) -> str: params = {'code': code, 'date': date} try: - df = await self._execute_with_rate_limit( - self.data_source.get_stock_industry, code=code, date=date) + df = await self._execute_with_rate_limit(self.data_source.get_stock_industry, code=code, date=date) # Generate filename clean_code = code.replace('.', '_') @@ -983,17 +841,12 @@ async def get_stock_industry(self, code: str, date: str) -> str: except Exception as e: return self._create_error_response(e, 'get_stock_industry', params) - async def get_stock_list(self, - date: str, - data_type: str = 'all_a_share') -> str: + async def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> str: """Get index constituents or all stocks.""" params = {'date': date, 'data_type': data_type} try: - df = await self._execute_with_rate_limit( - self.data_source.get_stock_list, - date=date, - data_type=data_type) + df = await self._execute_with_rate_limit(self.data_source.get_stock_list, date=date, data_type=data_type) # Generate filename filename = f'stock_list_{data_type}_{date}' @@ -1002,31 +855,28 @@ async def get_stock_list(self, saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = { - 'date': date, - 'data_type': data_type, - 'total_stocks': len(df) - } + metadata = {'date': date, 'data_type': data_type, 'total_stocks': len(df)} return self._create_success_response(df, saved_path, metadata) except Exception as e: return self._create_error_response(e, 'get_stock_list', params) - async def get_macro_data(self, - start_date: str, - end_date: str, - data_types: Optional[list] = None, - extra_kwargs: Optional[dict] = None) -> str: + async def get_macro_data( + self, start_date: str, end_date: str, data_types: Optional[list] = None, extra_kwargs: Optional[dict] = None + ) -> str: """Get macroeconomic data (BaoStock).""" data_types = data_types or [ - 'deposit_rate', 'loan_rate', 'required_reserve_ratio', - 'money_supply_month', 'money_supply_year' + 'deposit_rate', + 'loan_rate', + 'required_reserve_ratio', + 'money_supply_month', + 'money_supply_year', ] params = { 'start_date': start_date, 'end_date': end_date, 'data_types': data_types, - 'extra_kwargs': extra_kwargs + 'extra_kwargs': extra_kwargs, } try: @@ -1035,7 +885,8 @@ async def get_macro_data(self, start_date=start_date, end_date=end_date, data_types=data_types, - extra_kwargs=extra_kwargs) + extra_kwargs=extra_kwargs, + ) # Save each data type and prepare response saved_files = {} @@ -1055,22 +906,15 @@ async def get_macro_data(self, example_data[key] = value response = { - 'success': - True, - 'date_range': - f'{start_date} to {end_date}', - 'data_types': - list(result.keys()), - 'saved_files': - saved_files, - 'example_data': - example_data, - 'note': - 'Macro data saved to separate files. Showing sample rows for each data type.' + 'success': True, + 'date_range': f'{start_date} to {end_date}', + 'data_types': list(result.keys()), + 'saved_files': saved_files, + 'example_data': example_data, + 'note': 'Macro data saved to separate files. Showing sample rows for each data type.', } - return json.dumps( - response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) except Exception as e: return self._create_error_response(e, 'get_macro_data', params) diff --git a/ms_agent/tools/findata/hybrid_source.py b/ms_agent/tools/findata/hybrid_source.py index 79380fd14..57ad3c374 100644 --- a/ms_agent/tools/findata/hybrid_source.py +++ b/ms_agent/tools/findata/hybrid_source.py @@ -3,10 +3,10 @@ from typing import Any, Callable, Dict, List, Optional import pandas as pd + from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource -from ms_agent.tools.findata.data_source_base import (DataSourceError, - FinancialDataSource) +from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource from ms_agent.utils import get_logger logger = get_logger() @@ -27,9 +27,7 @@ def __init__(self): logger.info('Initializing Hybrid data source') self.baostock = BaoStockDataSource() self.akshare = AKShareDataSource() - logger.info( - 'Hybrid data source initialized (A-shares: BaoStock, Others: AKShare)' - ) + logger.info('Hybrid data source initialized (A-shares: BaoStock, Others: AKShare)') def _detect_market(self, code: str) -> str: """ @@ -56,9 +54,7 @@ def _detect_market(self, code: str) -> str: logger.warning(f'Unknown market type for code: {code}') return 'unknown' - def _get_source(self, - code: str, - market: str = None) -> List[FinancialDataSource]: + def _get_source(self, code: str, market: str = None) -> List[FinancialDataSource]: """Select data source based on stock code""" market = market if market else self._detect_market(code) @@ -69,8 +65,7 @@ def _get_source(self, logger.debug(f'Using AKShare for {market}: {code}') return [self.akshare] - def _call_sources(self, sources: List[FinancialDataSource], - query_func: Callable) -> pd.DataFrame: + def _call_sources(self, sources: List[FinancialDataSource], query_func: Callable) -> pd.DataFrame: """Call query function for multiple data sources""" for source in sources: try: @@ -80,9 +75,7 @@ def _call_sources(self, sources: List[FinancialDataSource], if isinstance(result, dict) and result: return result except Exception as e: - logger.warning( - f'Data source {source.__class__.__name__} failed, continue to next source: {e}' - ) + logger.warning(f'Data source {source.__class__.__name__} failed, continue to next source: {e}') continue source_names = [s.__class__.__name__ for s in sources] @@ -100,78 +93,58 @@ def get_historical_k_data( """Get historical K-line data""" sources = self._get_source(code) return self._call_sources( - sources, lambda s: s.get_historical_k_data( - code, start_date, end_date, frequency, adjust_flag, fields)) + sources, lambda s: s.get_historical_k_data(code, start_date, end_date, frequency, adjust_flag, fields) + ) def get_stock_basic_info(self, code: str) -> pd.DataFrame: """Get stock basic information""" sources = self._get_source(code) - return self._call_sources(sources, - lambda s: s.get_stock_basic_info(code)) + return self._call_sources(sources, lambda s: s.get_stock_basic_info(code)) - def get_dividend_data(self, - code: str, - year: Optional[str] = None, - year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: """Get dividend data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources( - sources, lambda s: s.get_dividend_data(code, year, year_type)) + return self._call_sources(sources, lambda s: s.get_dividend_data(code, year, year_type)) - def get_adjust_factor_data(self, code: str, start_date: str, - end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: """Get adjustment factor data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources( - sources, - lambda s: s.get_adjust_factor_data(code, start_date, end_date)) + return self._call_sources(sources, lambda s: s.get_adjust_factor_data(code, start_date, end_date)) - def get_financial_data(self, code: str, year: str, quarter: int, - data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data for multiple categories in one call""" sources = self._get_source(code) - return self._call_sources( - sources, - lambda s: s.get_financial_data(code, year, quarter, data_types)) - - def get_report(self, - code: str, - start_date: str, - end_date: str, - report_type: str = 'performance_express') -> pd.DataFrame: + return self._call_sources(sources, lambda s: s.get_financial_data(code, year, quarter, data_types)) + + def get_report( + self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' + ) -> pd.DataFrame: """Get report data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources( - sources, - lambda s: s.get_report(code, start_date, end_date, report_type)) + return self._call_sources(sources, lambda s: s.get_report(code, start_date, end_date, report_type)) def get_stock_industry(self, code: str, date: str) -> pd.DataFrame: """Get industry classification (BaoStock only)""" sources = self._get_source(code) - return self._call_sources(sources, - lambda s: s.get_stock_industry(code, date)) + return self._call_sources(sources, lambda s: s.get_stock_industry(code, date)) - def get_stock_list(self, - date: str, - data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list or index constituents (BaoStock only)""" sources = self._get_source('', market='a_share') - return self._call_sources(sources, - lambda s: s.get_stock_list(date, data_type)) + return self._call_sources(sources, lambda s: s.get_stock_list(date, data_type)) def get_trade_dates(self, start_date: str, end_date: str) -> pd.DataFrame: """Get trading calendar (BaoStock only)""" sources = self._get_source('', market='a_share') - return self._call_sources( - sources, lambda s: s.get_trade_dates(start_date, end_date)) + return self._call_sources(sources, lambda s: s.get_trade_dates(start_date, end_date)) def get_macro_data( self, start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, - Any]] = None) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, pd.DataFrame]: """Get macroeconomic data for multiple categories in one call (BaoStock only)""" if data_types is None: data_types = [] @@ -179,6 +152,4 @@ def get_macro_data( extra_kwargs = {} sources = self._get_source('', market='a_share') - return self._call_sources( - sources, lambda s: s.get_macro_data(start_date, end_date, - data_types, extra_kwargs)) + return self._call_sources(sources, lambda s: s.get_macro_data(start_date, end_date, data_types, extra_kwargs)) diff --git a/ms_agent/tools/image_generator/ds_image_gen.py b/ms_agent/tools/image_generator/ds_image_gen.py index b7286b218..0a3f482d9 100644 --- a/ms_agent/tools/image_generator/ds_image_gen.py +++ b/ms_agent/tools/image_generator/ds_image_gen.py @@ -6,23 +6,18 @@ class DSImageGenerator: - def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, - positive_prompt, - negative_prompt=None, - size=None, - ratio=None, - **kwargs): + async def generate_image(self, positive_prompt, negative_prompt=None, size=None, ratio=None, **kwargs): import aiohttp + image_generator = self.config.tools.image_generator base_url = ( - getattr(image_generator, 'base_url', None) - or 'https://dashscope.aliyuncs.com/compatible-mode').strip('/') + getattr(image_generator, 'base_url', None) or 'https://dashscope.aliyuncs.com/compatible-mode' + ).strip('/') api_key = image_generator.api_key model_id = image_generator.model assert api_key is not None @@ -38,34 +33,24 @@ async def generate_image(self, request_body = { 'model': model_id, - 'dashscope_extend_params': { - 'provider': 'b', - 'using_native_protocol': True - }, + 'dashscope_extend_params': {'provider': 'b', 'using_native_protocol': True}, 'stream': False, - 'contents': { - 'role': 'USER', - 'parts': { - 'text': positive_prompt - } - }, + 'contents': {'role': 'USER', 'parts': {'text': positive_prompt}}, 'generationConfig': { 'responseModalities': ['TEXT', 'IMAGE'], 'image_config': { 'aspect_ratio': ratio, }, - } + }, } async with aiohttp.ClientSession() as session: - async with session.post( - base_url, headers=headers, json=request_body) as resp: + async with session.post(base_url, headers=headers, json=request_body) as resp: resp.raise_for_status() data = await resp.json() try: - image_url = data['candidates'][0]['content']['parts'][-1][ - 'inlineData']['data'] + image_url = data['candidates'][0]['content']['parts'][-1]['inlineData']['data'] async with session.get(image_url) as img_resp: img_content = await img_resp.read() image = Image.open(BytesIO(img_content)) diff --git a/ms_agent/tools/image_generator/google_image_gen.py b/ms_agent/tools/image_generator/google_image_gen.py index fe3f6f011..103d52cc2 100644 --- a/ms_agent/tools/image_generator/google_image_gen.py +++ b/ms_agent/tools/image_generator/google_image_gen.py @@ -3,18 +3,15 @@ class GoogleImageGenerator: - def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, - positive_prompt, - negative_prompt=None, - **kwargs): + async def generate_image(self, positive_prompt, negative_prompt=None, **kwargs): # TODO not tested from google import genai + image_generator = self.config.tools.image_generator api_key = image_generator.api_key model_id = image_generator.model diff --git a/ms_agent/tools/image_generator/image_gen.py b/ms_agent/tools/image_generator/image_gen.py index 3ab8a71df..c9e5f12b4 100644 --- a/ms_agent/tools/image_generator/image_gen.py +++ b/ms_agent/tools/image_generator/image_gen.py @@ -6,21 +6,22 @@ class ImageGenerator(ToolBase): - def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', - 'image_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', 'image_generator') os.makedirs(self.temp_dir, exist_ok=True) image_generator = self.config.image_generator if image_generator.type == 'modelscope': from .ms_image_gen import MSImageGenerator + self.generator = MSImageGenerator(self.config, self.temp_dir) elif image_generator.type == 'dashscope': from .ds_image_gen import DSImageGenerator + self.generator = DSImageGenerator(self.config, self.temp_dir) elif image_generator.type == 'google': from .google_image_gen import GoogleImageGenerator + self.generator = GoogleImageGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -34,30 +35,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_image', server_name='image_generator', - description= - 'Generate an image with a positive prompt, and return the image file path.', + description='Generate an image with a positive prompt, and return the image file path.', parameters={ 'type': 'object', 'properties': { - 'positive_prompt': { - 'type': 'string', - 'description': - 'The prompt to generate the image.' - } + 'positive_prompt': {'type': 'string', 'description': 'The prompt to generate the image.'} }, 'required': ['positive_prompt'], - 'additionalProperties': False - }) + 'additionalProperties': False, + }, + ) ] } - async def generate_image(self, - positive_prompt, - negative_prompt=None, - **kwargs): - return await self.generator.generate_image(positive_prompt, - negative_prompt, **kwargs) + async def generate_image(self, positive_prompt, negative_prompt=None, **kwargs): + return await self.generator.generate_image(positive_prompt, negative_prompt, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await self.generate_image(**tool_args) diff --git a/ms_agent/tools/image_generator/ms_image_gen.py b/ms_agent/tools/image_generator/ms_image_gen.py index b121458e6..9000e26b4 100644 --- a/ms_agent/tools/image_generator/ms_image_gen.py +++ b/ms_agent/tools/image_generator/ms_image_gen.py @@ -1,33 +1,27 @@ import asyncio +import json import os import uuid from io import BytesIO -import json from PIL import Image class MSImageGenerator: - def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, - positive_prompt, - negative_prompt=None, - size=None, - **kwargs): + async def generate_image(self, positive_prompt, negative_prompt=None, size=None, **kwargs): import aiohttp + image_generator = self.config.tools.image_generator - base_url = (getattr(image_generator, 'base_url', None) - or 'https://api-inference.modelscope.cn').strip('/') + base_url = (getattr(image_generator, 'base_url', None) or 'https://api-inference.modelscope.cn').strip('/') api_key = image_generator.api_key model_id = image_generator.model assert api_key is not None - output_file = os.path.join(self.temp_dir, - f'{str(uuid.uuid4())[:8]}.png') + output_file = os.path.join(self.temp_dir, f'{str(uuid.uuid4())[:8]}.png') headers = { 'Authorization': f'Bearer {api_key}', @@ -36,18 +30,18 @@ async def generate_image(self, async with aiohttp.ClientSession() as session: async with session.post( - f'{base_url}/v1/images/generations', - headers={ - **headers, 'X-ModelScope-Async-Mode': 'true' + f'{base_url}/v1/images/generations', + headers={**headers, 'X-ModelScope-Async-Mode': 'true'}, + data=json.dumps( + { + 'model': model_id, + 'prompt': positive_prompt, + 'negative_prompt': negative_prompt or '', + 'size': size or '', }, - data=json.dumps( - { - 'model': model_id, - 'prompt': positive_prompt, - 'negative_prompt': negative_prompt or '', - 'size': size or '', - }, - ensure_ascii=False)) as resp: + ensure_ascii=False, + ), + ) as resp: resp.raise_for_status() task_id = (await resp.json())['task_id'] @@ -61,11 +55,8 @@ async def generate_image(self, elapsed_time += poll_interval async with session.get( - f'{base_url}/v1/tasks/{task_id}', - headers={ - **headers, 'X-ModelScope-Task-Type': - 'image_generation' - }) as result: + f'{base_url}/v1/tasks/{task_id}', headers={**headers, 'X-ModelScope-Task-Type': 'image_generation'} + ) as result: result.raise_for_status() data = await result.json() @@ -82,5 +73,5 @@ async def generate_image(self, poll_interval = min(poll_interval * 1.5, max_poll_interval) return ( - f'Retrieval timeout, consider retry the task, or waiting for ' - f'longer time(current is {max_wait_time}s).') + f'Retrieval timeout, consider retry the task, or waiting for longer time(current is {max_wait_time}s).' + ) diff --git a/ms_agent/tools/jina_reader.py b/ms_agent/tools/jina_reader.py index b3f663971..acb9a7336 100644 --- a/ms_agent/tools/jina_reader.py +++ b/ms_agent/tools/jina_reader.py @@ -10,15 +10,13 @@ from urllib.parse import quote, urlparse from urllib.request import Request, urlopen -from ms_agent.tools.fetch_playwright_fallback import (looks_like_spa_shell_html, - try_playwright_inner_text) +from ms_agent.tools.fetch_playwright_fallback import looks_like_spa_shell_html, try_playwright_inner_text from ms_agent.utils.logger import get_logger logger = get_logger() DEFAULT_HEADERS: Dict[str, str] = { - 'User-Agent': - 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', + 'User-Agent': 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', 'Accept': 'text/plain; charset=utf-8', 'Accept-Language': 'en-US,en;q=0.9', } @@ -28,8 +26,7 @@ _DIRECT_FETCH_HEADERS: Dict[str, str] = { 'User-Agent': DEFAULT_HEADERS['User-Agent'], - 'Accept': - 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', + 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', 'Accept-Language': DEFAULT_HEADERS['Accept-Language'], } @@ -41,8 +38,7 @@ class JinaReaderConfig: retries: int = 3 backoff_base: float = 0.8 backoff_max: float = 8.0 - headers: Dict[str, - str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) + headers: Dict[str, str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) # When Jina Reader returns empty after retries, try HTTP GET on the target URL. direct_fetch_fallback: bool = True # Tier 2 (urllib): shorter than Jina timeout — fail fast on slow origins. @@ -57,8 +53,7 @@ class JinaReaderConfig: def _build_reader_url(target_url: str, base_endpoint: str) -> str: encoded_target = quote(target_url, safe=":/?&=%#@!$'*+,;[]()") - base = base_endpoint if base_endpoint.endswith( - '/') else f'{base_endpoint}/' + base = base_endpoint if base_endpoint.endswith('/') else f'{base_endpoint}/' return f'{base}{encoded_target}' @@ -132,8 +127,11 @@ def _fetch_direct_http_pair(url: str, timeout: float) -> Tuple[str, str]: content_type = (resp.headers.get('Content-Type') or '').lower() content_type_main = content_type.split(';')[0].strip() text = raw.decode(charset, errors='replace') - if 'html' in content_type_main or text.lstrip().lower().startswith( - ' Tuple[str, str]: return '', '' -def _should_try_playwright_after_direct(plain: str, raw_html: str, - min_chars: int) -> bool: +def _should_try_playwright_after_direct(plain: str, raw_html: str, min_chars: int) -> bool: """Whether tier-3 headless fetch is worth attempting.""" p = plain.strip() if raw_html: @@ -168,34 +165,29 @@ def _fetch_via_jina(url: str, config: JinaReaderConfig) -> str: return data.decode('utf-8', errors='replace') except HTTPError as e: status = getattr(e, 'code', None) - if status in (429, 500, 502, 503, - 504) and attempt <= config.retries: - sleep_s = min(config.backoff_max, - config.backoff_base * (2**(attempt - 1))) + if status in (429, 500, 502, 503, 504) and attempt <= config.retries: + sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' except URLError: if attempt <= config.retries: - sleep_s = min(config.backoff_max, - config.backoff_base * (2**(attempt - 1))) + sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' except Exception: if attempt <= config.retries: - sleep_s = min(config.backoff_max, - config.backoff_base * (2**(attempt - 1))) + sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' -def fetch_single_text_with_meta(url: str, - config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: +def fetch_single_text_with_meta(url: str, config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: """ Tiered fetch: Jina Reader → direct HTTP → optional Playwright (empty / short / SPA shell). @@ -209,15 +201,17 @@ def fetch_single_text_with_meta(url: str, return jina_text, {'content_source': 'jina_reader'} if not config.direct_fetch_fallback: return '', {'content_source': 'none'} - d_timeout = (float(config.timeout) if float(config.direct_fetch_timeout or 0) - <= 0 else float(config.direct_fetch_timeout)) + d_timeout = ( + float(config.timeout) if float(config.direct_fetch_timeout or 0) <= 0 else float(config.direct_fetch_timeout) + ) direct_plain, raw_html = _fetch_direct_http_pair(url, d_timeout) direct_text = _postprocess_text(direct_plain) try_playwright = ( - bool(config.playwright_fetch_fallback) and _is_direct_http_allowed(url) - and _should_try_playwright_after_direct(direct_text, raw_html, - config.playwright_retry_min_chars)) + bool(config.playwright_fetch_fallback) + and _is_direct_http_allowed(url) + and _should_try_playwright_after_direct(direct_text, raw_html, config.playwright_retry_min_chars) + ) if try_playwright: pw_text = _postprocess_text( @@ -225,17 +219,14 @@ def fetch_single_text_with_meta(url: str, url, int(config.playwright_timeout_ms), settle_ms=int(config.playwright_settle_ms), - )) + ) + ) if pw_text.strip(): - logger.info( - 'Using headless Chromium fallback after Jina/direct HTTP ' - f'(url prefix): {url[:80]}') + logger.info(f'Using headless Chromium fallback after Jina/direct HTTP (url prefix): {url[:80]}') return pw_text, {'content_source': 'playwright_fallback'} if direct_text: - logger.info( - 'Jina Reader returned no body for URL; using direct HTTP fallback ' - f'(url prefix): {url[:80]}') + logger.info(f'Jina Reader returned no body for URL; using direct HTTP fallback (url prefix): {url[:80]}') return direct_text, {'content_source': 'direct_http_fallback'} return '', {'content_source': 'none'} @@ -250,10 +241,11 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: async def fetch_texts_via_jina( - urls: List[str], - config: Optional[JinaReaderConfig] = None, - semaphore: Optional[asyncio.Semaphore] = None, - executor: Optional[ThreadPoolExecutor] = None) -> List[str]: + urls: List[str], + config: Optional[JinaReaderConfig] = None, + semaphore: Optional[asyncio.Semaphore] = None, + executor: Optional[ThreadPoolExecutor] = None, +) -> List[str]: """ Asynchronously fetch a list of URLs via Jina Reader. Allows caller-provided concurrency controls (semaphore/executor) to integrate with pipeline resource management. @@ -267,8 +259,7 @@ async def fetch_texts_via_jina( async def _bound(u: str) -> str: async with local_sem: - return await loop.run_in_executor(executor, fetch_single_text, u, - cfg) + return await loop.run_in_executor(executor, fetch_single_text, u, cfg) tasks = [_bound(u) for u in urls] results = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index d6b971fbb..e6074fe2a 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -8,12 +8,13 @@ from mcp import ClientSession, ListToolsResult, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from omegaconf import DictConfig + from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import enhance_error, get_logger -from omegaconf import DictConfig logger = get_logger() @@ -51,18 +52,14 @@ def __init__( self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}} if config is not None: config_from_file = Config.convert_mcp_servers_to_json(config) - self.mcp_config['mcpServers'].update( - config_from_file.get('mcpServers', {})) + self.mcp_config['mcpServers'].update(config_from_file.get('mcpServers', {})) self.exclude_functions = {} self.include_functions = {} if mcp_config is not None: - self.mcp_config['mcpServers'].update( - mcp_config.get('mcpServers', {})) + self.mcp_config['mcpServers'].update(mcp_config.get('mcpServers', {})) - async def call_tool(self, server_name: str, tool_name: str, - tool_args: dict): - response = await self.sessions[server_name].call_tool( - tool_name, tool_args) + async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): + response = await self.sessions[server_name].call_tool(tool_name, tool_args) texts = [] resources = [] @@ -80,6 +77,7 @@ async def call_tool(self, server_name: str, tool_name: str, texts.append(content.text) elif content.type == 'resource': import json5 + json_str = content.resource.model_dump_json(by_alias=True) texts.append(json_str) resources.append(json5.loads(json_str)) @@ -96,8 +94,7 @@ async def get_tools(self) -> Dict: try: response = await session.list_tools() except Exception as e: - new_eg = enhance_error( - e, f'MCP `{key}` list tool failed, details: ') + new_eg = enhance_error(e, f'MCP `{key}` list tool failed, details: ') raise new_eg from e _session_tools = response.tools exclude = [] @@ -108,19 +105,12 @@ async def get_tools(self) -> Dict: elif self.exclude_functions: if key in self.exclude_functions: exclude = self.exclude_functions[key] - _session_tools = [ - t for t in _session_tools if t.name not in exclude - ] + _session_tools = [t for t in _session_tools if t.name not in exclude] if include: - _session_tools = [ - t for t in _session_tools if t.name in include - ] + _session_tools = [t for t in _session_tools if t.name in include] _session_tools = [ - Tool( - tool_name=t.name, - server_name=key, - description=t.description, - parameters=t.inputSchema) for t in _session_tools + Tool(tool_name=t.name, server_name=key, description=t.description, parameters=t.inputSchema) + for t in _session_tools ] tools[key].extend(_session_tools) return tools @@ -132,18 +122,13 @@ def print_tools(server_name: str, tools: ListToolsResult): if len(tools) > 10: tools = [tool.name for tool in tools][:10] logger.info( - f'\nConnected to server "{server_name}" ' - f'with tools: \n{sep.join(tools)}\nOnly list first 10 of them.' + f'\nConnected to server "{server_name}" with tools: \n{sep.join(tools)}\nOnly list first 10 of them.' ) else: tools = [tool.name for tool in tools] - logger.info(f'\nConnected to server "{server_name}" ' - f'with tools: \n{sep.join(tools)}.') + logger.info(f'\nConnected to server "{server_name}" with tools: \n{sep.join(tools)}.') - async def connect_to_server(self, - server_name: str, - timeout: int = CONNECTION_TIMEOUT, - **kwargs): + async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TIMEOUT, **kwargs): logger.info(f'connect to {server_name}') # transport: stdio, sse, streamable_http, websocket transport = kwargs.get('transport') or kwargs.get('type') @@ -152,21 +137,19 @@ async def connect_to_server(self, session_kwargs = kwargs.get('session_kwargs') if url: if transport and transport.lower() == 'sse': - logger.info( - '`transport` or `type` is configured as "sse", using sse transport.' - ) + logger.info('`transport` or `type` is configured as "sse", using sse transport.') sse_transport = await self.exit_stack.enter_async_context( sse_client( - url, kwargs.get('headers'), + url, + kwargs.get('headers'), kwargs.get('timeout', DEFAULT_HTTP_TIMEOUT), - kwargs.get('sse_read_timeout', - DEFAULT_SSE_READ_TIMEOUT))) + kwargs.get('sse_read_timeout', DEFAULT_SSE_READ_TIMEOUT), + ) + ) read, write = sse_transport elif transport and transport.lower() == 'websocket': - logger.info( - '`transport` or `type` is configured as "websocket", using websocket transport.' - ) + logger.info('`transport` or `type` is configured as "websocket", using websocket transport.') try: from mcp.client.websocket import websocket_client except ImportError: @@ -175,21 +158,22 @@ async def connect_to_server(self, 'To use Websocket connections, please install the required dependency with: ' "'pip install mcp[ws]' or 'pip install websockets'" ) from None - websocket_transport = await self.exit_stack.enter_async_context( - websocket_client(url)) + websocket_transport = await self.exit_stack.enter_async_context(websocket_client(url)) read, write = websocket_transport else: logger.info( 'Using streamable_http transport. To configure a different transport such as sse, please' - 'set the `type` or `transport` variable to "sse".') + 'set the `type` or `transport` variable to "sse".' + ) try: from mcp.client.streamable_http import streamablehttp_client except ImportError: raise ImportError( 'Could not import streamablehttp_client. ' 'To use streamable http connections, please upgrade to the latest version of mcp with: ' - "'pip install -U mcp'") from None + "'pip install -U mcp'" + ) from None httpx_client_factory = kwargs.get('httpx_client_factory') other_kwargs = {} if httpx_client_factory is not None: @@ -198,46 +182,36 @@ async def connect_to_server(self, streamablehttp_client( url, headers=kwargs.get('headers'), - timeout=kwargs.get('timeout', - DEFAULT_STREAMABLE_HTTP_TIMEOUT), - sse_read_timeout=kwargs.get( - 'sse_read_timeout', - DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT), - **other_kwargs)) + timeout=kwargs.get('timeout', DEFAULT_STREAMABLE_HTTP_TIMEOUT), + sse_read_timeout=kwargs.get('sse_read_timeout', DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT), + **other_kwargs, + ) + ) read, write, _ = streamable_transport session_kwargs = session_kwargs or {} - timeout = max( - session_kwargs.pop('read_timeout_seconds', timeout), 1) + timeout = max(session_kwargs.pop('read_timeout_seconds', timeout), 1) session = await self.exit_stack.enter_async_context( - ClientSession( - read, - write, - read_timeout_seconds=timedelta(seconds=timeout), - **session_kwargs)) + ClientSession(read, write, read_timeout_seconds=timedelta(seconds=timeout), **session_kwargs) + ) elif command: # transport: 'stdio' args = kwargs.get('args') if not args: - raise ValueError( - "'args' parameter is required for stdio connection") + raise ValueError("'args' parameter is required for stdio connection") server_params = StdioServerParameters( command=command, args=args, env=kwargs.get('env'), encoding=kwargs.get('encoding', DEFAULT_ENCODING), - encoding_error_handler=kwargs.get( - 'encoding_error_handler', DEFAULT_ENCODING_ERROR_HANDLER), + encoding_error_handler=kwargs.get('encoding_error_handler', DEFAULT_ENCODING_ERROR_HANDLER), ) - stdio, write = await self.exit_stack.enter_async_context( - stdio_client(server_params)) - session = await self.exit_stack.enter_async_context( - ClientSession(stdio, write)) + stdio, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) else: - raise ValueError( - "'url' or 'command' parameter is required for connection") + raise ValueError("'url' or 'command' parameter is required for connection") await session.initialize() # Store session @@ -252,20 +226,16 @@ async def connect(self, timeout: int = CONNECTION_TIMEOUT): for name, server in mcp_config.items(): try: env_dict = server.pop('env', {}) - env_dict = { - key: value if value else envs.get(key, '') - for key, value in env_dict.items() - } + env_dict = {key: value if value else envs.get(key, '') for key, value in env_dict.items()} if 'exclude' in server: self.exclude_functions[name] = server.pop('exclude') if 'include' in server: self.include_functions[name] = server.pop('include') - assert (not self.include_functions.get(name)) or ( - not self.exclude_functions.get(name) - ), 'Set either `include` or `exclude` in tools config.' + assert (not self.include_functions.get(name)) or (not self.exclude_functions.get(name)), ( + 'Set either `include` or `exclude` in tools config.' + ) timeout = server.pop('timeout', timeout) - await self.connect_to_server( - server_name=name, env=env_dict, timeout=timeout, **server) + await self.connect_to_server(server_name=name, env=env_dict, timeout=timeout, **server) except Exception as e: new_eg = enhance_error(e, f'Connect `{name}` failed, details:') raise new_eg from e @@ -282,14 +252,10 @@ async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]): else: servers[name] = server env_dict = server.pop('env', {}) - env_dict = { - key: value if value else envs.get(key, '') - for key, value in env_dict.items() - } + env_dict = {key: value if value else envs.get(key, '') for key, value in env_dict.items()} if 'exclude' in server: self.exclude_functions[name] = server.pop('exclude') - await self.connect_to_server( - server_name=name, env=env_dict, **server) + await self.connect_to_server(server_name=name, env=env_dict, **server) self.mcp_config['mcpServers'].update(new_mcp_config) async def cleanup(self): diff --git a/ms_agent/tools/mineru/pdf_parser.py b/ms_agent/tools/mineru/pdf_parser.py index d210fd9b1..a352fe396 100644 --- a/ms_agent/tools/mineru/pdf_parser.py +++ b/ms_agent/tools/mineru/pdf_parser.py @@ -1,16 +1,13 @@ import os from magic_pdf.config.enums import SupportedPdfParseMethod -from magic_pdf.data.data_reader_writer import (FileBasedDataReader, - FileBasedDataWriter) +from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze class PdfParser: - def __init__(self, parser_workdir: str): - # e.g. "your_workdir/resources/mineru" self._workdir = parser_workdir os.makedirs(self._workdir, exist_ok=True) @@ -18,8 +15,7 @@ def __init__(self, parser_workdir: str): self.relative_image_dir = 'images' self.markdown_dir = self._workdir - self.img_writer = FileBasedDataWriter( - os.path.join(self._workdir, self.relative_image_dir)) + self.img_writer = FileBasedDataWriter(os.path.join(self._workdir, self.relative_image_dir)) self.md_writer = FileBasedDataWriter(self.markdown_dir) self.data_reader = FileBasedDataReader('') @@ -43,8 +39,7 @@ def parse(self, f_path: str, reuse: bool = True) -> str: print(f'Processing file: {f_path}') file_name_no_suffix = os.path.splitext(os.path.basename(f_path))[0] - entry_md_file = os.path.join(self.markdown_dir, - f'{file_name_no_suffix}.md') + entry_md_file = os.path.join(self.markdown_dir, f'{file_name_no_suffix}.md') if reuse and os.path.exists(entry_md_file): print(f'File {entry_md_file} already exists. Skipping processing.') @@ -68,32 +63,24 @@ def parse(self, f_path: str, reuse: bool = True) -> str: pipe_result = infer_result.pipe_txt_mode(self.img_writer) # draw model result on each page - infer_result.draw_model( - os.path.join(self.markdown_dir, - f'{file_name_no_suffix}_model.pdf')) + infer_result.draw_model(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_model.pdf')) # draw layout result on each page - pipe_result.draw_layout( - os.path.join(self.markdown_dir, - f'{file_name_no_suffix}_layout.pdf')) + pipe_result.draw_layout(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_layout.pdf')) # draw spans result on each page - pipe_result.draw_span( - os.path.join(self.markdown_dir, - f'{file_name_no_suffix}_spans.pdf')) + pipe_result.draw_span(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_spans.pdf')) # dump markdown - pipe_result.dump_md(self.md_writer, f'{file_name_no_suffix}.md', - self.relative_image_dir) + pipe_result.dump_md(self.md_writer, f'{file_name_no_suffix}.md', self.relative_image_dir) # dump content list pipe_result.dump_content_list( - self.md_writer, f'{file_name_no_suffix}_content_list.json', - self.relative_image_dir) + self.md_writer, f'{file_name_no_suffix}_content_list.json', self.relative_image_dir + ) # dump middle json - pipe_result.dump_middle_json(self.md_writer, - f'{file_name_no_suffix}_middle.json') + pipe_result.dump_middle_json(self.md_writer, f'{file_name_no_suffix}_middle.json') print(f'Finished processing file: {f_path}') diff --git a/ms_agent/tools/search/arxiv/__init__.py b/ms_agent/tools/search/arxiv/__init__.py index 2cbf5d6fe..d4308d5f7 100644 --- a/ms_agent/tools/search/arxiv/__init__.py +++ b/ms_agent/tools/search/arxiv/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, - ArxivSearchResult) +from ms_agent.tools.search.arxiv.schema import ArxivSearchRequest, ArxivSearchResult from ms_agent.tools.search.arxiv.search import ArxivSearch diff --git a/ms_agent/tools/search/arxiv/schema.py b/ms_agent/tools/search/arxiv/schema.py index 154c603ae..f0bc2aed4 100644 --- a/ms_agent/tools/search/arxiv/schema.py +++ b/ms_agent/tools/search/arxiv/schema.py @@ -1,12 +1,11 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, Generator, List, Optional - import arxiv import json from arxiv import SortCriterion, SortOrder -from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, - SearchResponse, SearchResult) +from dataclasses import dataclass, field +from typing import Any, Dict, Generator, List, Optional + +from ms_agent.tools.search.search_base import BaseResult, SearchRequest, SearchResponse, SearchResult from ms_agent.utils.logger import get_logger logger = get_logger() @@ -17,15 +16,17 @@ class ArxivSearchRequest(SearchRequest): A class representing a search request to ArXiv. """ - def __init__(self, - query: str = None, - num_results: Optional[int] = 10, - sort_strategy: SortCriterion = SortCriterion.Relevance, - sort_order: SortOrder = SortOrder.Descending, - categories: Optional[List[str]] = None, - date_from: Optional[str] = None, - date_to: Optional[str] = None, - **kwargs: Any): + def __init__( + self, + query: str = None, + num_results: Optional[int] = 10, + sort_strategy: SortCriterion = SortCriterion.Relevance, + sort_order: SortOrder = SortOrder.Descending, + categories: Optional[List[str]] = None, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + **kwargs: Any, + ): """ Initialize ArxivSearchRequest with search parameters. @@ -47,12 +48,9 @@ def __init__(self, self.sort_strategy_map = { 'relevance': SortCriterion.Relevance, 'lastUpdatedDate': SortCriterion.LastUpdatedDate, - 'submittedDate': SortCriterion.SubmittedDate - } - self.sort_order_map = { - 'descending': SortOrder.Descending, - 'ascending': SortOrder.Ascending + 'submittedDate': SortCriterion.SubmittedDate, } + self.sort_order_map = {'descending': SortOrder.Descending, 'ascending': SortOrder.Ascending} def to_dict(self) -> Dict[str, Any]: """ @@ -61,18 +59,16 @@ def to_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters as a dictionary """ - if isinstance(self.sort_strategy, str) and self.sort_strategy_map.get( - self.sort_strategy): + if isinstance(self.sort_strategy, str) and self.sort_strategy_map.get(self.sort_strategy): self.sort_strategy = self.sort_strategy_map[self.sort_strategy] - if isinstance(self.sort_order, str) and self.sort_order_map.get( - self.sort_order): + if isinstance(self.sort_order, str) and self.sort_order_map.get(self.sort_order): self.sort_order = self.sort_order_map[self.sort_order] return { 'query': self.query, 'max_results': self.num_results, 'sort_by': self.sort_strategy, - 'sort_order': self.sort_order + 'sort_order': self.sort_order, } def to_json(self) -> Dict[str, Any]: @@ -87,18 +83,16 @@ def to_json(self) -> Dict[str, Any]: 'query': self.query, 'max_results': self.num_results, 'sort_strategy': self.sort_strategy.value, - 'sort_order': self.sort_order.value + 'sort_order': self.sort_order.value, }, - ensure_ascii=False) + ensure_ascii=False, + ) class ArxivSearchResult(SearchResult): """ArXiv search result implementation.""" - def __init__(self, - query: str, - arguments: Dict[str, Any] = None, - response: List['arxiv.Result'] = None): + def __init__(self, query: str, arguments: Dict[str, Any] = None, response: List['arxiv.Result'] = None): """ Initialize ArxivSearchResult. @@ -140,21 +134,20 @@ def _process_results(self) -> SearchResponse: processed = [] for res in self.raw_response: if not isinstance(res, arxiv.Result): - print( - f'***Warning: Result {res} is not an instance of arxiv.Result.' - ) + print(f'***Warning: Result {res} is not an instance of arxiv.Result.') continue processed.append( BaseResult( - url=getattr(res, 'pdf_url', None) - or getattr(res, 'entry_id', None), + url=getattr(res, 'pdf_url', None) or getattr(res, 'entry_id', None), id=getattr(res, 'entry_id', None), title=getattr(res, 'title', None), highlights=None, highlight_scores=None, summary=getattr(res, 'summary', None), - markdown=None)) + markdown=None, + ) + ) return SearchResponse(results=processed) @@ -162,8 +155,7 @@ def _process_arguments(self) -> Dict[str, Any]: """Process the search arguments to be JSON serializable.""" sort_strategy = self.arguments.get('sort_strategy', None) if sort_strategy is None: - sort_strategy = self.arguments.get('sort_by', - SortCriterion.Relevance) + sort_strategy = self.arguments.get('sort_by', SortCriterion.Relevance) sort_order = self.arguments.get('sort_order', SortOrder.Descending) if isinstance(sort_strategy, SortCriterion): @@ -228,31 +220,21 @@ def to_list(self) -> List[Dict[str, Any]]: categories = getattr(res, 'categories', None) or [] - res_list.append({ - 'url': (getattr(res, 'pdf_url', None) - or getattr(res, 'entry_id', None) or ''), - 'id': - getattr(res, 'entry_id', None) or '', - 'title': - getattr(res, 'title', None) or '', - 'published_date': - published_date, - 'summary': - getattr(res, 'summary', None) or '', - 'highlights': - None, - 'highlight_scores': - None, - 'markdown': - None, - 'authors': - authors, - 'categories': - categories, - 'arxiv_id': - short_id or '', - 'resource_uri': - f'arxiv://{short_id}' if short_id else '', - }) + res_list.append( + { + 'url': (getattr(res, 'pdf_url', None) or getattr(res, 'entry_id', None) or ''), + 'id': getattr(res, 'entry_id', None) or '', + 'title': getattr(res, 'title', None) or '', + 'published_date': published_date, + 'summary': getattr(res, 'summary', None) or '', + 'highlights': None, + 'highlight_scores': None, + 'markdown': None, + 'authors': authors, + 'categories': categories, + 'arxiv_id': short_id or '', + 'resource_uri': f'arxiv://{short_id}' if short_id else '', + } + ) return res_list diff --git a/ms_agent/tools/search/arxiv/search.py b/ms_agent/tools/search/arxiv/search.py index 8f407d454..0761b2509 100644 --- a/ms_agent/tools/search/arxiv/search.py +++ b/ms_agent/tools/search/arxiv/search.py @@ -1,12 +1,11 @@ # flake8: noqa +import arxiv import os +from arxiv import SortCriterion, SortOrder from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING -import arxiv -from arxiv import SortCriterion, SortOrder -from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, - ArxivSearchResult) +from ms_agent.tools.search.arxiv.schema import ArxivSearchRequest, ArxivSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.utils.logger import get_logger @@ -132,20 +131,16 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: date_from_dt = None date_to_dt = None if getattr(search_request, 'date_from', None): - date_from_dt = _parse_yyyy_mm_dd( - search_request.date_from, end_of_day=False) + date_from_dt = _parse_yyyy_mm_dd(search_request.date_from, end_of_day=False) if getattr(search_request, 'date_to', None): - date_to_dt = _parse_yyyy_mm_dd( - search_request.date_to, end_of_day=True) + date_to_dt = _parse_yyyy_mm_dd(search_request.date_to, end_of_day=True) if date_from_dt or date_to_dt: desired = int(search_request.num_results or 10) - search_args['max_results'] = min( - max(desired + 10, desired), 50) + search_args['max_results'] = min(max(desired + 10, desired), 50) response = [] - for paper in self.client.results( - search=arxiv.Search(**search_args)): + for paper in self.client.results(search=arxiv.Search(**search_args)): if date_from_dt or date_to_dt: paper_date = getattr(paper, 'published', None) if paper_date is None: @@ -175,7 +170,8 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: **search_args, **extra_args, }, - response=response) + response=response, + ) except Exception as e: raise RuntimeError(f'Failed to perform search: {e}') from e @@ -185,6 +181,7 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for arXiv search engine.""" from ms_agent.llm.utils import Tool + return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -193,62 +190,52 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'type': 'object', 'properties': { 'query': { - 'type': - 'string', - 'description': - ('Search query using quoted phrases for exact matches ' - '(e.g., \'"machine learning" OR "deep learning"\') or ' - 'specific technical terms. Avoid overly broad or generic terms.' - ), + 'type': 'string', + 'description': ( + 'Search query using quoted phrases for exact matches ' + '(e.g., \'"machine learning" OR "deep learning"\') or ' + 'specific technical terms. Avoid overly broad or generic terms.' + ), }, 'num_results': { - 'type': - 'integer', - 'minimum': - 1, - 'maximum': - 15, - 'description': - ('Maximum number of results to return. Default is 5.' - 'Use 5-15 for comprehensive searches.'), + 'type': 'integer', + 'minimum': 1, + 'maximum': 15, + 'description': ( + 'Maximum number of results to return. Default is 5.Use 5-15 for comprehensive searches.' + ), }, 'date_from': { - 'type': - 'string', - 'description': - ('Start date for papers (YYYY-MM-DD format). ' - 'Use to find recent work, e.g., "2023-01-01".'), + 'type': 'string', + 'description': ( + 'Start date for papers (YYYY-MM-DD format). Use to find recent work, e.g., "2023-01-01".' + ), }, 'date_to': { - 'type': - 'string', - 'description': - ('End date for papers (YYYY-MM-DD format). ' - 'Use with date_from for historical windows, e.g., "2020-12-31".' - ), + 'type': 'string', + 'description': ( + 'End date for papers (YYYY-MM-DD format). ' + 'Use with date_from for historical windows, e.g., "2020-12-31".' + ), }, 'categories': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('Strongly recommended: arXiv categories to focus search ' - '(e.g., ["cs.AI", "cs.MA"] for agent research, ["cs.LG"] for ML, ' - '["cs.CL"] for NLP, ["cs.CV"] for computer vision). ' - 'Greatly improves relevance.'), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'Strongly recommended: arXiv categories to focus search ' + '(e.g., ["cs.AI", "cs.MA"] for agent research, ["cs.LG"] for ML, ' + '["cs.CL"] for NLP, ["cs.CV"] for computer vision). ' + 'Greatly improves relevance.' + ), }, 'sort_by': { - 'type': - 'string', - 'enum': - ['relevance', 'submittedDate', 'lastUpdatedDate'], - 'description': - ('How to sort results. "relevance" for best match, ' - '"submittedDate" for newest submissions, ' - '"lastUpdatedDate" for recently updated. Default is "relevance".' - ), + 'type': 'string', + 'enum': ['relevance', 'submittedDate', 'lastUpdatedDate'], + 'description': ( + 'How to sort results. "relevance" for best match, ' + '"submittedDate" for newest submissions, ' + '"lastUpdatedDate" for recently updated. Default is "relevance".' + ), }, 'sort_order': { 'type': 'string', @@ -270,8 +257,8 @@ def build_request_from_args(cls, **kwargs) -> ArxivSearchRequest: categories = [str(c).strip() for c in categories if str(c).strip()] if not _validate_categories(categories): logger.warning( - f"Invalid arXiv categories provided: {kwargs.get('categories')}. " - 'Ignoring categories filter.') + f"Invalid arXiv categories provided: {kwargs.get('categories')}. Ignoring categories filter." + ) categories = None # Build final query by AND-ing base query with category filter (OR across categories) diff --git a/ms_agent/tools/search/content_optimizer.py b/ms_agent/tools/search/content_optimizer.py index 998b13b66..c316ac7e0 100644 --- a/ms_agent/tools/search/content_optimizer.py +++ b/ms_agent/tools/search/content_optimizer.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import re from concurrent.futures import ThreadPoolExecutor @@ -8,12 +9,12 @@ from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse -import json +from omegaconf import DictConfig, OmegaConf + from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor -from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -143,6 +144,7 @@ @dataclass class SearchResultMeta: """Metadata for a search result used in reranking.""" + url: str title: str snippet: str = '' @@ -155,6 +157,7 @@ class SearchResultMeta: @dataclass class SummaryResult: """Result of content summarization.""" + summary: str key_excerpts: str original_length: int @@ -178,6 +181,7 @@ def total_tokens(self) -> int: @dataclass class ContentOptimizerConfig: """Configuration for content optimization.""" + # Summarization settings summarizer_model: str = 'qwen-flash' summarizer_base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1' @@ -200,7 +204,8 @@ class ContentOptimizerConfig: 'blog': 0.6, # Technical blogs 'forum': 0.4, # Forums, Q&A sites 'unknown': 0.5, - }) + } + ) # Domain patterns for source classification @@ -306,8 +311,7 @@ def classify_source(url: str) -> str: return 'paper' # Check for documentation indicators - if any(doc_pattern in domain - for doc_pattern in ['docs.', 'documentation.', 'developer.']): + if any(doc_pattern in domain for doc_pattern in ['docs.', 'documentation.', 'developer.']): return 'official' for news_domain in NEWS_DOMAINS: @@ -375,11 +379,7 @@ def _build_llm_config(self) -> DictConfig: 'openai_base_url': self.config.summarizer_base_url, 'openai_api_key': self.config.summarizer_api_key, }, - 'generation_config': { - 'extra_body': { - 'enable_thinking': False - } - }, + 'generation_config': {'extra_body': {'enable_thinking': False}}, } return OmegaConf.create(config_dict) @@ -396,9 +396,7 @@ async def initialize(self) -> None: thread_name_prefix='content_summarizer_', ) self._initialized = True - logger.info( - f'ContentSummarizer initialized with model: {self.config.summarizer_model}' - ) + logger.info(f'ContentSummarizer initialized with model: {self.config.summarizer_model}') except Exception as e: logger.error(f'Failed to initialize ContentSummarizer: {e}') raise @@ -430,8 +428,7 @@ def _parse_summary_response(self, response_text: str) -> Tuple[str, str]: Tuple of (summary, key_excerpts) """ # Try to find JSON in the response - json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', - response_text, re.DOTALL) + json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL) if json_match: try: data = json.loads(json_match.group(1)) @@ -445,7 +442,7 @@ def _parse_summary_response(self, response_text: str) -> Tuple[str, str]: start_idx = response_text.find('{') end_idx = response_text.rfind('}') if start_idx != -1 and end_idx != -1: - json_str = response_text[start_idx:end_idx + 1] + json_str = response_text[start_idx : end_idx + 1] data = json.loads(json_str) return data.get('summary', ''), data.get('key_excerpts', '') except json.JSONDecodeError: @@ -468,10 +465,7 @@ def _call_llm_sync(self, prompt: str) -> Message: response = self._llm.generate(messages) return response - async def summarize(self, - content: str, - task_context: str = '', - language: str = 'auto') -> SummaryResult: + async def summarize(self, content: str, task_context: str = '', language: str = 'auto') -> SummaryResult: """ Summarize webpage content using the configured LLM. @@ -500,13 +494,12 @@ async def summarize(self, ) # Truncate content if too long - content_to_summarize = content[:self.config.max_content_chars] + content_to_summarize = content[: self.config.max_content_chars] # Detect language and select prompt if language == 'auto': # Simple heuristic: check for Chinese characters - chinese_chars = len( - re.findall(r'[\u4e00-\u9fff]', content_to_summarize[:1000])) + chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', content_to_summarize[:1000])) language = 'zh' if chinese_chars > 30 else 'en' prompt_template = SUMMARIZE_WEBPAGE_PROMPT if language == 'zh' else SUMMARIZE_WEBPAGE_PROMPT_EN @@ -524,8 +517,7 @@ async def summarize(self, # Run synchronous LLM call in executor with timeout loop = asyncio.get_event_loop() response_msg: Message = await asyncio.wait_for( - loop.run_in_executor(self._executor, self._call_llm_sync, - prompt), + loop.run_in_executor(self._executor, self._call_llm_sync, prompt), timeout=self.config.summarization_timeout, ) @@ -540,8 +532,8 @@ async def summarize(self, compression_ratio = compressed_length / original_length if original_length > 0 else 1.0 logger.debug( - f'Content summarized: {original_length} -> {compressed_length} chars ' - f'(ratio: {compression_ratio:.2%})') + f'Content summarized: {original_length} -> {compressed_length} chars (ratio: {compression_ratio:.2%})' + ) return SummaryResult( summary=summary, @@ -550,25 +542,19 @@ async def summarize(self, compressed_length=compressed_length, compression_ratio=compression_ratio, success=True, - model=str( - getattr(self._llm, 'model', '') - or self.config.summarizer_model), - prompt_tokens=int( - getattr(response_msg, 'prompt_tokens', 0) or 0), - completion_tokens=int( - getattr(response_msg, 'completion_tokens', 0) or 0), - cached_tokens=int( - getattr(response_msg, 'cached_tokens', 0) or 0), - cache_creation_input_tokens=int( - getattr(response_msg, 'cache_creation_input_tokens', 0) - or 0), + model=str(getattr(self._llm, 'model', '') or self.config.summarizer_model), + prompt_tokens=int(getattr(response_msg, 'prompt_tokens', 0) or 0), + completion_tokens=int(getattr(response_msg, 'completion_tokens', 0) or 0), + cached_tokens=int(getattr(response_msg, 'cached_tokens', 0) or 0), + cache_creation_input_tokens=int(getattr(response_msg, 'cache_creation_input_tokens', 0) or 0), api_calls=int(getattr(response_msg, 'api_calls', 0) or 0), ) except asyncio.TimeoutError: logger.warning( f'Summarization timed out after {self.config.summarization_timeout}s, ' - 'returning truncated original content') + 'returning truncated original content' + ) # Return truncated original content truncated = content_to_summarize[:100000] return SummaryResult( @@ -583,9 +569,7 @@ async def summarize(self, ) except Exception as e: - logger.warning( - f'Summarization failed: {e}, returning truncated original content' - ) + logger.warning(f'Summarization failed: {e}, returning truncated original content') truncated = content_to_summarize[:100000] return SummaryResult( summary=truncated, @@ -620,8 +604,7 @@ async def summarize_batch( semaphore = asyncio.Semaphore(max_concurrent) - async def _bounded_summarize( - url: str, content: str) -> Tuple[str, SummaryResult]: + async def _bounded_summarize(url: str, content: str) -> Tuple[str, SummaryResult]: async with semaphore: result = await self.summarize(content, task_context) return url, result @@ -744,8 +727,7 @@ def _compute_recency_score(self, published_at: str) -> float: # Calculate months difference if month: - months_diff = (current_year - year) * 12 + ( - current_month - month) + months_diff = (current_year - year) * 12 + (current_month - month) else: months_diff = (current_year - year) * 12 @@ -786,8 +768,7 @@ def _build_result_meta( url = result.get('url', '') title = result.get('title', '') snippet = result.get('summary', '') or result.get('snippet', '') - published_at = result.get('published_date', '') or result.get( - 'published_at', '') + published_at = result.get('published_date', '') or result.get('published_at', '') source_type = classify_source(url) @@ -799,9 +780,7 @@ def _build_result_meta( # Weighted combination # Title relevance: 40%, Source type: 30%, Recency: 20%, Snippet: 10% - relevance_score = ( - title_relevance * 0.4 + source_weight * 0.3 + recency_score * 0.2 - + snippet_relevance * 0.1) + relevance_score = title_relevance * 0.4 + source_weight * 0.3 + recency_score * 0.2 + snippet_relevance * 0.1 return SearchResultMeta( url=url, @@ -839,10 +818,7 @@ def rerank( return results[:k] # Build metadata for all results - metas = [ - self._build_result_meta(result, idx, query) - for idx, result in enumerate(results) - ] + metas = [self._build_result_meta(result, idx, query) for idx, result in enumerate(results)] # Sort by relevance score (descending) sorted_pairs = sorted( @@ -867,8 +843,7 @@ def rerank( return top_results @staticmethod - def deduplicate_by_url( - results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def deduplicate_by_url(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Remove duplicate results based on URL. @@ -990,16 +965,15 @@ async def summarize_contents( if not self._initialized: await self.initialize() - results = await self.summarizer.summarize_batch( - contents, task_context, max_concurrent) + results = await self.summarizer.summarize_batch(contents, task_context, max_concurrent) # Convert SummaryResult to formatted strings formatted = {} for url, result in results.items(): if result.key_excerpts: formatted[url] = ( - f'\n{result.summary}\n\n\n' - f'\n{result.key_excerpts}\n') + f'\n{result.summary}\n\n\n\n{result.key_excerpts}\n' + ) else: formatted[url] = result.summary @@ -1020,8 +994,7 @@ async def summarize_contents_with_usage( if not self._initialized: await self.initialize() - results = await self.summarizer.summarize_batch( - contents, task_context, max_concurrent) + results = await self.summarizer.summarize_batch(contents, task_context, max_concurrent) formatted: Dict[str, str] = {} # Aggregate usage across results (best-effort; failures may have 0 usage) @@ -1035,8 +1008,8 @@ async def summarize_contents_with_usage( for url, result in results.items(): if result.key_excerpts: formatted[url] = ( - f'\n{result.summary}\n\n\n' - f'\n{result.key_excerpts}\n') + f'\n{result.summary}\n\n\n\n{result.key_excerpts}\n' + ) else: formatted[url] = result.summary @@ -1065,8 +1038,7 @@ async def summarize_contents_with_usage( def create_content_optimizer( summarizer_model: str = 'qwen-flash', - summarizer_base_url: - str = 'https://dashscope.aliyuncs.com/compatible-mode/v1', + summarizer_base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1', summarizer_api_key: Optional[str] = None, max_content_chars: int = 500000, enable_rerank: bool = False, diff --git a/ms_agent/tools/search/exa/schema.py b/ms_agent/tools/search/exa/schema.py index a80a1c401..b306fb8e8 100644 --- a/ms_agent/tools/search/exa/schema.py +++ b/ms_agent/tools/search/exa/schema.py @@ -1,14 +1,12 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - import json +from dataclasses import dataclass, field from exa_py.api import SearchResponse +from typing import Any, Dict, List, Optional @dataclass class ExaSearchRequest: - # The search query string query: str @@ -44,7 +42,7 @@ def to_dict(self) -> Dict[str, Any]: 'start_published_date': self.start_published_date, 'end_published_date': self.end_published_date, 'start_crawl_date': self.start_crawl_date, - 'end_crawl_date': self.end_crawl_date + 'end_crawl_date': self.end_crawl_date, } def to_json(self) -> str: @@ -56,7 +54,6 @@ def to_json(self) -> str: @dataclass class ExaSearchResult: - # The original search query string query: str @@ -81,22 +78,19 @@ def to_list(self): res_list: List[Any] = [] for res in self.response.results: - res_list.append({ - 'url': - getattr(res, 'url', ''), - 'id': - getattr(res, 'id', ''), - 'title': - getattr(res, 'title'), - 'published_date': - getattr(res, 'published_date', ''), - 'summary': - getattr(res, 'summary', ''), - # 'text': getattr(res, 'text', ''), - # 'highlights': getattr(res, 'highlights', ''), - # 'highlight_scores': getattr(res, 'highlight_scores', ''), - # 'markdown': getattr(res, 'markdown', ''), - }) + res_list.append( + { + 'url': getattr(res, 'url', ''), + 'id': getattr(res, 'id', ''), + 'title': getattr(res, 'title'), + 'published_date': getattr(res, 'published_date', ''), + 'summary': getattr(res, 'summary', ''), + # 'text': getattr(res, 'text', ''), + # 'highlights': getattr(res, 'highlights', ''), + # 'highlight_scores': getattr(res, 'highlight_scores', ''), + # 'markdown': getattr(res, 'markdown', ''), + } + ) return res_list @@ -140,18 +134,19 @@ def load_from_disk(file_path: str) -> List[Dict[str, Any]]: return data -def dump_batch_search_results(results: List[ExaSearchResult], - file_path: str) -> None: +def dump_batch_search_results(results: List[ExaSearchResult], file_path: str) -> None: """ Dump a batch of search results to a local file. """ out_list: List[Dict[str, Any]] = [] for res in results: - out_list.append({ - 'query': res.query, - 'arguments': res.arguments, - 'results': res.to_list(), - }) + out_list.append( + { + 'query': res.query, + 'arguments': res.arguments, + 'results': res.to_list(), + } + ) with open(file_path, 'w', encoding='utf-8') as f: json.dump(out_list, f, ensure_ascii=False, indent=2) diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index 08fa6fb71..c3eba0bd7 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -1,9 +1,9 @@ # flake8: noqa import os import threading +from exa_py import Exa from typing import TYPE_CHECKING, List, Optional, Set, Union -from exa_py import Exa from ms_agent.tools.search.exa.schema import ExaSearchRequest, ExaSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.utils.logger import get_logger @@ -36,13 +36,9 @@ class ExaSearch(SearchEngine): _global_exhausted_keys: Set[str] = set() _global_lock = threading.Lock() - def __init__(self, - api_key: Union[str, list, None] = None, - api_keys: Union[str, list, None] = None): + def __init__(self, api_key: Union[str, list, None] = None, api_keys: Union[str, list, None] = None): all_keys = self._collect_keys(api_key, api_keys) - assert all_keys, ( - 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments ' - 'or as environment variables') + assert all_keys, 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments or as environment variables' self._api_keys: List[str] = all_keys self._lock = threading.Lock() @@ -60,11 +56,12 @@ def __init__(self, if len(all_keys) > 1: with ExaSearch._global_lock: - n_exhausted = sum(1 for k in all_keys - if k in ExaSearch._global_exhausted_keys) - logger.info(f'Exa key pool: {len(all_keys)} keys, ' - f'{n_exhausted} previously exhausted, ' - f'starting at key {start_idx + 1}/{len(all_keys)}') + n_exhausted = sum(1 for k in all_keys if k in ExaSearch._global_exhausted_keys) + logger.info( + f'Exa key pool: {len(all_keys)} keys, ' + f'{n_exhausted} previously exhausted, ' + f'starting at key {start_idx + 1}/{len(all_keys)}' + ) @staticmethod def _collect_keys( @@ -119,8 +116,7 @@ def _add_source(value): def _is_credits_exhausted(error: Exception) -> bool: """Detect Exa 402 / NO_MORE_CREDITS errors.""" msg = str(error) - return ('402' in msg - and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg)) + return '402' in msg and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg) @staticmethod def _mask_key(key: str) -> str: @@ -162,8 +158,7 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: key_idx = self._current_key_idx try: - search_result.response = client.search_and_contents( - **search_args) + search_result.response = client.search_and_contents(**search_args) return search_result except Exception as e: if not self._is_credits_exhausted(e): @@ -181,28 +176,30 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: ) rotated = False for i in range(len(self._api_keys)): - if i not in instance_exhausted and not self._is_key_exhausted( - i): + if i not in instance_exhausted and not self._is_key_exhausted(i): self._current_key_idx = i self.client = Exa(api_key=self._api_keys[i]) - logger.info(f'Rotated to Exa API key ' - f'{self._mask_key(self._api_keys[i])} ' - f'({i + 1}/{len(self._api_keys)})') + logger.info( + f'Rotated to Exa API key ' + f'{self._mask_key(self._api_keys[i])} ' + f'({i + 1}/{len(self._api_keys)})' + ) rotated = True break if not rotated: raise RuntimeError( - f'All {len(self._api_keys)} Exa API keys have ' - f'been exhausted. Last error: {e}') from e + f'All {len(self._api_keys)} Exa API keys have been exhausted. Last error: {e}' + ) from e raise RuntimeError( - f'All {len(self._api_keys)} Exa API keys have been exhausted. ' - f'Last error: {last_error}') from last_error + f'All {len(self._api_keys)} Exa API keys have been exhausted. Last error: {last_error}' + ) from last_error @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for Exa search engine.""" from ms_agent.llm.utils import Tool + return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -210,50 +207,44 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Search the web using Exa neural search engine. ' 'Best for: semantic understanding, finding relevant content, ' 'recent web pages with date filtering. ' - 'Supports neural search (meaning-based) and keyword search.'), + 'Supports neural search (meaning-based) and keyword search.' + ), parameters={ 'type': 'object', 'properties': { 'query': { - 'type': - 'string', - 'description': - ('The search query. For neural search, use natural language ' - 'descriptions. For keyword search, use Google-style queries.' - ), + 'type': 'string', + 'description': ( + 'The search query. For neural search, use natural language ' + 'descriptions. For keyword search, use Google-style queries.' + ), }, 'num_results': { - 'type': - 'integer', - 'minimum': - 1, - 'maximum': - 10, - 'description': - 'Number of results to return. Default is 5.', + 'type': 'integer', + 'minimum': 1, + 'maximum': 10, + 'description': 'Number of results to return. Default is 5.', }, 'type': { - 'type': - 'string', + 'type': 'string', 'enum': ['auto', 'neural', 'keyword'], - 'description': - ('Search type. "neural" for semantic similarity, ' - '"keyword" for exact matching, "auto" to let Exa decide. ' - 'Default is "auto".'), + 'description': ( + 'Search type. "neural" for semantic similarity, ' + '"keyword" for exact matching, "auto" to let Exa decide. ' + 'Default is "auto".' + ), }, 'start_published_date': { - 'type': - 'string', - 'description': - ('Filter results published on/after this date. ' - 'Format: YYYY-MM-DD (e.g., "2024-01-01").'), + 'type': 'string', + 'description': ( + 'Filter results published on/after this date. Format: YYYY-MM-DD (e.g., "2024-01-01").' + ), }, 'end_published_date': { - 'type': - 'string', - 'description': - ('Filter results published on/before this date. ' - 'Format: YYYY-MM-DD (e.g., "2024-12-31").'), + 'type': 'string', + 'description': ( + 'Filter results published on/before this date. Format: YYYY-MM-DD (e.g., "2024-12-31").' + ), }, }, 'required': ['query'], diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py index d8ab34ef7..7ec628dbf 100644 --- a/ms_agent/tools/search/localsearch_tool.py +++ b/ms_agent/tools/search/localsearch_tool.py @@ -5,12 +5,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from ms_agent.tools.search.sirchmunk_search import ( - SirchmunkSearch, - effective_localsearch_settings, -) from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch, effective_localsearch_settings from ms_agent.utils.logger import get_logger logger = get_logger() @@ -60,9 +57,7 @@ def _resolved_localsearch_paths_from_config(config) -> List[str]: def _format_configured_roots(paths: List[str]) -> str: if not paths: - return ( - '(none — set tools.localsearch.paths in agent config, ' - 'or legacy knowledge_search.paths)') + return '(none — set tools.localsearch.paths in agent config, or legacy knowledge_search.paths)' return '\n'.join(f'- {p}' for p in paths) @@ -91,12 +86,10 @@ def __init__(self, config, **kwargs): if tool_cfg is not None: self.exclude_func(tool_cfg) self._searcher: Optional[SirchmunkSearch] = None - self._configured_roots: List[str] = ( - _resolved_localsearch_paths_from_config(config)) + self._configured_roots: List[str] = _resolved_localsearch_paths_from_config(config) def _tool_description(self) -> str: - return _LOCALSEARCH_DESCRIPTION.format( - configured_roots=_format_configured_roots(self._configured_roots)) + return _LOCALSEARCH_DESCRIPTION.format(configured_roots=_format_configured_roots(self._configured_roots)) def _paths_param_description(self) -> str: roots = _format_configured_roots(self._configured_roots) @@ -104,7 +97,8 @@ def _paths_param_description(self) -> str: 'Optional. Narrow search to specific files or directories under the ' 'configured roots below. Each path must exist on disk and lie under ' 'one of these roots (or be exactly one of them).\n' - f'Configured roots:\n{roots}') + f'Configured roots:\n{roots}' + ) def _ensure_searcher(self) -> SirchmunkSearch: if self._searcher is None: @@ -122,60 +116,43 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: server_name=_SERVER, description=self._tool_description(), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'query': { - 'type': - 'string', - 'description': - 'Search keywords or natural-language question about local content.', + 'type': 'string', + 'description': 'Search keywords or natural-language question about local content.', }, 'paths': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - self._paths_param_description(), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': self._paths_param_description(), }, 'mode': { - 'type': - 'string', + 'type': 'string', 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], - 'description': - 'Search mode; omit to use agent default (usually FAST).', + 'description': 'Search mode; omit to use agent default (usually FAST).', }, 'max_depth': { 'type': 'integer', 'minimum': 1, 'maximum': 20, - 'description': - 'Max directory depth for filesystem search.', + 'description': 'Max directory depth for filesystem search.', }, 'top_k_files': { 'type': 'integer', 'minimum': 1, 'maximum': 20, - 'description': - 'Max files for evidence / filename hits.', + 'description': 'Max files for evidence / filename hits.', }, 'include': { 'type': 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Glob patterns to include (e.g. *.py, *.md).', + 'items': {'type': 'string'}, + 'description': 'Glob patterns to include (e.g. *.py, *.md).', }, 'exclude': { 'type': 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Glob patterns to exclude (e.g. *.pyc).', + 'items': {'type': 'string'}, + 'description': 'Glob patterns to exclude (e.g. *.pyc).', }, }, 'required': ['query'], @@ -184,8 +161,7 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: ] } - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict): + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict): del server_name if tool_name != _TOOL: return f'Unknown tool: {tool_name}' @@ -219,11 +195,11 @@ async def call_tool(self, server_name: str, *, tool_name: str, if paths_arg: resolved_paths = searcher.resolve_tool_paths(paths_arg) if not resolved_paths: - roots = _format_configured_roots( - self._configured_roots) + roots = _format_configured_roots(self._configured_roots) return ( 'Error: `paths` are invalid. Each path must exist on disk and lie ' - 'under one of these configured roots:\n' + roots) + 'under one of these configured roots:\n' + roots + ) answer = await searcher.query( query, @@ -266,8 +242,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, result_parts.append('\nSource paths:') for item in excerpts[:12]: meta = item.get('metadata') or {} - result_parts.append( - f'- {meta.get("source", "?")}') + result_parts.append(f'- {meta.get("source", "?")}') result_text = '\n'.join(result_parts) return { @@ -279,4 +254,3 @@ async def call_tool(self, server_name: str, *, tool_name: str, except Exception as exc: logger.warning(f'localsearch failed: {exc}') return f'Local search failed: {exc}' - diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py index 8a10c9d20..443d4c79b 100644 --- a/ms_agent/tools/search/search_base.py +++ b/ms_agent/tools/search/search_base.py @@ -1,12 +1,11 @@ # flake8: noqa import enum +import json import os from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar -import json - if TYPE_CHECKING: from ms_agent.llm.utils import Tool @@ -63,10 +62,7 @@ class SearchResponse(Generic[T]): class SearchRequest(ABC): """Abstract base class for search requests.""" - def __init__(self, - query: str, - num_results: Optional[int] = 10, - **kwargs: Any): + def __init__(self, query: str, num_results: Optional[int] = 10, **kwargs: Any): """ Initialize SearchRequest with search parameters. @@ -96,10 +92,7 @@ def to_json(self) -> str: class SearchResult(ABC): """Base class for search results.""" - def __init__(self, - query: str, - arguments: Optional[Dict[str, Any]] = None, - response: Any = None): + def __init__(self, query: str, arguments: Optional[Dict[str, Any]] = None, response: Any = None): """ Initialize SearchResult. @@ -137,15 +130,17 @@ def to_list(self) -> List[Dict[str, Any]]: res_list: List[Dict[str, Any]] = [] for res in self.response.results: - res_list.append({ - 'url': res.url, - 'id': res.id, - 'title': res.title, - 'highlights': res.highlights, - 'highlight_scores': res.highlight_scores, - 'summary': res.summary, - 'markdown': res.markdown, - }) + res_list.append( + { + 'url': res.url, + 'id': res.id, + 'title': res.title, + 'highlights': res.highlights, + 'highlight_scores': res.highlight_scores, + 'summary': res.summary, + 'markdown': res.markdown, + } + ) return res_list @@ -201,6 +196,7 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': Tool definition dict """ from ms_agent.llm.utils import Tool + return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -235,5 +231,4 @@ def build_request_from_args(cls, **kwargs) -> SearchRequest: Returns: SearchRequest instance """ - raise NotImplementedError( - f'{cls.__name__} must implement build_request_from_args') + raise NotImplementedError(f'{cls.__name__} must implement build_request_from_args') diff --git a/ms_agent/tools/search/search_request.py b/ms_agent/tools/search/search_request.py index e2aed6666..81c8e13d7 100644 --- a/ms_agent/tools/search/search_request.py +++ b/ms_agent/tools/search/search_request.py @@ -114,7 +114,7 @@ def get_rewrite_prompt(self) -> str: f'\n2. 其中,query参数的值直接使用用户原始输入,即:{self.user_prompt}' f'\n3. 参数需要符合搜索引擎的要求,num_results需要根据实际问题的复杂程度来估算,最大25,最小1,对于复杂的问题,num_results的值需要尽量大;' f'\n4. start_published_date和end_published_date需要根据实际问题的时间范围来估算,默认均为None。' - f'当前日期为:{datetime.now().strftime("%Y-%m-%d")}') + f'当前日期为:{datetime.now().strftime('%Y-%m-%d')}') def create_request(self, search_request_d: Dict[str, Any]) -> ExaSearchRequest: diff --git a/ms_agent/tools/search/serpapi/__init__.py b/ms_agent/tools/search/serpapi/__init__.py index 8a46380ac..cf16db488 100644 --- a/ms_agent/tools/search/serpapi/__init__.py +++ b/ms_agent/tools/search/serpapi/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from ms_agent.tools.search.serpapi.schema import (SerpApiSearchRequest, - SerpApiSearchResult) +from ms_agent.tools.search.serpapi.schema import SerpApiSearchRequest, SerpApiSearchResult from ms_agent.tools.search.serpapi.search import SerpApiSearch diff --git a/ms_agent/tools/search/serpapi/schema.py b/ms_agent/tools/search/serpapi/schema.py index 633250473..5643ac4b7 100644 --- a/ms_agent/tools/search/serpapi/schema.py +++ b/ms_agent/tools/search/serpapi/schema.py @@ -2,8 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional -from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, - SearchResponse, SearchResult) +from ms_agent.tools.search.search_base import BaseResult, SearchRequest, SearchResponse, SearchResult class SerpApiSearchRequest(SearchRequest): @@ -11,11 +10,7 @@ class SerpApiSearchRequest(SearchRequest): A class representing a search request to SerpApi. """ - def __init__(self, - query: str, - num_results: Optional[int] = 5, - location: Optional[str] = None, - **kwargs: Any): + def __init__(self, query: str, num_results: Optional[int] = 5, location: Optional[str] = None, **kwargs: Any): """ Initialize SerpApiSearchRequest with search parameters. @@ -34,21 +29,13 @@ def to_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters as a dictionary """ - return { - 'q': self.query, - 'num': self.num_results, - 'location': self.location - } + return {'q': self.query, 'num': self.num_results, 'location': self.location} class SerpApiSearchResult(SearchResult): """SerpApi search result implementation.""" - def __init__(self, - provider: str, - query: str, - arguments: Dict[str, Any] = None, - response: Dict[str, Any] = None): + def __init__(self, provider: str, query: str, arguments: Dict[str, Any] = None, response: Dict[str, Any] = None): """ Initialize SerpApiSearchResult. @@ -76,8 +63,7 @@ def _process_results(self) -> SearchResponse: processed = [] if self.provider.lower() in ['google', 'bing', 'baidu']: # Extract organic results - organic_results: List[Dict[str, Any]] = self.response.get( - 'organic_results', []) + organic_results: List[Dict[str, Any]] = self.response.get('organic_results', []) for res in organic_results: processed.append( BaseResult( @@ -87,9 +73,10 @@ def _process_results(self) -> SearchResponse: highlights=res.get('snippet_highlighted_words'), highlight_scores=None, summary=None, - markdown=None)) + markdown=None, + ) + ) else: - raise NotImplementedError( - f"Provider '{self.provider}' is not supported yet.") + raise NotImplementedError(f"Provider '{self.provider}' is not supported yet.") return SearchResponse(results=processed) diff --git a/ms_agent/tools/search/serpapi/search.py b/ms_agent/tools/search/serpapi/search.py index d42fef413..adebbe044 100644 --- a/ms_agent/tools/search/serpapi/search.py +++ b/ms_agent/tools/search/serpapi/search.py @@ -3,8 +3,7 @@ from typing import TYPE_CHECKING from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType -from ms_agent.tools.search.serpapi.schema import (SerpApiSearchRequest, - SerpApiSearchResult) +from ms_agent.tools.search.serpapi.schema import SerpApiSearchRequest, SerpApiSearchResult if TYPE_CHECKING: from ms_agent.llm.utils import Tool @@ -21,16 +20,13 @@ class SerpApiSearch(SearchEngine): engine_type = SearchEngineType.SERPAPI def __init__(self, api_key: str = None, provider: str = None): - api_key = api_key or os.getenv('SERPAPI_API_KEY') assert api_key, 'SERPAPI_API_KEY must be set either as an argument or as an environment variable' self.provider = (provider or 'google').lower() - self.client = self._get_search_client( - provider=self.provider, api_key=api_key) + self.client = self._get_search_client(provider=self.provider, api_key=api_key) - def search(self, - search_request: SerpApiSearchRequest) -> SerpApiSearchResult: + def search(self, search_request: SerpApiSearchRequest) -> SerpApiSearchResult: """ Perform a search using SerpApi and return the results. @@ -46,10 +42,8 @@ def search(self, self.client.params_dict.update(search_args) response = self.client.get_dict() search_result = SerpApiSearchResult( - provider=self.provider, - query=search_request.query, - arguments=search_args, - response=response) + provider=self.provider, query=search_request.query, arguments=search_args, response=response + ) except Exception as e: raise RuntimeError(f'Failed to perform search: {e}') from e @@ -59,6 +53,7 @@ def search(self, def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for SerpApi search engine.""" from ms_agent.llm.utils import Tool + return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -67,33 +62,28 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Default provider is Google. ' 'Best for: general web search, current events, news, ' 'real-time information, and location-specific results. ' - 'Supports Google search operators.'), + 'Supports Google search operators.' + ), parameters={ 'type': 'object', 'properties': { 'query': { - 'type': - 'string', - 'description': - ('Google-style search query. Use operators as needed: ' - 'quotes for exact phrases ("..."), OR, -term to exclude. ' - 'Date limits: before:YYYY-MM-DD, after:YYYY-MM-DD.'), + 'type': 'string', + 'description': ( + 'Google-style search query. Use operators as needed: ' + 'quotes for exact phrases ("..."), OR, -term to exclude. ' + 'Date limits: before:YYYY-MM-DD, after:YYYY-MM-DD.' + ), }, 'num_results': { - 'type': - 'integer', - 'minimum': - 1, - 'maximum': - 10, - 'description': - 'Number of results to return. Default is 5.', + 'type': 'integer', + 'minimum': 1, + 'maximum': 10, + 'description': 'Number of results to return. Default is 5.', }, 'location': { - 'type': - 'string', - 'description': - ('Geographic location filter. Default is null'), + 'type': 'string', + 'description': ('Geographic location filter. Default is null'), }, }, 'required': ['query'], diff --git a/ms_agent/tools/search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py index cd86f9d63..a49cae982 100644 --- a/ms_agent/tools/search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -38,8 +38,7 @@ def effective_localsearch_settings(config: DictConfig) -> Optional[Any]: tools = getattr(config, 'tools', None) tl = None if tools is not None: - tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr( - tools, 'localsearch', None) + tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr(tools, 'localsearch', None) ks = getattr(config, 'knowledge_search', None) if tl is not None and _paths_from_block(tl): @@ -81,16 +80,13 @@ def __init__(self, config: DictConfig): paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] - self.search_paths: List[str] = [ - str(Path(p).expanduser().resolve()) for p in paths - ] + self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() self.reuse_knowledge = rag_config.get('reuse_knowledge', True) - self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', - 0.85) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) self.search_mode = rag_config.get('mode', 'FAST') self.max_loops = rag_config.get('max_loops', 10) @@ -100,23 +96,19 @@ def __init__(self, config: DictConfig): self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - if (self.llm_api_key is None or self.llm_base_url is None - or self.llm_model_name is None): + if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, - f'{service}_api_key', None) + self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, - f'{service}_base_url', None) + self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) self.embedding_model_id = rag_config.get('embedding_model', None) - self.embedding_model_cache_dir = rag_config.get( - 'embedding_model_cache_dir', None) + self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) self._searcher = None self._initialized = False @@ -135,15 +127,15 @@ def _validate_config(self, config: DictConfig): raise ValueError( 'Missing localsearch configuration. Add ' '`tools.localsearch` with non-empty `paths` (or legacy ' - '`knowledge_search.paths`).') + '`knowledge_search.paths`).' + ) paths = _paths_from_block(block) if not paths: raise ValueError( - 'tools.localsearch.paths (or legacy knowledge_search.paths) ' - 'must be specified and non-empty') + 'tools.localsearch.paths (or legacy knowledge_search.paths) must be specified and non-empty' + ) - def resolve_tool_paths( - self, paths: Optional[List[str]]) -> Optional[List[str]]: + def resolve_tool_paths(self, paths: Optional[List[str]]) -> Optional[List[str]]: """Restrict per-call paths to configured search roots.""" if not paths: return None @@ -156,12 +148,9 @@ def resolve_tool_paths( if not p.exists(): logger.warning(f'localsearch: path does not exist, skipped: {p}') continue - allowed = any( - p == r or p.is_relative_to(r) for r in roots) + allowed = any(p == r or p.is_relative_to(r) for r in roots) if not allowed: - logger.warning( - f'localsearch: path outside configured search roots, ' - f'skipped: {p}') + logger.warning(f'localsearch: path outside configured search roots, skipped: {p}') continue cleaned.append(str(p)) return cleaned or None @@ -184,13 +173,9 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - embedding_model_id = ( - self.embedding_model_id if self.embedding_model_id else None) - embedding_cache_dir = ( - self.embedding_model_cache_dir - if self.embedding_model_cache_dir else None) - embedding = EmbeddingUtil( - model_id=embedding_model_id, cache_dir=embedding_cache_dir) + embedding_model_id = self.embedding_model_id if self.embedding_model_id else None + embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None + embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) self._searcher = AgenticSearch( llm=llm, @@ -205,14 +190,10 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info( - f'SirschmunkSearch initialized with paths: {self.search_paths}' - ) + logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') except ImportError as e: - raise ImportError( - f'Failed to import sirchmunk: {e}. ' - 'Please install sirchmunk: pip install sirchmunk') + raise ImportError(f'Failed to import sirchmunk: {e}. Please install sirchmunk: pip install sirchmunk') except Exception as e: raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') @@ -279,19 +260,16 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: try: for file_path in file_paths: if Path(file_path).exists(): - await self._searcher.scan_directory( - str(Path(file_path).parent)) + await self._searcher.scan_directory(str(Path(file_path).parent)) return True except Exception as e: logger.error(f'Failed to scan files: {e}') return False return True - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.7, - **filters) -> List[Dict[str, Any]]: + async def retrieve( + self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters + ) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. Args: @@ -310,8 +288,7 @@ async def retrieve(self, try: mode = filters.get('mode', self.search_mode) max_loops = filters.get('max_loops', self.max_loops) - max_token_budget = filters.get('max_token_budget', - self.max_token_budget) + max_token_budget = filters.get('max_token_budget', self.max_token_budget) result = await self._searcher.search( query=query, @@ -324,11 +301,9 @@ async def retrieve(self, self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) return self._parse_search_result(result, score_threshold, limit) @@ -374,8 +349,7 @@ async def query( mode_eff = mode_eff.strip().upper() allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') if mode_eff not in allowed_modes: - return ( - f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') + return f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.' kw: Dict[str, Any] = dict( query=query, @@ -402,34 +376,29 @@ async def query( self._last_search_result = [] for item in result[:20]: if isinstance(item, dict): - src = (item.get('path') or item.get('file_path') - or item.get('file') or '') - self._last_search_result.append({ - 'text': - json.dumps(item, ensure_ascii=False), - 'score': - 1.0, - 'metadata': { - 'source': str(src), - 'type': 'filename_match', - }, - }) + src = item.get('path') or item.get('file_path') or item.get('file') or '' + self._last_search_result.append( + { + 'text': json.dumps(item, ensure_ascii=False), + 'score': 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + } + ) return json.dumps(result, ensure_ascii=False, indent=2) self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, - '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr( - result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) - self._last_search_result = self._parse_search_result( - result, score_threshold=0.7, limit=5) + self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) - if hasattr(result, 'answer') and getattr(result, 'answer', - None) is not None: + if hasattr(result, 'answer') and getattr(result, 'answer', None) is not None: return result.answer if isinstance(result, str): @@ -441,8 +410,7 @@ async def query( logger.error(f'SirschmunkSearch query failed: {e}') return f'Query failed: {e}' - def _parse_search_result(self, result: Any, score_threshold: float, - limit: int) -> List[Dict[str, Any]]: + def _parse_search_result(self, result: Any, score_threshold: float, limit: int) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -468,62 +436,57 @@ def _parse_search_result(self, result: Any, score_threshold: float, else: text_parts.append(str(snippet)) - results.append({ - 'text': - '\n'.join(text_parts) if text_parts else getattr( - unit, 'summary', ''), - 'score': - score, - 'metadata': { - 'source': - source, - 'type': - getattr(unit, 'abstraction_level', 'text') - if hasattr(unit, 'abstraction_level') else 'text', - }, - }) + results.append( + { + 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') + else 'text', + }, + } + ) elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append({ - 'text': - str(unit.content) - if hasattr(unit, 'content') else str(unit), - 'score': - score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), - }, - }) + results.append( + { + 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, + } + ) elif isinstance(result, list): for item in result: if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': - item.get('content', item.get('text', str(item))), - 'score': - score, - 'metadata': - item.get('metadata', {}), - }) + results.append( + { + 'text': item.get('content', item.get('text', str(item))), + 'score': score, + 'metadata': item.get('metadata', {}), + } + ) elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': - result.get('content', result.get('text', str(result))), - 'score': - score, - 'metadata': - result.get('metadata', {}), - }) + results.append( + { + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + } + ) results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] diff --git a/ms_agent/tools/search/tavily/fetcher.py b/ms_agent/tools/search/tavily/fetcher.py index 4082907a2..b38e2bf96 100644 --- a/ms_agent/tools/search/tavily/fetcher.py +++ b/ms_agent/tools/search/tavily/fetcher.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Tavily Extract API as ContentFetcher (replaces Jina for fetch_page / URL fetch).""" + import os import time from typing import Any, Dict, Optional, Tuple @@ -34,8 +35,7 @@ def __init__( ): key = api_key or os.getenv('TAVILY_API_KEY') if not key: - raise ValueError( - 'TAVILY_API_KEY required for tavily_extract fetcher') + raise ValueError('TAVILY_API_KEY required for tavily_extract fetcher') self._api_key = key self._extract_depth = extract_depth self._format = format diff --git a/ms_agent/tools/search/tavily/http.py b/ms_agent/tools/search/tavily/http.py index d4916d271..7c1d3981b 100644 --- a/ms_agent/tools/search/tavily/http.py +++ b/ms_agent/tools/search/tavily/http.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Minimal HTTP JSON client for Tavily REST API (stdlib only).""" + import json from typing import Any, Dict from urllib.error import HTTPError, URLError @@ -44,7 +45,6 @@ def post_json( detail = json.loads(err_body) if err_body else {} except json.JSONDecodeError: detail = {'raw': err_body} - raise RuntimeError( - f'Tavily HTTP {e.code}: {detail}') from e + raise RuntimeError(f'Tavily HTTP {e.code}: {detail}') from e except URLError as e: raise RuntimeError(f'Tavily network error: {e}') from e diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py index 75f3f0aed..9b92d812f 100644 --- a/ms_agent/tools/search/tavily/schema.py +++ b/ms_agent/tools/search/tavily/schema.py @@ -86,21 +86,23 @@ def to_list(self) -> List[Dict[str, Any]]: raw = (r.get('raw_content') or '').strip() # Prefer full page text for downstream summarization; fallback to snippets body = raw if raw else snippet - rows.append({ - 'url': url, - 'id': url, - 'title': title, - 'highlights': None, - 'highlight_scores': None, - 'summary': snippet, - 'markdown': raw if raw else None, - # Pipeline uses these keys: - 'content': body, - 'fetch_success': bool(raw), - 'score': r.get('score'), - 'tavily_images': r.get('images') or [], - 'favicon': r.get('favicon'), - }) + rows.append( + { + 'url': url, + 'id': url, + 'title': title, + 'highlights': None, + 'highlight_scores': None, + 'summary': snippet, + 'markdown': raw if raw else None, + # Pipeline uses these keys: + 'content': body, + 'fetch_success': bool(raw), + 'score': r.get('score'), + 'tavily_images': r.get('images') or [], + 'favicon': r.get('favicon'), + } + ) return rows def extra_response_fields(self) -> Dict[str, Any]: diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py index b4b7d3f3b..aa42cb556 100644 --- a/ms_agent/tools/search/tavily/search.py +++ b/ms_agent/tools/search/tavily/search.py @@ -33,17 +33,14 @@ def __init__( ): key = api_key or os.getenv('TAVILY_API_KEY') if not key: - raise ValueError( - 'TAVILY_API_KEY must be set in environment or web_search.tavily_api_key' - ) + raise ValueError('TAVILY_API_KEY must be set in environment or web_search.tavily_api_key') self._api_key = key self._request_timeout = float(request_timeout) def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: body = search_request.to_api_body(self._api_key) try: - data = post_json( - TAVILY_SEARCH_URL, body, timeout=self._request_timeout) + data = post_json(TAVILY_SEARCH_URL, body, timeout=self._request_timeout) except Exception as e: raise RuntimeError(f'Tavily search failed: {e}') from e safe_args = {k: v for k, v in body.items() if k != 'api_key'} @@ -57,6 +54,7 @@ def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': from ms_agent.llm.utils import Tool + return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -64,7 +62,8 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Search the web using Tavily (built for AI agents). ' 'Returns ranked results with optional full-page markdown via ' '`include_raw_content`. Use `search_depth` advanced for best ' - 'relevance and richer `content` chunks (higher API credit use).'), + 'relevance and richer `content` chunks (higher API credit use).' + ), parameters={ 'type': 'object', 'properties': { @@ -81,20 +80,18 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'search_depth': { 'type': 'string', 'enum': ['advanced', 'basic', 'fast', 'ultra-fast'], - 'description': - ('advanced: best quality, 2 credits; ' - 'basic/fast/ultra-fast: 1 credit (see Tavily docs).'), + 'description': ( + 'advanced: best quality, 2 credits; basic/fast/ultra-fast: 1 credit (see Tavily docs).' + ), }, 'topic': { 'type': 'string', 'enum': ['general', 'news', 'finance'], - 'description': - 'Search category (`news` / `finance` for focused verticals).', + 'description': 'Search category (`news` / `finance` for focused verticals).', }, 'time_range': { 'type': 'string', - 'description': - ('Filter by recency: day, week, month, year or d,w,m,y.'), + 'description': ('Filter by recency: day, week, month, year or d,w,m,y.'), }, 'start_date': { 'type': 'string', @@ -107,50 +104,39 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'include_answer': { 'type': 'string', 'enum': ['false', 'true', 'basic', 'advanced'], - 'description': - ('LLM answer: true/basic for short, advanced for detailed. ' - 'Use false to skip.'), + 'description': ('LLM answer: true/basic for short, advanced for detailed. Use false to skip.'), }, 'include_raw_content': { 'type': 'string', - 'enum': - ['false', 'true', 'markdown', 'text'], - 'description': - ('full page text: markdown (recommended) or text; ' - 'false to skip raw content.'), + 'enum': ['false', 'true', 'markdown', 'text'], + 'description': ('full page text: markdown (recommended) or text; false to skip raw content.'), }, 'chunks_per_source': { 'type': 'integer', 'minimum': 1, 'maximum': 3, - 'description': - ('Relevant chunks per URL when search_depth=advanced. ' - 'Each chunk up to ~500 chars in `content` field.'), + 'description': ( + 'Relevant chunks per URL when search_depth=advanced. ' + 'Each chunk up to ~500 chars in `content` field.' + ), }, 'include_domains': { 'type': 'array', - 'items': { - 'type': 'string' - }, + 'items': {'type': 'string'}, 'description': 'Only include these domains (max 300).', }, 'exclude_domains': { 'type': 'array', - 'items': { - 'type': 'string' - }, + 'items': {'type': 'string'}, 'description': 'Exclude domains (max 150).', }, 'country': { 'type': 'string', - 'description': - ('Boost results from country (e.g. united states). ' - 'See Tavily docs for enum.'), + 'description': ('Boost results from country (e.g. united states). See Tavily docs for enum.'), }, 'exact_match': { 'type': 'boolean', - 'description': - 'Only results with exact quoted phrases in query.', + 'description': 'Only results with exact quoted phrases in query.', }, }, 'required': ['query'], @@ -204,8 +190,7 @@ def _boolish(name: str, default: Any) -> Any: include_answer=inc_ans, include_raw_content=inc_raw, include_images=bool(_boolish('include_images', False)), - include_image_descriptions=bool( - _boolish('include_image_descriptions', False)), + include_image_descriptions=bool(_boolish('include_image_descriptions', False)), include_favicon=bool(_boolish('include_favicon', False)), include_domains=list(kwargs.get('include_domains') or []), exclude_domains=list(kwargs.get('exclude_domains') or []), diff --git a/ms_agent/tools/search/web_search_spill.py b/ms_agent/tools/search/web_search_spill.py index 9c8da73c3..f1b92ebb1 100644 --- a/ms_agent/tools/search/web_search_spill.py +++ b/ms_agent/tools/search/web_search_spill.py @@ -27,6 +27,7 @@ JSON gains ``spill`` with ``digest`` (instructions + quick index) and paths relative to ``output_dir`` so ``read_file`` can open them. """ + from __future__ import annotations import copy @@ -111,18 +112,18 @@ def _build_spill_markdown(item: Dict[str, Any]) -> str: return ''.join(lines) -def _shrink_item_after_spill(item: Dict[str, Any], - spill_preview_chars: int) -> Dict[str, Any]: +def _shrink_item_after_spill(item: Dict[str, Any], spill_preview_chars: int) -> Dict[str, Any]: """Replace heavy fields with short previews + pointers.""" out = dict(item) note = ( 'Full text spilled to disk; see content_path / manifest_path in parent ' - 'JSON spill block. Use read_file on content_path for this row.') + 'JSON spill block. Use read_file on content_path for this row.' + ) sm = out.get('summary') if isinstance(sm, str) and sm.strip(): out['summary'] = _preview(sm, spill_preview_chars) out.setdefault('content_note', note) - main = (out.get('content') or '') + main = out.get('content') or '' if isinstance(main, str) and main.strip(): out['content'] = _preview(main, spill_preview_chars) out['content_note'] = note @@ -131,12 +132,14 @@ def _shrink_item_after_spill(item: Dict[str, Any], out['abstract'] = _preview(ab, min(800, spill_preview_chars)) ch = out.get('chunks') if isinstance(ch, list) and ch: - out['chunks'] = [{ - 'chunk_id': - c.get('chunk_id', ''), - 'content': - _preview(str(c.get('content', '')), min(400, spill_preview_chars)), - } for c in ch if isinstance(c, dict)] + out['chunks'] = [ + { + 'chunk_id': c.get('chunk_id', ''), + 'content': _preview(str(c.get('content', '')), min(400, spill_preview_chars)), + } + for c in ch + if isinstance(c, dict) + ] out['chunks_note'] = 'Full chunk bodies are in the spilled markdown file.' return out @@ -189,13 +192,10 @@ def order_by_size() -> List[int]: if _item_inline_chars(item) == 0: break full_md = _build_spill_markdown(item) - rel_body = os.path.join(spill_subdir, run_key, 'bodies', - f'{idx:03d}.md').replace('\\', '/') - abs_body = os.path.normpath( - os.path.join(output_dir, rel_body.replace('/', os.sep))) + rel_body = os.path.join(spill_subdir, run_key, 'bodies', f'{idx:03d}.md').replace('\\', '/') + abs_body = os.path.normpath(os.path.join(output_dir, rel_body.replace('/', os.sep))) os.makedirs(os.path.dirname(abs_body), exist_ok=True) - header = ( - f'\n') + header = f'\n' with open(abs_body, 'w', encoding='utf-8') as bf: bf.write(header + full_md) @@ -206,55 +206,37 @@ def order_by_size() -> List[int]: work[idx]['content_path'] = rel_body work[idx]['content_chars_spilled'] = before_chars - preview_src = ( - item.get('content') or item.get('summary') or item.get('abstract') - or '')[:4000] - manifest_rows.append({ - 'index': - idx, - 'url': - item.get('url', ''), - 'title': - item.get('title', ''), - 'body_file': - f'bodies/{idx:03d}.md', - 'content_path': - rel_body, - 'chars_spilled': - before_chars, - 'preview': - _preview(preview_src, min(500, spill_preview_chars)), - }) + preview_src = (item.get('content') or item.get('summary') or item.get('abstract') or '')[:4000] + manifest_rows.append( + { + 'index': idx, + 'url': item.get('url', ''), + 'title': item.get('title', ''), + 'body_file': f'bodies/{idx:03d}.md', + 'content_path': rel_body, + 'chars_spilled': before_chars, + 'preview': _preview(preview_src, min(500, spill_preview_chars)), + } + ) manifest: Dict[str, Any] = { - 'version': - 1, - 'created_at_utc': - time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), - 'query': - query, - 'engine': - engine, - 'run_key': - run_key, - 'lifecycle': - ('Ephemeral: lives under this task output_dir; delete the task directory ' - 'to remove. ms-agent does not auto-prune.'), - 'inline_chars_before': - total, - 'inline_chars_after': - _total_inline_chars(work), - 'spill_threshold_chars': - spill_max_inline_chars, - 'spilled_row_indices': - spilled_indices, - 'rows': - manifest_rows, + 'version': 1, + 'created_at_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), + 'query': query, + 'engine': engine, + 'run_key': run_key, + 'lifecycle': ( + 'Ephemeral: lives under this task output_dir; delete the task directory ' + 'to remove. ms-agent does not auto-prune.' + ), + 'inline_chars_before': total, + 'inline_chars_after': _total_inline_chars(work), + 'spill_threshold_chars': spill_max_inline_chars, + 'spilled_row_indices': spilled_indices, + 'rows': manifest_rows, } - rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace( - '\\', '/') - abs_manifest = os.path.normpath( - os.path.join(output_dir, rel_manifest.replace('/', os.sep))) + rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace('\\', '/') + abs_manifest = os.path.normpath(os.path.join(output_dir, rel_manifest.replace('/', os.sep))) with open(abs_manifest, 'w', encoding='utf-8') as mf: json.dump(manifest, mf, ensure_ascii=False, indent=2) @@ -262,31 +244,24 @@ def order_by_size() -> List[int]: 'Large web_search payload was written to disk under this task output_dir.', f'- **Manifest (map of rows → files, sizes)**: `{rel_manifest}`', f'- **Bodies**: `{spill_subdir}/{run_key}/bodies/`', - 'Read **manifest.json** first, then **read_file** on specific ' - '`bodies/NNN.md` files as needed.', + 'Read **manifest.json** first, then **read_file** on specific `bodies/NNN.md` files as needed.', '', '**Quick index**', ] for row in manifest_rows: lines.append( f'{row["index"]}. {row.get("title") or "(no title)"} — ' - f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)') + f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)' + ) digest = '\n'.join(lines) spill_meta = { - 'spilled': - True, - 'run_key': - run_key, - 'artifact_dir': - f'{spill_subdir}/{run_key}'.replace('\\', '/'), - 'manifest_path': - rel_manifest, - 'digest': - digest, - 'inline_chars_before_spill': - total, - 'inline_chars_after_spill': - _total_inline_chars(work), + 'spilled': True, + 'run_key': run_key, + 'artifact_dir': f'{spill_subdir}/{run_key}'.replace('\\', '/'), + 'manifest_path': rel_manifest, + 'digest': digest, + 'inline_chars_before_spill': total, + 'inline_chars_after_spill': _total_inline_chars(work), } return work, spill_meta diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 16d6005d2..410ba0a60 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -10,11 +10,8 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase -from ms_agent.tools.jina_reader import (JinaReaderConfig, - fetch_single_text_with_meta) -from ms_agent.tools.search.content_optimizer import (ContentOptimizer, - ContentOptimizerConfig, - SearchResultReranker) +from ms_agent.tools.jina_reader import JinaReaderConfig, fetch_single_text_with_meta +from ms_agent.tools.search.content_optimizer import ContentOptimizer, ContentOptimizerConfig, SearchResultReranker from ms_agent.tools.search.search_base import ENGINE_TOOL_NAMES, SearchEngine from ms_agent.tools.search.web_search_spill import maybe_spill_web_search_payload from ms_agent.utils.logger import get_logger @@ -42,14 +39,14 @@ def default_per_url_fetch_timeout_s( retries = max(0, int(fetch_retries)) # Up to (retries+1) attempts each up to ``ft``; 1.35 leaves slack for urllib backoff. jina_budget = ft * float(retries + 1) * 1.35 - tail = max(10.0, float(direct_fetch_timeout)) + ( - float(playwright_timeout_ms) / 1000.0) + 30.0 + tail = max(10.0, float(direct_fetch_timeout)) + (float(playwright_timeout_ms) / 1000.0) + 30.0 raw = jina_budget + tail return max(210.0, min(720.0, raw)) def _json_dumps(data: Any) -> str: import json + return json.dumps(data, ensure_ascii=False, indent=2) @@ -79,10 +76,7 @@ class TextChunk: end_pos: int -def chunk_text_simple(text: str, - chunk_size: int = 1500, - overlap: int = 200, - prefix: str = '') -> List[TextChunk]: +def chunk_text_simple(text: str, chunk_size: int = 1500, overlap: int = 200, prefix: str = '') -> List[TextChunk]: """ Simple text chunking by character count with overlap. Tries to break at paragraph or sentence boundaries when possible. @@ -101,13 +95,7 @@ def chunk_text_simple(text: str, text = text.strip() if len(text) <= chunk_size: - return [ - TextChunk( - chunk_id=f'{prefix}0' if prefix else '0', - content=text, - start_pos=0, - end_pos=len(text)) - ] + return [TextChunk(chunk_id=f'{prefix}0' if prefix else '0', content=text, start_pos=0, end_pos=len(text))] chunks: List[TextChunk] = [] start = 0 @@ -134,11 +122,12 @@ def chunk_text_simple(text: str, if chunk_content: chunks.append( TextChunk( - chunk_id=f'{prefix}{chunk_idx}' - if prefix else str(chunk_idx), + chunk_id=f'{prefix}{chunk_idx}' if prefix else str(chunk_idx), content=chunk_content, start_pos=start, - end_pos=end)) + end_pos=end, + ) + ) chunk_idx += 1 # Move start with overlap @@ -168,11 +157,7 @@ class JinaContentFetcher(ContentFetcher): def __init__(self, config: Optional[JinaReaderConfig] = None): self.config = config or JinaReaderConfig() - def fetch( - self, - url: str, - max_chars: Optional[int] = MAX_FETCH_CHARS - ) -> Tuple[str, Dict[str, Any]]: + def fetch(self, url: str, max_chars: Optional[int] = MAX_FETCH_CHARS) -> Tuple[str, Dict[str, Any]]: content, source_meta = fetch_single_text_with_meta(url, self.config) metadata: Dict[str, Any] = { 'fetcher': 'jina_reader', @@ -192,8 +177,7 @@ def fetch( # pass -def get_content_fetcher(fetcher_type: str = 'jina_reader', - **kwargs) -> ContentFetcher: +def get_content_fetcher(fetcher_type: str = 'jina_reader', **kwargs) -> ContentFetcher: """Factory function to get content fetcher by type.""" if fetcher_type == 'jina_reader': config = JinaReaderConfig( @@ -201,17 +185,15 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', retries=kwargs.get('retries', 3), direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), - playwright_fetch_fallback=bool( - kwargs.get('playwright_fetch_fallback', True)), - playwright_retry_min_chars=int( - kwargs.get('playwright_retry_min_chars', 400) or 400), - playwright_timeout_ms=int( - kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_fetch_fallback=bool(kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int(kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int(kwargs.get('playwright_timeout_ms', 30_000) or 30_000), playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), ) return JinaContentFetcher(config) if fetcher_type == 'tavily_extract': from ms_agent.tools.search.tavily.fetcher import TavilyExtractFetcher + return TavilyExtractFetcher( api_key=kwargs.get('tavily_api_key'), extract_depth=str(kwargs.get('tavily_extract_depth', 'advanced')), @@ -226,26 +208,19 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', # elif fetcher_type == 'docling': # return DoclingContentFetcher(**kwargs) else: - logger.warning( - f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader" - ) + logger.warning(f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader") return JinaContentFetcher( JinaReaderConfig( timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), - direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', - True)), - direct_fetch_timeout=float( - kwargs.get('direct_fetch_timeout', 15.0)), - playwright_fetch_fallback=bool( - kwargs.get('playwright_fetch_fallback', True)), - playwright_retry_min_chars=int( - kwargs.get('playwright_retry_min_chars', 400) or 400), - playwright_timeout_ms=int( - kwargs.get('playwright_timeout_ms', 30_000) or 30_000), - playwright_settle_ms=int( - kwargs.get('playwright_settle_ms', 350)), - )) + direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), + direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool(kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int(kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int(kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), + ) + ) def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: @@ -262,26 +237,28 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch + return ExaSearch elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch + return SerpApiSearch elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch + return ArxivSearch elif engine_type == 'tavily': from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch else: - logger.warning( - f"Unknown search engine '{engine_type}', falling back to arxiv") + logger.warning(f"Unknown search engine '{engine_type}', falling back to arxiv") from ms_agent.tools.search.arxiv import ArxivSearch + return ArxivSearch -def get_search_engine(engine_type: str, - api_key: Optional[str] = None, - **kwargs) -> SearchEngine: +def get_search_engine(engine_type: str, api_key: Optional[str] = None, **kwargs) -> SearchEngine: """ Get search engine instance by type. @@ -298,46 +275,45 @@ def get_search_engine(engine_type: str, if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch + return ExaSearch( api_key=api_key or os.getenv('EXA_API_KEY'), api_keys=kwargs.get('api_keys') or os.getenv('EXA_API_KEYS'), ) elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch - default_provider = ('google' if engine_type in ('serpapi', 'serp') else - engine_type) + + default_provider = 'google' if engine_type in ('serpapi', 'serp') else engine_type return SerpApiSearch( api_key=api_key or os.getenv('SERPAPI_API_KEY'), provider=kwargs.get('provider', default_provider), ) elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch + return ArxivSearch() elif engine_type == 'tavily': from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch( api_key=api_key or os.getenv('TAVILY_API_KEY'), request_timeout=float(kwargs.get('request_timeout', 120.0)), ) else: - logger.warning( - f"Unknown search engine '{engine_type}', falling back to arxiv") + logger.warning(f"Unknown search engine '{engine_type}', falling back to arxiv") from ms_agent.tools.search.arxiv import ArxivSearch + return ArxivSearch() # Kept for backward compatibility -def build_search_request(engine_type: str, - query: str, - num_results: int = 5, - **kwargs): +def build_search_request(engine_type: str, query: str, num_results: int = 5, **kwargs): """Build a search request for the specified engine. DEPRECATED: Use SearchEngine.build_request_from_args() instead. """ engine_cls = get_search_engine_class(engine_type) - return engine_cls.build_request_from_args( - query=query, num_results=num_results, **kwargs) + return engine_cls.build_request_from_args(query=query, num_results=num_results, **kwargs) class WebSearchTool(ToolBase): @@ -396,12 +372,8 @@ def get_global_summarization_usage(cls) -> Dict[str, Any]: """Get process-wide summarization usage totals (best-effort).""" with cls._GLOBAL_SUMMARY_USAGE_LOCK: total = dict(cls._GLOBAL_SUMMARY_USAGE_TOTAL) - by_model = { - k: dict(v) - for k, v in cls._GLOBAL_SUMMARY_USAGE_BY_MODEL.items() - } - total['total_tokens'] = total.get('prompt_tokens', 0) + total.get( - 'completion_tokens', 0) + by_model = {k: dict(v) for k, v in cls._GLOBAL_SUMMARY_USAGE_BY_MODEL.items()} + total['total_tokens'] = total.get('prompt_tokens', 0) + total.get('completion_tokens', 0) return { 'total': total, 'by_model': by_model, @@ -412,8 +384,7 @@ def log_global_summarization_usage(cls) -> None: """Log process-wide summarization totals once at end-of-run.""" usage = cls.get_global_summarization_usage() total = usage.get('total', {}) or {} - if not (total.get('prompt_tokens', 0) or total.get( - 'completion_tokens', 0) or total.get('api_calls', 0)): + if not (total.get('prompt_tokens', 0) or total.get('completion_tokens', 0) or total.get('api_calls', 0)): return logger.info( '[web_search_summarization_usage_process_total] ' @@ -447,31 +418,23 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Parse engine configuration - support both single and multi-engine modes - engines_config = getattr(tool_cfg, 'engines', - None) if tool_cfg else None + engines_config = getattr(tool_cfg, 'engines', None) if tool_cfg else None if engines_config: # Multi-engine mode: engines: [exa, arxiv] # Note: OmegaConf ListConfig is iterable but not isinstance of list/tuple - if hasattr(engines_config, - '__iter__') and not isinstance(engines_config, str): - self._engine_types = [ - str(e).lower().strip() for e in engines_config - ] + if hasattr(engines_config, '__iter__') and not isinstance(engines_config, str): + self._engine_types = [str(e).lower().strip() for e in engines_config] else: self._engine_types = [str(engines_config).lower().strip()] else: # Single engine mode (backward compatible): engine: exa - single_engine = getattr(tool_cfg, 'engine', - 'arxiv') if tool_cfg else 'arxiv' + single_engine = getattr(tool_cfg, 'engine', 'arxiv') if tool_cfg else 'arxiv' self._engine_types = [single_engine.lower().strip()] # Validate engine types - self._engine_types = [ - e for e in self._engine_types if e in self.SUPPORTED_ENGINES - ] + self._engine_types = [e for e in self._engine_types if e in self.SUPPORTED_ENGINES] if not self._engine_types: - logger.warning( - 'No valid engines configured, falling back to arxiv') + logger.warning('No valid engines configured, falling back to arxiv') self._engine_types = ['arxiv'] # API keys for each engine. @@ -483,15 +446,17 @@ def __init__(self, config, **kwargs): getattr(tool_cfg, 'exa_api_keys', None) or getattr(tool_cfg, 'exa_api_key', None) or getattr(tool_cfg, 'api_key', None) # backward compat - or os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')) - if tool_cfg else - (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), - 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) - or os.getenv('SERPAPI_API_KEY')) - if tool_cfg else os.getenv('SERPAPI_API_KEY'), - 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) - or os.getenv('TAVILY_API_KEY')) if tool_cfg else - os.getenv('TAVILY_API_KEY'), + or os.getenv('EXA_API_KEYS') + or os.getenv('EXA_API_KEY') + ) + if tool_cfg + else (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), + 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) + if tool_cfg + else os.getenv('SERPAPI_API_KEY'), + 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) or os.getenv('TAVILY_API_KEY')) + if tool_cfg + else os.getenv('TAVILY_API_KEY'), } # Tavily search defaults from optional `tavily:` sub-block in YAML @@ -501,79 +466,64 @@ def __init__(self, config, **kwargs): if tv is not None: try: from omegaconf import OmegaConf + if OmegaConf.is_config(tv): - self._tavily_defaults = dict( - OmegaConf.to_container(tv, resolve=True)) + self._tavily_defaults = dict(OmegaConf.to_container(tv, resolve=True)) elif isinstance(tv, dict): self._tavily_defaults = dict(tv) except Exception: if isinstance(tv, dict): self._tavily_defaults = dict(tv) - self._tavily_request_timeout = float( - getattr(tool_cfg, 'tavily_request_timeout', 120.0) - or 120.0) if tool_cfg else 120.0 + self._tavily_request_timeout = ( + float(getattr(tool_cfg, 'tavily_request_timeout', 120.0) or 120.0) if tool_cfg else 120.0 + ) # SerpApi provider (google, bing, baidu) - self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', - 'google') if tool_cfg else 'google' + self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', 'google') if tool_cfg else 'google' # Default result count - self._max_results = int(getattr(tool_cfg, 'max_results', 5) - or 5) if tool_cfg else 5 + self._max_results = int(getattr(tool_cfg, 'max_results', 5) or 5) if tool_cfg else 5 # Content fetcher config - self._fetcher_type = getattr( - tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' - self._fetch_timeout = float( - getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 - self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) - or 3) if tool_cfg else 3 - self._jina_direct_fetch_fallback = bool( - getattr(tool_cfg, 'jina_direct_fetch_fallback', True) - ) if tool_cfg else True + self._fetcher_type = getattr(tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' + self._fetch_timeout = float(getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 + self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) or 3) if tool_cfg else 3 + self._jina_direct_fetch_fallback = ( + bool(getattr(tool_cfg, 'jina_direct_fetch_fallback', True)) if tool_cfg else True + ) if tool_cfg is not None and hasattr(tool_cfg, 'jina_direct_fetch_timeout'): - self._jina_direct_fetch_timeout = float( - tool_cfg.jina_direct_fetch_timeout) + self._jina_direct_fetch_timeout = float(tool_cfg.jina_direct_fetch_timeout) else: self._jina_direct_fetch_timeout = 15.0 - self._jina_playwright_fetch_fallback = bool( - getattr(tool_cfg, 'jina_playwright_fetch_fallback', True) - ) if tool_cfg else True - self._jina_playwright_retry_min_chars = int( - getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400 - ) if tool_cfg else 400 - self._jina_playwright_timeout_ms = int( - getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000 - ) if tool_cfg else 30000 + self._jina_playwright_fetch_fallback = ( + bool(getattr(tool_cfg, 'jina_playwright_fetch_fallback', True)) if tool_cfg else True + ) + self._jina_playwright_retry_min_chars = ( + int(getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400) if tool_cfg else 400 + ) + self._jina_playwright_timeout_ms = ( + int(getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000) if tool_cfg else 30000 + ) if tool_cfg is not None and hasattr(tool_cfg, 'jina_playwright_settle_ms'): - self._jina_playwright_settle_ms = int( - tool_cfg.jina_playwright_settle_ms) + self._jina_playwright_settle_ms = int(tool_cfg.jina_playwright_settle_ms) else: self._jina_playwright_settle_ms = 350 - self._fetch_content_default = bool( - getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True + self._fetch_content_default = bool(getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True # Chunking config - self._enable_chunking = bool( - getattr(tool_cfg, 'enable_chunking', False)) if tool_cfg else False - self._chunk_size = int(getattr(tool_cfg, 'chunk_size', 2000) - or 2000) if tool_cfg else 2000 - self._chunk_overlap = int( - getattr(tool_cfg, 'chunk_overlap', 200) - or 200) if tool_cfg else 200 + self._enable_chunking = bool(getattr(tool_cfg, 'enable_chunking', False)) if tool_cfg else False + self._chunk_size = int(getattr(tool_cfg, 'chunk_size', 2000) or 2000) if tool_cfg else 2000 + self._chunk_overlap = int(getattr(tool_cfg, 'chunk_overlap', 200) or 200) if tool_cfg else 200 # Concurrency - self._max_concurrent_fetch = int( - getattr(tool_cfg, 'max_concurrent_fetch', 3) - or 3) if tool_cfg else 3 + self._max_concurrent_fetch = int(getattr(tool_cfg, 'max_concurrent_fetch', 3) or 3) if tool_cfg else 3 # Hard cap (seconds) per URL for asyncio.wait_for around run_in_executor. # When hit, this URL gets empty content + fetch_error; other URLs in the # same web_search call keep their already-fetched bodies. Set 0 to disable # the asyncio cap (underlying urllib/Jina timeouts still apply). if tool_cfg is not None and hasattr(tool_cfg, 'per_url_fetch_timeout'): - self._per_url_fetch_timeout_s = float( - tool_cfg.per_url_fetch_timeout) + self._per_url_fetch_timeout_s = float(tool_cfg.per_url_fetch_timeout) else: self._per_url_fetch_timeout_s = default_per_url_fetch_timeout_s( self._fetch_timeout, @@ -581,61 +531,47 @@ def __init__(self, config, **kwargs): self._jina_direct_fetch_timeout, self._jina_playwright_timeout_ms, ) - self._max_concurrent_summarization = int( - getattr(tool_cfg, 'max_concurrent_summarization', 5) - or 5) if tool_cfg else 5 + self._max_concurrent_summarization = ( + int(getattr(tool_cfg, 'max_concurrent_summarization', 5) or 5) if tool_cfg else 5 + ) # Content optimization config (summarization & reranking) - self._enable_summarization = bool( - getattr(tool_cfg, 'enable_summarization', - False)) if tool_cfg else False - self._summarizer_model = getattr( - tool_cfg, 'summarizer_model', - 'qwen-flash') if tool_cfg else 'qwen-flash' - self._summarizer_base_url = getattr( - tool_cfg, 'summarizer_base_url', - 'https://dashscope.aliyuncs.com/compatible-mode/v1' - ) if tool_cfg else 'https://dashscope.aliyuncs.com/compatible-mode/v1' - self._summarizer_api_key = getattr(tool_cfg, 'summarizer_api_key', - None) if tool_cfg else None - self._max_content_chars = int( - getattr(tool_cfg, 'max_content_chars', 500000) - or 500000) if tool_cfg else 500000 - self._summarizer_max_workers = int( - getattr(tool_cfg, 'summarizer_max_workers', 5) - or 5) if tool_cfg else 5 - self._summarization_timeout = float( - getattr(tool_cfg, 'summarization_timeout', 90.0) - or 90.0) if tool_cfg else 90.0 + self._enable_summarization = bool(getattr(tool_cfg, 'enable_summarization', False)) if tool_cfg else False + self._summarizer_model = getattr(tool_cfg, 'summarizer_model', 'qwen-flash') if tool_cfg else 'qwen-flash' + self._summarizer_base_url = ( + getattr(tool_cfg, 'summarizer_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1') + if tool_cfg + else 'https://dashscope.aliyuncs.com/compatible-mode/v1' + ) + self._summarizer_api_key = getattr(tool_cfg, 'summarizer_api_key', None) if tool_cfg else None + self._max_content_chars = int(getattr(tool_cfg, 'max_content_chars', 500000) or 500000) if tool_cfg else 500000 + self._summarizer_max_workers = int(getattr(tool_cfg, 'summarizer_max_workers', 5) or 5) if tool_cfg else 5 + self._summarization_timeout = ( + float(getattr(tool_cfg, 'summarization_timeout', 90.0) or 90.0) if tool_cfg else 90.0 + ) # Large payload spill (write bodies to disk; keep JSON small) - self._spill_enabled = bool( - getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True - self._spill_max_inline_chars = int( - getattr(tool_cfg, 'spill_max_inline_chars', 120000) - or 120000) if tool_cfg else 120000 - self._spill_subdir = str( - getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') - or 'web_search_artifacts') if tool_cfg else 'web_search_artifacts' - self._spill_preview_chars = int( - getattr(tool_cfg, 'spill_preview_chars', 600) - or 600) if tool_cfg else 600 + self._spill_enabled = bool(getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True + self._spill_max_inline_chars = ( + int(getattr(tool_cfg, 'spill_max_inline_chars', 120000) or 120000) if tool_cfg else 120000 + ) + self._spill_subdir = ( + str(getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') or 'web_search_artifacts') + if tool_cfg + else 'web_search_artifacts' + ) + self._spill_preview_chars = int(getattr(tool_cfg, 'spill_preview_chars', 600) or 600) if tool_cfg else 600 # Reranking config - self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', - False)) if tool_cfg else False - self._rerank_top_k = int(getattr(tool_cfg, 'rerank_top_k', 3) - or 3) if tool_cfg else 3 + self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', False)) if tool_cfg else False + self._rerank_top_k = int(getattr(tool_cfg, 'rerank_top_k', 3) or 3) if tool_cfg else 3 # Task context for summarization (can be set dynamically) - self._task_context = getattr(tool_cfg, 'task_context', - '') if tool_cfg else '' + self._task_context = getattr(tool_cfg, 'task_context', '') if tool_cfg else '' # Runtime instances (lazy init) - self._engines: Dict[str, SearchEngine] = { - } # engine_type -> engine instance - self._engine_classes: Dict[str, Type[SearchEngine]] = { - } # engine_type -> engine class + self._engines: Dict[str, SearchEngine] = {} # engine_type -> engine instance + self._engine_classes: Dict[str, Type[SearchEngine]] = {} # engine_type -> engine class self._content_fetcher: Optional[ContentFetcher] = None self._content_optimizer: Optional[ContentOptimizer] = None self._executor: Optional[ThreadPoolExecutor] = None @@ -658,8 +594,7 @@ async def connect(self) -> None: # Create engine instance if engine_type == 'exa': - self._engines[engine_type] = engine_cls( - api_key=self._api_keys.get('exa')) + self._engines[engine_type] = engine_cls(api_key=self._api_keys.get('exa')) elif engine_type == 'serpapi': self._engines[engine_type] = engine_cls( api_key=self._api_keys.get('serpapi'), @@ -675,8 +610,7 @@ async def connect(self) -> None: logger.info(f'Initialized search engine: {engine_type}') except Exception as e: - logger.warning( - f'Failed to initialize {engine_type} engine: {e}') + logger.warning(f'Failed to initialize {engine_type} engine: {e}') if not self._engines: raise RuntimeError('No search engines could be initialized') @@ -694,20 +628,16 @@ async def connect(self) -> None: 'playwright_settle_ms': self._jina_playwright_settle_ms, } if wcfg is not None: - _fk.update({ - 'tavily_extract_depth': - getattr(wcfg, 'tavily_extract_depth', 'advanced'), - 'tavily_extract_format': - getattr(wcfg, 'tavily_extract_format', 'markdown'), - 'tavily_extract_chunks_per_source': - int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), - 'tavily_extract_include_images': - bool(getattr(wcfg, 'tavily_extract_include_images', False)), - 'tavily_extract_include_favicon': - bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), - 'tavily_extract_include_usage': - bool(getattr(wcfg, 'tavily_extract_include_usage', False)), - }) + _fk.update( + { + 'tavily_extract_depth': getattr(wcfg, 'tavily_extract_depth', 'advanced'), + 'tavily_extract_format': getattr(wcfg, 'tavily_extract_format', 'markdown'), + 'tavily_extract_chunks_per_source': int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), + 'tavily_extract_include_images': bool(getattr(wcfg, 'tavily_extract_include_images', False)), + 'tavily_extract_include_favicon': bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), + 'tavily_extract_include_usage': bool(getattr(wcfg, 'tavily_extract_include_usage', False)), + } + ) self._content_fetcher = get_content_fetcher(self._fetcher_type, **_fk) # Use daemon threads: tool-call timeouts can cancel the awaiting coroutine, # but not the underlying sync network calls running in executor threads. @@ -721,9 +651,9 @@ async def connect(self) -> None: optimizer_config = ContentOptimizerConfig( summarizer_model=self._summarizer_model, summarizer_base_url=self._summarizer_base_url, - summarizer_api_key=(self._summarizer_api_key - or os.getenv('DASHSCOPE_API_KEY') - or os.getenv('OPENAI_API_KEY')), + summarizer_api_key=( + self._summarizer_api_key or os.getenv('DASHSCOPE_API_KEY') or os.getenv('OPENAI_API_KEY') + ), max_content_chars=self._max_content_chars, summarizer_max_workers=self._summarizer_max_workers, summarization_timeout=self._summarization_timeout, @@ -733,12 +663,9 @@ async def connect(self) -> None: self._content_optimizer = ContentOptimizer(optimizer_config) if self._enable_summarization: await self._content_optimizer.initialize() - logger.info( - f'Content optimizer initialized with model: {self._summarizer_model}' - ) + logger.info(f'Content optimizer initialized with model: {self._summarizer_model}') else: - logger.info( - 'Content reranking enabled (summarization disabled)') + logger.info('Content reranking enabled (summarization disabled)') async def cleanup(self) -> None: """Cleanup resources.""" @@ -756,11 +683,12 @@ async def cleanup(self) -> None: self._engine_classes.clear() # Optional: instance-level totals can be noisy when multiple sub-agents # create their own WebSearchTool instances. Default off; use env var to enable. - if os.getenv('MS_AGENT_WEB_SEARCH_LOG_INSTANCE_SUMMARY_USAGE', - '').lower() in ('1', 'true', 'yes'): - if (self._summary_usage_total.get('prompt_tokens', 0) - or self._summary_usage_total.get('completion_tokens', 0) - or self._summary_usage_total.get('api_calls', 0)): + if os.getenv('MS_AGENT_WEB_SEARCH_LOG_INSTANCE_SUMMARY_USAGE', '').lower() in ('1', 'true', 'yes'): + if ( + self._summary_usage_total.get('prompt_tokens', 0) + or self._summary_usage_total.get('completion_tokens', 0) + or self._summary_usage_total.get('api_calls', 0) + ): model = self._summary_usage_model or self._summarizer_model logger.info( '[web_search_summarization_usage_total] ' @@ -795,21 +723,19 @@ async def _get_tools_inner(self) -> Dict[str, Any]: continue # Get engine's tool definition - tool_def = engine_cls.get_tool_definition( - server_name=self.SERVER_NAME) + tool_def = engine_cls.get_tool_definition(server_name=self.SERVER_NAME) # Add fetch_content parameter if content fetcher is available if self._content_fetcher: tool_params = dict(tool_def.get('parameters', {})) tool_props = dict(tool_params.get('properties', {})) tool_props['fetch_content'] = { - 'type': - 'boolean', - 'description': - ('Whether to fetch and parse full page content. ' - 'Set to false for faster results with only titles/snippets. ' - f'Default is {self._fetch_content_default}. Suggested to set to True.' - ), + 'type': 'boolean', + 'description': ( + 'Whether to fetch and parse full page content. ' + 'Set to false for faster results with only titles/snippets. ' + f'Default is {self._fetch_content_default}. Suggested to set to True.' + ), } tool_params['properties'] = tool_props tool_def['parameters'] = tool_params @@ -821,8 +747,9 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='fetch_page', server_name=self.SERVER_NAME, - description=('Fetch and parse a single web page by URL. ' - 'Use this when you have a specific URL to read.'), + description=( + 'Fetch and parse a single web page by URL. Use this when you have a specific URL to read.' + ), parameters={ 'type': 'object', 'properties': { @@ -833,12 +760,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, 'required': ['url'], }, - )) + ) + ) return {self.SERVER_NAME: tools} - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: """Route tool calls to appropriate handler.""" if tool_name == 'fetch_page': return await self.fetch_page(**(tool_args or {})) @@ -848,12 +775,9 @@ async def call_tool(self, server_name: str, *, tool_name: str, engine_type = tool_to_engine.get(tool_name) if not engine_type or engine_type not in self._engines: - return _json_dumps({ - 'status': - 'error', - 'message': - f'Unknown tool: {tool_name}. Available: {list(tool_to_engine.keys())}' - }) + return _json_dumps( + {'status': 'error', 'message': f'Unknown tool: {tool_name}. Available: {list(tool_to_engine.keys())}'} + ) return await self._execute_search(engine_type, tool_args or {}) @@ -879,11 +803,9 @@ def _fetch_content_sync(self, url: str) -> Dict[str, Any]: content, chunk_size=self._chunk_size, overlap=self._chunk_overlap, - prefix=f'{hash(url) & 0xFFFFFF:06x}_') - result['chunks'] = [{ - 'chunk_id': c.chunk_id, - 'content': c.content - } for c in chunks] + prefix=f'{hash(url) & 0xFFFFFF:06x}_', + ) + result['chunks'] = [{'chunk_id': c.chunk_id, 'content': c.content} for c in chunks] return result except Exception as e: @@ -898,8 +820,7 @@ def _fetch_content_sync(self, url: str) -> Dict[str, Any]: async def _fetch_content_async(self, url: str) -> Dict[str, Any]: """Async wrapper for content fetching.""" loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, - self._fetch_content_sync, url) + return await loop.run_in_executor(self._executor, self._fetch_content_sync, url) def _url_log_preview(self, url: str, max_len: int = 220) -> str: u = (url or '').strip() @@ -932,7 +853,10 @@ async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: logger.warning( '[web_search] fetch TIMEOUT url=%s elapsed=%.1fs cap=%.1fs — ' 'this URL is dropped for this response; others are unchanged', - preview, elapsed, cap) + preview, + elapsed, + cap, + ) return { 'url': u, 'content': '', @@ -942,20 +866,20 @@ async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: 'fetch_timed_out': True, } elapsed = time.perf_counter() - t0 - src = (out or {}).get('content_source') or (out or {}).get( - 'fetcher', '') or '' + src = (out or {}).get('content_source') or (out or {}).get('fetcher', '') or '' ok = bool((out or {}).get('fetch_success')) - logger.info( - '[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', - preview, elapsed, ok, src) - return out if out is not None else { - 'url': u, - 'content': '', - 'fetch_success': False, - } + logger.info('[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', preview, elapsed, ok, src) + return ( + out + if out is not None + else { + 'url': u, + 'content': '', + 'fetch_success': False, + } + ) - async def _fetch_multiple_async(self, - urls: List[str]) -> List[Dict[str, Any]]: + async def _fetch_multiple_async(self, urls: List[str]) -> List[Dict[str, Any]]: """Fetch multiple URLs concurrently with semaphore.""" semaphore = asyncio.Semaphore(self._max_concurrent_fetch) @@ -967,29 +891,30 @@ async def _bounded_fetch(url: str) -> Dict[str, Any]: return await asyncio.gather(*tasks) def _do_search( - self, engine_type: str, engine: SearchEngine, - engine_cls: Type[SearchEngine], - tool_args: Dict[str, Any] + self, engine_type: str, engine: SearchEngine, engine_cls: Type[SearchEngine], tool_args: Dict[str, Any] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Perform search; returns (result rows, extra top-level metadata e.g. Tavily).""" try: merged = dict(tool_args) - if engine_type == 'tavily' and getattr(self, '_tavily_defaults', - None): + if engine_type == 'tavily' and getattr(self, '_tavily_defaults', None): merged = {**self._tavily_defaults, **merged} # Keys only for engine / fetcher YAML, not TavilySearchRequest - for _k in ('request_timeout', 'tavily_extract_depth', - 'tavily_extract_format', - 'tavily_extract_chunks_per_source', - 'tavily_extract_include_images', - 'tavily_extract_include_favicon', - 'tavily_extract_include_usage'): + for _k in ( + 'request_timeout', + 'tavily_extract_depth', + 'tavily_extract_format', + 'tavily_extract_chunks_per_source', + 'tavily_extract_include_images', + 'tavily_extract_include_favicon', + 'tavily_extract_include_usage', + ): merged.pop(_k, None) request = engine_cls.build_request_from_args(**merged) result = engine.search(request) rows = result.to_list() extra: Dict[str, Any] = {} from ms_agent.tools.search.tavily.schema import TavilySearchResult + if isinstance(result, TavilySearchResult): extra = result.extra_response_fields() return rows, extra @@ -997,8 +922,7 @@ def _do_search( logger.error(f'Search failed ({engine_type}): {e}') return [], {} - async def _execute_search(self, engine_type: str, - tool_args: Dict[str, Any]) -> str: + async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> str: """ Execute search using the specified engine. The search pipeline with optimization: @@ -1014,40 +938,30 @@ async def _execute_search(self, engine_type: str, """ query = tool_args.get('query', '').strip() if not query: - return _json_dumps({ - 'status': 'error', - 'message': 'Query is required.' - }) + return _json_dumps({'status': 'error', 'message': 'Query is required.'}) call_id_for_spill = str(tool_args.pop('__call_id', '') or '') # Get fetch_content preference, default to configured value - fetch_content = tool_args.pop('fetch_content', - self._fetch_content_default) + fetch_content = tool_args.pop('fetch_content', self._fetch_content_default) # Get task context for summarization (can be passed in tool_args) task_context = tool_args.pop('task_context', self._task_context) if 'num_results' not in tool_args or tool_args['num_results'] is None: - tool_args[ - 'num_results'] = 10 if engine_type == 'arxiv' else self._max_results + tool_args['num_results'] = 10 if engine_type == 'arxiv' else self._max_results engine = self._engines.get(engine_type) engine_cls = self._engine_classes.get(engine_type) if not engine or not engine_cls: - return _json_dumps({ - 'status': - 'error', - 'message': - f'Engine {engine_type} not initialized.' - }) + return _json_dumps({'status': 'error', 'message': f'Engine {engine_type} not initialized.'}) # Perform search loop = asyncio.get_event_loop() search_results, tavily_extra = await loop.run_in_executor( - self._executor, self._do_search, engine_type, engine, engine_cls, - tool_args) + self._executor, self._do_search, engine_type, engine, engine_cls, tool_args + ) if not search_results: out_empty: Dict[str, Any] = { @@ -1072,16 +986,13 @@ async def _execute_search(self, engine_type: str, query, top_k=self._rerank_top_k, ) - logger.info( - f'Reranked {original_count} results to top {len(search_results)} ' - f'for query: {query[:50]}...') + logger.info(f'Reranked {original_count} results to top {len(search_results)} for query: {query[:50]}...') # Step 3: Fetch content for (filtered) results (skip URLs already filled e.g. Tavily raw_content) fetch_attempts = 0 fetch_timeouts = 0 if fetch_content and self._content_fetcher: - search_results = SearchResultReranker.deduplicate_by_url( - search_results) + search_results = SearchResultReranker.deduplicate_by_url(search_results) urls: List[str] = [] for r in search_results: u = r.get('url') @@ -1093,8 +1004,7 @@ async def _execute_search(self, engine_type: str, if urls: fetch_attempts = len(urls) fetch_results = await self._fetch_multiple_async(urls) - fetch_timeouts = sum( - 1 for r in fetch_results if r.get('fetch_timed_out')) + fetch_timeouts = sum(1 for r in fetch_results if r.get('fetch_timed_out')) # Merge search metadata with fetched content url_to_fetch = {r['url']: r for r in fetch_results} @@ -1103,8 +1013,7 @@ async def _execute_search(self, engine_type: str, if url and url in url_to_fetch: fetched = url_to_fetch[url] sr['content'] = fetched.get('content', '') - sr['fetch_success'] = fetched.get( - 'fetch_success', False) + sr['fetch_success'] = fetched.get('fetch_success', False) if fetched.get('fetch_error'): sr['fetch_error'] = fetched['fetch_error'] else: @@ -1115,8 +1024,7 @@ async def _execute_search(self, engine_type: str, sr.pop('fetch_timed_out', None) if fetched.get('content_source'): sr['content_source'] = fetched['content_source'] - if fetched.get('published_at' - ) and not sr.get('published_date'): + if fetched.get('published_at') and not sr.get('published_date'): sr['published_at'] = fetched['published_at'] if self._enable_chunking and fetched.get('chunks'): sr['chunks'] = fetched['chunks'] @@ -1132,89 +1040,74 @@ async def _execute_search(self, engine_type: str, ] if contents_to_summarize: - logger.info( - f'Summarizing {len(contents_to_summarize)} pages...') + logger.info(f'Summarizing {len(contents_to_summarize)} pages...') # Summarize all contents in parallel + collect usage summaries, summarization_usage = await self._content_optimizer.summarize_contents_with_usage( contents_to_summarize, task_context=task_context, - max_concurrent=min(self._max_concurrent_summarization, - len(contents_to_summarize)), + max_concurrent=min(self._max_concurrent_summarization, len(contents_to_summarize)), ) # Update global usage totals for this tool instance (independent from LLMAgent) try: if summarization_usage: self._summary_usage_model = str( - summarization_usage.get('model') - or self._summary_usage_model or '') - self._summary_usage_total['api_calls'] += int( - summarization_usage.get('api_calls', 0) or 0) + summarization_usage.get('model') or self._summary_usage_model or '' + ) + self._summary_usage_total['api_calls'] += int(summarization_usage.get('api_calls', 0) or 0) self._summary_usage_total['prompt_tokens'] += int( - summarization_usage.get('prompt_tokens', 0) or 0) + summarization_usage.get('prompt_tokens', 0) or 0 + ) self._summary_usage_total['completion_tokens'] += int( - summarization_usage.get('completion_tokens', 0) - or 0) + summarization_usage.get('completion_tokens', 0) or 0 + ) self._summary_usage_total['cached_tokens'] += int( - summarization_usage.get('cached_tokens', 0) or 0) - self._summary_usage_total[ - 'cache_creation_input_tokens'] += int( - summarization_usage.get( - 'cache_creation_input_tokens', 0) or 0) + summarization_usage.get('cached_tokens', 0) or 0 + ) + self._summary_usage_total['cache_creation_input_tokens'] += int( + summarization_usage.get('cache_creation_input_tokens', 0) or 0 + ) # Process-wide totals (thread-safe; sub-agents may run in background threads) - model = str( - summarization_usage.get('model') - or self._summarizer_model) + model = str(summarization_usage.get('model') or self._summarizer_model) with WebSearchTool._GLOBAL_SUMMARY_USAGE_LOCK: - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'pages'] += int( - summarization_usage.get('pages', 0) or 0) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'api_calls'] += int( - summarization_usage.get('api_calls', 0) - or 0) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'prompt_tokens'] += int( - summarization_usage.get( - 'prompt_tokens', 0) or 0) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'completion_tokens'] += int( - summarization_usage.get( - 'completion_tokens', 0) or 0) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'cached_tokens'] += int( - summarization_usage.get( - 'cached_tokens', 0) or 0) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ - 'cache_creation_input_tokens'] += int( - summarization_usage.get( - 'cache_creation_input_tokens', 0) or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['pages'] += int( + summarization_usage.get('pages', 0) or 0 + ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['api_calls'] += int( + summarization_usage.get('api_calls', 0) or 0 + ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['prompt_tokens'] += int( + summarization_usage.get('prompt_tokens', 0) or 0 + ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['completion_tokens'] += int( + summarization_usage.get('completion_tokens', 0) or 0 + ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['cached_tokens'] += int( + summarization_usage.get('cached_tokens', 0) or 0 + ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['cache_creation_input_tokens'] += int( + summarization_usage.get('cache_creation_input_tokens', 0) or 0 + ) m = WebSearchTool._GLOBAL_SUMMARY_USAGE_BY_MODEL.setdefault( - model, { + model, + { 'pages': 0, 'api_calls': 0, 'prompt_tokens': 0, 'completion_tokens': 0, 'cached_tokens': 0, 'cache_creation_input_tokens': 0, - }) - m['pages'] += int( - summarization_usage.get('pages', 0) or 0) - m['api_calls'] += int( - summarization_usage.get('api_calls', 0) or 0) - m['prompt_tokens'] += int( - summarization_usage.get('prompt_tokens', 0) - or 0) - m['completion_tokens'] += int( - summarization_usage.get( - 'completion_tokens', 0) or 0) - m['cached_tokens'] += int( - summarization_usage.get('cached_tokens', 0) - or 0) + }, + ) + m['pages'] += int(summarization_usage.get('pages', 0) or 0) + m['api_calls'] += int(summarization_usage.get('api_calls', 0) or 0) + m['prompt_tokens'] += int(summarization_usage.get('prompt_tokens', 0) or 0) + m['completion_tokens'] += int(summarization_usage.get('completion_tokens', 0) or 0) + m['cached_tokens'] += int(summarization_usage.get('cached_tokens', 0) or 0) m['cache_creation_input_tokens'] += int( - summarization_usage.get( - 'cache_creation_input_tokens', 0) or 0) + summarization_usage.get('cache_creation_input_tokens', 0) or 0 + ) logger.info( '[web_search_summarization_usage] ' f"model={summarization_usage.get('model', self._summarizer_model)} " @@ -1227,8 +1120,7 @@ async def _execute_search(self, engine_type: str, f"cache_creation_input_tokens={summarization_usage.get('cache_creation_input_tokens', 0)}" ) except Exception as e: - logger.warning( - f'Failed to record summarization usage: {e}') + logger.warning(f'Failed to record summarization usage: {e}') # Replace original content with summaries for sr in search_results: @@ -1238,42 +1130,31 @@ async def _execute_search(self, engine_type: str, sr['content'] = summaries[url] sr['content_summarized'] = True sr['original_content_length'] = original_len - logger.debug( - f'Summarized content for {url[:50]}: ' - f"{original_len} -> {len(sr['content'])} chars") + logger.debug(f'Summarized content for {url[:50]}: {original_len} -> {len(sr["content"])} chars') # Format output output_results = [] for sr in search_results: item = { - 'url': - sr.get('url', ''), - 'title': - sr.get('title', ''), - 'published_at': - sr.get('published_date') or sr.get('published_at', ''), + 'url': sr.get('url', ''), + 'title': sr.get('title', ''), + 'published_at': sr.get('published_date') or sr.get('published_at', ''), } # Preserve arXiv-specific metadata (aligned with arxiv-mcp-server) if engine_type == 'arxiv': - item.update({ - 'id': - sr.get('arxiv_id', '') or '', # arXiv short id - 'abs_url': - sr.get('id', '') or '', # entry_id (abstract page) - 'pdf_url': - sr.get('pdf_url', '') or '', - 'abstract': - sr.get('summary', '') or '', - 'authors': - sr.get('authors', []) or [], - 'categories': - sr.get('categories', []) or [], - 'resource_uri': - sr.get('resource_uri', '') or '', - 'published': - sr.get('published_date') or sr.get('published_at', ''), - }) + item.update( + { + 'id': sr.get('arxiv_id', '') or '', # arXiv short id + 'abs_url': sr.get('id', '') or '', # entry_id (abstract page) + 'pdf_url': sr.get('pdf_url', '') or '', + 'abstract': sr.get('summary', '') or '', + 'authors': sr.get('authors', []) or [], + 'categories': sr.get('categories', []) or [], + 'resource_uri': sr.get('resource_uri', '') or '', + 'published': sr.get('published_date') or sr.get('published_at', ''), + } + ) if fetch_content: item['content'] = sr.get('content', '') @@ -1287,8 +1168,7 @@ async def _execute_search(self, engine_type: str, # Add summarization metadata if applicable if sr.get('content_summarized'): item['content_summarized'] = True - item['original_content_length'] = sr.get( - 'original_content_length', 0) + item['original_content_length'] = sr.get('original_content_length', 0) if self._enable_chunking and sr.get('chunks'): item['chunks'] = sr['chunks'] @@ -1317,12 +1197,9 @@ async def _execute_search(self, engine_type: str, } if fetch_content and self._content_fetcher: response['fetch_stats'] = { - 'per_url_timeout_s': - self._per_url_fetch_timeout_s, - 'urls_fetched_this_call': - fetch_attempts, - 'urls_timed_out': - fetch_timeouts, + 'per_url_timeout_s': self._per_url_fetch_timeout_s, + 'urls_fetched_this_call': fetch_attempts, + 'urls_timed_out': fetch_timeouts, } if tavily_extra: response.update(tavily_extra) @@ -1334,34 +1211,25 @@ async def _execute_search(self, engine_type: str, 'summarization_enabled': self._enable_summarization, } if self._enable_rerank: - response['optimization'][ - 'original_result_count'] = original_count + response['optimization']['original_result_count'] = original_count response['optimization']['filtered_to'] = len(output_results) if self._enable_summarization: - summarized_count = sum(1 for r in output_results - if r.get('content_summarized')) + summarized_count = sum(1 for r in output_results if r.get('content_summarized')) response['optimization']['pages_summarized'] = summarized_count # Include per-call usage + cumulative totals (separate from LLMAgent usage) if summarization_usage: - response['optimization'][ - 'summarization_usage'] = summarization_usage + response['optimization']['summarization_usage'] = summarization_usage response['optimization']['summarization_usage_total'] = { - 'model': - self._summary_usage_model or self._summarizer_model, - 'api_calls': - self._summary_usage_total.get('api_calls', 0), - 'prompt_tokens': - self._summary_usage_total.get('prompt_tokens', 0), - 'completion_tokens': - self._summary_usage_total.get('completion_tokens', 0), - 'total_tokens': - (self._summary_usage_total.get('prompt_tokens', 0) - + self._summary_usage_total.get('completion_tokens', 0)), - 'cached_tokens': - self._summary_usage_total.get('cached_tokens', 0), - 'cache_creation_input_tokens': - self._summary_usage_total.get( - 'cache_creation_input_tokens', 0), + 'model': self._summary_usage_model or self._summarizer_model, + 'api_calls': self._summary_usage_total.get('api_calls', 0), + 'prompt_tokens': self._summary_usage_total.get('prompt_tokens', 0), + 'completion_tokens': self._summary_usage_total.get('completion_tokens', 0), + 'total_tokens': ( + self._summary_usage_total.get('prompt_tokens', 0) + + self._summary_usage_total.get('completion_tokens', 0) + ), + 'cached_tokens': self._summary_usage_total.get('cached_tokens', 0), + 'cache_creation_input_tokens': self._summary_usage_total.get('cache_creation_input_tokens', 0), } # Process-wide totals so far (across all WebSearchTool instances) response['optimization'][ @@ -1369,8 +1237,7 @@ async def _execute_search(self, engine_type: str, ] = WebSearchTool.get_global_summarization_usage() # yapf: disable if self._spill_enabled: - od = getattr(self, 'output_dir', None) or getattr( - getattr(self, 'config', None), 'output_dir', '') or '' + od = getattr(self, 'output_dir', None) or getattr(getattr(self, 'config', None), 'output_dir', '') or '' if od: try: new_results, spill_meta = maybe_spill_web_search_payload( @@ -1387,47 +1254,34 @@ async def _execute_search(self, engine_type: str, response['results'] = new_results response['spill'] = spill_meta except Exception as e: - logger.warning( - f'web_search spill failed (returning full inline JSON): {e}' - ) + logger.warning(f'web_search spill failed (returning full inline JSON): {e}') return _json_dumps(response) async def fetch_page(self, url: str) -> str: """Fetch and parse a single web page.""" if not url or not url.strip(): - return _json_dumps({ - 'status': 'error', - 'message': 'URL is required.' - }) + return _json_dumps({'status': 'error', 'message': 'URL is required.'}) result = await self._fetch_content_async_bounded(url.strip()) - return _json_dumps({ - 'status': - 'ok' if result.get('fetch_success') else 'error', - 'url': - url, - 'content': - result.get('content', ''), - 'published_at': - result.get('published_at', ''), - 'fetch_success': - result.get('fetch_success', False), - 'fetch_error': - result.get('fetch_error', ''), - 'fetch_timed_out': - bool(result.get('fetch_timed_out')), - 'chunks': - result.get('chunks') if self._enable_chunking else None, - }) + return _json_dumps( + { + 'status': 'ok' if result.get('fetch_success') else 'error', + 'url': url, + 'content': result.get('content', ''), + 'published_at': result.get('published_at', ''), + 'fetch_success': result.get('fetch_success', False), + 'fetch_error': result.get('fetch_error', ''), + 'fetch_timed_out': bool(result.get('fetch_timed_out')), + 'chunks': result.get('chunks') if self._enable_chunking else None, + } + ) # Backward compatibility aliases - async def web_search(self, - query: str, - num_results: Optional[int] = None, - fetch_content: bool = True, - **kwargs) -> str: + async def web_search( + self, query: str, num_results: Optional[int] = None, fetch_content: bool = True, **kwargs + ) -> str: """ Search the web and optionally fetch page content. @@ -1437,12 +1291,7 @@ async def web_search(self, # Use first engine as default engine_type = self._engine_types[0] if self._engine_types else 'arxiv' - tool_args = { - 'query': query, - 'num_results': num_results, - 'fetch_content': fetch_content, - **kwargs - } + tool_args = {'query': query, 'num_results': num_results, 'fetch_content': fetch_content, **kwargs} return await self._execute_search(engine_type, tool_args) diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py index 6bb0be4d6..20ad080d2 100644 --- a/ms_agent/tools/search_engine.py +++ b/ms_agent/tools/search_engine.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional from dotenv import load_dotenv + from ms_agent.config.env import Env from ms_agent.tools.search.arxiv import ArxivSearch from ms_agent.tools.search.exa import ExaSearch @@ -29,10 +30,7 @@ def set_search_env_overrides(env_overrides: Optional[Dict[str, str]]) -> None: if hasattr(_search_env_local, 'overrides'): delattr(_search_env_local, 'overrides') return - _search_env_local.overrides = { - k: v - for k, v in env_overrides.items() if v is not None - } + _search_env_local.overrides = {k: v for k, v in env_overrides.items() if v is not None} def get_search_env_overrides() -> Dict[str, str]: @@ -62,12 +60,11 @@ def load_base_config(file_path: str) -> Dict[str, Any]: Env.load_env() if not os.path.exists(file_path): - logger.warning( - f'Config file {file_path} does not exist. Using default config (ArxivSearch).' - ) + logger.warning(f'Config file {file_path} does not exist. Using default config (ArxivSearch).') return {} import yaml + with open(file_path, 'r') as file: config = yaml.safe_load(file) @@ -130,12 +127,16 @@ def get_web_search_tool(config_file: str): # Engine override precedence: # 1) Thread-local override (per-request, e.g. FinResearch UI) # 2) Global environment variable (shared default) - engine_override = ((local_env.get(SEARCH_ENGINE_OVERRIDE_ENV, '') or '') - or (os.getenv(SEARCH_ENGINE_OVERRIDE_ENV, '') - or '')).strip().lower() - if engine_override and engine_override in (SearchEngineType.EXA.value, - SearchEngineType.SERPAPI.value, - SearchEngineType.ARXIV.value): + engine_override = ( + ((local_env.get(SEARCH_ENGINE_OVERRIDE_ENV, '') or '') or (os.getenv(SEARCH_ENGINE_OVERRIDE_ENV, '') or '')) + .strip() + .lower() + ) + if engine_override and engine_override in ( + SearchEngineType.EXA.value, + SearchEngineType.SERPAPI.value, + SearchEngineType.ARXIV.value, + ): search_config['engine'] = engine_override engine_name = (search_config.get('engine', '') or '').lower() @@ -145,14 +146,12 @@ def get_web_search_tool(config_file: str): override_serp_key = local_env.get('SERPAPI_API_KEY') if engine_name == SearchEngineType.EXA.value: - return ExaSearch( - api_key=override_exa_key or search_config.get( - 'exa_api_key', os.getenv('EXA_API_KEY', None))) + return ExaSearch(api_key=override_exa_key or search_config.get('exa_api_key', os.getenv('EXA_API_KEY', None))) elif engine_name == SearchEngineType.SERPAPI.value: return SerpApiSearch( - api_key=override_serp_key or search_config.get( - 'serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)), - provider=search_config.get('provider', 'google').lower()) + api_key=override_serp_key or search_config.get('serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)), + provider=search_config.get('provider', 'google').lower(), + ) elif engine_name == SearchEngineType.ARXIV.value: return ArxivSearch() else: diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py index 01bd75f64..29ae693c8 100644 --- a/ms_agent/tools/task_control_tool.py +++ b/ms_agent/tools/task_control_tool.py @@ -2,10 +2,11 @@ import json from typing import Any, Dict, Optional +from omegaconf import DictConfig + from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -47,7 +48,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: server_name=_SERVER, description=( 'List all background tasks and their current status. ' - 'Returns task_id, tool_name, description, status, and duration.'), + 'Returns task_id, tool_name, description, status, and duration.' + ), parameters={ 'type': 'object', 'properties': {}, @@ -74,8 +76,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: ] } - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if self._task_manager is None: return 'TaskManager not available.' @@ -90,14 +91,17 @@ async def call_tool(self, server_name: str, *, tool_name: str, duration = f'{t.ended_at - t.started_at:.1f}s' elif t.status == 'running': import time + duration = f'{time.monotonic() - t.started_at:.1f}s (running)' - rows.append({ - 'task_id': t.task_id, - 'tool_name': t.tool_name, - 'description': t.description, - 'status': t.status, - 'duration': duration, - }) + rows.append( + { + 'task_id': t.task_id, + 'tool_name': t.tool_name, + 'description': t.description, + 'status': t.status, + 'duration': duration, + } + ) return json.dumps(rows, ensure_ascii=False, indent=2) if tool_name == 'cancel_task': diff --git a/ms_agent/tools/todolist_tool.py b/ms_agent/tools/todolist_tool.py index aee860134..9ab5cbced 100644 --- a/ms_agent/tools/todolist_tool.py +++ b/ms_agent/tools/todolist_tool.py @@ -1,9 +1,9 @@ +import json import os import time from dataclasses import dataclass from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -40,18 +40,14 @@ def _write_text(path: str, content: str) -> None: def _validate_status(status: str) -> str: allowed = {'pending', 'in_progress', 'completed', 'cancelled'} if status not in allowed: - raise ValueError( - f'Invalid todo status "{status}", must be one of {sorted(allowed)}.' - ) + raise ValueError(f'Invalid todo status "{status}", must be one of {sorted(allowed)}.') return status def _validate_priority(priority: str) -> str: allowed = {'high', 'medium', 'low'} if priority not in allowed: - raise ValueError( - f'Invalid todo priority "{priority}", must be one of {sorted(allowed)}.' - ) + raise ValueError(f'Invalid todo priority "{priority}", must be one of {sorted(allowed)}.') return priority @@ -82,14 +78,10 @@ def __init__(self, config, **kwargs): tool_cfg = getattr(getattr(config, 'tools'), 'todo_list') self.exclude_func(tool_cfg) - self._plan_filename = getattr(tool_cfg, 'plan_filename', - 'plan.json') if tool_cfg else 'plan.json' - self._plan_md_filename = getattr(tool_cfg, 'plan_md_filename', - 'plan.md') if tool_cfg else 'plan.md' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', - '.locks') if tool_cfg else '.locks' - self._auto_render_md = bool(getattr(tool_cfg, 'auto_render_md', - True)) if tool_cfg else True + self._plan_filename = getattr(tool_cfg, 'plan_filename', 'plan.json') if tool_cfg else 'plan.json' + self._plan_md_filename = getattr(tool_cfg, 'plan_md_filename', 'plan.md') if tool_cfg else 'plan.md' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' + self._auto_render_md = bool(getattr(tool_cfg, 'auto_render_md', True)) if tool_cfg else True async def connect(self) -> None: # Nothing to connect; file-based tool. @@ -109,57 +101,47 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_write', server_name=self.SERVER_NAME, - description= - ('Create or update the structured todo list (plan.json) for this session/workdir. ' - 'Use merge=true to merge by id (partial updates allowed for existing ids); ' - 'merge=false replaces the list (full items required).'), + description=( + 'Create or update the structured todo list (plan.json) for this session/workdir. ' + 'Use merge=true to merge by id (partial updates allowed for existing ids); ' + 'merge=false replaces the list (full items required).' + ), parameters={ 'type': 'object', 'properties': { 'merge': { - 'type': - 'boolean', - 'description': - ('If true, merge todo items into existing list by id (partial updates allowed). ' - 'If false, replace the list entirely.'), - 'default': - True, + 'type': 'boolean', + 'description': ( + 'If true, merge todo items into existing list by id (partial updates allowed). ' + 'If false, replace the list entirely.' + ), + 'default': True, }, 'todos': { 'type': 'array', - 'description': - 'The updated/created todo list.', + 'description': 'The updated/created todo list.', 'items': { 'type': 'object', 'properties': { 'id': { - 'type': - 'string', - 'description': - ('Unique identifier for the todo item. ' - 'e.g. "T_1", "T_2", ...'), + 'type': 'string', + 'description': ( + 'Unique identifier for the todo item. e.g. "T_1", "T_2", ...' + ), }, 'content': { - 'type': - 'string', - 'description': - 'Brief description of the task', + 'type': 'string', + 'description': 'Brief description of the task', }, 'status': { - 'type': - 'string', - 'enum': [ - 'pending', 'in_progress', - 'completed', 'cancelled' - ], - 'description': - 'Current status of the task', + 'type': 'string', + 'enum': ['pending', 'in_progress', 'completed', 'cancelled'], + 'description': 'Current status of the task', }, 'priority': { 'type': 'string', 'enum': ['high', 'medium', 'low'], - 'description': - 'Priority level of the task', + 'description': 'Priority level of the task', 'default': 'medium', }, }, @@ -178,8 +160,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_read', server_name=self.SERVER_NAME, - description= - 'Read the current todo list for this session/workdir.', + description='Read the current todo list for this session/workdir.', parameters={ 'type': 'object', 'properties': {}, @@ -190,17 +171,16 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_render_md', server_name=self.SERVER_NAME, - description= - 'Render plan.md from plan.json (checkbox view).', + description='Render plan.md from plan.json (checkbox view).', parameters={ 'type': 'object', 'properties': { 'path': { - 'type': - 'string', - 'description': - ('Optional relative output path for the markdown file. ' - 'Defaults to plan.md in the workdir.'), + 'type': 'string', + 'description': ( + 'Optional relative output path for the markdown file. ' + 'Defaults to plan.md in the workdir.' + ), } }, 'required': [], @@ -211,8 +191,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_plan_locked(self, paths: _PlanPaths) -> Dict[str, Any]: @@ -244,15 +223,13 @@ def _load_plan_locked(self, paths: _PlanPaths) -> Dict[str, Any]: data['updated_at'] = _now_iso() return data - def _save_plan_locked(self, paths: _PlanPaths, plan: Dict[str, - Any]) -> None: + def _save_plan_locked(self, paths: _PlanPaths, plan: Dict[str, Any]) -> None: plan = dict(plan or {}) plan['schema_version'] = int(plan.get('schema_version', 1) or 1) plan['updated_at'] = _now_iso() _write_text(paths.plan_json, _json_dumps(plan)) - def _normalize_todos(self, todos: List[Dict[str, - Any]]) -> List[Dict[str, Any]]: + def _normalize_todos(self, todos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: normalized: List[Dict[str, Any]] = [] for idx, item in enumerate(todos or []): if not isinstance(item, dict): @@ -262,11 +239,9 @@ def _normalize_todos(self, todos: List[Dict[str, status = str(item.get('status', '')).strip() priority = str(item.get('priority', 'medium') or 'medium').strip() if not todo_id: - raise ValueError( - f'todos[{idx}].id is required and must be non-empty.') + raise ValueError(f'todos[{idx}].id is required and must be non-empty.') if not content: - raise ValueError( - f'todos[{idx}].content is required and must be non-empty.') + raise ValueError(f'todos[{idx}].content is required and must be non-empty.') _validate_status(status) _validate_priority(priority) # Keep extra fields as-is @@ -300,8 +275,7 @@ def _normalize_todo_updates( todo_id = str(item.get('id', '')).strip() if not todo_id: - raise ValueError( - f'todos[{idx}].id is required and must be non-empty.') + raise ValueError(f'todos[{idx}].id is required and must be non-empty.') is_new = todo_id not in existing_ids @@ -312,27 +286,20 @@ def _normalize_todo_updates( if 'content' in item: content = str(item.get('content', '')).strip() if not content: - raise ValueError( - f'todos[{idx}].content is required and must be non-empty.' - ) + raise ValueError(f'todos[{idx}].content is required and must be non-empty.') upd['content'] = content elif is_new: - raise ValueError( - f'todos[{idx}] is a new id "{todo_id}" so content is required.' - ) + raise ValueError(f'todos[{idx}] is a new id "{todo_id}" so content is required.') if 'status' in item: status = str(item.get('status', '')).strip() _validate_status(status) upd['status'] = status elif is_new: - raise ValueError( - f'todos[{idx}] is a new id "{todo_id}" so status is required.' - ) + raise ValueError(f'todos[{idx}] is a new id "{todo_id}" so status is required.') if 'priority' in item: - priority = str(item.get('priority', 'medium') - or 'medium').strip() + priority = str(item.get('priority', 'medium') or 'medium').strip() _validate_priority(priority) upd['priority'] = priority @@ -340,16 +307,11 @@ def _normalize_todo_updates( return normalized - def _merge_todos(self, base: List[Dict[str, Any]], - updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _merge_todos(self, base: List[Dict[str, Any]], updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: base_by_id: Dict[str, Dict[str, Any]] = { - str(t.get('id')): dict(t) - for t in (base or []) if isinstance(t, dict) and t.get('id') + str(t.get('id')): dict(t) for t in (base or []) if isinstance(t, dict) and t.get('id') } - order: List[str] = [ - str(t.get('id')) for t in (base or []) - if isinstance(t, dict) and t.get('id') - ] + order: List[str] = [str(t.get('id')) for t in (base or []) if isinstance(t, dict) and t.get('id')] for upd in updates or []: tid = str(upd.get('id')) if tid in base_by_id: @@ -386,9 +348,7 @@ def _render_plan_md_text(self, plan: Dict[str, Any]) -> str: lines.append('') return '\n'.join(lines) - async def todo_write(self, - todos: List[Dict[str, Any]], - merge: bool = True) -> str: + async def todo_write(self, todos: List[Dict[str, Any]], merge: bool = True) -> str: paths = self._paths() _ensure_dir(self.output_dir) _ensure_dir(paths.lock_dir) @@ -400,8 +360,7 @@ async def todo_write(self, # For merge=true, allow partial updates for existing ids. existing_full = self._normalize_todos(existing) existing_ids = {str(t.get('id')) for t in existing_full} - updates = self._normalize_todo_updates( - todos, existing_ids=existing_ids) + updates = self._normalize_todo_updates(todos, existing_ids=existing_ids) merged = self._merge_todos(existing_full, updates) plan['todos'] = self._normalize_todos(merged) else: @@ -413,18 +372,16 @@ async def todo_write(self, md_text = self._render_plan_md_text(plan) _write_text(paths.plan_md, md_text) - render_markdown_todo( - paths.plan_md, title='CURRENT PLAN', use_pager=False) + render_markdown_todo(paths.plan_md, title='CURRENT PLAN', use_pager=False) # Return a JSON list (opencode-style) so the model can easily read it. - return _json_dumps({ - 'status': - 'ok', - 'plan_path': - os.path.relpath(paths.plan_json, self.output_dir), - 'todos': - plan.get('todos', []), - }) + return _json_dumps( + { + 'status': 'ok', + 'plan_path': os.path.relpath(paths.plan_json, self.output_dir), + 'todos': plan.get('todos', []), + } + ) async def todo_read(self) -> str: paths = self._paths() @@ -433,8 +390,7 @@ async def todo_read(self) -> str: with file_lock(paths.lock_dir, self._plan_filename): plan = self._load_plan_locked(paths) if self._auto_render_md: - render_markdown_todo( - paths.plan_md, title='CURRENT PLAN', use_pager=False) + render_markdown_todo(paths.plan_md, title='CURRENT PLAN', use_pager=False) return _json_dumps(plan.get('todos', [])) @@ -442,8 +398,7 @@ async def todo_render_md(self, path: Optional[str] = None) -> str: paths = self._paths() _ensure_dir(self.output_dir) _ensure_dir(paths.lock_dir) - out_path = paths.plan_md if not path else os.path.join( - self.output_dir, path) + out_path = paths.plan_md if not path else os.path.join(self.output_dir, path) with file_lock(paths.lock_dir, self._plan_filename): plan = self._load_plan_locked(paths) diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 70703fbfb..f108b6b87 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,6 +2,7 @@ import asyncio import importlib import inspect +import json import os import sys import uuid @@ -9,7 +10,6 @@ from types import TracebackType from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool, ToolCall from ms_agent.tools.agent_tool import AgentTool from ms_agent.tools.base import ToolBase @@ -33,55 +33,42 @@ class ToolManager: - """Interacting with Agent class, hold all tools - """ + """Interacting with Agent class, hold all tools""" TOOL_SPLITER = '---' - def __init__(self, - config, - mcp_config: Optional[Dict[str, Any]] = None, - mcp_client: Optional[MCPClient] = None, - **kwargs): + def __init__( + self, config, mcp_config: Optional[Dict[str, Any]] = None, mcp_client: Optional[MCPClient] = None, **kwargs + ): self.config = config self.trust_remote_code = kwargs.get('trust_remote_code', False) self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False - if hasattr(config, 'tools') and hasattr(config.tools, - 'image_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, 'image_generator'): self.extra_tools.append(ImageGenerator(config)) - if hasattr(config, 'tools') and hasattr(config.tools, - 'video_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, 'video_generator'): self.extra_tools.append(VideoGenerator(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'file_system'): - self.extra_tools.append( - FileSystemTool( - config, trust_remote_code=self.trust_remote_code)) + self.extra_tools.append(FileSystemTool(config, trust_remote_code=self.trust_remote_code)) if hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): code_exec_cfg = getattr(config.tools, 'code_executor') - implementation = getattr(code_exec_cfg, 'implementation', - 'sandbox') - if isinstance(implementation, - str) and implementation.lower() == 'python_env': + implementation = getattr(code_exec_cfg, 'implementation', 'sandbox') + if isinstance(implementation, str) and implementation.lower() == 'python_env': self.extra_tools.append(LocalCodeExecutionTool(config)) - elif isinstance(implementation, - str) and implementation.lower() == 'sandbox': + elif isinstance(implementation, str) and implementation.lower() == 'sandbox': self.extra_tools.append(CodeExecutionTool(config)) else: - logger.warning( - f'Unknown code execution implementation: {implementation},' - f'using sandbox instead.') + logger.warning(f'Unknown code execution implementation: {implementation},using sandbox instead.') self.extra_tools.append(CodeExecutionTool(config)) - if hasattr(config, 'tools') and hasattr(config.tools, - 'financial_data_fetcher'): + if hasattr(config, 'tools') and hasattr(config.tools, 'financial_data_fetcher'): from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher + self.extra_tools.append(FinancialDataFetcher(config)) if hasattr(config, 'tools') and ( - getattr(config.tools, 'agent_tools', None) - or hasattr(config.tools, 'split_task')): - agent_tool = AgentTool( - config, trust_remote_code=self.trust_remote_code) + getattr(config.tools, 'agent_tools', None) or hasattr(config.tools, 'split_task') + ): + agent_tool = AgentTool(config, trust_remote_code=self.trust_remote_code) if agent_tool.enabled: self.extra_tools.append(agent_tool) if hasattr(config, 'tools') and hasattr(config.tools, 'todo_list'): @@ -92,13 +79,11 @@ def __init__(self, self.extra_tools.append(LocalSearchTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'task_control'): from ms_agent.tools.task_control_tool import TaskControlTool + self.extra_tools.append(TaskControlTool(config)) - self.tool_call_timeout = getattr(config, 'tool_call_timeout', - TOOL_CALL_TIMEOUT) - local_dir = self.config.local_dir if hasattr(self.config, - 'local_dir') else None - if hasattr(config, 'tools') and hasattr(config.tools, - TOOL_PLUGIN_NAME): + self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) + local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None + if hasattr(config, 'tools') and hasattr(config.tools, TOOL_PLUGIN_NAME): plugins = getattr(config.tools, TOOL_PLUGIN_NAME) for plugin in plugins: subdir = os.path.dirname(plugin) @@ -119,11 +104,7 @@ def __init__(self, if _plugin.endswith('.py'): _plugin = _plugin[:-3] plugin_file = importlib.import_module(_plugin) - module_classes = { - name: cls - for name, cls in inspect.getmembers( - plugin_file, inspect.isclass) - } + module_classes = {name: cls for name, cls in inspect.getmembers(plugin_file, inspect.isclass)} for name, cls in module_classes.items(): # Find cls which base class is `ToolBase` if issubclass(cls, ToolBase) and cls.__module__ == _plugin: @@ -173,15 +154,12 @@ async def cleanup(self): pass async def reindex_tool(self): - - def extend_tool(tool_ins: ToolBase, server_name: str, - tool_list: List[Tool]): + def extend_tool(tool_ins: ToolBase, server_name: str, tool_list: List[Tool]): for tool in tool_list: # Subtract the length of the tool name splitter - max_server_len = MAX_TOOL_NAME_LEN - len( - tool['tool_name']) - len(self.TOOL_SPLITER) + max_server_len = MAX_TOOL_NAME_LEN - len(tool['tool_name']) - len(self.TOOL_SPLITER) if len(server_name) > max_server_len: - key = f"{server_name[:max(0, max_server_len)]}{self.TOOL_SPLITER}{tool['tool_name']}" + key = f"{server_name[: max(0, max_server_len)]}{self.TOOL_SPLITER}{tool['tool_name']}" else: key = f"{server_name}{self.TOOL_SPLITER}{tool['tool_name']}" assert key not in self._tool_index, f'Tool name duplicated {tool["tool_name"]}' @@ -201,7 +179,7 @@ async def get_tools(self): # Return tools in deterministic order to improve prompt/prefix cache hit rate # across process restarts and across different MCP tool listing orders. tools = [value[2] for value in self._tool_index.values()] - return sorted(tools, key=lambda t: (t.get('tool_name', ''), )) + return sorted(tools, key=lambda t: (t.get('tool_name', ''),)) async def single_call_tool(self, tool_info: ToolCall): if self._concurrent_limiter is None: @@ -209,8 +187,7 @@ async def single_call_tool(self, tool_info: ToolCall): self._init_lock = asyncio.Lock() async with self._init_lock: if self._concurrent_limiter is None: - self._concurrent_limiter = asyncio.Semaphore( - MAX_CONCURRENT_TOOLS) + self._concurrent_limiter = asyncio.Semaphore(MAX_CONCURRENT_TOOLS) async with self._concurrent_limiter: brief_info = json.dumps(tool_info, ensure_ascii=False) @@ -231,27 +208,27 @@ async def single_call_tool(self, tool_info: ToolCall): call_args = dict(tool_args or {}) call_id = tool_info.get('id') or str(uuid.uuid4()) call_args['__call_id'] = call_id - elif isinstance( - tool_ins, - LocalCodeExecutionTool) and tool_name.endswith( - f'{self.TOOL_SPLITER}shell_executor'): + elif isinstance(tool_ins, LocalCodeExecutionTool) and tool_name.endswith( + f'{self.TOOL_SPLITER}shell_executor' + ): call_args = dict(tool_args or {}) - call_args['__call_id'] = tool_info.get('id') or str( - uuid.uuid4()) + call_args['__call_id'] = tool_info.get('id') or str(uuid.uuid4()) response = await asyncio.wait_for( tool_ins.call_tool( - server_name, - tool_name=tool_name.split(self.TOOL_SPLITER)[1], - tool_args=call_args), - timeout=self.tool_call_timeout) + server_name, tool_name=tool_name.split(self.TOOL_SPLITER)[1], tool_args=call_args + ), + timeout=self.tool_call_timeout, + ) return response except asyncio.TimeoutError: import traceback + logger.warning(traceback.format_exc()) # TODO: How to get the information printed by the tool before hanging to return to the model? return f'Execute tool call timeout: {brief_info}' except Exception as e: import traceback + logger.warning(traceback.format_exc()) return f'Tool calling failed: {brief_info}, details: {str(e)}' @@ -261,7 +238,6 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]): return result async def __aenter__(self) -> 'ToolManager': - return self async def __aexit__( diff --git a/ms_agent/tools/video_generator/ds_video_gen.py b/ms_agent/tools/video_generator/ds_video_gen.py index d757038d3..48db015df 100644 --- a/ms_agent/tools/video_generator/ds_video_gen.py +++ b/ms_agent/tools/video_generator/ds_video_gen.py @@ -8,32 +8,27 @@ class DSVideoGenerator: - def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_video(self, - positive_prompt, - size='1280x720', - seconds=4): + async def generate_video(self, positive_prompt, size='1280x720', seconds=4): video_generator = self.config.tools.video_generator - base_url = (getattr(video_generator, 'base_url', None) - or 'https://dashscope.aliyuncs.com').strip('/') + base_url = (getattr(video_generator, 'base_url', None) or 'https://dashscope.aliyuncs.com').strip('/') api_key = video_generator.api_key model_id = video_generator.model assert api_key is not None task_id = str(uuid.uuid4())[:8] output_file = os.path.join(self.temp_dir, f'{task_id}.mp4') - video_url = await self._generate_video(base_url, api_key, model_id, - positive_prompt, size, seconds) + video_url = await self._generate_video(base_url, api_key, model_id, positive_prompt, size, seconds) await self.download_video(video_url, output_file) return output_file @staticmethod async def download_video(video_url, output_file): import aiohttp + max_retries = 3 retry_count = 0 @@ -41,8 +36,7 @@ async def download_video(video_url, output_file): headers = {} while retry_count < max_retries: try: - async with session.get( - video_url, headers=headers) as video_resp: + async with session.get(video_url, headers=headers) as video_resp: video_resp.raise_for_status() video_content = await video_resp.read() with open(output_file, 'wb') as f: @@ -58,6 +52,7 @@ async def download_video(video_url, output_file): @staticmethod async def _generate_video(base_url, api_key, model, prompt, size, seconds): import aiohttp + base_url = base_url.strip('/') create_endpoint = '/api/v1/services/aigc/model-evaluation/async-inference/' @@ -73,23 +68,19 @@ async def _generate_video(base_url, api_key, model, prompt, size, seconds): 'size': size, 'seconds': seconds, }, - 'parameters': {} + 'parameters': {}, } async with aiohttp.ClientSession() as session: - async with session.post( - f'{base_url}{create_endpoint}', headers=headers, - json=payload) as resp: + async with session.post(f'{base_url}{create_endpoint}', headers=headers, json=payload) as resp: resp.raise_for_status() response_data = await resp.json() task_id = response_data['output']['task_id'] if not task_id: - raise RuntimeError( - f'No task ID in response: {response_data}') + raise RuntimeError(f'No task ID in response: {response_data}') - return await DSVideoGenerator._poll_video_task( - session, base_url, task_id, headers) + return await DSVideoGenerator._poll_video_task(session, base_url, task_id, headers) @staticmethod async def _poll_video_task(session, base_url, task_id, headers): @@ -106,28 +97,21 @@ async def _poll_video_task(session, base_url, task_id, headers): await asyncio.sleep(poll_interval) elapsed_time += poll_interval - async with session.get( - f'{base_url}{poll_endpoint}', headers=headers) as result: + async with session.get(f'{base_url}{poll_endpoint}', headers=headers) as result: result.raise_for_status() data = await result.json() status = data['output']['task_status'] - logger.info( - f'Task {task_id} status: {status}, detailed message: {str(data)}' - ) + logger.info(f'Task {task_id} status: {status}, detailed message: {str(data)}') if status in success_statuses: video_url = data['output']['video_url'] if not video_url: - raise RuntimeError( - f'Video URL not found in response: {data}') + raise RuntimeError(f'Video URL not found in response: {data}') return video_url elif status in failed_statuses: - error_msg = data['output'].get( - 'message') or 'Unknown error' + error_msg = data['output'].get('message') or 'Unknown error' raise RuntimeError(f'Video generation failed: {error_msg}') poll_interval = min(poll_interval * 1.2, max_poll_interval) - raise TimeoutError( - f'Video generation task {task_id} timed out after {max_wait_time} seconds' - ) + raise TimeoutError(f'Video generation task {task_id} timed out after {max_wait_time} seconds') diff --git a/ms_agent/tools/video_generator/video_gen.py b/ms_agent/tools/video_generator/video_gen.py index 7578ebf59..9af5e2599 100644 --- a/ms_agent/tools/video_generator/video_gen.py +++ b/ms_agent/tools/video_generator/video_gen.py @@ -6,15 +6,14 @@ class VideoGenerator(ToolBase): - def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', - 'video_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', 'video_generator') os.makedirs(self.temp_dir, exist_ok=True) video_generator = self.config.video_generator if video_generator.type == 'dashscope': from .ds_video_gen import DSVideoGenerator + self.generator = DSVideoGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -28,32 +27,25 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_video', server_name='video_generator', - description= - 'Generate a video with a positive prompt, and return the video file path.', + description='Generate a video with a positive prompt, and return the video file path.', parameters={ 'type': 'object', 'properties': { - 'positive_prompt': { - 'type': 'string', - 'description': - 'The prompt to generate the image.' - }, + 'positive_prompt': {'type': 'string', 'description': 'The prompt to generate the image.'}, 'seconds': { - 'type': - 'integer', - 'description': - 'The generated video seconds, supported is 4/8/12' - } + 'type': 'integer', + 'description': 'The generated video seconds, supported is 4/8/12', + }, }, 'required': ['positive_prompt'], - 'additionalProperties': False - }) + 'additionalProperties': False, + }, + ) ] } async def generate_video(self, positive_prompt, **kwargs): return await self.generator.generate_video(positive_prompt, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await self.generate_video(**tool_args) diff --git a/ms_agent/utils/__init__.py b/ms_agent/utils/__init__.py index 32655bd93..09061daab 100644 --- a/ms_agent/utils/__init__.py +++ b/ms_agent/utils/__init__.py @@ -2,7 +2,6 @@ from .llm_utils import async_retry, retry from .logger import get_logger from .prompt import get_fact_retrieval_prompt -from .utils import (assert_package_exist, enhance_error, read_history, - save_history, strtobool) +from .utils import assert_package_exist, enhance_error, read_history, save_history, strtobool MAX_CONTINUE_RUNS = 3 diff --git a/ms_agent/utils/artifact_manager.py b/ms_agent/utils/artifact_manager.py index 9953f25c0..92ac60604 100644 --- a/ms_agent/utils/artifact_manager.py +++ b/ms_agent/utils/artifact_manager.py @@ -49,8 +49,7 @@ def pack_text_result( out.update(extra) return out - safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id - )[:120] or 'call' + safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id)[:120] or 'call' rel_dir = Path(tool_name) / safe_id out_dir = self._artifact_root / rel_dir out_dir.mkdir(parents=True, exist_ok=True) @@ -60,21 +59,14 @@ def pack_text_result( fpath = out_dir / fname fpath.write_text(body, encoding='utf-8', errors='replace') rel = fpath.relative_to(self._root).as_posix() - preview = _make_preview(body, self.preview_head_chars, - self.preview_tail_chars) + preview = _make_preview(body, self.preview_head_chars, self.preview_tail_chars) result = { - 'output': stdout[:self.preview_head_chars] - if len(stdout) > self.preview_head_chars else stdout, - 'error': - (stderr[:self.preview_head_chars] if stderr else None), - 'truncated': - True, - 'artifact_path': - rel, - 'preview': - preview, - 'artifact_bytes': - len(enc), + 'output': stdout[: self.preview_head_chars] if len(stdout) > self.preview_head_chars else stdout, + 'error': (stderr[: self.preview_head_chars] if stderr else None), + 'truncated': True, + 'artifact_path': rel, + 'preview': preview, + 'artifact_bytes': len(enc), } if extra: result.update(extra) @@ -95,23 +87,15 @@ def pack_json_shell_result( call_id=call_id, stdout=stdout, stderr=stderr, - extra={ - k: v - for k, v in payload.items() if k not in ('output', 'error') - }, + extra={k: v for k, v in payload.items() if k not in ('output', 'error')}, ) # pack_text_result merged extra into top level; rebuild standard shell shape out = { - 'success': - payload.get('success'), - 'output': - packed.get('output'), - 'error': - packed.get('error'), - 'return_code': - payload.get('return_code'), - 'truncated': - packed.get('truncated', False), + 'success': payload.get('success'), + 'output': packed.get('output'), + 'error': packed.get('error'), + 'return_code': payload.get('return_code'), + 'truncated': packed.get('truncated', False), } if packed.get('artifact_path'): out['artifact_path'] = packed['artifact_path'] @@ -123,4 +107,4 @@ def pack_json_shell_result( def _make_preview(text: str, head: int, tail: int) -> str: if len(text) <= head + tail: return text - return (text[:head] + '\n... [truncated] ...\n' + text[-tail:]) + return text[:head] + '\n... [truncated] ...\n' + text[-tail:] diff --git a/ms_agent/utils/constants.py b/ms_agent/utils/constants.py index e068e654e..9929a26e4 100644 --- a/ms_agent/utils/constants.py +++ b/ms_agent/utils/constants.py @@ -61,36 +61,30 @@ class ServiceConfig: @dataclass class ModelscopeConfig(ServiceConfig): - def __init__(self): super().__init__(base_url='https://api-inference.modelscope.cn/v1') @dataclass class DashscopeConfig(ServiceConfig): - def __init__(self): - super().__init__( - base_url='https://dashscope.aliyuncs.com/compatible-mode/v1') + super().__init__(base_url='https://dashscope.aliyuncs.com/compatible-mode/v1') @dataclass class DeepseekConfig(ServiceConfig): - def __init__(self): super().__init__(base_url='https://api.deepseek.com/v1') @dataclass class AnthropicConfig(ServiceConfig): - def __init__(self): # without /v1, using Anthropic API super().__init__(base_url='https://api.anthropic.com') class OpenaiConfig(ServiceConfig): - def __init__(self): super().__init__(base_url='https://api.openai.com/v1') diff --git a/ms_agent/utils/llm_utils.py b/ms_agent/utils/llm_utils.py index 4f800bd45..c9fc50036 100644 --- a/ms_agent/utils/llm_utils.py +++ b/ms_agent/utils/llm_utils.py @@ -11,15 +11,15 @@ T = TypeVar('T') -def retry(max_attempts: int = 3, - delay: float = 1.0, - backoff_factor: float = 2.0, - exceptions: Union[Type[Exception], Tuple[Type[Exception], - ...]] = Exception): +def retry( + max_attempts: int = 3, + delay: float = 1.0, + backoff_factor: float = 2.0, + exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, +): """Retry doing something""" def decorator(func: Callable[..., T]) -> Callable[..., T]: - @functools.wraps(func) def wrapper(*args, **kwargs) -> T: current_delay = delay @@ -30,6 +30,7 @@ def wrapper(*args, **kwargs) -> T: return func(*args, **kwargs) except exceptions as e: import traceback + logger.warning(traceback.format_exc()) last_exception = e if attempt < max_attempts: @@ -42,7 +43,8 @@ def wrapper(*args, **kwargs) -> T: else: logger.error( f'Attempt to call {func.__name__} over {max_attempts} times. ' - f'The last exception message: {e}') + f'The last exception message: {e}' + ) raise last_exception return wrapper @@ -50,15 +52,15 @@ def wrapper(*args, **kwargs) -> T: return decorator -def async_retry(max_attempts: int = 3, - delay: float = 1.0, - backoff_factor: float = 2.0, - exceptions: Union[Type[Exception], Tuple[Type[Exception], - ...]] = Exception): +def async_retry( + max_attempts: int = 3, + delay: float = 1.0, + backoff_factor: float = 2.0, + exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, +): """Retry doing something""" def decorator(func: Callable[..., T]) -> Callable[..., T]: - @functools.wraps(func) async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: current_delay = delay @@ -71,6 +73,7 @@ async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: return except exceptions as e: import traceback + logger.warning(traceback.format_exc()) last_exception = e if attempt < max_attempts: @@ -83,7 +86,8 @@ async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: else: logger.error( f'Attempt to call {func.__name__} over {max_attempts} times. ' - f'The last exception message: {e}') + f'The last exception message: {e}' + ) raise last_exception return wrapper diff --git a/ms_agent/utils/logger.py b/ms_agent/utils/logger.py index 0ac9d48a8..b005d1ab7 100644 --- a/ms_agent/utils/logger.py +++ b/ms_agent/utils/logger.py @@ -30,10 +30,8 @@ def warning_once(self, msg, *args, **kwargs): self.warning(msg) -def get_logger(log_file: Optional[str] = None, - log_level: Optional[int] = None, - file_mode: str = 'w'): - """ Get logging logger +def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'): + """Get logging logger Args: log_file: Log filename, if specified, file handler will be added to diff --git a/ms_agent/utils/parser_utils.py b/ms_agent/utils/parser_utils.py index 5455dae52..ae2f6ba46 100644 --- a/ms_agent/utils/parser_utils.py +++ b/ms_agent/utils/parser_utils.py @@ -1,15 +1,15 @@ +import json import os import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional -import json - @dataclass class ImportInfo: """Detailed information about an import statement""" + # Source file path (resolved path) source_file: str # Original import statement @@ -24,8 +24,7 @@ class ImportInfo: is_type_only: bool = False def __repr__(self): - items_str = ', '.join( - self.imported_items) if self.imported_items else 'all' + items_str = ', '.join(self.imported_items) if self.imported_items else 'all' alias_str = f' as {self.alias}' if self.alias else '' return f"ImportInfo(file='{self.source_file}', items=[{items_str}]{alias_str})" @@ -64,8 +63,7 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 1: from ... import ... from_pattern = r'^\s*from\s+([\w.]+)\s+import\s+(?:\(([^)]+)\)|([^\n]+))' - for match in re.finditer(from_pattern, code_content, - re.MULTILINE | re.DOTALL): + for match in re.finditer(from_pattern, code_content, re.MULTILINE | re.DOTALL): info = self._extract_from_import(match, code_content) if info: imports.append(info) @@ -78,8 +76,7 @@ def parse(self, code_content: str) -> List[ImportInfo]: return imports - def _extract_from_import(self, match, - code_content) -> Optional[ImportInfo]: + def _extract_from_import(self, match, code_content) -> Optional[ImportInfo]: """Extract 'from ... import ...' statement""" module_path = match.group(1) # Group 2 is parenthesized multi-line imports, group 3 is single-line imports @@ -90,7 +87,7 @@ def _extract_from_import(self, match, cleaned_items = [] for line in lines: if '#' in line: - line = line[:line.index('#')] + line = line[: line.index('#')] cleaned_items.append(line.strip()) imports_str = ','.join(cleaned_items) @@ -118,7 +115,8 @@ def _extract_from_import(self, match, source_file=file_path, raw_statement=match.group(0), imported_items=imported_items, - import_type='namespace' if '*' in imported_items else 'named') + import_type='namespace' if '*' in imported_items else 'named', + ) def _extract_simple_import(self, match) -> List[ImportInfo]: """Extract 'import ...' statement""" @@ -147,7 +145,9 @@ def _extract_simple_import(self, match) -> List[ImportInfo]: raw_statement=f'import {module}', imported_items=[module.split('.')[-1]], import_type='default', - alias=alias)) + alias=alias, + ) + ) return results @@ -206,8 +206,7 @@ def safe_relpath(path): target_dir = os.path.join(target_dir, module_file_path) # Try as package - package_init = os.path.normpath( - os.path.join(target_dir, '__init__.py')) + package_init = os.path.normpath(os.path.join(target_dir, '__init__.py')) if os.path.exists(package_init): return safe_relpath(package_init) @@ -227,26 +226,22 @@ def safe_relpath(path): module_file_path = module_path.replace('.', os.sep) # Try as package (relative to current file) - package_init = os.path.normpath( - os.path.join(self.current_dir, module_file_path, '__init__.py')) + package_init = os.path.normpath(os.path.join(self.current_dir, module_file_path, '__init__.py')) if os.path.exists(package_init): return safe_relpath(package_init) # Try as module (relative to current file) - module_file = os.path.normpath( - os.path.join(self.current_dir, module_file_path + '.py')) + module_file = os.path.normpath(os.path.join(self.current_dir, module_file_path + '.py')) if os.path.exists(module_file): return safe_relpath(module_file) # Try from output_dir (absolute import) if self.output_dir: - package_init_abs = os.path.normpath( - os.path.join(self.output_dir, module_file_path, '__init__.py')) + package_init_abs = os.path.normpath(os.path.join(self.output_dir, module_file_path, '__init__.py')) if os.path.exists(package_init_abs): return os.path.join(module_file_path, '__init__.py') - module_file_abs = os.path.normpath( - os.path.join(self.output_dir, module_file_path + '.py')) + module_file_abs = os.path.normpath(os.path.join(self.output_dir, module_file_path + '.py')) if os.path.exists(module_file_abs): return module_file_path + '.py' @@ -269,16 +264,14 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 1: Mixed import - import Default, { Named } from 'path' # Must come BEFORE Pattern 2 and 3 to avoid partial matches mixed_pattern = r"^\s*import\s+(type\s+)?(\w+)\s*,\s*\{([^}]+)\}\s*from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(mixed_pattern, code_content, - re.MULTILINE | re.DOTALL): + for match in re.finditer(mixed_pattern, code_content, re.MULTILINE | re.DOTALL): infos = self._extract_mixed_import(match) if infos: imports.extend(infos) # Pattern 2: Named import - import { A, B } from 'path' (supports multiline) named_pattern = r"^\s*import\s+(type\s+)?\{([^}]+)\}\s*from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(named_pattern, code_content, - re.MULTILINE | re.DOTALL): + for match in re.finditer(named_pattern, code_content, re.MULTILINE | re.DOTALL): info = self._extract_named_import(match) if info: imports.append(info) @@ -292,40 +285,35 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 4: Namespace import - import * as name from 'path' namespace_pattern = r"^\s*import\s+(type\s+)?\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(namespace_pattern, code_content, - re.MULTILINE): + for match in re.finditer(namespace_pattern, code_content, re.MULTILINE): info = self._extract_namespace_import(match) if info: imports.append(info) # Pattern 5: Side-effect import - import 'path' side_effect_pattern = r"^\s*import\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(side_effect_pattern, code_content, - re.MULTILINE): + for match in re.finditer(side_effect_pattern, code_content, re.MULTILINE): info = self._extract_side_effect_import(match) if info: imports.append(info) # Pattern 6: Named re-export - export { A, B } from 'path' (supports multiline) export_named_pattern = r"^\s*export\s+(type\s+)?\{([^}]+)\}\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_named_pattern, code_content, - re.MULTILINE | re.DOTALL): + for match in re.finditer(export_named_pattern, code_content, re.MULTILINE | re.DOTALL): info = self._extract_export_named(match) if info: imports.append(info) # Pattern 7: Wildcard re-export - export * from 'path' export_wildcard_pattern = r"^\s*export\s+(type\s+)?\*\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_wildcard_pattern, code_content, - re.MULTILINE): + for match in re.finditer(export_wildcard_pattern, code_content, re.MULTILINE): info = self._extract_export_wildcard(match) if info: imports.append(info) # Pattern 8: Named wildcard re-export - export * as name from 'path' export_named_wildcard_pattern = r"^\s*export\s+(type\s+)?\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_named_wildcard_pattern, code_content, - re.MULTILINE): + for match in re.finditer(export_named_wildcard_pattern, code_content, re.MULTILINE): info = self._extract_export_named_wildcard(match) if info: imports.append(info) @@ -371,7 +359,9 @@ def _extract_mixed_import(self, match) -> List[ImportInfo]: raw_statement=match.group(0), imported_items=[default_name], import_type='default', - is_type_only=is_type)) + is_type_only=is_type, + ) + ) # Create named import info results.append( @@ -380,7 +370,9 @@ def _extract_mixed_import(self, match) -> List[ImportInfo]: raw_statement=match.group(0), imported_items=named_items, import_type='named', - is_type_only=is_type)) + is_type_only=is_type, + ) + ) return results @@ -413,7 +405,8 @@ def _extract_named_import(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=items, import_type='named', - is_type_only=is_type) + is_type_only=is_type, + ) def _extract_default_import(self, match) -> Optional[ImportInfo]: """Extract: import React from 'path'""" @@ -431,7 +424,8 @@ def _extract_default_import(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=[name], import_type='default', - is_type_only=is_type) + is_type_only=is_type, + ) def _extract_namespace_import(self, match) -> Optional[ImportInfo]: """Extract: import * as name from 'path'""" @@ -450,7 +444,8 @@ def _extract_namespace_import(self, match) -> Optional[ImportInfo]: imported_items=['*'], import_type='namespace', alias=name, - is_type_only=is_type) + is_type_only=is_type, + ) def _extract_side_effect_import(self, match) -> Optional[ImportInfo]: """Extract: import 'path'""" @@ -461,10 +456,8 @@ def _extract_side_effect_import(self, match) -> Optional[ImportInfo]: resolved_path = import_path return ImportInfo( - source_file=resolved_path, - raw_statement=match.group(0), - imported_items=[], - import_type='side-effect') + source_file=resolved_path, raw_statement=match.group(0), imported_items=[], import_type='side-effect' + ) def _extract_export_named(self, match) -> Optional[ImportInfo]: """Extract: export { A, B } from 'path'""" @@ -495,7 +488,8 @@ def _extract_export_named(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=items, import_type='named', - is_type_only=is_type) + is_type_only=is_type, + ) def _extract_export_wildcard(self, match) -> Optional[ImportInfo]: """Extract: export * from 'path'""" @@ -512,7 +506,8 @@ def _extract_export_wildcard(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=['*'], import_type='namespace', - is_type_only=is_type) + is_type_only=is_type, + ) def _extract_export_named_wildcard(self, match) -> Optional[ImportInfo]: """Extract: export * as name from 'path'""" @@ -531,7 +526,8 @@ def _extract_export_named_wildcard(self, match) -> Optional[ImportInfo]: imported_items=['*'], import_type='namespace', alias=name, - is_type_only=is_type) + is_type_only=is_type, + ) def _resolve_js_path(self, import_path: str) -> Optional[str]: """Resolve JavaScript/TypeScript import path to file @@ -540,8 +536,7 @@ def _resolve_js_path(self, import_path: str) -> Optional[str]: Returns None for external packages. """ # Check if it's an external package (doesn't start with . or /) - is_external = not import_path.startswith( - '.') and not import_path.startswith('/') + is_external = not import_path.startswith('.') and not import_path.startswith('/') # External packages return None early if is_external: @@ -586,14 +581,11 @@ def to_relative(path): abs_resolved = os.path.join(self.output_dir, resolved) else: # Both are relative, make absolute from current working directory - abs_resolved = os.path.abspath( - os.path.join(self.output_dir, resolved)) + abs_resolved = os.path.abspath(os.path.join(self.output_dir, resolved)) # Try as directory with index file first if os.path.isdir(abs_resolved): - for index_file in [ - 'index.ts', 'index.tsx', 'index.js', 'index.jsx' - ]: + for index_file in ['index.ts', 'index.tsx', 'index.js', 'index.jsx']: index_path = os.path.join(abs_resolved, index_file) if os.path.exists(index_path): # Return relative path with index file @@ -603,8 +595,19 @@ def to_relative(path): # Try different extensions extensions = [ - '.ts', '.tsx', '.js', '.jsx', '.mjs', '.cjs', '.json', '.css', - '.scss', '.sass', '.less', '.module.css', '.module.scss' + '.ts', + '.tsx', + '.js', + '.jsx', + '.mjs', + '.cjs', + '.json', + '.css', + '.scss', + '.sass', + '.less', + '.module.css', + '.module.scss', ] for ext in extensions: @@ -648,9 +651,7 @@ def to_relative(path): def _load_path_aliases(self) -> Dict[str, str]: """Load path aliases from tsconfig.json and vite.config""" aliases = {} - excluded_dirs = { - 'node_modules', 'dist', 'build', '.git', '__pycache__' - } + excluded_dirs = {'node_modules', 'dist', 'build', '.git', '__pycache__'} # Search for config files for root, dirs, files in os.walk(self.output_dir): @@ -658,16 +659,12 @@ def _load_path_aliases(self) -> Dict[str, str]: # tsconfig.json if 'tsconfig.json' in files: - self._parse_tsconfig_aliases( - os.path.join(root, 'tsconfig.json'), root, aliases) + self._parse_tsconfig_aliases(os.path.join(root, 'tsconfig.json'), root, aliases) # vite.config.* - for config_file in [ - 'vite.config.js', 'vite.config.ts', 'vite.config.mjs' - ]: + for config_file in ['vite.config.js', 'vite.config.ts', 'vite.config.mjs']: if config_file in files: - self._parse_vite_config_aliases( - os.path.join(root, config_file), root, aliases) + self._parse_vite_config_aliases(os.path.join(root, config_file), root, aliases) # Default aliases if not aliases: @@ -680,41 +677,36 @@ def _load_path_aliases(self) -> Dict[str, str]: return aliases - def _parse_tsconfig_aliases(self, tsconfig_path: str, base_dir: str, - aliases: Dict[str, str]): + def _parse_tsconfig_aliases(self, tsconfig_path: str, base_dir: str, aliases: Dict[str, str]): """Parse tsconfig.json and extract path aliases""" try: with open(tsconfig_path, 'r', encoding='utf-8') as f: content = f.read() # Remove comments - content = re.sub( - r'//.*?\n|/\*.*?\*/', '', content, flags=re.DOTALL) + content = re.sub(r'//.*?\n|/\*.*?\*/', '', content, flags=re.DOTALL) tsconfig = json.loads(content) - if 'compilerOptions' in tsconfig and 'paths' in tsconfig[ - 'compilerOptions']: + if 'compilerOptions' in tsconfig and 'paths' in tsconfig['compilerOptions']: base_url = tsconfig['compilerOptions'].get('baseUrl', '.') - for alias, paths in tsconfig['compilerOptions'][ - 'paths'].items(): + for alias, paths in tsconfig['compilerOptions']['paths'].items(): clean_alias = alias.rstrip('/*') if paths and len(paths) > 0: target = paths[0].rstrip('/*') - resolved_target = os.path.normpath( - os.path.join(base_dir, base_url, target)) + resolved_target = os.path.normpath(os.path.join(base_dir, base_url, target)) if clean_alias not in aliases: aliases[clean_alias] = resolved_target except (json.JSONDecodeError, IOError, KeyError): pass - def _parse_vite_config_aliases(self, config_path: str, base_dir: str, - aliases: Dict[str, str]): + def _parse_vite_config_aliases(self, config_path: str, base_dir: str, aliases: Dict[str, str]): """Parse vite.config and extract path aliases""" try: with open(config_path, 'r', encoding='utf-8') as f: content = f.read() alias_pattern = ( r"['\"]([^'\"]+)['\"]\s*:\s*(?:path\.resolve\([^,]+,\s*['\"]" - r"([^'\"]+)['\"]\)|['\"]([^'\"]+)['\"])") + r"([^'\"]+)['\"]\)|['\"]([^'\"]+)['\"])" + ) for match in re.finditer(alias_pattern, content): alias_key = match.group(1) target = match.group(2) or match.group(3) @@ -732,7 +724,7 @@ def _resolve_alias_path(self, import_path: str) -> Optional[str]: if import_path == alias: return target elif import_path.startswith(alias + '/'): - remainder = import_path[len(alias) + 1:] + remainder = import_path[len(alias) + 1 :] return os.path.join(target, remainder) return None @@ -774,10 +766,8 @@ def _extract_java_import(self, match) -> Optional[ImportInfo]: items = [import_path.split('.')[-1]] return ImportInfo( - source_file=file_path, - raw_statement=match.group(0), - imported_items=items, - import_type=import_type) + source_file=file_path, raw_statement=match.group(0), imported_items=items, import_type=import_type + ) def _resolve_java_path(self, import_path: str) -> Optional[str]: """Resolve Java import to file path""" @@ -799,8 +789,7 @@ class ImportParserFactory: """Factory to get appropriate parser for file type""" @staticmethod - def get_parser(file_ext: str, output_dir: str, current_file: str, - current_dir: str) -> Optional[BaseImportParser]: + def get_parser(file_ext: str, output_dir: str, current_file: str, current_dir: str) -> Optional[BaseImportParser]: """Get parser instance for given file extension""" parsers = [ PythonImportParser, @@ -816,8 +805,7 @@ def get_parser(file_ext: str, output_dir: str, current_file: str, return None -def parse_imports(current_file: str, code_content: str, - output_dir: str) -> List[ImportInfo]: +def parse_imports(current_file: str, code_content: str, output_dir: str) -> List[ImportInfo]: """ Parse imports from code content (main entry point for backward compatibility) @@ -833,13 +821,11 @@ def parse_imports(current_file: str, code_content: str, List of ImportInfo objects for project files only (external packages are excluded) """ # Detect file extension - file_ext = os.path.splitext(current_file)[1].lstrip( - '.').lower() if current_file else '' + file_ext = os.path.splitext(current_file)[1].lstrip('.').lower() if current_file else '' current_dir = os.path.dirname(current_file) if current_file else '.' # Get appropriate parser - parser = ImportParserFactory.get_parser(file_ext, output_dir, current_file, - current_dir) + parser = ImportParserFactory.get_parser(file_ext, output_dir, current_file, current_dir) if not parser: return [] @@ -872,28 +858,41 @@ def parse_imports(current_file: str, code_content: str, # They don't start with '.', '/', or contain path separators (except scoped packages) # Check if it's a scoped package (starts with @ but file doesn't exist) - is_scoped_package = source.startswith('@') and not os.path.exists( - os.path.join(output_dir, source)) + is_scoped_package = source.startswith('@') and not os.path.exists(os.path.join(output_dir, source)) # Check if it's a project file (exists in output_dir) - full_path = os.path.join( - output_dir, source) if not os.path.isabs(source) else source + full_path = os.path.join(output_dir, source) if not os.path.isabs(source) else source is_project_file = os.path.exists(full_path) # Check if source has common code file extension # This helps identify resolved file paths vs package names - common_extensions = ('.js', '.jsx', '.ts', '.tsx', '.mjs', '.cjs', - '.java', '.py', '.pyw', '.css', '.scss', '.json') + common_extensions = ( + '.js', + '.jsx', + '.ts', + '.tsx', + '.mjs', + '.cjs', + '.java', + '.py', + '.pyw', + '.css', + '.scss', + '.json', + ) has_code_extension = source.endswith(common_extensions) # Check if it's an external package (package name without path separators) # For Java: java.util.List has dots but no file extension, so it's external # For JS: utils.js has extension, so it's a file - is_external = ( - is_scoped_package - or (not is_project_file and not has_code_extension - and not source.startswith('.') and not source.startswith('/') - and '/' not in source and os.sep not in source)) + is_external = is_scoped_package or ( + not is_project_file + and not has_code_extension + and not source.startswith('.') + and not source.startswith('/') + and '/' not in source + and os.sep not in source + ) if not is_external: project_imports.append(imp) diff --git a/ms_agent/utils/patcher.py b/ms_agent/utils/patcher.py index 4d721f266..b0602d435 100644 --- a/ms_agent/utils/patcher.py +++ b/ms_agent/utils/patcher.py @@ -4,8 +4,7 @@ T = TypeVar('T') -def patch(target_object: Any, attribute_name: str, - patch_value: Any) -> Callable[[Callable[..., T]], Callable[..., T]]: +def patch(target_object: Any, attribute_name: str, patch_value: Any) -> Callable[[Callable[..., T]], Callable[..., T]]: """ A decorator factory that patches an attribute of an object for the duration of a function's execution. @@ -30,9 +29,7 @@ def wrapper(*args: Any, **kwargs: Any) -> T: """ # Check if the target attribute exists if not hasattr(target_object, attribute_name): - raise AttributeError( - f'{target_object} does not have attribute {attribute_name}' - ) + raise AttributeError(f'{target_object} does not have attribute {attribute_name}') # 1. Save the original value (similar to __enter__) original_value = getattr(target_object, attribute_name) diff --git a/ms_agent/utils/push_to_hub.py b/ms_agent/utils/push_to_hub.py index a164603dd..3df0f9e56 100644 --- a/ms_agent/utils/push_to_hub.py +++ b/ms_agent/utils/push_to_hub.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import mimetypes import os import re @@ -8,11 +9,10 @@ from pathlib import Path from typing import List, Optional, Tuple -import json import requests + from ms_agent.utils.logger import get_logger -from ms_agent.utils.utils import (get_files_from_dir, is_package_installed, - text_hash) +from ms_agent.utils.utils import get_files_from_dir, is_package_installed, text_hash logger = get_logger() @@ -22,8 +22,7 @@ class PushToHub(ABC): The abstract base class for pushing files to a remote hub (e.g., GitHub). """ - def __init__(self, *args, **kwargs): - ... + def __init__(self, *args, **kwargs): ... @abstractmethod def push(self, *args, **kwargs): @@ -32,15 +31,16 @@ def push(self, *args, **kwargs): class PushToGitHub(PushToHub): - GITHUB_API_URL = 'https://api.github.com' - def __init__(self, - user_name: str, - repo_name: str, - token: str, - visibility: Optional[str] = 'public', - description: Optional[str] = None): + def __init__( + self, + user_name: str, + repo_name: str, + token: str, + visibility: Optional[str] = 'public', + description: Optional[str] = None, + ): """ Initialize the `PushToGitHub` class with authentication. @@ -72,9 +72,7 @@ def __init__(self, super().__init__() if not all([user_name, repo_name, token]): - raise ValueError( - 'GitHub username, repository name, and token must be provided.' - ) + raise ValueError('GitHub username, repository name, and token must be provided.') self.user_name = user_name self.repo_name = repo_name @@ -84,10 +82,12 @@ def __init__(self, # Create a session and set authentication headers self.session = requests.Session() - self.session.headers.update({ - 'Authorization': f'token {self.token}', - 'Accept': 'application/vnd.github.v3+json', - }) + self.session.headers.update( + { + 'Authorization': f'token {self.token}', + 'Accept': 'application/vnd.github.v3+json', + } + ) self._check_auth() self._create_github_repo( @@ -103,11 +103,8 @@ def _check_auth(self): RuntimeError: If authentication fails. """ user_data_resp = self.session.get(f'{self.GITHUB_API_URL}/user') - if user_data_resp.status_code != 200 or user_data_resp.json( - )['login'] != self.user_name: - raise RuntimeError( - 'Authentication failed! Please check your username and Personal Access Token.' - ) + if user_data_resp.status_code != 200 or user_data_resp.json()['login'] != self.user_name: + raise RuntimeError('Authentication failed! Please check your username and Personal Access Token.') def _create_github_repo( self, @@ -132,49 +129,37 @@ def _create_github_repo( raise ValueError('Repository name cannot be empty.') if visibility not in ['public', 'private']: - raise ValueError( - "Visibility must be either 'public' or 'private'.") + raise ValueError("Visibility must be either 'public' or 'private'.") if description is None: description = f'Repository - `{repo_name}` created by MS-Agent.' # Create the first commit with README url = f'{self.GITHUB_API_URL}/user/repos' - payload = { - 'name': repo_name, - 'description': description, - 'private': visibility == 'private', - 'auto_init': True - } + payload = {'name': repo_name, 'description': description, 'private': visibility == 'private', 'auto_init': True} response = self.session.post(url, json=payload) if response.status_code == 201: - logger.info( - f"Successfully created and initialized repository: {response.json()['html_url']}" - ) + logger.info(f"Successfully created and initialized repository: {response.json()['html_url']}") return response.json() elif response.status_code == 422: - error_message = response.json().get('errors', - [{}])[0].get('message', '') + error_message = response.json().get('errors', [{}])[0].get('message', '') if 'name already exists' in error_message: - logger.info( - f"Repository '{repo_name}' already exists. Will attempt to upload files to it." - ) + logger.info(f"Repository '{repo_name}' already exists. Will attempt to upload files to it.") return None else: - raise ValueError( - f'Validation error (422) while creating repository: {response.json()}' - ) + raise ValueError(f'Validation error (422) while creating repository: {response.json()}') else: logger.error(response.json()) - raise RuntimeError( - f'Failed to create repository: {response.status_code}') - - def _upload_files(self, - files_to_upload: List[Path], - work_dir: Path, - path_in_repo: Optional[str] = None, - branch: Optional[str] = 'main', - commit_message: Optional[str] = None) -> None: + raise RuntimeError(f'Failed to create repository: {response.status_code}') + + def _upload_files( + self, + files_to_upload: List[Path], + work_dir: Path, + path_in_repo: Optional[str] = None, + branch: Optional[str] = 'main', + commit_message: Optional[str] = None, + ) -> None: """ Upload multiple files to a GitHub repository in a single commit. @@ -202,8 +187,7 @@ def _upload_files(self, commit_response.raise_for_status() base_tree_sha = commit_response.json()['tree']['sha'] - logger.info( - f"Found '{branch}' branch, latest commit: {latest_commit_sha[:7]}") + logger.info(f"Found '{branch}' branch, latest commit: {latest_commit_sha[:7]}") # 2. Create a blob for each file blobs = [] @@ -211,13 +195,10 @@ def _upload_files(self, repo_base_path = Path(path_in_repo or '') for full_path in files_to_upload: - - file_relative_path: str = str( - full_path.relative_to(work_dir)).replace('\\', '/') + file_relative_path: str = str(full_path.relative_to(work_dir)).replace('\\', '/') mime_type, _ = mimetypes.guess_type(full_path) - is_binary = not (mime_type and mime_type.startswith('text/') - ) if mime_type else False + is_binary = not (mime_type and mime_type.startswith('text/')) if mime_type else False with open(full_path, 'rb') as f: content_bytes = f.read() @@ -236,22 +217,14 @@ def _upload_files(self, blob_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/blobs' blob_payload = {'content': content, 'encoding': encoding} - response = self.session.post( - blob_url, data=json.dumps(blob_payload)) + response = self.session.post(blob_url, data=json.dumps(blob_payload)) response.raise_for_status() remote_path = repo_base_path / file_relative_path remote_path_str = str(remote_path).replace('\\', '/') - blobs.append({ - 'path': remote_path_str, - 'mode': '100644', - 'type': 'blob', - 'sha': response.json()['sha'] - }) - logger.info( - f" - Local: '{str(full_path)}' -> Remote: '{remote_path_str}'" - ) + blobs.append({'path': remote_path_str, 'mode': '100644', 'type': 'blob', 'sha': response.json()['sha']}) + logger.info(f" - Local: '{str(full_path)}' -> Remote: '{remote_path_str}'") # 3. Create a tree object tree_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/trees' @@ -264,13 +237,11 @@ def _upload_files(self, # 4. Create a commit commit_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/commits' commit_payload = { - 'message': commit_message - or f"Upload files to '{path_in_repo or '/'}'", + 'message': commit_message or f"Upload files to '{path_in_repo or '/'}'", 'tree': tree_sha, - 'parents': [latest_commit_sha] + 'parents': [latest_commit_sha], } - response = self.session.post( - commit_url, data=json.dumps(commit_payload)) + response = self.session.post(commit_url, data=json.dumps(commit_payload)) response.raise_for_status() new_commit_sha = response.json()['sha'] logger.info(f'Commit created: {new_commit_sha[:7]}') @@ -282,13 +253,15 @@ def _upload_files(self, logger.info(f"Branch '{branch}' successfully points to the new commit") - def push(self, - folder_path: str, - path_in_repo: Optional[str] = None, - branch: Optional[str] = 'main', - commit_message: Optional[str] = None, - exclude: Optional[List[str]] = None, - **kwargs) -> None: + def push( + self, + folder_path: str, + path_in_repo: Optional[str] = None, + branch: Optional[str] = 'main', + commit_message: Optional[str] = None, + exclude: Optional[List[str]] = None, + **kwargs, + ) -> None: """ Push files from a local directory to the GitHub repository. @@ -349,12 +322,9 @@ def __init__( """ if not is_package_installed('modelscope'): - raise ImportError( - 'ModelScope package is not installed. Please install it with `pip install modelscope`.' - ) + raise ImportError('ModelScope package is not installed. Please install it with `pip install modelscope`.') - from modelscope.hub.api import HubApi - from modelscope.hub.api import get_endpoint + from modelscope.hub.api import HubApi, get_endpoint self.api = HubApi() self.token = token @@ -363,10 +333,7 @@ def __init__( super().__init__() @staticmethod - def _preprocess(folder_path: str, - path_in_repo_url: str, - add_powered_by: bool = True) -> 'Tuple[str, str]': - + def _preprocess(folder_path: str, path_in_repo_url: str, add_powered_by: bool = True) -> 'Tuple[str, str]': report_filename = 'report.md' file_path = os.path.join(folder_path, report_filename) file_path_hash: str = text_hash(text=file_path, keep_n_chars=8) @@ -376,9 +343,7 @@ def _preprocess(folder_path: str, new_file_path = os.path.join(current_cache_path, new_report_filename) if not os.path.exists(file_path): - logger.warning( - f'The report file: {file_path} does not exist. Skipping preprocessing.' - ) + logger.warning(f'The report file: {file_path} does not exist. Skipping preprocessing.') return '', '' try: @@ -390,12 +355,15 @@ def _preprocess(folder_path: str, try: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - if add_powered_by and not content.lstrip().startswith( - '""" + + '\n\n' + + content + ) pattern = r'!\[(.*?)\]\((resources/.*?)\)' replacement = rf'![\1]({path_in_repo_url}\2)' @@ -404,35 +372,26 @@ def _preprocess(folder_path: str, if count > 0: with open(file_path, 'w', encoding='utf-8') as f: f.write(new_content) - logger.info( - f"Preprocessed {count} 'resources/' links in {file_path}.") + logger.info(f"Preprocessed {count} 'resources/' links in {file_path}.") else: - logger.info( - f'No "resources/" links found in {file_path}. No changes made.' - ) + logger.info(f'No "resources/" links found in {file_path}. No changes made.') except IOError as e: logger.error(f'Error reading or writing the report file: {e}') return '', '' except Exception as e: - logger.error( - f'Unexpected error during preprocessing of PushToModelScope: {e}' - ) + logger.error(f'Unexpected error during preprocessing of PushToModelScope: {e}') return '', '' return file_path, new_file_path @staticmethod def _postprocess(report_file_path: str, report_file_path_in_cache: str): - try: shutil.move(report_file_path_in_cache, report_file_path) - shutil.rmtree( - os.path.dirname(report_file_path_in_cache), ignore_errors=True) + shutil.rmtree(os.path.dirname(report_file_path_in_cache), ignore_errors=True) except FileNotFoundError: - logger.warning( - f'The backup file of report: {report_file_path_in_cache} does not exist.' - ) + logger.warning(f'The backup file of report: {report_file_path_in_cache} does not exist.') def push( self, @@ -477,8 +436,7 @@ def push( revision='master', ) target_url: str = f'{self.endpoint}/{repo_type}s/{repo_id}/files' - logger.info( - f'Successfully pushed files to ModelScope: {target_url}') + logger.info(f'Successfully pushed files to ModelScope: {target_url}') except Exception as e: logger.error(f'Failed to push files to ModelScope: {e}') finally: @@ -490,7 +448,6 @@ def push( class PushToHuggingFace(PushToHub): - def __init__(self, token: str): """ Initialize the `PushToHuggingFace` with authentication. @@ -541,8 +498,7 @@ def push( raise ValueError('Repository ID cannot be empty.') if repo_type not in ['model', 'dataset']: - raise ValueError( - "Repository type must be either 'model' or 'dataset'.") + raise ValueError("Repository type must be either 'model' or 'dataset'.") try: self.api.upload_folder( @@ -562,6 +518,5 @@ def push( f'https://huggingface.co/{repo_type_in_url}{repo_id}/tree/main/{path_in_repo or ""}' ) except Exception as e: - logger.error( - f'Failed to push files to {repo_id} on HuggingFace: {e}') + logger.error(f'Failed to push files to {repo_id} on HuggingFace: {e}') raise e diff --git a/ms_agent/utils/rate_limiter.py b/ms_agent/utils/rate_limiter.py index 3bb94ec52..5273dbf75 100644 --- a/ms_agent/utils/rate_limiter.py +++ b/ms_agent/utils/rate_limiter.py @@ -73,8 +73,7 @@ async def _wait_if_needed(self): elapsed = now - self._last_request_time if elapsed < self.min_request_interval: wait_time = self.min_request_interval - elapsed - logger.debug( - f'Enforcing min interval: waiting {wait_time:.3f}s') + logger.debug(f'Enforcing min interval: waiting {wait_time:.3f}s') await asyncio.sleep(wait_time) now = time.time() @@ -88,18 +87,14 @@ async def _wait_if_needed(self): # If rate limit reached, wait until oldest request expires if len(self._request_times) >= self.max_requests_per_second: oldest_request = self._request_times[0] - wait_time = 1.0 - (now - - oldest_request) + 0.01 # Add 10ms margin + wait_time = 1.0 - (now - oldest_request) + 0.01 # Add 10ms margin if wait_time > 0: - logger.debug( - f'Rate limit reached ({self.max_requests_per_second} req/s): ' - f'waiting {wait_time:.3f}s') + logger.debug(f'Rate limit reached ({self.max_requests_per_second} req/s): waiting {wait_time:.3f}s') await asyncio.sleep(wait_time) now = time.time() # Clean up expired records cutoff_time = now - 1.0 - while self._request_times and self._request_times[ - 0] < cutoff_time: + while self._request_times and self._request_times[0] < cutoff_time: self._request_times.popleft() # Record this request time @@ -140,23 +135,15 @@ def get_stats(self) -> dict: with self._lock: now = time.time() cutoff_time = now - 1.0 - recent_requests = sum(1 for t in self._request_times - if t >= cutoff_time) + recent_requests = sum(1 for t in self._request_times if t >= cutoff_time) return { - 'max_requests_per_second': - self.max_requests_per_second, - 'min_request_interval': - self.min_request_interval, - 'max_concurrent': - self.max_concurrent, - 'recent_requests_count': - recent_requests, - 'available_concurrent_slots': - self._semaphore._value, - 'last_request_ago': - now - self._last_request_time - if self._last_request_time > 0 else None, + 'max_requests_per_second': self.max_requests_per_second, + 'min_request_interval': self.min_request_interval, + 'max_concurrent': self.max_concurrent, + 'recent_requests_count': recent_requests, + 'available_concurrent_slots': self._semaphore._value, + 'last_request_ago': now - self._last_request_time if self._last_request_time > 0 else None, } def reset(self): @@ -215,7 +202,8 @@ def __init__( logger.info( f'AdaptiveRateLimiter initialized: {initial_requests_per_second} req/s ' - f'(range: {min_requests_per_second}-{max_requests_per_second})') + f'(range: {min_requests_per_second}-{max_requests_per_second})' + ) def record_success(self): """Record successful request""" @@ -227,13 +215,13 @@ def record_success(self): # Consecutive successes reached threshold, attempt to increase rate if self._consecutive_successes >= self._success_threshold: old_rps = self.max_requests_per_second - new_rps = min( - round(old_rps * self._recovery_factor), self._max_rps) + new_rps = min(round(old_rps * self._recovery_factor), self._max_rps) if new_rps > old_rps: self.max_requests_per_second = new_rps logger.info( f'Rate limit increased: {old_rps} → {new_rps} req/s ' - f'(after {self._consecutive_successes} successes)') + f'(after {self._consecutive_successes} successes)' + ) self._consecutive_successes = 0 def record_error(self, is_rate_limit_error: bool = False): @@ -252,46 +240,41 @@ def record_error(self, is_rate_limit_error: bool = False): # If rate limit error, immediately reduce rate if is_rate_limit_error: old_rps = self.max_requests_per_second - new_rps = max( - int(old_rps * self._backoff_factor), self._min_rps) + new_rps = max(int(old_rps * self._backoff_factor), self._min_rps) if new_rps < old_rps: self.max_requests_per_second = new_rps # Also increase minimum request interval - self.min_request_interval = min( - self.min_request_interval * 1.5, 2.0) + self.min_request_interval = min(self.min_request_interval * 1.5, 2.0) logger.warning( f'Rate limit error detected! Reducing rate: {old_rps} → {new_rps} req/s, ' - f'min_interval → {self.min_request_interval:.2f}s') + f'min_interval → {self.min_request_interval:.2f}s' + ) self._consecutive_errors = 0 # Consecutive errors reached threshold, reduce rate elif self._consecutive_errors >= self._error_threshold: old_rps = self.max_requests_per_second - new_rps = max( - int(old_rps * self._backoff_factor), self._min_rps) + new_rps = max(int(old_rps * self._backoff_factor), self._min_rps) if new_rps < old_rps: self.max_requests_per_second = new_rps logger.warning( f'Multiple errors detected! Reducing rate: {old_rps} → {new_rps} req/s ' - f'(after {self._consecutive_errors} errors)') + f'(after {self._consecutive_errors} errors)' + ) self._consecutive_errors = 0 def get_stats(self) -> dict: """Get extended statistics""" stats = super().get_stats() with self._lock: - stats.update({ - 'total_requests': - self._total_requests, - 'total_errors': - self._total_errors, - 'error_rate': - self._total_errors / max(self._total_requests, 1), - 'consecutive_successes': - self._consecutive_successes, - 'consecutive_errors': - self._consecutive_errors, - 'current_requests_per_second': - self.max_requests_per_second, - }) + stats.update( + { + 'total_requests': self._total_requests, + 'total_errors': self._total_errors, + 'error_rate': self._total_errors / max(self._total_requests, 1), + 'consecutive_successes': self._consecutive_successes, + 'consecutive_errors': self._consecutive_errors, + 'current_requests_per_second': self.max_requests_per_second, + } + ) return stats diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py index 2014b49da..649acab26 100644 --- a/ms_agent/utils/snapshot.py +++ b/ms_agent/utils/snapshot.py @@ -8,8 +8,9 @@ All git commands are run with GIT_DIR and GIT_WORK_TREE explicitly set, so the snapshot repo is fully isolated from any surrounding repository. """ -import os + import json +import os import subprocess from typing import Optional @@ -21,8 +22,7 @@ _META_FILE = 'snapshot_meta.json' -def _git(args: list[str], work_tree: str, git_dir: str, - check: bool = True) -> subprocess.CompletedProcess: +def _git(args: list[str], work_tree: str, git_dir: str, check: bool = True) -> subprocess.CompletedProcess: env = os.environ.copy() env['GIT_DIR'] = git_dir env['GIT_WORK_TREE'] = work_tree @@ -51,10 +51,7 @@ def _configure_snapshot_repo_for_automation(work_tree: str, git_dir: str) -> Non Git-supported way to disable hooks (POSIX ``/dev/null``, Windows ``nul``). """ try: - _git(['config', 'core.hooksPath', os.devnull], - work_tree=work_tree, - git_dir=git_dir, - check=False) + _git(['config', 'core.hooksPath', os.devnull], work_tree=work_tree, git_dir=git_dir, check=False) except Exception: pass @@ -67,10 +64,8 @@ def _ensure_repo(output_dir: str) -> str: # Use non-bare init with explicit GIT_DIR — no --bare so work tree is supported. # Do NOT pass a path argument; GIT_DIR env var points git at our custom dir. _git(['init'], work_tree=output_dir, git_dir=git_dir) - _git(['config', 'user.email', 'ms-agent@snapshot'], - work_tree=output_dir, git_dir=git_dir) - _git(['config', 'user.name', 'ms-agent'], - work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.email', 'ms-agent@snapshot'], work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.name', 'ms-agent'], work_tree=output_dir, git_dir=git_dir) # Exclude the snapshot dir itself from tracking info_dir = os.path.join(git_dir, 'info') os.makedirs(info_dir, exist_ok=True) @@ -103,8 +98,7 @@ def _save_meta(output_dir: str, meta: dict) -> None: json.dump(meta, f, indent=2) -def take_snapshot(output_dir: str, message: str, - message_count: int = 0) -> Optional[str]: +def take_snapshot(output_dir: str, message: str, message_count: int = 0) -> Optional[str]: """ Stage all changes in output_dir and create a snapshot commit. @@ -127,15 +121,13 @@ def take_snapshot(output_dir: str, message: str, _git(['add', '-A'], work_tree=output_dir, git_dir=git_dir) # Check if there's anything to commit - status = _git(['status', '--porcelain'], - work_tree=output_dir, git_dir=git_dir) + status = _git(['status', '--porcelain'], work_tree=output_dir, git_dir=git_dir) if not status.stdout.strip(): return None # Nothing changed # Truncate message to keep commit subject readable subject = message.strip().replace('\n', ' ')[:120] - result = _git(['commit', '--no-verify', '-m', subject], - work_tree=output_dir, git_dir=git_dir) + result = _git(['commit', '--no-verify', '-m', subject], work_tree=output_dir, git_dir=git_dir) commit_hash = None for line in result.stdout.splitlines(): @@ -154,8 +146,7 @@ def take_snapshot(output_dir: str, message: str, return commit_hash except FileNotFoundError: - logger.warning_once( - '[snapshot] git not found — snapshots disabled.') + logger.warning_once('[snapshot] git not found — snapshots disabled.') return None except subprocess.CalledProcessError as e: logger.warning(f'[snapshot] git error: {e.stderr.strip()}') @@ -188,19 +179,20 @@ def list_snapshots(output_dir: str) -> list[dict]: parts = line.split('\t', 2) if len(parts) == 3: h = parts[0] - snapshots.append({ - 'hash': h, - 'date': parts[1], - 'message': parts[2], - 'message_count': meta.get(h, {}).get('message_count', 0), - }) + snapshots.append( + { + 'hash': h, + 'date': parts[1], + 'message': parts[2], + 'message_count': meta.get(h, {}).get('message_count', 0), + } + ) return snapshots except Exception: return [] -def restore_snapshot(output_dir: str, - commit_hash: str) -> tuple[bool, int]: +def restore_snapshot(output_dir: str, commit_hash: str) -> tuple[bool, int]: """ Restore output_dir to the state at commit_hash. @@ -212,8 +204,7 @@ def restore_snapshot(output_dir: str, logger.warning('[snapshot] No snapshot repo found.') return False, 0 try: - _git(['checkout', commit_hash, '--', '.'], - work_tree=output_dir, git_dir=git_dir) + _git(['checkout', commit_hash, '--', '.'], work_tree=output_dir, git_dir=git_dir) logger.info(f'[snapshot] Restored to {commit_hash}') meta = _load_meta(output_dir) message_count = meta.get(commit_hash, {}).get('message_count', 0) diff --git a/ms_agent/utils/stats.py b/ms_agent/utils/stats.py index 7ed705a1d..def45af31 100644 --- a/ms_agent/utils/stats.py +++ b/ms_agent/utils/stats.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import time from datetime import datetime from typing import Any, Dict, Iterable, Optional -import json from ms_agent.llm.utils import Message from .logger import get_logger @@ -23,8 +23,7 @@ def _get_lock(path: str) -> asyncio.Lock: return lock -def get_stats_path(config: Any, - default_filename: str = 'workflow_stats.json') -> str: +def get_stats_path(config: Any, default_filename: str = 'workflow_stats.json') -> str: stats_file = getattr(config, 'stats_file', None) output_dir = getattr(config, 'output_dir', './output') if stats_file: @@ -47,8 +46,7 @@ def summarize_usage(messages: Optional[Iterable[Message]]) -> Dict[str, int]: prompt_tokens += int(getattr(msg, 'prompt_tokens', 0) or 0) completion_tokens += int(getattr(msg, 'completion_tokens', 0) or 0) cached_tokens += int(getattr(msg, 'cached_tokens', 0) or 0) - cache_creation_input_tokens += int( - getattr(msg, 'cache_creation_input_tokens', 0) or 0) + cache_creation_input_tokens += int(getattr(msg, 'cache_creation_input_tokens', 0) or 0) api_calls += int(getattr(msg, 'api_calls', 0) or 0) return { 'prompt_tokens': prompt_tokens, @@ -72,8 +70,7 @@ async def append_stats(path: str, record: Dict[str, Any]) -> None: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) or [] except Exception as exc: - logger.warning( - f'Failed to read stats file {path}, resetting: {exc}') + logger.warning(f'Failed to read stats file {path}, resetting: {exc}') data = [] if not isinstance(data, list): data = [] @@ -85,16 +82,17 @@ async def append_stats(path: str, record: Dict[str, Any]) -> None: def build_timing_record( - *, - event: str, - agent_tag: Optional[str], - agent_type: Optional[str], - started_at: str, - ended_at: str, - duration_s: float, - status: str, - usage: Optional[Dict[str, int]] = None, - extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + *, + event: str, + agent_tag: Optional[str], + agent_type: Optional[str], + started_at: str, + ended_at: str, + duration_s: float, + status: str, + usage: Optional[Dict[str, int]] = None, + extra: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: record = { 'event': event, 'agent_tag': agent_tag, diff --git a/ms_agent/utils/stream_writer.py b/ms_agent/utils/stream_writer.py index 46c5cf95b..045c32e1c 100644 --- a/ms_agent/utils/stream_writer.py +++ b/ms_agent/utils/stream_writer.py @@ -21,6 +21,7 @@ File path: ``{output_dir}/subagents/{call_id}.stream.jsonl`` """ + import json import os import threading @@ -76,16 +77,17 @@ def on_start(self, agent_tag: Optional[str]) -> None: self._agent_tag = agent_tag try: self._file = open(self._path, 'w', encoding='utf-8') - self._write_line({ - 'type': 'header', - 'call_id': self._call_id, - 'tool_name': self._tool_name, - 'agent_tag': agent_tag or '', - 'ts': _now_iso(), - }) + self._write_line( + { + 'type': 'header', + 'call_id': self._call_id, + 'tool_name': self._tool_name, + 'agent_tag': agent_tag or '', + 'ts': _now_iso(), + } + ) except Exception as exc: - logger.warning( - 'SubAgentStreamWriter: failed to open %s: %s', self._path, exc) + logger.warning('SubAgentStreamWriter: failed to open %s: %s', self._path, exc) self._file = None def on_chunk(self, history: Any) -> None: @@ -102,13 +104,15 @@ def on_chunk(self, history: Any) -> None: with self._lock: if self._closed or self._file is None: return - for msg in messages[self._last_written_count:]: - self._write_line({ - 'type': 'message', - 'index': self._last_written_count, - 'message': _msg_to_dict(msg), - 'ts': _now_iso(), - }) + for msg in messages[self._last_written_count :]: + self._write_line( + { + 'type': 'message', + 'index': self._last_written_count, + 'message': _msg_to_dict(msg), + 'ts': _now_iso(), + } + ) self._last_written_count += 1 def on_end(self, history: Any) -> None: @@ -125,19 +129,20 @@ def on_end(self, history: Any) -> None: self._closed = True if self._file is not None: try: - self._write_line({ - 'type': 'footer', - 'call_id': self._call_id, - 'agent_tag': self._agent_tag or '', - 'status': 'complete', - 'total_messages': self._last_written_count, - 'ts': _now_iso(), - }) + self._write_line( + { + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'complete', + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + } + ) self._file.flush() self._file.close() except Exception as exc: - logger.warning( - 'SubAgentStreamWriter: close error on %s: %s', self._path, exc) + logger.warning('SubAgentStreamWriter: close error on %s: %s', self._path, exc) finally: self._file = None @@ -153,15 +158,17 @@ def on_error(self, error: str) -> None: self._closed = True if self._file is not None: try: - self._write_line({ - 'type': 'footer', - 'call_id': self._call_id, - 'agent_tag': self._agent_tag or '', - 'status': 'error', - 'error': error, - 'total_messages': self._last_written_count, - 'ts': _now_iso(), - }) + self._write_line( + { + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'error', + 'error': error, + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + } + ) self._file.flush() self._file.close() except Exception: diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py index a897e2fa9..cff5b5691 100644 --- a/ms_agent/utils/task_manager.py +++ b/ms_agent/utils/task_manager.py @@ -14,8 +14,8 @@ @dataclass class BackgroundTask: task_id: str - task_type: str # 'agent' | 'shell' - tool_name: str # which tool spawned this + task_type: str # 'agent' | 'shell' + tool_name: str # which tool spawned this description: str status: str = 'running' # 'running' | 'completed' | 'failed' | 'killed' proc: Optional[Any] = field(default=None, repr=False) # mp.Process or asyncio.Task diff --git a/ms_agent/utils/thread_util.py b/ms_agent/utils/thread_util.py index 16e46eba9..815407b48 100644 --- a/ms_agent/utils/thread_util.py +++ b/ms_agent/utils/thread_util.py @@ -5,19 +5,16 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps -from ms_agent.utils.logger import get_logger from tqdm.auto import tqdm +from ms_agent.utils.logger import get_logger + logger = get_logger() -DEFAULT_MAX_WORKERS = int( - os.getenv('DEFAULT_MAX_WORKERS', min(8, - os.cpu_count() + 4))) +DEFAULT_MAX_WORKERS = int(os.getenv('DEFAULT_MAX_WORKERS', min(8, os.cpu_count() + 4))) -def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, - disable_tqdm: bool = False, - tqdm_desc: str = None): +def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, disable_tqdm: bool = False, tqdm_desc: str = None): """ A decorator to execute a function in a threaded manner using ThreadPoolExecutor. @@ -43,26 +40,22 @@ def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, """ def decorator(func): - @wraps(func) def wrapper(iterable, *args, **kwargs): results = [] # Create a tqdm progress bar with the total number of items to process with tqdm( - unit_scale=True, - unit_divisor=1024, - initial=0, - total=len(iterable), - desc=tqdm_desc or f'Processing {len(iterable)} items', - disable=disable_tqdm, + unit_scale=True, + unit_divisor=1024, + initial=0, + total=len(iterable), + desc=tqdm_desc or f'Processing {len(iterable)} items', + disable=disable_tqdm, ) as pbar: # Define a wrapper function to update the progress bar with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all tasks - futures = { - executor.submit(func, item, *args, **kwargs): item - for item in iterable - } + futures = {executor.submit(func, item, *args, **kwargs): item for item in iterable} # Update the progress bar as tasks complete for future in as_completed(futures): @@ -99,16 +92,14 @@ def weakref_cb(_, q=self._work_queue): num_threads = len(self._threads) if num_threads < self._max_workers: - thread_name = '%s_%d' % (self._thread_name_prefix - or self, num_threads) + thread_name = '%s_%d' % (self._thread_name_prefix or self, num_threads) # Import internal helpers from stdlib to keep behavior consistent. - from concurrent.futures.thread import _worker, _threads_queues # type: ignore + from concurrent.futures.thread import _threads_queues, _worker # type: ignore t = threading.Thread( name=thread_name, target=_worker, - args=(weakref.ref(self, weakref_cb), self._work_queue, - self._initializer, self._initargs), + args=(weakref.ref(self, weakref_cb), self._work_queue, self._initializer, self._initargs), ) t.daemon = True t.start() diff --git a/ms_agent/utils/tokenizer_util.py b/ms_agent/utils/tokenizer_util.py index b26e5228f..423b4084f 100644 --- a/ms_agent/utils/tokenizer_util.py +++ b/ms_agent/utils/tokenizer_util.py @@ -59,14 +59,10 @@ def segment(self, content: str) -> List[str]: token_ids = self.encode(content) # Decode each token ID individually to get its string representation - token_strings = [ - self.tokenizer.decode([tid], skip_special_tokens=True) - for tid in token_ids - ] + token_strings = [self.tokenizer.decode([tid], skip_special_tokens=True) for tid in token_ids] return token_strings - def count_tokens(self, - contents: Union[str, List[str]]) -> Union[int, List[int]]: + def count_tokens(self, contents: Union[str, List[str]]) -> Union[int, List[int]]: """ Batch count tokens for multiple texts. diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index a6d0bc87d..0a2e1e1a7 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -5,6 +5,7 @@ import html import importlib import importlib.util +import json import os.path import re import subprocess @@ -15,7 +16,6 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -import json import requests import yaml from omegaconf import DictConfig, OmegaConf @@ -30,11 +30,10 @@ else: # Define a placeholder class for type-checking compatibility class BuiltInExceptionGroup(BaseException): - def __init__(self, message, exceptions): self.message = message self.exceptions = exceptions - self.args = (message, ) + self.args = (message,) def __str__(self): return f'{self.message}: {self.exceptions}' @@ -169,8 +168,7 @@ def escape_yaml_string(text: str) -> str: return text -def save_history(output_dir: str, task: str, config: DictConfig, - messages: List['Message']): +def save_history(output_dir: str, task: str, config: DictConfig, messages: List['Message']): """ Saves the specified configuration and conversation history to a cache directory for later retrieval or restoration. @@ -205,10 +203,7 @@ def save_history(output_dir: str, task: str, config: DictConfig, with open(config_file, 'w') as f: OmegaConf.save(config, f) with open(message_file, 'w') as f: - json.dump([message.to_dict() for message in messages], - f, - indent=4, - ensure_ascii=False) + json.dump([message.to_dict() for message in messages], f, indent=4, ensure_ascii=False) def read_history(output_dir: str, task: str): @@ -240,8 +235,9 @@ def read_history(output_dir: str, task: str): TypeError / AttributeError: If the deserialized JSON data lacks expected keys or structure for Message objects. """ - from ms_agent.llm import Message from ms_agent.config import Config + from ms_agent.llm import Message + cache_dir = os.path.join(output_dir, DEFAULT_MEMORY_DIR) os.makedirs(cache_dir, exist_ok=True) config_file = os.path.join(cache_dir, f'{task}.yaml') @@ -309,6 +305,7 @@ def json_loads(text: str) -> dict: JSON decoding error is raised. """ import json5 + text = text.strip('\n') if text.startswith('```') and text.endswith('\n```'): text = '\n'.join(text.split('\n')[1:-1]) @@ -332,14 +329,12 @@ def download_pdf(url: str, out_file_path: str, reuse: bool = True): """ if reuse and os.path.exists(out_file_path): - logger.info( - f"File '{out_file_path}' already exists. Skipping download.") + logger.info(f"File '{out_file_path}' already exists. Skipping download.") return try: response = requests.get(url, stream=True) - response.raise_for_status( - ) # Raise an exception for bad status codes (4xx or 5xx) + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) with open(out_file_path, 'wb') as pdf_file: for chunk in response.iter_content(chunk_size=8192): @@ -377,6 +372,7 @@ def load_image_from_url_to_pil(url: str) -> 'Image.Image': A PIL Image object if successful, None otherwise. """ from PIL import Image + try: response = requests.get(url, timeout=(10, 25)) # Raise an HTTPError for bad responses (4xx or 5xx) @@ -404,6 +400,7 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': tuple: (PIL Image object, file extension string) or None if failed """ from PIL import Image + try: header, encoded = uri.split(',', 1) if ';base64' in header: @@ -422,14 +419,11 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': logger.error(f'Error opening image with PIL for uri_to_pil: {e}') return None except Exception as e: - logger.error( - f'Unexpected error loading image from URI for uri_to_pil: {e}') + logger.error(f'Unexpected error loading image from URI for uri_to_pil: {e}') return None -def validate_url( - img_url: str, - backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str: +def validate_url(img_url: str, backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str: """ Validates and resolves a relative image URL using the base URL from the HTML document's metadata. @@ -448,23 +442,23 @@ def validate_url( from urllib.parse import urljoin, urlparse # Check if we have a valid soup object in the backend - if not backend or not hasattr( - backend, 'soup') or not backend.soup or not backend.soup.head: + if not backend or not hasattr(backend, 'soup') or not backend.soup or not backend.soup.head: return None # Potential sources of base URLs to try sources = [ # Try base tag lambda: backend.soup.head.find('base', href=True)['href'] - if backend.soup.head.find('base', href=True) else None, + if backend.soup.head.find('base', href=True) + else None, # Try canonical link - lambda: backend.soup.head.find('link', rel='canonical', href=True)[ - 'href'] if backend.soup.head.find( - 'link', rel='canonical', href=True) else None, + lambda: backend.soup.head.find('link', rel='canonical', href=True)['href'] + if backend.soup.head.find('link', rel='canonical', href=True) + else None, # Try OG URL meta tag - lambda: backend.soup.head.find( - 'meta', property='og:url', content=True)['content'] if backend.soup - .head.find('meta', property='og:url', content=True) else None + lambda: backend.soup.head.find('meta', property='og:url', content=True)['content'] + if backend.soup.head.find('meta', property='og:url', content=True) + else None, ] # Try each source until we find a valid base URL @@ -508,7 +502,9 @@ def get_default_config(): os.path.dirname(__file__), # ms_agent/utils/ '..', # ↑ up to ms_agent/ 'agent', # → agent/ - 'agent.yaml')) + 'agent.yaml', + ) + ) with open(config_path, 'r', encoding='utf-8') as file: return yaml.safe_load(file) @@ -578,8 +574,7 @@ def txt_to_html(txt_path: str, html_path: Optional[str] = None) -> str: return html_path -def get_files_from_dir(folder_path: Union[str, Path], - exclude: Optional[List[str]] = None) -> List[Path]: +def get_files_from_dir(folder_path: Union[str, Path], exclude: Optional[List[str]] = None) -> List[Path]: """ Get all files in the target directory recursively, excluding files that match any of the given regex patterns. @@ -607,11 +602,12 @@ def get_files_from_dir(folder_path: Union[str, Path], # Filter files based on exclusion patterns file_list = [ - file_path for file_path in files if not any( - pattern.search( - str(file_path.resolve().relative_to( - folder_path.resolve())).replace('\\', '/')) - for pattern in exclude_patterns) + file_path + for file_path in files + if not any( + pattern.search(str(file_path.resolve().relative_to(folder_path.resolve())).replace('\\', '/')) + for pattern in exclude_patterns + ) ] return file_list @@ -630,9 +626,7 @@ def is_package_installed(package_or_import_name: str) -> bool: return importlib.util.find_spec(package_or_import_name) is not None -def install_package(package_name: str, - import_name: Optional[str] = None, - extend_module: str = None): +def install_package(package_name: str, import_name: Optional[str] = None, extend_module: str = None): """ Check and install a package using pip. @@ -652,8 +646,7 @@ def install_package(package_name: str, package_name = f'{package_name}[{extend_module}]' if not is_package_installed(import_name): - subprocess.check_call( - [sys.executable, '-m', 'pip', 'install', package_name]) + subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_name]) logger.info(f'Package {package_name} installed successfully.') else: logger.info(f'Package {import_name} is already installed.') @@ -670,7 +663,7 @@ def extract_by_tag(text: str, tag: str) -> str: Returns: str: The content found between the specified tags, or an empty string if not found. """ - pattern = fr'<{tag}>(.*?)' + pattern = rf'<{tag}>(.*?)' match = re.search(pattern, text, re.DOTALL) if match: return match.group(1).strip() @@ -753,15 +746,12 @@ def file_lock(lock_dir: str, filename: str, timeout: float = 120.0): while True: try: - lock_fd = os.open(lock_file_path, - os.O_CREAT | os.O_EXCL | os.O_WRONLY) + lock_fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) os.write(lock_fd, f'{os.getpid()}'.encode()) break except FileExistsError: if time.time() - start_time >= timeout: - raise TimeoutError( - f'Failed to acquire lock for {filename} after {timeout} seconds' - ) + raise TimeoutError(f'Failed to acquire lock for {filename} after {timeout} seconds') time.sleep(0.1) # Wait 100ms before retry try: @@ -776,10 +766,7 @@ def file_lock(lock_dir: str, filename: str, timeout: float = 120.0): pass -def render_markdown_todo(md_path: str, - *, - title: str = ' CURRENT PLAN', - use_pager: bool = False) -> None: +def render_markdown_todo(md_path: str, *, title: str = ' CURRENT PLAN', use_pager: bool = False) -> None: """ Render a Markdown todo list nicely in terminal using Rich. - Cross-platform (Windows/Linux/macOS) @@ -790,15 +777,17 @@ def render_markdown_todo(md_path: str, from rich.panel import Panel from rich.theme import Theme - theme = Theme({ - 'markdown.code': 'bold', - 'markdown.code_block': 'dim', - 'markdown.h1': 'bold', - 'markdown.h2': 'bold', - 'markdown.h3': 'bold', - 'markdown.link': 'underline', - 'markdown.list': '', - }) + theme = Theme( + { + 'markdown.code': 'bold', + 'markdown.code_block': 'dim', + 'markdown.h1': 'bold', + 'markdown.h2': 'bold', + 'markdown.h3': 'bold', + 'markdown.link': 'underline', + 'markdown.list': '', + } + ) console = Console(theme=theme, soft_wrap=True, highlight=False) try: diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py index a8380b710..dbc485321 100644 --- a/ms_agent/utils/workspace_policy.py +++ b/ms_agent/utils/workspace_policy.py @@ -68,11 +68,9 @@ def resolve_under_roots(self, user_path: str | Path) -> Path: except ValueError: continue else: - raise WorkspacePolicyError( - f'Path is outside allowed workspace roots: {resolved}') + raise WorkspacePolicyError(f'Path is outside allowed workspace roots: {resolved}') if self._is_denied(resolved): - raise WorkspacePolicyError( - f'Path matches a deny_globs pattern: {resolved}') + raise WorkspacePolicyError(f'Path matches a deny_globs pattern: {resolved}') return resolved def _is_denied(self, path: Path) -> bool: @@ -108,16 +106,12 @@ def assert_shell_command_allowed(self, command: str) -> None: if not command or not command.strip(): raise WorkspacePolicyError('Empty shell command') if len(command) > self.max_command_chars: - raise WorkspacePolicyError( - f'Shell command exceeds max length ({self.max_command_chars})') + raise WorkspacePolicyError(f'Shell command exceeds max length ({self.max_command_chars})') mode = self.shell_default_mode if mode == 'read_only': - if _shell_looks_mutating_or_network(command, - allow_network=False): - raise WorkspacePolicyError( - 'Shell is in read_only mode: mutating or network commands are not allowed' - ) + if _shell_looks_mutating_or_network(command, allow_network=False): + raise WorkspacePolicyError('Shell is in read_only mode: mutating or network commands are not allowed') elif mode == 'workspace_write': if not self.shell_network_enabled and _shell_looks_network(command): raise WorkspacePolicyError( @@ -146,15 +140,13 @@ def _shell_looks_network(command: str) -> bool: return any(t in lowered for t in tokens) -def _shell_looks_mutating_or_network(command: str, *, - allow_network: bool) -> bool: +def _shell_looks_mutating_or_network(command: str, *, allow_network: bool) -> bool: if not allow_network and _shell_looks_network(command): return True # redirection that creates/overwrites files if re.search(r'[>]{1,2}\s*[^\s]', command): return True - if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', - command): + if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', command): return True return False @@ -180,14 +172,12 @@ def dir_skipped(dirpath: Path) -> bool: return True parts = rel.split('/') for i in range(len(parts)): - sub = '/'.join(parts[:i + 1]) - if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch( - sub + '/', pat): + sub = '/'.join(parts[: i + 1]) + if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch(sub + '/', pat): return True return False - for dirpath, dirnames, filenames in os.walk( - root, topdown=True, followlinks=False): + for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=False): dp = Path(dirpath) if dir_skipped(dp): dirnames[:] = [] diff --git a/ms_agent/workflow/base.py b/ms_agent/workflow/base.py index 9d484118c..2c21d7bcb 100644 --- a/ms_agent/workflow/base.py +++ b/ms_agent/workflow/base.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from typing import Dict, Optional -from ms_agent.config import Config from omegaconf import DictConfig +from ms_agent.config import Config + class Workflow(ABC): """Base class for workflows that define a sequence of agent-based processing steps. @@ -22,12 +23,14 @@ class Workflow(ABC): - mcp_server_file (Optional[str]): Path to an MCP server file if needed. Default is None. """ - def __init__(self, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + **kwargs, + ): if config_dir_or_id is None: self.config = config else: diff --git a/ms_agent/workflow/chain_workflow.py b/ms_agent/workflow/chain_workflow.py index 2b2b739e1..7ea8cd55c 100644 --- a/ms_agent/workflow/chain_workflow.py +++ b/ms_agent/workflow/chain_workflow.py @@ -1,10 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +from omegaconf import DictConfig + from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() @@ -26,9 +27,7 @@ def build_workflow(self): if isinstance(next_tasks, str): has_next.add(next_tasks) else: - assert len( - next_tasks - ) == 1, 'ChainWorkflow only supports one next task' + assert len(next_tasks) == 1, 'ChainWorkflow only supports one next task' has_next.update(next_tasks) for task_name in self.config.keys(): @@ -89,8 +88,7 @@ async def run(self, inputs, **kwargs): init_args['task'] = task init_args['load_cache'] = self.load_cache if isinstance(config, str): - init_args['config_dir_or_id'] = os.path.join( - self.config.local_dir, config) + init_args['config_dir_or_id'] = os.path.join(self.config.local_dir, config) else: init_args['config'] = config init_args['env'] = self.env diff --git a/ms_agent/workflow/dag_workflow.py b/ms_agent/workflow/dag_workflow.py index 1419ccf46..97e3ac65c 100644 --- a/ms_agent/workflow/dag_workflow.py +++ b/ms_agent/workflow/dag_workflow.py @@ -3,10 +3,11 @@ from collections import defaultdict, deque from typing import Any, Dict, List, Set +from omegaconf import DictConfig + from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() @@ -37,9 +38,7 @@ def build_workflow(self): self.nodes = set(list(self.graph.keys()) + list(indegree.keys())) # Find root tasks (indegree==0) - self.roots = [ - t for t in tasks if 'next' in self.config[t] and indegree[t] == 0 - ] + self.roots = [t for t in tasks if 'next' in self.config[t] and indegree[t] == 0] if not self.roots: raise ValueError('No root task found for DagWorkflow') @@ -76,12 +75,10 @@ async def run(self, inputs: Any, **kwargs): task_input = inputs else: parent_outs = [outputs[p] for p in self.parents[task]] - task_input = parent_outs if len( - parent_outs) > 1 else parent_outs[0] + task_input = parent_outs if len(parent_outs) > 1 else parent_outs[0] task_info: DictConfig = getattr(self.config, task) - agent_cfg_path = os.path.join(self.config.local_dir, - task_info.agent_config) + agent_cfg_path = os.path.join(self.config.local_dir, task_info.agent_config) if not hasattr(task_info, 'agent'): task_info.agent = DictConfig({}) init_args = getattr(task_info.agent, 'kwargs', {}) @@ -98,8 +95,5 @@ async def run(self, inputs: Any, **kwargs): outputs[task] = result # Return results of terminal nodes (no outgoing edges) - terminals = [ - t for t in self.config.keys() - if t not in self.graph and t in self.nodes - ] + terminals = [t for t in self.config.keys() if t not in self.graph and t in self.nodes] return {t: outputs[t] for t in terminals} diff --git a/ms_agent/workflow/deep_research/__init__.py b/ms_agent/workflow/deep_research/__init__.py index df00c463b..f79565769 100644 --- a/ms_agent/workflow/deep_research/__init__.py +++ b/ms_agent/workflow/deep_research/__init__.py @@ -1,6 +1,12 @@ -from .principle import (BSGMatrixPrinciple, MECEPrinciple, ParetoPrinciple, - Principle, PyramidPrinciple, SWOTPrinciple, - ValueChainPrinciple) +from .principle import ( + BSGMatrixPrinciple, + MECEPrinciple, + ParetoPrinciple, + Principle, + PyramidPrinciple, + SWOTPrinciple, + ValueChainPrinciple, +) from .research_workflow import ResearchWorkflow from .research_workflow_beta import ResearchWorkflowBeta diff --git a/ms_agent/workflow/deep_research/principle.py b/ms_agent/workflow/deep_research/principle.py index 818a89092..89dd929cb 100644 --- a/ms_agent/workflow/deep_research/principle.py +++ b/ms_agent/workflow/deep_research/principle.py @@ -2,11 +2,12 @@ class Principle: - def __init__(self, breakdown_prompt: str = None): - - self.breakdown_prompt: str = breakdown_prompt or """\n首先生成一份系统性的分析方案,自上而下breakdown,输出markdown格式:\n + self.breakdown_prompt: str = ( + breakdown_prompt + or """\n首先生成一份系统性的分析方案,自上而下breakdown,输出markdown格式:\n """ + ) self.todo_prompt: str = """"\n基于上述breakdown,生成todo list,输出markdown格式,形式必须遵循:\n # Title @@ -56,48 +57,56 @@ def __init__(self, breakdown_prompt: str = None): class BSGMatrixPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用Boston Matrix Analysis Principle(Boston Consulting Group matrix analysis)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt + or '\n首先使用Boston Matrix Analysis Principle(Boston Consulting Group matrix analysis)来拆解和分析上述问题,输出markdown格式:' + ) class ParetoPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用Pareto Principle(80/20 Rule)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt or '\n首先使用Pareto Principle(80/20 Rule)来拆解和分析上述问题,输出markdown格式:' + ) class MECEPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用MECE原则(Mutually Exclusive and Collectively Exhaustive)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt + or '\n首先使用MECE原则(Mutually Exclusive and Collectively Exhaustive)来拆解和分析上述问题,输出markdown格式:' + ) class PyramidPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用金字塔原理(Pyramid Principle)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt or '\n首先使用金字塔原理(Pyramid Principle)来拆解和分析上述问题,输出markdown格式:' + ) class SWOTPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用SWOT分析法(SWOT Analysis)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt or '\n首先使用SWOT分析法(SWOT Analysis)来拆解和分析上述问题,输出markdown格式:' + ) class ValueChainPrinciple(Principle): - def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = breakdown_prompt or '\n首先使用价值链分析(Value Chain Analysis)来拆解和分析上述问题,输出markdown格式:' + self.breakdown_prompt = ( + breakdown_prompt or '\n首先使用价值链分析(Value Chain Analysis)来拆解和分析上述问题,输出markdown格式:' + ) diff --git a/ms_agent/workflow/deep_research/research_utils.py b/ms_agent/workflow/deep_research/research_utils.py index 7d5ccdb69..56dcb7952 100644 --- a/ms_agent/workflow/deep_research/research_utils.py +++ b/ms_agent/workflow/deep_research/research_utils.py @@ -32,8 +32,7 @@ class ResearchRequest(BaseModel): query: str = Field(..., description='Research query') depth: int = Field(default=2, ge=1, le=5, description='Research depth') - breadth: int = Field( - default=4, ge=1, le=10, description='Research breadth') + breadth: int = Field(default=4, ge=1, le=10, description='Research breadth') class ResearchResponse(BaseModel): @@ -48,10 +47,8 @@ class ResearchResponse(BaseModel): class LearningsResponse(BaseModel): """Response model for processed search results.""" - learnings: List[str] = Field( - ..., description='List of learnings extracted from search results') - follow_up_questions: List[str] = Field( - ..., description='List of follow-up questions for further research') + learnings: List[str] = Field(..., description='List of learnings extracted from search results') + follow_up_questions: List[str] = Field(..., description='List of follow-up questions for further research') class ProgressTracker: @@ -63,12 +60,10 @@ def __init__(self): def __enter__(self): self.progress = Progress( - SpinnerColumn(), - TextColumn('[progress.description]{task.description}'), - console=console) + SpinnerColumn(), TextColumn('[progress.description]{task.description}'), console=console + ) self.progress.__enter__() - self.task_id = self.progress.add_task( - 'Starting research...', total=None) + self.task_id = self.progress.add_task('Starting research...', total=None) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/ms_agent/workflow/deep_research/research_workflow.py b/ms_agent/workflow/deep_research/research_workflow.py index 03d64bf80..d3dcc48a4 100644 --- a/ms_agent/workflow/deep_research/research_workflow.py +++ b/ms_agent/workflow/deep_research/research_workflow.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable import copy +import json import os import re from typing import Any, Dict, List, Optional, Union -import json from ms_agent.llm.openai import OpenAIChat from ms_agent.utils import get_logger @@ -484,4 +484,4 @@ def run(self, # Dump report to markdown file with open(self.workdir_structure['report_md'], 'w', encoding='utf-8') as f_report: f_report.write(resp_content) - logger.info(f'Report saved to {self.workdir_structure["report_md"]}') + logger.info(f'Report saved to {self.workdir_structure['report_md']}') diff --git a/ms_agent/workflow/deep_research/research_workflow_beta.py b/ms_agent/workflow/deep_research/research_workflow_beta.py index 24816ae52..6dd4003a2 100644 --- a/ms_agent/workflow/deep_research/research_workflow_beta.py +++ b/ms_agent/workflow/deep_research/research_workflow_beta.py @@ -8,6 +8,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import click +from rich.prompt import Confirm, Prompt + from ms_agent.llm.openai import OpenAIChat from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.tools.search.exa.schema import dump_batch_search_results @@ -16,12 +18,13 @@ from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import remove_resource_info, text_hash from ms_agent.workflow.deep_research.principle import MECEPrinciple, Principle -from ms_agent.workflow.deep_research.research_utils import (LearningsResponse, - ProgressTracker, - ResearchProgress, - ResearchResult) +from ms_agent.workflow.deep_research.research_utils import ( + LearningsResponse, + ProgressTracker, + ResearchProgress, + ResearchResult, +) from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow -from rich.prompt import Confirm, Prompt logger = get_logger() @@ -343,7 +346,7 @@ async def generate_search_queries( if learnings: learnings_prompt = ( f'\n\nHere are some learnings from previous research, ' - f'use them to generate more specific queries: {", ".join(learnings)}' + f'use them to generate more specific queries: {', '.join(learnings)}' ) rewrite_prompt = ( @@ -592,7 +595,7 @@ async def process_search_results( f'information dense as possible.\n' f'- Make sure to include any entities like people, places, companies, products, ' f'things, etc in the learnings, as well as any exact metrics, numbers, or dates.\n' - f'{multimodal_prompt if self._enable_multimodal else ""}' + f'{multimodal_prompt if self._enable_multimodal else ''}' f'- The learnings will be used to research the topic further.\n' f'- Do NOT repeat the query verbatim as a learning. ' f'Do NOT invent facts not present in .\n' @@ -620,7 +623,7 @@ async def process_search_results( response_data = response_data.get('learnings_extraction', {}) or response_data except Exception as e: logger.error(f'Error parsing JSON response: {e}') - logger.error(f'Raw response content: {response.get("content", "")}') + logger.error(f'Raw response content: {response.get('content', '')}') return LearningsResponse(learnings=[], follow_up_questions=[]) learnings = response_data.get('learnings', [])[:num_learnings] @@ -670,7 +673,7 @@ async def _process_single_query( if new_depth > 0 and len(processed_results.follow_up_questions) > 0: logger.info( f'Researching deeper, breadth: {new_breadth}, ' - f'depth: {progress_manager.get_current().current_depth if progress_manager else "N/A"}' + f'depth: {progress_manager.get_current().current_depth if progress_manager else 'N/A'}' ) # Use atomic increment to avoid race conditions if progress_manager is not None: @@ -679,8 +682,8 @@ async def _process_single_query( # Create next query from follow-up questions next_query = ( f'Previous Query: {search_request.query}\n' - f'Previous research goal: {getattr(search_request, "research_goal", "")}\n' - f'Follow-up research directions: {", ".join(processed_results.follow_up_questions)}' + f'Previous research goal: {getattr(search_request, 'research_goal', '')}\n' + f'Follow-up research directions: {', '.join(processed_results.follow_up_questions)}' ).strip() # Continue with deeper research, passing through the progress manager @@ -869,7 +872,7 @@ async def write_final_report(self, prompt: str, f'\n{learnings_text}\n' f'\n\nPlease respond with valid JSON that matches provided schema:\n{json_schema}\n' f'Please respond in the language of the . ' - f'{multimodal_prompt if self._enable_multimodal else ""}' + f'{multimodal_prompt if self._enable_multimodal else ''}' ) response = await self._chat_async( @@ -888,7 +891,7 @@ async def write_final_report(self, prompt: str, fix_prompt = ( f'The response is not valid JSON. Please fix it. ' f'You can only return the fixed JSON, no other text. ' - f'The response is: {response.get("content", "")}' + f'The response is: {response.get('content', '')}' ) response = await self._chat_async( messages=[ @@ -1082,7 +1085,7 @@ async def _run(self, encoding='utf-8') as f_report: f_report.write(report) logger.info( - f'Report saved to {self.workdir_structure["report_md"]}') + f'Report saved to {self.workdir_structure['report_md']}') else: # Generate and save answer answer = await self.write_final_answer( @@ -1097,7 +1100,7 @@ async def _run(self, encoding='utf-8') as f_answer: f_answer.write(answer) logger.info( - f'Answer saved to {self.workdir_structure["report_md"]}') + f'Answer saved to {self.workdir_structure['report_md']}') return self.workdir_structure['report_md'] diff --git a/ms_agent/workflow/loader.py b/ms_agent/workflow/loader.py index 7e0589cfa..3df2d6b57 100644 --- a/ms_agent/workflow/loader.py +++ b/ms_agent/workflow/loader.py @@ -1,19 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Dict, Optional -from ms_agent.config.config import Config from omegaconf import DictConfig, OmegaConf +from ms_agent.config.config import Config -class WorkflowLoader: +class WorkflowLoader: @classmethod - def build(cls, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - trust_remote_code: bool = False, - **kwargs): + def build( + cls, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + **kwargs, + ): wf_config: Optional[DictConfig] = None if config_dir_or_id is not None: wf_config: DictConfig = Config.from_task(config_dir_or_id, env) @@ -25,6 +27,7 @@ def build(cls, from ms_agent.workflow.chain_workflow import ChainWorkflow from ms_agent.workflow.dag_workflow import DagWorkflow + wf_type = ChainWorkflow.WORKFLOW_NAME.lower() wf_type = getattr(wf_config, 'type', '').lower() or wf_type @@ -35,7 +38,8 @@ def build(cls, env=env, mcp_server_file=kwargs.get('mcp_server_file'), load_cache=kwargs.get('load_cache', False), - trust_remote_code=trust_remote_code) + trust_remote_code=trust_remote_code, + ) elif wf_type == DagWorkflow.WORKFLOW_NAME.lower(): wf_instance = DagWorkflow( config_dir_or_id=config_dir_or_id, @@ -43,7 +47,8 @@ def build(cls, env=env, mcp_server_file=kwargs.get('mcp_server_file'), load_cache=kwargs.get('load_cache', False), - trust_remote_code=trust_remote_code) + trust_remote_code=trust_remote_code, + ) elif wf_type == 'ResearchWorkflow'.lower(): # TODO raise NotImplementedError() diff --git a/projects/code_genesis/tools/build_sandbox_image.py b/projects/code_genesis/tools/build_sandbox_image.py new file mode 100644 index 000000000..8995df386 --- /dev/null +++ b/projects/code_genesis/tools/build_sandbox_image.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""Build code_genesis sandbox image via Docker API (Colima / API-compatible daemon). + +Avoids requiring the standalone `docker` CLI binary; uses the PyPI `docker` package +(``pip install docker`` / requirements/code.txt) like the rest of ms-agent. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +IMAGE_NAME = "code-genesis-sandbox" +IMAGE_TAG = "version1" + +DOCKERFILE = r"""FROM python:3.12-slim + +# Install system dependencies and Node.js +RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \ + && apt-get update -o Acquire::Retries=5 \ + && apt-get install -y --no-install-recommends \ + curl \ + git \ + build-essential \ + && curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Configure npm to use a Chinese mirror. Comment out this line if not needed. +RUN npm config set registry https://registry.npmmirror.com/ + +# Install Jupyter kernel gateway (required by sandbox) +RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com \ + jupyter_kernel_gateway \ + jupyter_client \ + ipykernel + +# Install Python kernel +RUN python -m ipykernel install --sys-prefix --name python3 --display-name "Python 3" + +WORKDIR /data + +EXPOSE 8888 +CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip=0.0.0.0", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin=*"] +""" + + +def _repo_root() -> Path: + # projects/code_genesis/tools/thisfile -> parents[3] == repo root + return Path(__file__).resolve().parents[3] + + +def main() -> int: + try: + import docker + except ImportError: + print( + "Missing Python package 'docker'. Run: pip install docker\n" + "or: pip install -r requirements/code.txt", + file=sys.stderr, + ) + return 1 + + root = _repo_root() + dockerfile_path = root / "Dockerfile.sandbox" + dockerfile_path.write_text(DOCKERFILE, encoding="utf-8") + try: + client = docker.from_env() + client.ping() + print("Pulling python:3.12-slim ...") + client.images.pull("python:3.12-slim") + tag = f"{IMAGE_NAME}:{IMAGE_TAG}" + print(f"Building {tag} (context: {root}) ...") + stream = client.api.build( + path=str(root), + dockerfile="Dockerfile.sandbox", + tag=tag, + rm=True, + forcerm=True, + decode=True, + ) + for chunk in stream: + if not chunk: + continue + if "stream" in chunk and chunk["stream"]: + print(chunk["stream"], end="") + if "errorDetail" in chunk: + print(chunk.get("error", chunk["errorDetail"]), file=sys.stderr) + return 1 + print(f"Done: {tag}") + finally: + if dockerfile_path.is_file(): + dockerfile_path.unlink() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/projects/code_genesis/tools/build_sandbox_image.sh b/projects/code_genesis/tools/build_sandbox_image.sh index 8aae6859a..9c368011d 100755 --- a/projects/code_genesis/tools/build_sandbox_image.sh +++ b/projects/code_genesis/tools/build_sandbox_image.sh @@ -1,53 +1,10 @@ -#!/bin/bash - -# Build Docker sandbox image for code_genesis -# Includes Python + Node.js for full-stack project support - -set -e - -IMAGE_NAME="code-genesis-sandbox" -IMAGE_TAG="version1" - -echo "Building code-genesis sandbox Docker image..." - -docker pull python:3.12-slim - -cat > Dockerfile.sandbox << 'EOF' -FROM python:3.12-slim - -# Install system dependencies and Node.js -RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \ - && apt-get update -o Acquire::Retries=5 \ - && apt-get install -y --no-install-recommends \ - curl \ - git \ - build-essential \ - && curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \ - && apt-get install -y --no-install-recommends nodejs \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -# Configure npm to use a Chinese mirror. Comment out this line if not needed. -RUN npm config set registry https://registry.npmmirror.com/ - -# Install Jupyter kernel gateway (required by sandbox) -RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com \ - jupyter_kernel_gateway \ - jupyter_client \ - ipykernel - -# Install Python kernel -RUN python -m ipykernel install --sys-prefix --name python3 --display-name "Python 3" - -WORKDIR /data - -EXPOSE 8888 -CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip=0.0.0.0", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin=*"] -EOF - -echo "Building Docker image: ${IMAGE_NAME}:${IMAGE_TAG}" -docker build -f Dockerfile.sandbox -t "${IMAGE_NAME}:${IMAGE_TAG}" . - -rm Dockerfile.sandbox - -echo "Done: ${IMAGE_NAME}:${IMAGE_TAG}" -echo "Contains: Python 3.12, Node.js 20, npm, git, curl" +#!/usr/bin/env bash +# Build sandbox image using Docker HTTP API (PyPI `docker`); Colima supplies the daemon. +# No Docker Desktop and no standalone `docker` CLI required. + +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +if [[ -n "${VIRTUAL_ENV:-}" && -x "${VIRTUAL_ENV}/bin/python" ]]; then + exec "${VIRTUAL_ENV}/bin/python" "${SCRIPT_DIR}/build_sandbox_image.py" "$@" +fi +exec python3 "${SCRIPT_DIR}/build_sandbox_image.py" "$@" diff --git a/projects/code_genesis/workflow/api_search.py b/projects/code_genesis/workflow/api_search.py index 4e380d17a..60a635530 100644 --- a/projects/code_genesis/workflow/api_search.py +++ b/projects/code_genesis/workflow/api_search.py @@ -1,16 +1,15 @@ +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.constants import DEFAULT_INDEX_DIR class ApiSearch(ToolBase): - def __init__(self, config): super().__init__(config) index_dir = getattr(config, 'index_cache_dir', DEFAULT_INDEX_DIR) @@ -43,21 +42,19 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'keywords': { - 'type': - 'string', - 'description': - 'The keywords/regex in the url to search api of.', + 'type': 'string', + 'description': 'The keywords/regex in the url to search api of.', } }, 'required': [], - 'additionalProperties': False - }), + 'additionalProperties': False, + }, + ), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await self.url_search(**tool_args) async def url_search(self, keywords: str = None): @@ -84,9 +81,7 @@ async def url_search(self, keywords: str = None): use_regex = True except re.error: # Not a valid regex, treat as comma-separated keywords - keyword_list = [ - kw.strip() for kw in keywords.split(',') if kw.strip() - ] + keyword_list = [kw.strip() for kw in keywords.split(',') if kw.strip()] use_regex = False def search_in_file(file_path): @@ -107,12 +102,10 @@ def search_in_file(file_path): is_match = regex_pattern.search(url) is not None else: # Substring matching (any keyword matches) - is_match = any(keyword in url - for keyword in keyword_list) + is_match = any(keyword in url for keyword in keyword_list) if is_match: - matches.append( - json.dumps(protocol, ensure_ascii=False)) + matches.append(json.dumps(protocol, ensure_ascii=False)) except Exception: # noqa return [] if matches: @@ -120,7 +113,8 @@ def search_in_file(file_path): matches.insert( 0, f'API{" with keywords: " + str(keywords) + " " + match_mode if keywords else ""} defined ' - f'in {file_path}:') + f'in {file_path}:', + ) matches.append('\n') return matches @@ -132,10 +126,7 @@ def search_in_file(file_path): # Use thread pool to search files in parallel all_matches = [] with ThreadPoolExecutor(max_workers=8) as executor: - future_to_file = { - executor.submit(search_in_file, f): f - for f in files_to_search - } + future_to_file = {executor.submit(search_in_file, f): f for f in files_to_search} for future in as_completed(future_to_file): matches = future.result() all_matches.extend(matches) diff --git a/projects/code_genesis/workflow/architect.py b/projects/code_genesis/workflow/architect.py index b340ab78a..889b34384 100644 --- a/projects/code_genesis/workflow/architect.py +++ b/projects/code_genesis/workflow/architect.py @@ -6,7 +6,6 @@ class ArchitectureAgent(LLMAgent): - async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() diff --git a/projects/code_genesis/workflow/coding.py b/projects/code_genesis/workflow/coding.py index c6cfc398f..70e6df08e 100644 --- a/projects/code_genesis/workflow/coding.py +++ b/projects/code_genesis/workflow/coding.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import json import os import re import shutil @@ -8,18 +9,17 @@ from pathlib import Path from typing import List, Optional, Set -import json +from omegaconf import DictConfig + from ms_agent import LLMAgent from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.memory.condenser.code_condenser import CodeCondenser from ms_agent.tools.code_server import LSPCodeServer from ms_agent.utils import get_logger -from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, - DEFAULT_TAG) +from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_TAG from ms_agent.utils.parser_utils import ImportInfo, parse_imports from ms_agent.utils.utils import extract_code_blocks, file_lock -from omegaconf import DictConfig logger = get_logger() @@ -54,13 +54,14 @@ class Programmer(LLMAgent): - - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - code_file: str = None, - **kwargs): + def __init__( + self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + code_file: str = None, + **kwargs, + ): # Validate and adjust config before passing to parent config = self._validate_config(config) super().__init__(config, tag, trust_remote_code, **kwargs) @@ -90,22 +91,19 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check edit_file_config.api_key edit_file_api_key = None try: - edit_file_api_key = config.get('tools', {}).get( - 'file_system', {}).get('edit_file_config', {}).get('api_key') + edit_file_api_key = ( + config.get('tools', {}).get('file_system', {}).get('edit_file_config', {}).get('api_key') + ) except Exception: pass if not edit_file_api_key: # Remove edit_file from include list try: - include_list = config.get('tools', - {}).get('file_system', - {}).get('include', []) + include_list = config.get('tools', {}).get('file_system', {}).get('include', []) if include_list and 'edit_file' in include_list: include_list.remove('edit_file') - logger.warning( - '[coding] edit_file_config.api_key not set, removing edit_file from tools' - ) + logger.warning('[coding] edit_file_config.api_key not set, removing edit_file from tools') except Exception: pass else: @@ -133,8 +131,7 @@ def stop_nothing(self): self.llm.args['extra_body']['stop_sequences'] = self.stop_words[1] def is_stop_imports(self): - return self.llm.args['extra_body'][ - 'stop_sequences'] == self.stop_words[0] + return self.llm.args['extra_body']['stop_sequences'] == self.stop_words[0] def find_all_files(self): self.all_code_files = [] @@ -183,36 +180,28 @@ def read_file(path): contents = content.split('\n') comments = ['*', '#', '-', '%', '/'] - contents = [ - c for c in contents - if not any(c.strip().startswith(cm) for cm in comments) - ] - all_files = parse_imports(code_file, '\n'.join(contents), - self.output_dir) or [] + contents = [c for c in contents if not any(c.strip().startswith(cm) for cm in comments)] + all_files = parse_imports(code_file, '\n'.join(contents), self.output_dir) or [] all_read_files = find_all_read_files() all_notes = [] for file in all_files: if 'react' in file.source_file or 'vue' in file.source_file: continue if file.source_file == code_file: - all_notes.append( - f'You should not import the file itself: {code_file}') + all_notes.append(f'You should not import the file itself: {code_file}') continue - file.imported_items = [ - item for item in file.imported_items - if item not in ('*', 'default') - ] + file.imported_items = [item for item in file.imported_items if item not in ('*', 'default')] filename = os.path.join(self.output_dir, file.source_file) if not os.path.exists(filename): if file.source_file in self.all_code_files: all_notes.append( - f'The dependency you import: {file.source_file} does not exist, ' - f'the order may be incorrect.') + f'The dependency you import: {file.source_file} does not exist, the order may be incorrect.' + ) else: all_notes.append( - f'The dependency you import: {file.source_file} is not in the code plan, ' - f'stop importing it.') + f'The dependency you import: {file.source_file} is not in the code plan, stop importing it.' + ) elif os.path.isfile(filename): if file.source_file not in all_read_files: all_notes.append( @@ -221,8 +210,7 @@ def read_file(path): elif os.path.isdir(filename): index_file_path = self.find_index_file(filename) if index_file_path: - index_file_path = str( - Path(index_file_path).relative_to(self.output_dir)) + index_file_path = str(Path(index_file_path).relative_to(self.output_dir)) if index_file_path not in all_read_files: all_notes.append( f'Extra file {index_file_path} content in imports:\n{read_file(index_file_path)}' @@ -230,9 +218,9 @@ def read_file(path): if all_notes: all_notes = '\n'.join(all_notes) - user_content = (f'Problems found in your imports:\n' - f'\n{all_notes}\n' - f'Correct the errors and regenerate the code:\n') + user_content = ( + f'Problems found in your imports:\n\n{all_notes}\nCorrect the errors and regenerate the code:\n' + ) messages.append(Message(role='user', content=user_content)) else: messages.pop(-1) @@ -242,14 +230,12 @@ def read_file(path): async def _incremental_check(self, code_file: str, partial_code: str): if self.lsp_check: - lsp_result = await self._incremental_lsp_check( - code_file, partial_code) + lsp_result = await self._incremental_lsp_check(code_file, partial_code) else: lsp_result = None if self.post_import_check: - import_result = await self._after_import_check( - code_file, partial_code) + import_result = await self._after_import_check(code_file, partial_code) else: import_result = None return (lsp_result or '') + '\n' + (import_result or '') @@ -260,38 +246,27 @@ def find_index_file(full_path): return None else: result = None - for index_file in [ - 'index.ts', 'index.tsx', 'index.js', 'index.jsx', - 'index.vue', '__init__.py' - ]: + for index_file in ['index.ts', 'index.tsx', 'index.js', 'index.jsx', 'index.vue', '__init__.py']: index_path = os.path.join(full_path, index_file) if os.path.exists(index_path): result = index_path break return result - async def _after_import_check(self, code_file: str, - partial_code: str) -> Optional[str]: + async def _after_import_check(self, code_file: str, partial_code: str) -> Optional[str]: errors = [] partial_code = partial_code.split('\n') comments = ['*', '#', '-', '%', '/'] - contents = [ - c for c in partial_code - if not any(c.strip().startswith(cm) for cm in comments) - ] + contents = [c for c in partial_code if not any(c.strip().startswith(cm) for cm in comments)] partial_code = '\n'.join(contents) - all_imports: List[ImportInfo] = parse_imports(code_file, partial_code, - self.output_dir) + all_imports: List[ImportInfo] = parse_imports(code_file, partial_code, self.output_dir) for info in all_imports: source_file = info.source_file if not source_file or 'react' in source_file or 'vue' in source_file: continue - info.imported_items = [ - item for item in info.imported_items - if item not in ('*', 'default') - ] + info.imported_items = [item for item in info.imported_items if item not in ('*', 'default')] if not os.path.isabs(source_file): full_path = os.path.join(self.output_dir, source_file) @@ -308,14 +283,17 @@ async def _after_import_check(self, code_file: str, errors.append( f'Import error in {code_file}:\n' f" Directory '{source_file}' exists but has no index file (__init__.py, index.ts, etc.)\n" - f' Statement: {info.raw_statement}\n') + f' Statement: {info.raw_statement}\n' + ) continue else: full_path = index_file_path else: - errors.append(f'Import error in {code_file}:\n' - f" File '{source_file}' does not exist\n" - f' Statement: {info.raw_statement}\n') + errors.append( + f'Import error in {code_file}:\n' + f" File '{source_file}' does not exist\n" + f' Statement: {info.raw_statement}\n' + ) continue # 2. Check if imported symbols exist in the file @@ -337,12 +315,12 @@ async def _after_import_check(self, code_file: str, errors.append( f'Import error in {code_file}:\n' f" Items {missing_items} not found in '{source_file}'\n" - f' Statement: {info.raw_statement}\n') + f' Statement: {info.raw_statement}\n' + ) return '\n'.join(errors) if errors else None - async def _incremental_lsp_check(self, code_file: str, - partial_code: str) -> Optional[str]: + async def _incremental_lsp_check(self, code_file: str, partial_code: str) -> Optional[str]: lsp_servers = self.shared_lsp_context.get('lsp_servers', {}) if not lsp_servers: return None @@ -379,11 +357,8 @@ async def _incremental_lsp_check(self, code_file: str, return await lsp_server.call_tool( 'lsp_code_server', tool_name='update_and_check', - tool_args={ - 'file_path': code_file, - 'content': partial_code, - 'language': lang - }) + tool_args={'file_path': code_file, 'content': partial_code, 'language': lang}, + ) def filter_code_files(self): code_files = [] @@ -403,20 +378,20 @@ def increment_unchecked_file(self): self.unchecked_files.pop(key) logger.error( f"Unchecked file {key} still have problem:\n{self.unchecked_issues.get('key')}\n" - f'But the checking limit has reached.') + f'But the checking limit has reached.' + ) async def after_tool_call(self, messages: List[Message]): - is_prepare = len(messages[-1].tool_calls - or []) > 0 or messages[-1].role != 'assistant' - is_code_finish = '' in messages[ - -1].content and '' in messages[ - -1].content and not is_prepare + is_prepare = len(messages[-1].tool_calls or []) > 0 or messages[-1].role != 'assistant' + is_code_finish = '' in messages[-1].content and '' in messages[-1].content and not is_prepare is_import = ( - self.is_stop_imports() and not is_code_finish and not is_prepare + self.is_stop_imports() + and not is_code_finish + and not is_prepare and '' in messages[-1].content - and '' not in messages[-1].content) - is_check = messages[-1].role == 'assistant' and len( - messages[-1].tool_calls or []) == 0 and not is_import + and '' not in messages[-1].content + ) + is_check = messages[-1].role == 'assistant' and len(messages[-1].tool_calls or []) == 0 and not is_import message = messages[-1] all_issues = [] @@ -424,7 +399,6 @@ async def after_tool_call(self, messages: List[Message]): self._before_import_check(messages) if is_code_finish: - # Saving code result, remaining_text = extract_code_blocks(message.content) if result: @@ -456,15 +430,12 @@ async def after_tool_call(self, messages: List[Message]): if is_check: # After checking when fix ended or write ended for uncheck_file in list(self.unchecked_files.keys()): - with open(os.path.join(self.output_dir, uncheck_file), - 'r') as f: + with open(os.path.join(self.output_dir, uncheck_file), 'r') as f: _code = f.read() - lsp_feedback = await self._incremental_check( - uncheck_file, _code) + lsp_feedback = await self._incremental_check(uncheck_file, _code) lsp_feedback = lsp_feedback.strip() if lsp_feedback: - all_issues.append(f'❎Issues in {uncheck_file}:' - + lsp_feedback) + all_issues.append(f'❎Issues in {uncheck_file}:' + lsp_feedback) self.unchecked_issues[uncheck_file] = lsp_feedback else: logger.info(f'✅No issues found in {uncheck_file}.') @@ -506,16 +477,15 @@ async def after_tool_call(self, messages: List[Message]): if self.error_counter > 2: raise RuntimeError('The model does not output any response!') - new_task = is_code_finish and self.code_files and ( - not self.unchecked_files) + new_task = is_code_finish and self.code_files and (not self.unchecked_files) if new_task: last_file = self.code_files[-1] messages.append( Message( role='user', - content= - f'\nA code file in your imports not found, you should write it first: {last_file}\n' - )) + content=f'\nA code file in your imports not found, you should write it first: {last_file}\n', + ) + ) # Condense code block and prepare index files # await self.code_condenser.run(messages) @@ -523,7 +493,6 @@ async def after_tool_call(self, messages: List[Message]): @dataclasses.dataclass class FileRelation: - name: str description: str done: bool = False @@ -531,7 +500,6 @@ class FileRelation: class CodingAgent(CodeAgent): - def __init__(self, config, tag, trust_remote_code, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) # Shared LSP context across all Programmers @@ -549,31 +517,23 @@ async def _init_lsp_servers(self): # Detect all languages in the project detected_languages = set() - if any(kw in framework for kw in - ['typescript', 'javascript', 'react', 'node', 'npm', 'html']): + if any(kw in framework for kw in ['typescript', 'javascript', 'react', 'node', 'npm', 'html']): detected_languages.add('typescript') - if any(kw in framework - for kw in ['python', 'django', 'flask', 'fastapi']): + if any(kw in framework for kw in ['python', 'django', 'flask', 'fastapi']): detected_languages.add('python') - if any(kw in framework - for kw in ['java ', 'java\n', 'spring', 'maven', 'gradle']): + if any(kw in framework for kw in ['java ', 'java\n', 'spring', 'maven', 'gradle']): detected_languages.add('java') if not detected_languages: logger.info('No supported languages detected in framework.txt') return - logger.info( - f"Initializing LSP servers for languages: {', '.join(detected_languages)}" - ) + logger.info(f"Initializing LSP servers for languages: {', '.join(detected_languages)}") # Initialize LSP server for each detected language - lsp_config = DictConfig({ - 'workspace_dir': self.output_dir, - 'output_dir': self.output_dir - }) + lsp_config = DictConfig({'workspace_dir': self.output_dir, 'output_dir': self.output_dir}) lsp_servers = {} for lang in detected_languages: @@ -585,12 +545,8 @@ async def _init_lsp_servers(self): for lang, lsp_server in lsp_servers.items(): logger.info(f'Building LSP index for {lang}...') await lsp_server.call_tool( - 'lsp_code_server', - tool_name='check_directory', - tool_args={ - 'directory': '', - 'language': lang - }) + 'lsp_code_server', tool_name='check_directory', tool_args={'directory': '', 'language': lang} + ) logger.info(f'LSP index built for {lang}') self.shared_lsp_context['lsp_servers'] = lsp_servers @@ -607,9 +563,20 @@ async def _cleanup_lsp_servers(self): except Exception: # noqa pass - async def write_code(self, topic, user_story, framework, protocol, - file_order, name, description, index, last_batch, - siblings, next_batch): + async def write_code( + self, + topic, + user_story, + framework, + protocol, + file_order, + name, + description, + index, + last_batch, + siblings, + next_batch, + ): logger.info(f'Writing {name}') _config = deepcopy(self.config) messages = [ @@ -626,7 +593,8 @@ async def write_code(self, topic, user_story, framework, protocol, f'File description: {description}\n' f'Previous batch output:\n{last_batch}\n' f'Other workers writing in parallel:\n{siblings}\n' - f'Next batch planned:\n{next_batch}\n'), + f'Next batch planned:\n{next_batch}\n', + ), ] _config = deepcopy(self.config) @@ -637,7 +605,8 @@ async def write_code(self, topic, user_story, framework, protocol, tag=f'programmer-{name.replace(os.sep, "-")}', trust_remote_code=True, code_file=name, - shared_lsp_context=self.shared_lsp_context) # Pass shared context + shared_lsp_context=self.shared_lsp_context, + ) # Pass shared context await programmer.run(messages) async def execute_code(self, inputs, **kwargs): @@ -656,8 +625,7 @@ async def execute_code(self, inputs, **kwargs): file_orders = self.construct_file_orders() file_relation = OrderedDict() self.refresh_file_status(file_relation) - shutil.rmtree( - os.path.join(self.output_dir, 'locks'), ignore_errors=True) + shutil.rmtree(os.path.join(self.output_dir, 'locks'), ignore_errors=True) for idx, files in enumerate(file_orders): while True: @@ -689,7 +657,8 @@ async def execute_code(self, inputs, **kwargs): index=idx, last_batch=last_batch, siblings='\n'.join(set(files) - {name}), - next_batch=next_batch) + next_batch=next_batch, + ) for name, description in files.items() ] @@ -749,8 +718,7 @@ def refresh_file_status(self, file_relation): description = file['description'] file_path = os.path.join(self.output_dir, file_name) if file_name not in file_relation: - file_relation[file_name] = FileRelation( - name=file_name, description=description) + file_relation[file_name] = FileRelation(name=file_name, description=description) file_relation[file_name].done = os.path.exists(file_path) def construct_file_information(self, file_relation, add_output_dir=False): diff --git a/projects/code_genesis/workflow/file_design.py b/projects/code_genesis/workflow/file_design.py index b33f977a5..d1dae4756 100644 --- a/projects/code_genesis/workflow/file_design.py +++ b/projects/code_genesis/workflow/file_design.py @@ -1,13 +1,12 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message class FileDesignAgent(LLMAgent): - async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -32,15 +31,11 @@ async def after_tool_call(self, messages: List[Message]): if self.runtime.should_stop: query = None - if os.path.isfile( - os.path.join(self.output_dir, 'file_design.txt')): - with open( - os.path.join(self.output_dir, 'file_design.txt'), - 'r') as f: + if os.path.isfile(os.path.join(self.output_dir, 'file_design.txt')): + with open(os.path.join(self.output_dir, 'file_design.txt'), 'r') as f: file_design = json.load(f) - with open(os.path.join(self.output_dir, 'modules.txt'), - 'r') as f: + with open(os.path.join(self.output_dir, 'modules.txt'), 'r') as f: modules = f.readlines() files1 = set() @@ -63,8 +58,7 @@ async def after_tool_call(self, messages: List[Message]): f'please provide the correct file order without these files.' ) else: - query = ('The file design you provided is missing, ' - 'please provide the complete file design.') + query = 'The file design you provided is missing, please provide the complete file design.' if query: messages.append(Message(role='user', content=query)) diff --git a/projects/code_genesis/workflow/file_order.py b/projects/code_genesis/workflow/file_order.py index 592a68025..c380e926f 100644 --- a/projects/code_genesis/workflow/file_order.py +++ b/projects/code_genesis/workflow/file_order.py @@ -1,13 +1,12 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message class FileOrderAgent(LLMAgent): - async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -33,14 +32,10 @@ async def after_tool_call(self, messages: List[Message]): query = None if os.path.isfile(os.path.join(self.output_dir, 'file_order.txt')): - with open( - os.path.join(self.output_dir, 'file_order.txt'), - 'r') as f: + with open(os.path.join(self.output_dir, 'file_order.txt'), 'r') as f: file_order = json.load(f) - with open( - os.path.join(self.output_dir, 'file_design.txt'), - 'r') as f: + with open(os.path.join(self.output_dir, 'file_design.txt'), 'r') as f: file_design = json.load(f) files1 = set() @@ -63,8 +58,7 @@ async def after_tool_call(self, messages: List[Message]): f'please provide the correct file order without these files.' ) else: - query = ('The file order you provided is missing, ' - 'please provide the complete file order.') + query = 'The file order you provided is missing, please provide the complete file order.' if query: messages.append(Message(role='user', content=query)) diff --git a/projects/code_genesis/workflow/install.py b/projects/code_genesis/workflow/install.py index 82f091147..69ea4ee13 100644 --- a/projects/code_genesis/workflow/install.py +++ b/projects/code_genesis/workflow/install.py @@ -5,7 +5,6 @@ class InstallAgent(LLMAgent): - async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -19,7 +18,8 @@ async def run(self, messages, **kwargs): query = ( f'Topic: {topic}\nFramework: {framework}\nFile Design: {file_design}\n' f'Your `workflow_dir` is "./", ' - 'Please write dependency files and install dependencies.') + 'Please write dependency files and install dependencies.' + ) messages = [ Message(role='system', content=self.config.prompt.system), diff --git a/projects/code_genesis/workflow/refine.py b/projects/code_genesis/workflow/refine.py index 17c9ffba4..6247f5369 100644 --- a/projects/code_genesis/workflow/refine.py +++ b/projects/code_genesis/workflow/refine.py @@ -1,26 +1,24 @@ +import json import os import sys from typing import List, OrderedDict -import json from coding import CodingAgent +from omegaconf import DictConfig + from ms_agent import LLMAgent from ms_agent.llm import Message from ms_agent.memory.condenser.refine_condenser import RefineCondenser from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() class RefineAgent(LLMAgent): - - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs + ): # Validate and adjust config before passing to parent config = self._validate_config(config) super().__init__(config, tag, trust_remote_code, **kwargs) @@ -36,22 +34,19 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check edit_file_config.api_key edit_file_api_key = None try: - edit_file_api_key = config.get('tools', {}).get( - 'file_system', {}).get('edit_file_config', {}).get('api_key') + edit_file_api_key = ( + config.get('tools', {}).get('file_system', {}).get('edit_file_config', {}).get('api_key') + ) except Exception: pass if not edit_file_api_key: # Remove edit_file from include list try: - include_list = config.get('tools', - {}).get('file_system', - {}).get('include', []) + include_list = config.get('tools', {}).get('file_system', {}).get('include', []) if 'edit_file' in include_list: include_list.remove('edit_file') - logger.warning( - '[refine] edit_file_config.api_key not set, removing edit_file from tools' - ) + logger.warning('[refine] edit_file_config.api_key not set, removing edit_file from tools') except Exception: pass else: @@ -60,9 +55,9 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check EDGEONE_PAGES_API_TOKEN edgeone_token = None try: - edgeone_token = config.get('tools', {}).get( - 'edgeone-pages-mcp', {}).get('env', - {}).get('EDGEONE_PAGES_API_TOKEN') + edgeone_token = ( + config.get('tools', {}).get('edgeone-pages-mcp', {}).get('env', {}).get('EDGEONE_PAGES_API_TOKEN') + ) except Exception: pass @@ -71,15 +66,11 @@ def _validate_config(self, config: DictConfig) -> DictConfig: try: if 'edgeone-pages-mcp' in config.get('tools', {}): del config['tools']['edgeone-pages-mcp'] - logger.warning( - '[refine] EDGEONE_PAGES_API_TOKEN not set, removing edgeone-pages-mcp from tools' - ) + logger.warning('[refine] EDGEONE_PAGES_API_TOKEN not set, removing edgeone-pages-mcp from tools') except Exception: pass else: - logger.info( - f'[refine] EDGEONE_PAGES_API_TOKEN is configured: {edgeone_token[:10]}...' - ) + logger.info(f'[refine] EDGEONE_PAGES_API_TOKEN is configured: {edgeone_token[:10]}...') return OmegaConf.create(config) @@ -111,7 +102,8 @@ async def run(self, messages, **kwargs): f'Project files are at the current working directory (/data). ' f'All relative paths work directly.\n' f'When creating the deployment zip file, name it workspace.zip.\n' - f'Please refine the project and deploy it to EdgeOne Pages:'), + f'Please refine the project and deploy it to EdgeOne Pages:', + ), ] return await super().run(messages, **kwargs) @@ -121,9 +113,7 @@ async def after_tool_call(self, messages: List[Message]): if self.runtime.should_stop: if not sys.stdin.isatty(): # Running in WebUI - notify user that agent is waiting for input - logger.info( - '[refine] Agent completed initial refinement. Waiting for user feedback.' - ) + logger.info('[refine] Agent completed initial refinement. Waiting for user feedback.') # # Add a system message to notify the user # messages.append( @@ -137,15 +127,12 @@ async def after_tool_call(self, messages: List[Message]): try: query = sys.stdin.readline().strip() if query: - logger.info( - f'[refine] Received input from WebUI: {query}') + logger.info(f'[refine] Received input from WebUI: {query}') messages.append(Message(role='user', content=query)) self.runtime.should_stop = False return else: - logger.warning( - '[refine] Received empty input, continuing to wait...' - ) + logger.warning('[refine] Received empty input, continuing to wait...') return except (EOFError, OSError, ValueError) as e: logger.error(f'[refine] Error reading from stdin: {e}') diff --git a/projects/code_genesis/workflow/user_story.py b/projects/code_genesis/workflow/user_story.py index 28c62f984..f92ee0380 100644 --- a/projects/code_genesis/workflow/user_story.py +++ b/projects/code_genesis/workflow/user_story.py @@ -6,7 +6,6 @@ class SplitModuleAgent(LLMAgent): - async def on_task_end(self, messages: List[Message]): assert os.path.isfile(os.path.join(self.output_dir, 'user_story.txt')) topic = '' diff --git a/projects/deep_research/run.py b/projects/deep_research/run.py index d2094fdc2..2cf8ef18c 100644 --- a/projects/deep_research/run.py +++ b/projects/deep_research/run.py @@ -6,16 +6,17 @@ from ms_agent.tools.search_engine import get_web_search_tool from ms_agent.workflow.deep_research.principle import MECEPrinciple from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow -from ms_agent.workflow.deep_research.research_workflow_beta import \ - ResearchWorkflowBeta +from ms_agent.workflow.deep_research.research_workflow_beta import ResearchWorkflowBeta -def run_workflow(user_prompt: str, - task_dir: str, - chat_client: OpenAIChat, - search_engine: SearchEngine, - reuse: bool, - use_ray: bool = False): +def run_workflow( + user_prompt: str, + task_dir: str, + chat_client: OpenAIChat, + search_engine: SearchEngine, + reuse: bool, + use_ray: bool = False, +): """ Run the deep research workflow, which follows a lightweight and efficient pipeline: 1. Receive a user prompt and generate search queries. @@ -43,15 +44,17 @@ def run_workflow(user_prompt: str, research_workflow.run(user_prompt=user_prompt) -def run_deep_workflow(user_prompt: str, - task_dir: str, - chat_client: OpenAIChat, - search_engine: SearchEngine, - breadth: int = 4, - depth: int = 2, - is_report: bool = True, - show_progress: bool = True, - use_ray: bool = False): +def run_deep_workflow( + user_prompt: str, + task_dir: str, + chat_client: OpenAIChat, + search_engine: SearchEngine, + breadth: int = 4, + depth: int = 2, + is_report: bool = True, + show_progress: bool = True, + use_ray: bool = False, +): """ Run the expandable deep research workflow (beta version). This version is more flexible and scalable than the original deep research workflow. @@ -78,23 +81,17 @@ def run_deep_workflow(user_prompt: str, """ research_workflow = ResearchWorkflowBeta( - client=chat_client, - search_engine=search_engine, - workdir=task_dir, - use_ray=use_ray, - enable_multimodal=True) + client=chat_client, search_engine=search_engine, workdir=task_dir, use_ray=use_ray, enable_multimodal=True + ) asyncio.run( research_workflow.run( - user_prompt=user_prompt, - breadth=breadth, - depth=depth, - is_report=is_report, - show_progress=show_progress)) + user_prompt=user_prompt, breadth=breadth, depth=depth, is_report=is_report, show_progress=show_progress + ) + ) if __name__ == '__main__': - query: str = 'Survey of the AI Agent within the recent 3 month, including the latest research papers, open-source projects, and industry applications.' # noqa task_workdir: str = '/path/to/your_workdir' # Specify your task work directory here reuse: bool = False @@ -110,9 +107,8 @@ def run_deep_workflow(user_prompt: str, api_key='xxx-xxx', base_url='https://api-inference.modelscope.cn/v1/', model='Qwen/Qwen3-235B-A22B-Instruct-2507', - generation_config={'extra_body': { - 'enable_thinking': False - }}) + generation_config={'extra_body': {'enable_thinking': False}}, + ) # Get web-search engine client # For the ExaSearch, you can get your API key from https://exa.ai diff --git a/projects/deep_research/v2/callbacks/quality_checker.py b/projects/deep_research/v2/callbacks/quality_checker.py index 36fadf902..a406face4 100644 --- a/projects/deep_research/v2/callbacks/quality_checker.py +++ b/projects/deep_research/v2/callbacks/quality_checker.py @@ -1,12 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json from abc import ABC, abstractmethod from typing import List, Optional -import json +from omegaconf import DictConfig, OmegaConf + from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -45,47 +46,49 @@ class ModelQualityChecker(ReportQualityChecker): """ _SYSTEM_PROMPTS = { - 'en': - ('You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' - 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' - 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' - 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' - 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' - 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' - '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' - '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' - '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' - 'is stored in xxx file", "See full analysis in evidence/xxx".\n' - '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n' - '4. Multiple reference/bibliography sections appear in the report (e.g., per-chapter reference lists), or any ' - 'variant heading such as "## References (Merged)", "## 参考文献(合并版)", "## 参考资料", etc. ' - 'Only one unified reference section at the very end is allowed.\n\n' - 'OUTPUT FORMAT:\n' - 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' - '{"pass": true} or {"pass": false, "reason": ""}\n' - 'Do NOT output anything else.'), - 'zh': - ('你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' - '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' - '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' - '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' - '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' - '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' - '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' - '4. 报告中出现多个参考文献/引用列表章节(如各章节末尾的独立引用列表),或使用变体标题如"## 参考文献(合并版)"、"## 参考资料"、"## References (Merged)"等。' - '报告仅允许在末尾保留唯一一个统一的参考文献章节。\n\n' - '输出格式:\n' - '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' - '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' - '不要输出任何其他内容。'), + 'en': ( + 'You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' + 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' + 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' + 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' + 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' + 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' + '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' + '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' + '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' + 'is stored in xxx file", "See full analysis in evidence/xxx".\n' + '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n' + '4. Multiple reference/bibliography sections appear in the report (e.g., per-chapter reference lists), or any ' + 'variant heading such as "## References (Merged)", "## 参考文献(合并版)", "## 参考资料", etc. ' + 'Only one unified reference section at the very end is allowed.\n\n' + 'OUTPUT FORMAT:\n' + 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' + '{"pass": true} or {"pass": false, "reason": ""}\n' + 'Do NOT output anything else.' + ), + 'zh': ( + '你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' + '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' + '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' + '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' + '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' + '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' + '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' + '4. 报告中出现多个参考文献/引用列表章节(如各章节末尾的独立引用列表),或使用变体标题如"## 参考文献(合并版)"、"## 参考资料"、"## References (Merged)"等。' + '报告仅允许在末尾保留唯一一个统一的参考文献章节。\n\n' + '输出格式:\n' + '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' + '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' + '不要输出任何其他内容。' + ), } _USER_TEMPLATES = { - 'en': - ('Please audit the following research report against the rules provided in the system instruction.\n\n' - '---BEGIN REPORT---\n{report}\n---END REPORT---'), - 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n' - '---报告开始---\n{report}\n---报告结束---'), + 'en': ( + 'Please audit the following research report against the rules provided in the system instruction.\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---' + ), + 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n---报告开始---\n{report}\n---报告结束---'), } _MAX_REPORT_CHARS = 80000 @@ -96,25 +99,27 @@ def __init__(self, config: DictConfig): qc_cfg = getattr(qc_cfg, 'quality_check', DictConfig({})) self._model: str = str(getattr(qc_cfg, 'model', 'qwen3.5-plus')) - self._api_key: Optional[str] = getattr( - qc_cfg, 'openai_api_key', None) or getattr(config.llm, - 'openai_api_key', None) - self._base_url: Optional[str] = getattr( - qc_cfg, 'openai_base_url', None) or getattr( - config.llm, 'openai_base_url', None) + self._api_key: Optional[str] = getattr(qc_cfg, 'openai_api_key', None) or getattr( + config.llm, 'openai_api_key', None + ) + self._base_url: Optional[str] = getattr(qc_cfg, 'openai_base_url', None) or getattr( + config.llm, 'openai_base_url', None + ) self._client: Optional[OpenAILLM] = None def _build_llm_config(self) -> DictConfig: """Build lightweight llm config for quality checker.""" - return OmegaConf.create({ - 'llm': { - 'model': self._model, - 'openai_api_key': self._api_key, - 'openai_base_url': self._base_url, - }, - 'generation_config': {}, - }) + return OmegaConf.create( + { + 'llm': { + 'model': self._model, + 'openai_api_key': self._api_key, + 'openai_base_url': self._base_url, + }, + 'generation_config': {}, + } + ) def _ensure_client(self): if self._client is not None: @@ -126,23 +131,23 @@ async def check(self, content: str, lang: str) -> Optional[str]: report_text = content if len(report_text) > self._MAX_REPORT_CHARS: - report_text = report_text[:self._MAX_REPORT_CHARS] + report_text = report_text[: self._MAX_REPORT_CHARS] sys_prompt = self._SYSTEM_PROMPTS.get(lang, self._SYSTEM_PROMPTS['en']) - usr_template = self._USER_TEMPLATES.get(lang, - self._USER_TEMPLATES['en']) + usr_template = self._USER_TEMPLATES.get(lang, self._USER_TEMPLATES['en']) try: - response = self._client.generate(messages=[ - Message(role='system', content=sys_prompt), - Message( - role='user', - content=usr_template.format(report=report_text), - ), - ]) + response = self._client.generate( + messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', + content=usr_template.format(report=report_text), + ), + ] + ) raw = (response.content or '').strip() - logger.info( - f'ModelQualityChecker ({self._model}): raw response: {raw}') + logger.info(f'ModelQualityChecker ({self._model}): raw response: {raw}') verdict = json.loads(raw) if verdict.get('pass', True): @@ -150,8 +155,7 @@ async def check(self, content: str, lang: str) -> Optional[str]: return verdict.get('reason', 'placeholder_content') except json.JSONDecodeError: - logger.warning(f'ModelQualityChecker: failed to parse JSON from ' - f'model response: {raw!r}') + logger.warning(f'ModelQualityChecker: failed to parse JSON from model response: {raw!r}') return None except Exception as exc: logger.warning(f'ModelQualityChecker: model call failed: {exc}') @@ -175,6 +179,5 @@ def build_quality_checkers(config: DictConfig) -> List[ReportQualityChecker]: checkers: List[ReportQualityChecker] = [] checkers.append(ModelQualityChecker(config)) - logger.info( - f'Quality checker chain initialised with {len(checkers)} checker(s).') + logger.info(f'Quality checker chain initialised with {len(checkers)} checker(s).') return checkers diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 477623a74..a3f4bac8f 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # yapf: disable +import json import os import re import shutil from typing import Any, Dict, List, Optional, Set -import json -from callbacks.quality_checker import (ReportQualityChecker, - build_quality_checkers) +from omegaconf import DictConfig + +from callbacks.quality_checker import ReportQualityChecker, build_quality_checkers from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_MEMORY_DIR -from omegaconf import DictConfig logger = get_logger() @@ -446,7 +446,7 @@ def _format_trajectory(self, messages: List[Dict[str, Any]]) -> str: lines.append('') elif role == 'tool': - lines.append(f'{labels["tool_result"]} ({tool_name})') + lines.append(f'{labels['tool_result']} ({tool_name})') # Truncate very long tool results if content and len(content) > 20000: content = content[:20000] + '\n...(truncated)' @@ -485,7 +485,7 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): labels = self._TRAJECTORY_LABELS.get( self.lang, self._TRAJECTORY_LABELS['en']) - trajectory_str = (f'{labels["trajectory_intro"]}\n\n' + trajectory_str = (f'{labels['trajectory_intro']}\n\n' f'{trajectory_text}') if messages[insert_pos].role == 'user': diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 8306a6ba7..2569638fb 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -5,14 +5,14 @@ import shutil from typing import List, Optional -from callbacks.quality_checker import (ReportQualityChecker, - build_quality_checkers) +from omegaconf import DictConfig, OmegaConf + +from callbacks.quality_checker import ReportQualityChecker, build_quality_checkers from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index 735a2d47a..093f7e0a7 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -1,15 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import uuid from typing import Any, List, Optional -import json +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -61,26 +62,25 @@ class SearcherCallback(Callback): # Bilingual round-reminder templates keyed by language code. _ROUND_REMINDER_TEMPLATES = { - 'zh': - ('你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' - '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' - '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' - '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' - ), - 'en': - ('You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' - '- From now on: Prioritize summarizing existing evidence and progress, fill critical gaps, and reduce exploratory divergence.\n' - '- Within the very few remaining rounds, immediately prepare and output the final JSON response.\n' - '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' - ), + 'zh': ( + '你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' + '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' + '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' + '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' + ), + 'en': ( + 'You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' + '- From now on: Prioritize summarizing existing evidence and progress, fill critical gaps, and reduce exploratory divergence.\n' + '- Within the very few remaining rounds, immediately prepare and output the final JSON response.\n' + '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' + ), } def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') self.search_task_id: Optional[str] = None - self.search_result_path = os.path.join( - self.output_dir, f'search_result_{uuid.uuid4().hex[:4]}.json') + self.search_result_path = os.path.join(self.output_dir, f'search_result_{uuid.uuid4().hex[:4]}.json') # Resolve language from config for bilingual prompt selection. self.lang = self._resolve_lang(config) self._ensure_output_dir() @@ -103,8 +103,7 @@ def _ensure_output_dir(self) -> None: try: os.makedirs(self.output_dir, exist_ok=True) except Exception as e: - logger.warning( - f'Failed to create output_dir {self.output_dir!r}: {e}') + logger.warning(f'Failed to create output_dir {self.output_dir!r}: {e}') @staticmethod def _sanitize_task_id(task_id: Any, max_len: int = 10) -> Optional[str]: @@ -138,27 +137,21 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): if not isinstance(message.content, str): continue search_task_description = json.loads(message.content) - raw_task_id = search_task_description.get( - 'task_id') or search_task_description.get('任务ID') + raw_task_id = search_task_description.get('task_id') or search_task_description.get('任务ID') safe_task_id = self._sanitize_task_id(raw_task_id) self.search_task_id = safe_task_id if safe_task_id: - self.search_result_path = os.path.join( - self.output_dir, - f'search_result_{safe_task_id}.json') + self.search_result_path = os.path.join(self.output_dir, f'search_result_{safe_task_id}.json') except json.JSONDecodeError: - logger.warning( - f'Failed to parse search task description: {message.content}' - ) + logger.warning(f'Failed to parse search task description: {message.content}') continue except Exception as e: logger.warning( - f'Unexpected error when parsing search task description: {message.content}, ' - f'with error: {e}') + f'Unexpected error when parsing search task description: {message.content}, with error: {e}' + ) continue - async def on_generate_response(self, runtime: Runtime, - messages: List[Message]): + async def on_generate_response(self, runtime: Runtime, messages: List[Message]): """ Inject a round-aware reminder into the system prompt near max rounds. @@ -188,10 +181,8 @@ async def on_generate_response(self, runtime: Runtime, custom_message = None if round_reminder_cfg is not None: enabled = bool(getattr(round_reminder_cfg, 'enabled', False)) - remind_before = getattr(round_reminder_cfg, - 'remind_before_max_round', remind_before) - remind_at_round = getattr(round_reminder_cfg, 'remind_at_round', - None) + remind_before = getattr(round_reminder_cfg, 'remind_before_max_round', remind_before) + remind_at_round = getattr(round_reminder_cfg, 'remind_at_round', None) custom_message = getattr(round_reminder_cfg, 'message', None) if not enabled: @@ -217,21 +208,18 @@ async def on_generate_response(self, runtime: Runtime, reminder_mark = '\n[ROUND_REMINDER]\n' # Avoid injecting duplicates (e.g. if resumed from history at the same round). for m in reversed(messages[-10:]): - if m.role == 'user' and isinstance( - m.content, str) and '[ROUND_REMINDER]' in m.content: + if m.role == 'user' and isinstance(m.content, str) and '[ROUND_REMINDER]' in m.content: return remaining = max_chat_round - runtime.round if not custom_message or not isinstance(custom_message, str): - custom_message = self._ROUND_REMINDER_TEMPLATES.get( - self.lang, self._ROUND_REMINDER_TEMPLATES['en']) + custom_message = self._ROUND_REMINDER_TEMPLATES.get(self.lang, self._ROUND_REMINDER_TEMPLATES['en']) injected = custom_message injected = injected.replace('', str(runtime.round)) injected = injected.replace('', str(max_chat_round)) injected = injected.replace('', str(remaining)) - messages.append( - Message(role='user', content=reminder_mark + injected + '\n')) + messages.append(Message(role='user', content=reminder_mark + injected + '\n')) async def on_task_end(self, runtime: Runtime, messages: List[Message]): """ @@ -240,12 +228,9 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): """ self._ensure_output_dir() json_path = self.search_result_path - md_path = (json_path[:-5] - + '.md') if json_path.endswith('.json') else ( - json_path.split('.')[0] + '.md') + md_path = (json_path[:-5] + '.md') if json_path.endswith('.json') else (json_path.split('.')[0] + '.md') if os.path.exists(json_path) or os.path.exists(md_path): - logger.info( - f'Search result already exists at {json_path} or {md_path}') + logger.info(f'Search result already exists at {json_path} or {md_path}') return # Find the last assistant message without tool calls @@ -265,39 +250,28 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): except (json.JSONDecodeError, TypeError): parsed_json = _parse_search_result_json(content) if parsed_json is not None: - logger.info( - 'Searcher: parsed JSON from fenced or embedded payload' - ) + logger.info('Searcher: parsed JSON from fenced or embedded payload') else: parsed_json = _parse_search_result_json(str(content)) if parsed_json is not None: try: with open(json_path, 'x', encoding='utf-8') as f: - json.dump( - parsed_json, f, ensure_ascii=False, indent=2) - logger.info( - f'Searcher: Search result saved to {json_path}') + json.dump(parsed_json, f, ensure_ascii=False, indent=2) + logger.info(f'Searcher: Search result saved to {json_path}') except FileExistsError: - logger.info( - f'Search result already exists at {json_path}') + logger.info(f'Search result already exists at {json_path}') else: - logger.warning( - 'Failed to parse search result as JSON, saving as markdown' - ) - text = content if isinstance(content, - str) else str(content) + logger.warning('Failed to parse search result as JSON, saving as markdown') + text = content if isinstance(content, str) else str(content) try: with open(md_path, 'x', encoding='utf-8') as f: f.write(text) - logger.info( - f'Searcher: Search result saved to {md_path}') + logger.info(f'Searcher: Search result saved to {md_path}') except FileExistsError: - logger.info( - f'Search result already exists at {md_path}') + logger.info(f'Search result already exists at {md_path}') except Exception as e: - logger.warning( - f'Unexpected error when saving search result: {e}') + logger.warning(f'Unexpected error when saving search result: {e}') return logger.warning('Searcher: No final search result found in messages') diff --git a/projects/deep_research/v2/eval/dr_bench_runner.py b/projects/deep_research/v2/eval/dr_bench_runner.py index 1917564bf..ce3978c6b 100644 --- a/projects/deep_research/v2/eval/dr_bench_runner.py +++ b/projects/deep_research/v2/eval/dr_bench_runner.py @@ -17,7 +17,9 @@ """ from __future__ import annotations + import argparse +import json import os import subprocess import sys @@ -28,8 +30,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple -import json - try: # Auto-load environment variables from a nearby `.env` (if present). from dotenv import find_dotenv, load_dotenv @@ -55,10 +55,7 @@ def _read_jsonl(path: str) -> List[Dict]: return items -def _append_jsonl(path: str, - obj: Dict, - *, - lock: Optional[threading.Lock] = None) -> None: +def _append_jsonl(path: str, obj: Dict, *, lock: Optional[threading.Lock] = None) -> None: os.makedirs(os.path.dirname(path) or '.', exist_ok=True) if lock is None: with open(path, 'a', encoding='utf-8') as f: @@ -267,8 +264,7 @@ def _report_is_stable( stable_since = now_s return False, sig, stable_since - return (now_s - stable_since) >= max(0.0, - stable_window_s), sig, stable_since + return (now_s - stable_since) >= max(0.0, stable_window_s), sig, stable_since def _run_one_task( @@ -324,18 +320,12 @@ def _run_one_task( # (e.g. process hung at shutdown). Force-reap to unblock # the batch runner. # - post_finish_grace_s = float( - os.getenv('DR_BENCH_POST_FINISH_GRACE_S', '180') or 180.0) - post_report_exit_grace_s = float( - os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '3600') or 3600.0) - report_stable_window_s = float( - os.getenv('DR_BENCH_REPORT_STABLE_WINDOW_S', '2') or 2.0) - poll_interval_s = float( - os.getenv('DR_BENCH_SUBPROCESS_POLL_INTERVAL_S', '0.5') or 0.5) - terminate_timeout_s = float( - os.getenv('DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S', '5') or 5.0) - kill_timeout_s = float( - os.getenv('DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S', '2') or 2.0) + post_finish_grace_s = float(os.getenv('DR_BENCH_POST_FINISH_GRACE_S', '180') or 180.0) + post_report_exit_grace_s = float(os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '3600') or 3600.0) + report_stable_window_s = float(os.getenv('DR_BENCH_REPORT_STABLE_WINDOW_S', '2') or 2.0) + poll_interval_s = float(os.getenv('DR_BENCH_SUBPROCESS_POLL_INTERVAL_S', '0.5') or 0.5) + terminate_timeout_s = float(os.getenv('DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S', '5') or 5.0) + kill_timeout_s = float(os.getenv('DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S', '2') or 2.0) report_seen_stable_at: Optional[float] = None report_last_sig: Optional[Tuple[float, int]] = None @@ -364,9 +354,11 @@ def _run_one_task( # --- Condition 1: .researcher_task_finished marker --- if marker_seen_at is None and os.path.exists(marker_path): marker_seen_at = now_s - if (marker_seen_at is not None and proc.poll() is None - and (now_s - marker_seen_at) >= max( - 0.0, post_finish_grace_s)): + if ( + marker_seen_at is not None + and proc.poll() is None + and (now_s - marker_seen_at) >= max(0.0, post_finish_grace_s) + ): _terminate_process( proc, terminate_timeout_s=terminate_timeout_s, @@ -377,8 +369,7 @@ def _run_one_task( # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) - if report_path_hint and _is_direct_final_report_path( - workdir, report_path_hint): + if report_path_hint and _is_direct_final_report_path(workdir, report_path_hint): stable, report_last_sig, report_stable_since = _report_is_stable( report_path_hint, stable_window_s=report_stable_window_s, @@ -391,10 +382,11 @@ def _run_one_task( report_seen_stable_at = now_s else: report_seen_stable_at = None - if (report_seen_stable_at is not None - and proc.poll() is None - and (now_s - report_seen_stable_at) >= max( - 0.0, post_report_exit_grace_s)): + if ( + report_seen_stable_at is not None + and proc.poll() is None + and (now_s - report_seen_stable_at) >= max(0.0, post_report_exit_grace_s) + ): _terminate_process( proc, terminate_timeout_s=terminate_timeout_s, @@ -406,8 +398,7 @@ def _run_one_task( # Drain available stdout without blocking. if select is not None: try: - r, _, _ = select.select([proc.stdout], [], [], - poll_interval_s) + r, _, _ = select.select([proc.stdout], [], [], poll_interval_s) except Exception: r = [] if r: @@ -422,8 +413,7 @@ def _run_one_task( print(f'[{task.task_id}] {line}', end='') else: with print_lock: - print( - f'[{task.task_id}] {line}', end='') + print(f'[{task.task_id}] {line}', end='') continue else: # No select available; degrade to polling only. @@ -449,9 +439,11 @@ def _run_one_task( returncode = 0 if returncode != 0: tail = ''.join(tail_lines)[-20000:] - return task.task_id, None, ( - f'ms-agent exited with code={returncode}. ' - f'log={log_path}. output tail:\n{tail}') + return ( + task.task_id, + None, + (f'ms-agent exited with code={returncode}. log={log_path}. output tail:\n{tail}'), + ) else: with open(log_path, 'w', encoding='utf-8') as logf: # Use Popen+poll so we can force-reap hung-at-exit children once @@ -470,9 +462,11 @@ def _run_one_task( # --- Condition 1: .researcher_task_finished marker --- if marker_seen_at is None and os.path.exists(marker_path): marker_seen_at = now_s - if (marker_seen_at is not None and proc2.poll() is None - and (now_s - marker_seen_at) >= max( - 0.0, post_finish_grace_s)): + if ( + marker_seen_at is not None + and proc2.poll() is None + and (now_s - marker_seen_at) >= max(0.0, post_finish_grace_s) + ): _terminate_process( proc2, terminate_timeout_s=terminate_timeout_s, @@ -483,8 +477,7 @@ def _run_one_task( # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) - if report_path_hint and _is_direct_final_report_path( - workdir, report_path_hint): + if report_path_hint and _is_direct_final_report_path(workdir, report_path_hint): stable, report_last_sig, report_stable_since = _report_is_stable( report_path_hint, stable_window_s=report_stable_window_s, @@ -497,10 +490,11 @@ def _run_one_task( report_seen_stable_at = now_s else: report_seen_stable_at = None - if (report_seen_stable_at is not None - and proc2.poll() is None - and (now_s - report_seen_stable_at) >= max( - 0.0, post_report_exit_grace_s)): + if ( + report_seen_stable_at is not None + and proc2.poll() is None + and (now_s - report_seen_stable_at) >= max(0.0, post_report_exit_grace_s) + ): _terminate_process( proc2, terminate_timeout_s=terminate_timeout_s, @@ -518,17 +512,23 @@ def _run_one_task( returncode = 0 if returncode != 0: tail = _tail_text_from_file(log_path, max_chars=20000) - return task.task_id, None, ( - f'ms-agent exited with code={returncode}. ' - f'log={log_path}. output tail:\n{tail}') + return ( + task.task_id, + None, + (f'ms-agent exited with code={returncode}. log={log_path}. output tail:\n{tail}'), + ) except Exception as e: return task.task_id, None, f'subprocess failed: {e}' report_path = _find_report_md(workdir) if not report_path: - return task.task_id, None, ( - f'final_report.md not found in workdir={workdir}. ' - f'log={log_path}. ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' + return ( + task.task_id, + None, + ( + f'final_report.md not found in workdir={workdir}. ' + f'log={log_path}. ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' + ), ) try: @@ -538,28 +538,23 @@ def _run_one_task( return task.task_id, None, f'failed to read report: {e} (path={report_path})' if not article.strip(): - return task.task_id, None, ( - f'empty report content (path={report_path}). log={log_path}. ' - f'ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' + return ( + task.task_id, + None, + ( + f'empty report content (path={report_path}). log={log_path}. ' + f'ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' + ), ) return task.task_id, article, None def main() -> None: - parser = argparse.ArgumentParser( - description= - 'Run ms-agent v2 on dr_bench queries and dump raw_data jsonl.') - parser.add_argument( - '--query_file', required=True, help='Path to dr_bench query.jsonl') - parser.add_argument( - '--output_jsonl', - required=True, - help='Output path for dr_bench raw_data/.jsonl') - parser.add_argument( - '--model_name', - default='ms_deepresearch', - help='Model/agent name used in output file naming') + parser = argparse.ArgumentParser(description='Run ms-agent v2 on dr_bench queries and dump raw_data jsonl.') + parser.add_argument('--query_file', required=True, help='Path to dr_bench query.jsonl') + parser.add_argument('--output_jsonl', required=True, help='Output path for dr_bench raw_data/.jsonl') + parser.add_argument('--model_name', default='ms_deepresearch', help='Model/agent name used in output file naming') parser.add_argument( '--config', default='projects/deep_research/v2/researcher.yaml', @@ -568,47 +563,31 @@ def main() -> None: parser.add_argument( '--work_root', default='eval/dr_bench/results/runs', - help= - 'Root dir to store per-task workdirs. Will create ///', + help='Root dir to store per-task workdirs. Will create ///', ) - parser.add_argument( - '--limit', - type=int, - default=0, - help='Limit number of tasks (0 means all)') - parser.add_argument( - '--workers', - type=int, - default=1, - help='Concurrency level (subprocess-based)') + parser.add_argument('--limit', type=int, default=0, help='Limit number of tasks (0 means all)') + parser.add_argument('--workers', type=int, default=1, help='Concurrency level (subprocess-based)') parser.add_argument( '--python', default=sys.executable, - help= - 'Python executable to run ms-agent (defaults to current interpreter)', + help='Python executable to run ms-agent (defaults to current interpreter)', ) - parser.add_argument( - '--trust_remote_code', - action='store_true', - help='Pass --trust_remote_code true to ms-agent') + parser.add_argument('--trust_remote_code', action='store_true', help='Pass --trust_remote_code true to ms-agent') parser.add_argument( '--ms_agent_root', default='.', - help= - 'Path to ms-agent repo root (contains ms_agent/). Defaults to current working directory.', + help='Path to ms-agent repo root (contains ms_agent/). Defaults to current working directory.', ) parser.add_argument( '--stream_subprocess_output', action='store_true', - help= - 'Stream ms-agent stdout/stderr to console (also written to /ms_agent.log).', + help='Stream ms-agent stdout/stderr to console (also written to /ms_agent.log).', ) parser.add_argument( '--extra', nargs=argparse.REMAINDER, default=[], - help= - 'Extra args passed through to ms-agent (e.g. --llm.model xxx --generation_config.stream false)', + help='Extra args passed through to ms-agent (e.g. --llm.model xxx --generation_config.stream false)', ) args = parser.parse_args() @@ -642,7 +621,7 @@ def main() -> None: tasks.append(Task(task_id=task_id, prompt=prompt)) if args.limit and args.limit > 0: - tasks = tasks[:args.limit] + tasks = tasks[: args.limit] done_ids = _load_existing_ids(output_jsonl) # Backfill: if a workdir already has a top-level final report file but the @@ -677,15 +656,12 @@ def main() -> None: print(msg) return - print( - f'Will run {len(tasks)} tasks (workers={args.workers}). Output: {output_jsonl}' - ) + print(f'Will run {len(tasks)} tasks (workers={args.workers}). Output: {output_jsonl}') os.makedirs(os.path.dirname(output_jsonl) or '.', exist_ok=True) # Ensure ms-agent is importable at runtime for subprocess (best-effort check) if not os.path.exists(os.path.join(ms_agent_root, 'ms_agent')): - raise FileNotFoundError( - f'ms_agent_root seems wrong: {ms_agent_root} (missing ms_agent/)') + raise FileNotFoundError(f'ms_agent_root seems wrong: {ms_agent_root} (missing ms_agent/)') extra_args = args.extra or [] print_lock = threading.Lock() @@ -707,13 +683,7 @@ def main() -> None: if err: print(f'[{tid}] ERROR: {err}', file=sys.stderr) continue - _append_jsonl( - output_jsonl, { - 'id': tid, - 'prompt': t.prompt, - 'article': article - }, - lock=write_lock) + _append_jsonl(output_jsonl, {'id': tid, 'prompt': t.prompt, 'article': article}, lock=write_lock) print(f'[{tid}] OK') return @@ -740,20 +710,12 @@ def main() -> None: try: tid, article, err = fut.result() except Exception as e: - print( - f'[{t.task_id}] ERROR: future failed: {e}', - file=sys.stderr) + print(f'[{t.task_id}] ERROR: future failed: {e}', file=sys.stderr) continue if err: print(f'[{tid}] ERROR: {err}', file=sys.stderr) continue - _append_jsonl( - output_jsonl, { - 'id': tid, - 'prompt': t.prompt, - 'article': article - }, - lock=write_lock) + _append_jsonl(output_jsonl, {'id': tid, 'prompt': t.prompt, 'article': article}, lock=write_lock) print(f'[{tid}] OK') diff --git a/projects/deep_research/v2/reporter.py b/projects/deep_research/v2/reporter.py index d9c6f2507..2ce320951 100644 --- a/projects/deep_research/v2/reporter.py +++ b/projects/deep_research/v2/reporter.py @@ -2,11 +2,12 @@ import os from typing import Any, AsyncGenerator, List, Union +from omegaconf import DictConfig + from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() @@ -23,23 +24,19 @@ class ReporterAgent(LLMAgent): 5. Assemble final reports """ - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs + ): super().__init__(config, tag, trust_remote_code, **kwargs) # Reporter-specific configuration self._reports_dir = 'reports' - if hasattr(config, 'tools') and hasattr(config.tools, - 'report_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, 'report_generator'): report_cfg = config.tools.report_generator self._reports_dir = getattr(report_cfg, 'reports_dir', 'reports') async def run( - self, inputs: Union[str, List[str], List[Message], - List[List[Message]]], **kwargs + self, inputs: Union[str, List[str], List[Message], List[List[Message]]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: # Add context about the reporter's role if isinstance(inputs, str): @@ -51,9 +48,7 @@ async def run( if os.path.exists(evidence_dir): evidence_index = os.path.join(evidence_dir, 'index.json') if os.path.exists(evidence_index): - logger.info( - f'ReporterAgent: Evidence index found at {evidence_index}' - ) + logger.info(f'ReporterAgent: Evidence index found at {evidence_index}') inputs = enhanced_input diff --git a/projects/deep_research/v2/researcher.py b/projects/deep_research/v2/researcher.py index 6af015717..c8ace978c 100644 --- a/projects/deep_research/v2/researcher.py +++ b/projects/deep_research/v2/researcher.py @@ -1,14 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, AsyncGenerator, List, Union +from omegaconf import DictConfig + from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from ms_agent.utils.stats import (append_stats, build_timing_record, - get_stats_path, monotonic, now_iso, - summarize_usage) -from omegaconf import DictConfig +from ms_agent.utils.stats import append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage logger = get_logger() @@ -18,15 +17,12 @@ class ResearcherAgent(LLMAgent): Researcher Agent that conducts deep research tasks using LLMs and various tools. """ - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs + ): super().__init__(config, tag, trust_remote_code, **kwargs) - async def run_loop(self, messages: Union[List[Message], str], - **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> AsyncGenerator[Any, Any]: start_ts = now_iso() start_time = monotonic() last_messages: List[Message] = [] @@ -66,7 +62,7 @@ async def on_task_end(self, messages: List[Message]): await super().on_task_end(messages) try: from ms_agent.tools.search.websearch_tool import WebSearchTool + WebSearchTool.log_global_summarization_usage() except Exception as exc: - logger.warning( - f'Failed to log web search summarization usage: {exc}') + logger.warning(f'Failed to log web search summarization usage: {exc}') diff --git a/projects/deep_research/v2/time_handler.py b/projects/deep_research/v2/time_handler.py index a36282b5f..90514f0a8 100644 --- a/projects/deep_research/v2/time_handler.py +++ b/projects/deep_research/v2/time_handler.py @@ -2,9 +2,10 @@ from datetime import datetime from typing import Any -from ms_agent.config.config import ConfigLifecycleHandler from omegaconf import DictConfig +from ms_agent.config.config import ConfigLifecycleHandler + class TimeHandler(ConfigLifecycleHandler): """Config handler that injects current date/time and other config values into prompts""" @@ -29,8 +30,7 @@ def task_begin(self, config: DictConfig, tag: str) -> DictConfig: def traverse_and_replace(_config: Any): if isinstance(_config, DictConfig): for name, value in _config.items(): - if isinstance(value, DictConfig) or isinstance( - value, list): + if isinstance(value, DictConfig) or isinstance(value, list): traverse_and_replace(value) elif isinstance(value, str): new_value = value @@ -38,8 +38,7 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace( - placeholder, var_value) + new_value = new_value.replace(placeholder, var_value) setattr(_config, name, new_value) elif isinstance(_config, list): @@ -52,8 +51,7 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace( - placeholder, var_value) + new_value = new_value.replace(placeholder, var_value) _config[i] = new_value traverse_and_replace(config) diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 1379ce511..7ef061b4c 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock @@ -161,8 +161,7 @@ def _render_analysis_card(analysis: Dict[str, Any]) -> str: if analysis.get('task_id'): lines.append(f"- **Task ID**: `{analysis['task_id']}`") if analysis.get('based_on_note_ids'): - ids_str = ', '.join(f'`{nid}`' - for nid in analysis.get('based_on_note_ids', [])) + ids_str = ', '.join(f'`{nid}`' for nid in analysis.get('based_on_note_ids', [])) lines.append(f'- **Based on Notes**: {ids_str}') if analysis.get('tags'): tags_str = ', '.join(f'`{t}`' for t in analysis['tags']) @@ -285,14 +284,15 @@ def _parse_note_from_md(content: str, note_id: str) -> Dict[str, Any]: elif header == 'Sources': sources = [] for line in body.split('\n'): - match = re.search( - r'- \[(\w+)\] (.+?)(?:\s+\(published: ([^)]+)\))?$', line) + match = re.search(r'- \[(\w+)\] (.+?)(?:\s+\(published: ([^)]+)\))?$', line) if match: - sources.append({ - 'url': match.group(2).strip(), - 'source_tier': match.group(1), - 'published_at': match.group(3) or '' - }) + sources.append( + { + 'url': match.group(2).strip(), + 'source_tier': match.group(1), + 'published_at': match.group(3) or '', + } + ) note['sources'] = sources return note @@ -321,44 +321,31 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Configurable paths - self._evidence_dir = getattr(tool_cfg, 'evidence_dir', - 'evidence') if tool_cfg else 'evidence' - self._chunks_dir = getattr(tool_cfg, 'chunks_dir', - 'chunks') if tool_cfg else 'chunks' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', - '.locks') if tool_cfg else '.locks' + self._evidence_dir = getattr(tool_cfg, 'evidence_dir', 'evidence') if tool_cfg else 'evidence' + self._chunks_dir = getattr(tool_cfg, 'chunks_dir', 'chunks') if tool_cfg else 'chunks' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' # Feature flags - self._enable_chunk_storage = bool( - getattr(tool_cfg, 'enable_chunk_storage', - False)) if tool_cfg else False + self._enable_chunk_storage = bool(getattr(tool_cfg, 'enable_chunk_storage', False)) if tool_cfg else False async def connect(self) -> None: """Initialize directory structure.""" _ensure_dir(self.output_dir) _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'notes')) - _ensure_dir( - os.path.join(self.output_dir, self._evidence_dir, 'analyses')) + _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'analyses')) # Backward-compat: older runs may have used evidence/conclusions/ - _ensure_dir( - os.path.join(self.output_dir, self._evidence_dir, 'conclusions')) + _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'conclusions')) _ensure_dir(os.path.join(self.output_dir, self._chunks_dir)) _ensure_dir(os.path.join(self.output_dir, self._lock_subdir)) def _paths(self) -> Dict[str, str]: return { - 'index': - os.path.join(self.output_dir, self._evidence_dir, 'index.json'), - 'notes_dir': - os.path.join(self.output_dir, self._evidence_dir, 'notes'), - 'analyses_dir': - os.path.join(self.output_dir, self._evidence_dir, 'analyses'), - 'legacy_conclusions_dir': - os.path.join(self.output_dir, self._evidence_dir, 'conclusions'), - 'chunks_dir': - os.path.join(self.output_dir, self._chunks_dir), - 'lock_dir': - os.path.join(self.output_dir, self._lock_subdir), + 'index': os.path.join(self.output_dir, self._evidence_dir, 'index.json'), + 'notes_dir': os.path.join(self.output_dir, self._evidence_dir, 'notes'), + 'analyses_dir': os.path.join(self.output_dir, self._evidence_dir, 'analyses'), + 'legacy_conclusions_dir': os.path.join(self.output_dir, self._evidence_dir, 'conclusions'), + 'chunks_dir': os.path.join(self.output_dir, self._chunks_dir), + 'lock_dir': os.path.join(self.output_dir, self._lock_subdir), } async def _get_tools_inner(self) -> Dict[str, Any]: @@ -367,107 +354,85 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='write_note', server_name=self.SERVER_NAME, - description= - ('Write a new evidence note (card) to the evidence store. ' - 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' - 'Returns the generated note_id.'), + description=( + 'Write a new evidence note (card) to the evidence store. ' + 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' + 'Returns the generated note_id.' + ), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'title': { - 'type': - 'string', - 'description': - ('Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' - 'Optional: if omitted, a title is derived from the first line of `content`.'), + 'type': 'string', + 'description': ( + 'Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' + 'Optional: if omitted, a title is derived from the first line of `content`.' + ), }, 'content': { - 'type': - 'string', - 'description': - ('The full evidence text for this note. ' - 'State the core finding or observation, then provide all ' - 'supporting details: specific data points, statistics, quotes, ' - 'case studies, reasoning, and any other substantive information. ' - 'Be thorough — preserve all valuable details from the source material. ' - 'Multi-paragraph allowed.'), + 'type': 'string', + 'description': ( + 'The full evidence text for this note. ' + 'State the core finding or observation, then provide all ' + 'supporting details: specific data points, statistics, quotes, ' + 'case studies, reasoning, and any other substantive information. ' + 'Be thorough — preserve all valuable details from the source material. ' + 'Multi-paragraph allowed.' + ), }, 'contradicts': { - 'type': - 'string', - 'description': - ('Optional: Evidence text that contradicts this finding. ' - 'Include if there are conflicting sources or caveats.' - ), + 'type': 'string', + 'description': ( + 'Optional: Evidence text that contradicts this finding. ' + 'Include if there are conflicting sources or caveats.' + ), }, 'sources': { 'type': 'array', - 'description': - 'List of source references for this evidence.', + 'description': 'List of source references for this evidence.', 'items': { 'type': 'object', 'properties': { - 'url': { - 'type': 'string', - 'description': 'Source URL' - }, + 'url': {'type': 'string', 'description': 'Source URL'}, 'published_at': { - 'type': - 'string', - 'description': - 'Publication date (YYYY-MM-DD)' + 'type': 'string', + 'description': 'Publication date (YYYY-MM-DD)', }, 'source_tier': { - 'type': - 'string', - 'enum': [ - 'official', 'primary', - 'secondary', 'unknown' - ], - 'description': - ('Source credibility tier (for example, Official ' - 'Documents/Papers/Standards > ' - 'Primary News/Announcements > Secondary Blogs)' - ), + 'type': 'string', + 'enum': ['official', 'primary', 'secondary', 'unknown'], + 'description': ( + 'Source credibility tier (for example, Official ' + 'Documents/Papers/Standards > ' + 'Primary News/Announcements > Secondary Blogs)' + ), }, }, 'required': ['url'], }, }, 'summary': { - 'type': - 'string', - 'description': - 'One-sentence summary of this evidence.', + 'type': 'string', + 'description': 'One-sentence summary of this evidence.', }, 'task_id': { - 'type': - 'string', - 'description': - 'The plan task this evidence relates to.', + 'type': 'string', + 'description': 'The plan task this evidence relates to.', }, 'tags': { 'type': 'array', - 'items': { - 'type': 'string' - }, + 'items': {'type': 'string'}, 'description': 'Tags for categorization.', }, 'quality_score': { - 'type': - 'integer', - 'minimum': - 0, - 'maximum': - 100, - 'description': - 'Optional: Confidence/quality score (0-100).', + 'type': 'integer', + 'minimum': 0, + 'maximum': 100, + 'description': 'Optional: Confidence/quality score (0-100).', }, }, 'required': ['content'], - 'additionalProperties': - False, + 'additionalProperties': False, }, ), Tool( @@ -479,8 +444,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'properties': { 'note_id': { 'type': 'string', - 'description': - 'The ID of the note to retrieve.', + 'description': 'The ID of the note to retrieve.', }, }, 'required': ['note_id'], @@ -490,9 +454,10 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='list_notes', server_name=self.SERVER_NAME, - description= - ('List all evidence notes, optionally filtered by task_id or tags. ' - 'Returns a summary list (not full content).'), + description=( + 'List all evidence notes, optionally filtered by task_id or tags. ' + 'Returns a summary list (not full content).' + ), parameters={ 'type': 'object', 'properties': { @@ -501,13 +466,9 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Optional: Filter by task ID.', }, 'tags': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Optional: Filter by tags (notes must have ALL specified tags).', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Optional: Filter by tags (notes must have ALL specified tags).', }, # 'min_quality': { # 'type': 'integer', @@ -521,8 +482,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='search_notes', server_name=self.SERVER_NAME, - description= - 'Search notes by keyword in title, claim, or summary.', + description='Search notes by keyword in title, claim, or summary.', parameters={ 'type': 'object', 'properties': { @@ -554,77 +514,60 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='write_analysis', server_name=self.SERVER_NAME, - description= - ('Write an interim **analysis** record to the evidence store. ' - 'Use this tool whenever you need to turn multiple evidence notes into reusable reasoning artifacts, e.g.: ' - '(1) synthesis / interim summaries; ' - '(2) comparisons and trade-off decisions (A vs B, pros/cons, why choose X); ' - '(3) framework building (typologies, evaluation rubrics, scoring criteria, checklists); ' - '(4) mapping & reconciliation (align competing definitions/metrics, resolve conflicts, record assumptions); ' - '(5) scenario framing and uncertainty tracking (what-if branches, key sensitivities/risks, open questions); ' - '(6) rankings/recommendations that require rationale (e.g., pick top 2–3 options and justify). ' - '(7) Structured / visual intermediate artifacts (e.g., mind-map-style hierarchical outlines, and ' - 'text-based flow/relationship diagrams—prefer Mermaid syntax when possible).' - '(8) other intermediate analysis that requires reasoning, justification and recording.' - 'This is **not** the final report; it is an intermediate analysis that should cite supporting evidence via ' - 'based_on_note_ids when possible so downstream writing can reuse it. ' - 'Returns the generated analysis_id.'), + description=( + 'Write an interim **analysis** record to the evidence store. ' + 'Use this tool whenever you need to turn multiple evidence notes into reusable reasoning artifacts, e.g.: ' + '(1) synthesis / interim summaries; ' + '(2) comparisons and trade-off decisions (A vs B, pros/cons, why choose X); ' + '(3) framework building (typologies, evaluation rubrics, scoring criteria, checklists); ' + '(4) mapping & reconciliation (align competing definitions/metrics, resolve conflicts, record assumptions); ' + '(5) scenario framing and uncertainty tracking (what-if branches, key sensitivities/risks, open questions); ' + '(6) rankings/recommendations that require rationale (e.g., pick top 2–3 options and justify). ' + '(7) Structured / visual intermediate artifacts (e.g., mind-map-style hierarchical outlines, and ' + 'text-based flow/relationship diagrams—prefer Mermaid syntax when possible).' + '(8) other intermediate analysis that requires reasoning, justification and recording.' + 'This is **not** the final report; it is an intermediate analysis that should cite supporting evidence via ' + 'based_on_note_ids when possible so downstream writing can reuse it. ' + 'Returns the generated analysis_id.' + ), parameters={ 'type': 'object', 'properties': { 'title': { - 'type': - 'string', - 'description': - 'Brief title describing this analysis (e.g., "Interim comparison: Framework A vs B").', + 'type': 'string', + 'description': 'Brief title describing this analysis (e.g., "Interim comparison: Framework A vs B").', }, 'content': { - 'type': - 'string', - 'description': - ('The analysis content in Markdown. ' - 'This should capture synthesis/comparison, constraints, assumptions, and reasoning. ' - 'Multi-paragraph allowed.'), + 'type': 'string', + 'description': ( + 'The analysis content in Markdown. ' + 'This should capture synthesis/comparison, constraints, assumptions, and reasoning. ' + 'Multi-paragraph allowed.' + ), }, 'summary': { - 'type': - 'string', - 'description': - 'Optional: One-sentence summary of this analysis.', + 'type': 'string', + 'description': 'Optional: One-sentence summary of this analysis.', }, 'task_id': { - 'type': - 'string', - 'description': - 'Optional: The plan task this analysis relates to.', + 'type': 'string', + 'description': 'Optional: The plan task this analysis relates to.', }, 'based_on_note_ids': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Optional: List of note_ids this analysis is based on.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Optional: List of note_ids this analysis is based on.', }, 'tags': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Optional: Tags for categorization.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Optional: Tags for categorization.', }, 'quality_score': { - 'type': - 'integer', - 'minimum': - 0, - 'maximum': - 100, - 'description': - 'Optional: Confidence/quality score (0-100).', + 'type': 'integer', + 'minimum': 0, + 'maximum': 100, + 'description': 'Optional: Confidence/quality score (0-100).', }, }, 'required': ['title', 'content', 'summary', 'tags'], @@ -639,16 +582,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'analysis_id': { - 'type': - 'string', - 'description': - 'The ID of the analysis to retrieve.', + 'type': 'string', + 'description': 'The ID of the analysis to retrieve.', }, 'parse_analysis': { - 'type': - 'boolean', - 'description': - 'Optional: Whether to parse stored markdown back to structured dict.', + 'type': 'boolean', + 'description': 'Optional: Whether to parse stored markdown back to structured dict.', }, }, 'required': ['analysis_id'], @@ -658,9 +597,10 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='list_analyses', server_name=self.SERVER_NAME, - description= - ('List all analyses, optionally filtered by task_id or tags. ' - 'Returns a summary list (not full content).'), + description=( + 'List all analyses, optionally filtered by task_id or tags. ' + 'Returns a summary list (not full content).' + ), parameters={ 'type': 'object', 'properties': { @@ -669,13 +609,9 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Optional: Filter by task ID.', }, 'tags': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Optional: Filter by tags (analyses must have ALL specified tags).', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Optional: Filter by tags (analyses must have ALL specified tags).', }, }, 'required': [], @@ -685,8 +621,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='search_analyses', server_name=self.SERVER_NAME, - description= - 'Search analyses by keyword in title, summary, or tags.', + description='Search analyses by keyword in title, summary, or tags.', parameters={ 'type': 'object', 'properties': { @@ -708,8 +643,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'properties': { 'analysis_id': { 'type': 'string', - 'description': - 'The ID of the analysis to delete.', + 'description': 'The ID of the analysis to delete.', }, }, 'required': ['analysis_id'], @@ -726,13 +660,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'required': [], 'additionalProperties': False, }, - ) + ), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: @@ -742,16 +675,13 @@ def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: return { 'schema_version': 2, 'updated_at': _now_iso(), - 'notes': - {}, # note_id -> {title, task_id, summary, sources, tags, quality_score, created_at} - 'analyses': - {}, # analysis_id -> {title, task_id, summary, based_on_note_ids, tags, quality_score, created_at, path} + 'notes': {}, # note_id -> {title, task_id, summary, sources, tags, quality_score, created_at} + 'analyses': {}, # analysis_id -> {title, task_id, summary, based_on_note_ids, tags, quality_score, created_at, path} } # Backward/forward compatible defaults if 'notes' not in data or not isinstance(data.get('notes'), dict): data['notes'] = {} - if 'analyses' not in data or not isinstance( - data.get('analyses'), dict): + if 'analyses' not in data or not isinstance(data.get('analyses'), dict): data['analyses'] = {} # Backward-compat: older schema used "conclusions" key. @@ -760,14 +690,12 @@ def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: data['analyses'] = legacy return data - def _save_index_locked(self, paths: Dict[str, str], - index: Dict[str, Any]) -> None: + def _save_index_locked(self, paths: Dict[str, str], index: Dict[str, Any]) -> None: """Save index.json.""" index['updated_at'] = _now_iso() _write_text(paths['index'], _json_dumps(index)) - def _add_to_index(self, index: Dict[str, Any], note: Dict[str, - Any]) -> None: + def _add_to_index(self, index: Dict[str, Any], note: Dict[str, Any]) -> None: """Add a note's metadata to the index.""" note_id = note['note_id'] index['notes'][note_id] = { @@ -780,9 +708,7 @@ def _add_to_index(self, index: Dict[str, Any], note: Dict[str, 'created_at': note.get('created_at', ''), } - def _add_analysis_to_index(self, index: Dict[str, Any], - analysis: Dict[str, Any], - analysis_path: str) -> None: + def _add_analysis_to_index(self, index: Dict[str, Any], analysis: Dict[str, Any], analysis_path: str) -> None: """Add an analysis' metadata to the index.""" aid = analysis['analysis_id'] index['analyses'][aid] = { @@ -803,16 +729,14 @@ def _remove_from_index(self, index: Dict[str, Any], note_id: str) -> bool: return True return False - def _remove_analysis_from_index(self, index: Dict[str, Any], - analysis_id: str) -> bool: + def _remove_analysis_from_index(self, index: Dict[str, Any], analysis_id: str) -> bool: """Remove an analysis from the index. Returns True if found and removed.""" if analysis_id in index.get('analyses', {}): del index['analyses'][analysis_id] return True return False - def _store_chunk(self, chunk_id: str, content: str, - metadata: Dict[str, Any]) -> str: + def _store_chunk(self, chunk_id: str, content: str, metadata: Dict[str, Any]) -> str: """ Store a text chunk. Reserved for future implementation. @@ -863,10 +787,12 @@ async def write_note( content = (content or '').strip() if not content: - return _json_dumps({ - 'status': 'error', - 'message': 'write_note requires non-empty content.', - }) + return _json_dumps( + { + 'status': 'error', + 'message': 'write_note requires non-empty content.', + } + ) if title is None or not str(title).strip(): first_line = content.split('\n', 1)[0].strip() @@ -890,8 +816,7 @@ async def write_note( if sources: # Validate source tiers for src in sources: - src['source_tier'] = _validate_source_tier( - src.get('source_tier', 'unknown')) + src['source_tier'] = _validate_source_tier(src.get('source_tier', 'unknown')) note['sources'] = sources if summary: note['summary'] = summary.strip() @@ -914,11 +839,13 @@ async def write_note( self._add_to_index(index, note) self._save_index_locked(paths, index) - return _json_dumps({ - 'status': 'ok', - 'note_id': note_id, - 'path': os.path.relpath(note_path, self.output_dir), - }) + return _json_dumps( + { + 'status': 'ok', + 'note_id': note_id, + 'path': os.path.relpath(note_path, self.output_dir), + } + ) async def write_analysis( self, @@ -947,16 +874,13 @@ async def write_analysis( if task_id: analysis['task_id'] = task_id.strip() if based_on_note_ids: - analysis['based_on_note_ids'] = [ - nid.strip() for nid in based_on_note_ids if nid.strip() - ] + analysis['based_on_note_ids'] = [nid.strip() for nid in based_on_note_ids if nid.strip()] if tags: analysis['tags'] = [t.strip() for t in tags if t.strip()] if quality_score is not None: analysis['quality_score'] = max(0, min(100, quality_score)) - analysis_path = os.path.join(paths['analyses_dir'], - f'analysis_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') analysis_content = _render_analysis_card(analysis) _write_text(analysis_path, analysis_content) @@ -965,33 +889,25 @@ async def write_analysis( self._add_analysis_to_index(index, analysis, analysis_path) self._save_index_locked(paths, index) - return _json_dumps({ - 'status': - 'ok', - 'analysis_id': - analysis_id, - 'path': - os.path.relpath(analysis_path, self.output_dir), - }) - - async def get_analysis(self, - analysis_id: str, - parse_analysis: Optional[bool] = False) -> str: + return _json_dumps( + { + 'status': 'ok', + 'analysis_id': analysis_id, + 'path': os.path.relpath(analysis_path, self.output_dir), + } + ) + + async def get_analysis(self, analysis_id: str, parse_analysis: Optional[bool] = False) -> str: """Retrieve an analysis by ID.""" paths = self._paths() - analysis_path = os.path.join(paths['analyses_dir'], - f'analysis_{analysis_id}.md') - legacy_path = os.path.join(paths['legacy_conclusions_dir'], - f'conclusion_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], f'conclusion_{analysis_id}.md') if not os.path.exists(analysis_path) and os.path.exists(legacy_path): analysis_path = legacy_path if not os.path.exists(analysis_path): - return _json_dumps({ - 'status': 'error', - 'message': f'Analysis {analysis_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Analysis {analysis_id} not found.'}) with open(analysis_path, 'r', encoding='utf-8') as f: content = f.read() @@ -999,16 +915,16 @@ async def get_analysis(self, if not parse_analysis: return _json_dumps({'status': 'ok', 'raw_content': content}) analysis = _parse_analysis_from_md(content, analysis_id) - return _json_dumps({ - 'status': 'ok', - 'analysis_id': analysis_id, - 'analysis': analysis, - 'raw_content': content, - }) + return _json_dumps( + { + 'status': 'ok', + 'analysis_id': analysis_id, + 'analysis': analysis, + 'raw_content': content, + } + ) - async def list_analyses(self, - task_id: Optional[str] = None, - tags: Optional[List[str]] = None) -> str: + async def list_analyses(self, task_id: Optional[str] = None, tags: Optional[List[str]] = None) -> str: """List analyses with optional filters.""" paths = self._paths() _ensure_dir(paths['lock_dir']) @@ -1025,33 +941,28 @@ async def list_analyses(self, a_tags = set(meta.get('tags', [])) if not all(t in a_tags for t in tags): continue - results.append({ - 'analysis_id': - aid, - 'title': - meta.get('title', ''), - 'task_id': - meta.get('task_id', ''), - 'summary': - meta.get('summary', ''), - 'based_on_note_ids': - meta.get('based_on_note_ids', []), - 'tags': - meta.get('tags', []), - 'quality_score': - meta.get('quality_score'), - 'created_at': - meta.get('created_at', ''), - 'path': - meta.get('path', ''), - }) + results.append( + { + 'analysis_id': aid, + 'title': meta.get('title', ''), + 'task_id': meta.get('task_id', ''), + 'summary': meta.get('summary', ''), + 'based_on_note_ids': meta.get('based_on_note_ids', []), + 'tags': meta.get('tags', []), + 'quality_score': meta.get('quality_score'), + 'created_at': meta.get('created_at', ''), + 'path': meta.get('path', ''), + } + ) results.sort(key=lambda x: x.get('created_at', ''), reverse=True) - return _json_dumps({ - 'status': 'ok', - 'count': len(results), - 'analyses': results, - }) + return _json_dumps( + { + 'status': 'ok', + 'count': len(results), + 'analyses': results, + } + ) async def search_analyses(self, keyword: str) -> str: """Search analyses by keyword.""" @@ -1060,10 +971,7 @@ async def search_analyses(self, keyword: str) -> str: keyword_lower = keyword.lower().strip() if not keyword_lower: - return _json_dumps({ - 'status': 'error', - 'message': 'Keyword is required.' - }) + return _json_dumps({'status': 'error', 'message': 'Keyword is required.'}) with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) @@ -1071,50 +979,48 @@ async def search_analyses(self, keyword: str) -> str: analyses_meta = index.get('analyses', {}) results = [] for aid, meta in analyses_meta.items(): - searchable = ' '.join([ - meta.get('title', ''), - meta.get('summary', ''), - ]).lower() + searchable = ' '.join( + [ + meta.get('title', ''), + meta.get('summary', ''), + ] + ).lower() a_tags = meta.get('tags', []) searchable += ' ' + ' '.join(a_tags).lower() if keyword_lower in searchable: - results.append({ - 'analysis_id': aid, - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - 'task_id': meta.get('task_id', ''), - 'quality_score': meta.get('quality_score'), - }) + results.append( + { + 'analysis_id': aid, + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'task_id': meta.get('task_id', ''), + 'quality_score': meta.get('quality_score'), + } + ) - return _json_dumps({ - 'status': 'ok', - 'keyword': keyword, - 'count': len(results), - 'analyses': results, - }) + return _json_dumps( + { + 'status': 'ok', + 'keyword': keyword, + 'count': len(results), + 'analyses': results, + } + ) async def delete_analysis(self, analysis_id: str) -> str: """Delete an analysis by ID.""" paths = self._paths() _ensure_dir(paths['lock_dir']) - analysis_path = os.path.join(paths['analyses_dir'], - f'analysis_{analysis_id}.md') - legacy_path = os.path.join(paths['legacy_conclusions_dir'], - f'conclusion_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], f'conclusion_{analysis_id}.md') with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) removed = self._remove_analysis_from_index(index, analysis_id) - if not removed and not os.path.exists( - analysis_path) and not os.path.exists(legacy_path): - return _json_dumps({ - 'status': - 'error', - 'message': - f'Analysis {analysis_id} not found.' - }) + if not removed and not os.path.exists(analysis_path) and not os.path.exists(legacy_path): + return _json_dumps({'status': 'error', 'message': f'Analysis {analysis_id} not found.'}) self._save_index_locked(paths, index) @@ -1125,18 +1031,13 @@ async def delete_analysis(self, analysis_id: str) -> str: return _json_dumps({'status': 'ok', 'deleted': analysis_id}) - async def get_note(self, - note_id: str, - parse_note: Optional[bool] = False) -> str: + async def get_note(self, note_id: str, parse_note: Optional[bool] = False) -> str: """Retrieve a note by ID.""" paths = self._paths() note_path = os.path.join(paths['notes_dir'], f'note_{note_id}.md') if not os.path.exists(note_path): - return _json_dumps({ - 'status': 'error', - 'message': f'Note {note_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Note {note_id} not found.'}) with open(note_path, 'r', encoding='utf-8') as f: content = f.read() @@ -1145,17 +1046,11 @@ async def get_note(self, return _json_dumps({'status': 'ok', 'raw_content': content}) else: note = _parse_note_from_md(content, note_id) - return _json_dumps({ - 'status': 'ok', - 'note_id': note_id, - 'note': note, - 'raw_content': content - }) - - async def list_notes(self, - task_id: Optional[str] = None, - tags: Optional[List[str]] = None, - min_quality: Optional[int] = None) -> str: + return _json_dumps({'status': 'ok', 'note_id': note_id, 'note': note, 'raw_content': content}) + + async def list_notes( + self, task_id: Optional[str] = None, tags: Optional[List[str]] = None, min_quality: Optional[int] = None + ) -> str: """List notes with optional filters. Args: @@ -1185,25 +1080,29 @@ async def list_notes(self, if score is None or score < min_quality: continue - results.append({ - 'note_id': nid, - 'title': meta.get('title', ''), - 'task_id': meta.get('task_id', ''), - 'summary': meta.get('summary', ''), - 'sources': meta.get('sources', []), - 'tags': meta.get('tags', []), - 'quality_score': meta.get('quality_score'), - 'created_at': meta.get('created_at', ''), - }) + results.append( + { + 'note_id': nid, + 'title': meta.get('title', ''), + 'task_id': meta.get('task_id', ''), + 'summary': meta.get('summary', ''), + 'sources': meta.get('sources', []), + 'tags': meta.get('tags', []), + 'quality_score': meta.get('quality_score'), + 'created_at': meta.get('created_at', ''), + } + ) # Sort by created_at descending results.sort(key=lambda x: x.get('created_at', ''), reverse=True) - return _json_dumps({ - 'status': 'ok', - 'count': len(results), - 'notes': results, - }) + return _json_dumps( + { + 'status': 'ok', + 'count': len(results), + 'notes': results, + } + ) async def search_notes(self, keyword: str) -> str: """Search notes by keyword.""" @@ -1212,10 +1111,7 @@ async def search_notes(self, keyword: str) -> str: keyword_lower = keyword.lower().strip() if not keyword_lower: - return _json_dumps({ - 'status': 'error', - 'message': 'Keyword is required.' - }) + return _json_dumps({'status': 'error', 'message': 'Keyword is required.'}) with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) @@ -1225,28 +1121,34 @@ async def search_notes(self, keyword: str) -> str: for nid, meta in notes_meta.items(): # Search in title, summary - searchable = ' '.join([ - meta.get('title', ''), - meta.get('summary', ''), - ]).lower() + searchable = ' '.join( + [ + meta.get('title', ''), + meta.get('summary', ''), + ] + ).lower() tags = meta.get('tags', []) searchable += ' ' + ' '.join(tags).lower() if keyword_lower in searchable: - results.append({ - 'note_id': nid, - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - 'task_id': meta.get('task_id', ''), - 'quality_score': meta.get('quality_score'), - }) + results.append( + { + 'note_id': nid, + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'task_id': meta.get('task_id', ''), + 'quality_score': meta.get('quality_score'), + } + ) - return _json_dumps({ - 'status': 'ok', - 'keyword': keyword, - 'count': len(results), - 'notes': results, - }) + return _json_dumps( + { + 'status': 'ok', + 'keyword': keyword, + 'count': len(results), + 'notes': results, + } + ) async def delete_note(self, note_id: str) -> str: """Delete a note by ID.""" @@ -1261,10 +1163,7 @@ async def delete_note(self, note_id: str) -> str: removed = self._remove_from_index(index, note_id) if not removed and not os.path.exists(note_path): - return _json_dumps({ - 'status': 'error', - 'message': f'Note {note_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Note {note_id} not found.'}) self._save_index_locked(paths, index) @@ -1289,11 +1188,13 @@ async def load_index(self) -> str: notes = index.get('notes', {}) analyses = index.get('analyses', {}) - return _json_dumps({ - 'status': 'ok', - 'updated_at': index.get('updated_at', ''), - 'total_notes': len(notes), - 'total_analyses': len(analyses), - 'notes': notes, - 'analyses': analyses, - }) + return _json_dumps( + { + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + } + ) diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 819da108c..5d0bcfbdd 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -46,20 +46,17 @@ def _write_text(path: str, content: str) -> None: def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: """Normalize `chapters` from the model (list, JSON string, or nested strings).""" if chapters is None: - return [], ( - 'commit_outline requires `chapters` (array of chapter objects, ' - 'or a JSON string of that array).') + return [], ('commit_outline requires `chapters` (array of chapter objects, or a JSON string of that array).') raw: Any = chapters if isinstance(raw, str): try: raw = json.loads(raw.strip()) except json.JSONDecodeError as e: return [], ( - 'commit_outline `chapters` must be a JSON array of objects, ' - f'or a JSON string of that array: {e}') + f'commit_outline `chapters` must be a JSON array of objects, or a JSON string of that array: {e}' + ) if not isinstance(raw, list): - return [], ( - f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') + return [], (f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') out: List[Dict[str, Any]] = [] for i, ch in enumerate(raw): if isinstance(ch, str): @@ -67,11 +64,10 @@ def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Opti ch = json.loads(ch.strip()) except json.JSONDecodeError: return [], ( - f'commit_outline chapters[{i}] must be an object; ' - 'string entry is not valid JSON for an object.') + f'commit_outline chapters[{i}] must be an object; string entry is not valid JSON for an object.' + ) if not isinstance(ch, dict): - return [], ( - f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') + return [], (f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') out.append(ch) return out, None @@ -81,14 +77,9 @@ def _render_outline_md(outline: Dict[str, Any]) -> str: lines = [f"# {outline.get('title', 'Report Outline')}", ''] for ch in outline.get('chapters', []): - status_icon = { - 'pending': '⏳', - 'in_progress': '🔄', - 'completed': '✅' - }.get(ch.get('status', 'pending'), '⏳') + status_icon = {'pending': '⏳', 'in_progress': '🔄', 'completed': '✅'}.get(ch.get('status', 'pending'), '⏳') - lines.append( - f"## Chapter {ch['chapter_id']}: {ch['title']} {status_icon}") + lines.append(f"## Chapter {ch['chapter_id']}: {ch['title']} {status_icon}") if ch.get('goals'): lines.append('') @@ -103,8 +94,7 @@ def _render_outline_md(outline: Dict[str, Any]) -> str: if ch.get('candidate_evidence'): lines.append('') - lines.append( - f"**Related evidence:** {', '.join(ch['candidate_evidence'])}") + lines.append(f"**Related evidence:** {', '.join(ch['candidate_evidence'])}") lines.append('') @@ -116,27 +106,19 @@ def _render_outline_progress_md(outline: Dict[str, Any]) -> str: chapters = outline.get('chapters', []) total = len(chapters) completed = sum(1 for ch in chapters if ch.get('status') == 'completed') - in_progress = sum(1 for ch in chapters - if ch.get('status') == 'in_progress') + in_progress = sum(1 for ch in chapters if ch.get('status') == 'in_progress') pending = total - completed - in_progress lines = [f"# {outline.get('title', 'Report Outline')}", ''] - lines.append( - f'Progress: {completed}/{total} completed | {in_progress} in progress | {pending} pending' - ) + lines.append(f'Progress: {completed}/{total} completed | {in_progress} in progress | {pending} pending') lines.append('') lines.append('## Chapters') lines.append('') for ch in chapters: status = ch.get('status', 'pending') - status_icon = { - 'pending': '⏳', - 'in_progress': '🔄', - 'completed': '✅' - }.get(status, '⏳') - lines.append( - f"- {status_icon} Chapter {ch['chapter_id']}: {ch['title']}") + status_icon = {'pending': '⏳', 'in_progress': '🔄', 'completed': '✅'}.get(status, '⏳') + lines.append(f"- {status_icon} Chapter {ch['chapter_id']}: {ch['title']}") lines.append('') return '\n'.join(lines) @@ -172,43 +154,28 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Configurable paths - self._reports_dir = getattr(tool_cfg, 'reports_dir', - 'reports') if tool_cfg else 'reports' - self._evidence_dir = getattr(tool_cfg, 'evidence_dir', - 'evidence') if tool_cfg else 'evidence' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', - '.locks') if tool_cfg else '.locks' + self._reports_dir = getattr(tool_cfg, 'reports_dir', 'reports') if tool_cfg else 'reports' + self._evidence_dir = getattr(tool_cfg, 'evidence_dir', 'evidence') if tool_cfg else 'evidence' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' async def connect(self) -> None: """Initialize directory structure.""" _ensure_dir(self.output_dir) - _ensure_dir( - os.path.join(self.output_dir, self._reports_dir, 'chapters')) + _ensure_dir(os.path.join(self.output_dir, self._reports_dir, 'chapters')) _ensure_dir(os.path.join(self.output_dir, self._lock_subdir)) def _paths(self) -> Dict[str, str]: return { - 'outline_json': - os.path.join(self.output_dir, self._reports_dir, 'outline.json'), - 'outline_md': - os.path.join(self.output_dir, self._reports_dir, 'outline.md'), - 'outline_progress_md': - os.path.join(self.output_dir, self._reports_dir, - 'outline_progress.md'), - 'chapters_dir': - os.path.join(self.output_dir, self._reports_dir, 'chapters'), - 'conflict_json': - os.path.join(self.output_dir, self._reports_dir, 'conflict.json'), - 'draft_md': - os.path.join(self.output_dir, self._reports_dir, 'draft.md'), - 'report_md': - os.path.join(self.output_dir, self._reports_dir, 'report.md'), - 'evidence_index': - os.path.join(self.output_dir, self._evidence_dir, 'index.json'), - 'evidence_notes_dir': - os.path.join(self.output_dir, self._evidence_dir, 'notes'), - 'lock_dir': - os.path.join(self.output_dir, self._lock_subdir), + 'outline_json': os.path.join(self.output_dir, self._reports_dir, 'outline.json'), + 'outline_md': os.path.join(self.output_dir, self._reports_dir, 'outline.md'), + 'outline_progress_md': os.path.join(self.output_dir, self._reports_dir, 'outline_progress.md'), + 'chapters_dir': os.path.join(self.output_dir, self._reports_dir, 'chapters'), + 'conflict_json': os.path.join(self.output_dir, self._reports_dir, 'conflict.json'), + 'draft_md': os.path.join(self.output_dir, self._reports_dir, 'draft.md'), + 'report_md': os.path.join(self.output_dir, self._reports_dir, 'report.md'), + 'evidence_index': os.path.join(self.output_dir, self._evidence_dir, 'index.json'), + 'evidence_notes_dir': os.path.join(self.output_dir, self._evidence_dir, 'notes'), + 'lock_dir': os.path.join(self.output_dir, self._lock_subdir), } def _filter_candidate_evidence( @@ -248,11 +215,11 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_outline', server_name=self.SERVER_NAME, - description= - ('Generate the report outline with chapter structure. ' - 'Each chapter must be bound to relevant evidence (note_ids). ' - 'Ensures all evidence is covered by at least one chapter.' - ), + description=( + 'Generate the report outline with chapter structure. ' + 'Each chapter must be bound to relevant evidence (note_ids). ' + 'Ensures all evidence is covered by at least one chapter.' + ), parameters={ 'type': 'object', 'properties': { @@ -264,52 +231,38 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'array', 'description': 'List of chapter definitions.', 'items': { - 'type': - 'object', + 'type': 'object', 'properties': { 'title': { 'type': 'string', 'description': 'Chapter title.', }, 'goals': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Main objectives of this chapter.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Main objectives of this chapter.', }, 'sections_description': { - 'type': - 'string', - 'description': - ('Detailed section-by-section plan for this chapter ' - '(NOT a single-sentence summary). ' - 'Write subsections as a numbered list in markdown. ' - 'For EACH subsection include: ' - '(a) subsection title, (b) 2-5 bullet key ' - 'points / questions to answer, ' - '(c) expected output form: narrative synthesis is required; ' - 'optionally add an artifact ' - '(e.g., table/checklist) to support the narrative.' - ), + 'type': 'string', + 'description': ( + 'Detailed section-by-section plan for this chapter ' + '(NOT a single-sentence summary). ' + 'Write subsections as a numbered list in markdown. ' + 'For EACH subsection include: ' + '(a) subsection title, (b) 2-5 bullet key ' + 'points / questions to answer, ' + '(c) expected output form: narrative synthesis is required; ' + 'optionally add an artifact ' + '(e.g., table/checklist) to support the narrative.' + ), }, 'candidate_evidence': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'List of note_ids relevant to this chapter.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'List of note_ids relevant to this chapter.', }, }, - 'required': [ - 'title', 'goals', - 'sections_description', - 'candidate_evidence' - ], + 'required': ['title', 'goals', 'sections_description', 'candidate_evidence'], }, }, }, @@ -320,11 +273,11 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='prepare_chapter_bundle', server_name=self.SERVER_NAME, - description= - ('Prepare metadata and evidence content for writing a specific chapter. ' - 'Returns the chapter info with full evidence details for review. ' - 'Call this before commit_chapter to review evidence quality.' - ), + description=( + 'Prepare metadata and evidence content for writing a specific chapter. ' + 'Returns the chapter info with full evidence details for review. ' + 'Call this before commit_chapter to review evidence quality.' + ), parameters={ 'type': 'object', 'properties': { @@ -333,15 +286,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'The chapter number (1-based).', }, 'relevant_evidence': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('List of note_ids maybe used in this chapter. ' - 'The note_ids in this list will be loaded for review.' - ), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'List of note_ids maybe used in this chapter. ' + 'The note_ids in this list will be loaded for review.' + ), }, # 'need_raw_chunks': { # 'type': 'boolean', @@ -356,70 +306,57 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_chapter', server_name=self.SERVER_NAME, - description= - ('Write the content of a specific chapter. ' - 'The chapter will be saved as chapter_XX.md and status updated to completed.' - ), + description=( + 'Write the content of a specific chapter. ' + 'The chapter will be saved as chapter_XX.md and status updated to completed.' + ), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'chapter_id': { 'type': 'integer', 'description': 'The chapter number (1-based).', }, 'reranked_evidence': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'List of note_ids reranked and chosen for this chapter.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'List of note_ids reranked and chosen for this chapter.', }, 'content': { - 'type': - 'string', - 'description': - ('The markdown content of the chapter. ' - 'The content should include citations to the resources used in this chapter.' - 'Make sure the content is based on the reranked evidence.' - ), + 'type': 'string', + 'description': ( + 'The markdown content of the chapter. ' + 'The content should include citations to the resources used in this chapter.' + 'Make sure the content is based on the reranked evidence.' + ), }, 'cited_urls': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('List of resource urls actually cited in this chapter.' - 'Keep the same order as cited in content.'), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'List of resource urls actually cited in this chapter.' + 'Keep the same order as cited in content.' + ), }, }, # Keep schema consistent with Python signature (reranked_evidence has no default) - 'required': [ - 'chapter_id', 'reranked_evidence', 'content', - 'cited_urls' - ], - 'additionalProperties': - False, + 'required': ['chapter_id', 'reranked_evidence', 'content', 'cited_urls'], + 'additionalProperties': False, }, ), Tool( tool_name='load_chunk', server_name=self.SERVER_NAME, - description= - ('Load raw chunk content when evidence summaries are insufficient. ' - 'Reserved for future implementation.'), + description=( + 'Load raw chunk content when evidence summaries are insufficient. ' + 'Reserved for future implementation.' + ), parameters={ 'type': 'object', 'properties': { 'chunk_ids': { 'type': 'array', - 'items': { - 'type': 'string' - }, + 'items': {'type': 'string'}, 'description': 'List of chunk IDs to load.', }, }, @@ -430,8 +367,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_conflict', server_name=self.SERVER_NAME, - description= - 'Record a conflict or contradiction between evidence.', + description='Record a conflict or contradiction between evidence.', parameters={ 'type': 'object', 'properties': { @@ -440,24 +376,17 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Description of the conflict.', }, 'evidence_ids': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Note IDs involved in the conflict.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Note IDs involved in the conflict.', }, 'chapter_id': { 'type': 'integer', - 'description': - 'Optional: Related chapter number.', + 'description': 'Optional: Related chapter number.', }, 'resolution': { - 'type': - 'string', - 'description': - 'Optional: How the conflict was resolved.', + 'type': 'string', + 'description': 'Optional: How the conflict was resolved.', }, }, 'required': ['description', 'evidence_ids'], @@ -467,8 +396,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='update_outline', server_name=self.SERVER_NAME, - description= - 'Update a specific chapter in the outline (title, goals, or evidence bindings).', + description='Update a specific chapter in the outline (title, goals, or evidence bindings).', parameters={ 'type': 'object', 'properties': { @@ -478,45 +406,35 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, 'updates': { 'type': 'object', - 'description': - 'Fields to update (title, goals, sections_description, candidate_evidence).', + 'description': 'Fields to update (title, goals, sections_description, candidate_evidence).', 'properties': { 'title': { 'type': 'string', 'description': 'Title of the chapter.', }, 'goals': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'Main objectives of this chapter.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'Main objectives of this chapter.', }, 'sections_description': { - 'type': - 'string', - 'description': - ('Detailed section-by-section plan for ' - 'this chapter (NOT a single-sentence summary). ' - 'Write subsections as a numbered list in markdown. ' - 'For EACH subsection include: ' - '(a) subsection title, (b) 2-5 bullet key ' - 'points / questions to answer, ' - '(c) expected output form: narrative synthesis ' - 'is required; optionally add an artifact ' - '(e.g., table/checklist) to support the narrative.' - ), + 'type': 'string', + 'description': ( + 'Detailed section-by-section plan for ' + 'this chapter (NOT a single-sentence summary). ' + 'Write subsections as a numbered list in markdown. ' + 'For EACH subsection include: ' + '(a) subsection title, (b) 2-5 bullet key ' + 'points / questions to answer, ' + '(c) expected output form: narrative synthesis ' + 'is required; optionally add an artifact ' + '(e.g., table/checklist) to support the narrative.' + ), }, 'candidate_evidence': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - 'List of note_ids relevant to this chapter.', + 'type': 'array', + 'items': {'type': 'string'}, + 'description': 'List of note_ids relevant to this chapter.', }, }, }, @@ -528,24 +446,22 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='assemble_draft', server_name=self.SERVER_NAME, - description= - ('Assemble all chapters into a draft (draft.md) with TOC and references. ' - 'Returns the draft path along with a summary of recorded conflicts. ' - 'The model should then review the draft and conflicts to produce the final report.' - ), + description=( + 'Assemble all chapters into a draft (draft.md) with TOC and references. ' + 'Returns the draft path along with a summary of recorded conflicts. ' + 'The model should then review the draft and conflicts to produce the final report.' + ), parameters={ 'type': 'object', 'properties': { 'include_toc': { 'type': 'boolean', - 'description': - 'Whether to include table of contents.', + 'description': 'Whether to include table of contents.', 'default': True, }, 'include_references': { 'type': 'boolean', - 'description': - 'Whether to include references section.', + 'description': 'Whether to include references section.', 'default': True, }, }, @@ -583,30 +499,22 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_outline(self, paths: Dict[str, str]) -> Optional[Dict[str, Any]]: """Load outline.json.""" return _safe_read_json(paths['outline_json']) - def _save_outline(self, - paths: Dict[str, str], - outline: Dict[str, Any], - render: bool = True) -> None: + def _save_outline(self, paths: Dict[str, str], outline: Dict[str, Any], render: bool = True) -> None: """Save outline.json and render outline.md.""" outline['updated_at'] = _now_iso() _write_text(paths['outline_json'], _json_dumps(outline)) _write_text(paths['outline_md'], _render_outline_md(outline)) - _write_text(paths['outline_progress_md'], - _render_outline_progress_md(outline)) + _write_text(paths['outline_progress_md'], _render_outline_progress_md(outline)) if render: - render_markdown_todo( - paths['outline_progress_md'], - title='CURRENT REPORT OUTLINE', - use_pager=False) + render_markdown_todo(paths['outline_progress_md'], title='CURRENT REPORT OUTLINE', use_pager=False) def _load_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: """Load evidence index.""" @@ -634,11 +542,9 @@ def _load_full_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: data['analyses'] = legacy return data - def _load_note_content(self, paths: Dict[str, str], - note_id: str) -> Optional[Dict[str, Any]]: + def _load_note_content(self, paths: Dict[str, str], note_id: str) -> Optional[Dict[str, Any]]: """Load a single note's full content from markdown file.""" - note_path = os.path.join(paths['evidence_notes_dir'], - f'note_{note_id}.md') + note_path = os.path.join(paths['evidence_notes_dir'], f'note_{note_id}.md') if not os.path.exists(note_path): return None @@ -657,8 +563,7 @@ def _load_conflict(self, paths: Dict[str, str]) -> Dict[str, Any]: return {'updated_at': _now_iso(), 'conflicts': []} return data - def _save_conflict(self, paths: Dict[str, str], - conflict: Dict[str, Any]) -> None: + def _save_conflict(self, paths: Dict[str, str], conflict: Dict[str, Any]) -> None: """Save conflict.json.""" conflict['updated_at'] = _now_iso() _write_text(paths['conflict_json'], _json_dumps(conflict)) @@ -671,14 +576,16 @@ async def load_index(self) -> str: index = self._load_full_evidence_index(paths) notes = index.get('notes', {}) analyses = index.get('analyses', {}) - return _json_dumps({ - 'status': 'ok', - 'updated_at': index.get('updated_at', ''), - 'total_notes': len(notes), - 'total_analyses': len(analyses), - 'notes': notes, - 'analyses': analyses, - }) + return _json_dumps( + { + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + } + ) async def commit_outline( self, @@ -705,34 +612,27 @@ async def commit_outline( for idx, ch in enumerate(chapters_list, start=1): candidate_raw = ch.get('candidate_evidence', []) - kept, dropped = self._filter_candidate_evidence( - paths, candidate_raw) + kept, dropped = self._filter_candidate_evidence(paths, candidate_raw) if dropped: invalid_candidate_by_chapter[str(idx)] = dropped covered_evidence.update(kept) - outline_chapters.append({ - 'chapter_id': - idx, - 'title': - ch.get('title', f'Chapter {idx}'), - 'goals': - ch.get('goals', []), - 'sections_description': - ch.get('sections_description', ''), - 'candidate_evidence': - kept, - 'status': - 'pending', - }) + outline_chapters.append( + { + 'chapter_id': idx, + 'title': ch.get('title', f'Chapter {idx}'), + 'goals': ch.get('goals', []), + 'sections_description': ch.get('sections_description', ''), + 'candidate_evidence': kept, + 'status': 'pending', + } + ) # Check coverage uncovered = all_note_ids - covered_evidence coverage_warning = None if uncovered: - coverage_warning = ( - f'Warning: the following evidence is not covered by any chapter: {list(uncovered)}' - ) + coverage_warning = f'Warning: the following evidence is not covered by any chapter: {list(uncovered)}' outline = { 'title': title, @@ -746,8 +646,7 @@ async def commit_outline( result = { 'status': 'ok', - 'outline_path': os.path.relpath(paths['outline_json'], - self.output_dir), + 'outline_path': os.path.relpath(paths['outline_json'], self.output_dir), 'chapters_count': len(outline_chapters), 'total_evidence': len(all_note_ids), 'covered_evidence': len(covered_evidence), @@ -779,12 +678,9 @@ async def prepare_chapter_bundle( # Load outline outline = self._load_outline(paths) if outline is None: - return _json_dumps({ - 'status': - 'error', - 'message': - 'Outline not created yet. Please call commit_outline first.' - }) + return _json_dumps( + {'status': 'error', 'message': 'Outline not created yet. Please call commit_outline first.'} + ) # Find chapter chapter = None @@ -794,25 +690,17 @@ async def prepare_chapter_bundle( break if chapter is None: - return _json_dumps({ - 'status': 'error', - 'message': f'Chapter {chapter_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) - cand_kept, cand_dropped = self._filter_candidate_evidence( - paths, chapter.get('candidate_evidence', [])) - rel_kept, rel_dropped = self._filter_candidate_evidence( - paths, relevant_evidence or []) + cand_kept, cand_dropped = self._filter_candidate_evidence(paths, chapter.get('candidate_evidence', [])) + rel_kept, rel_dropped = self._filter_candidate_evidence(paths, relevant_evidence or []) # Load evidence content evidence_index = self._load_evidence_index(paths) notes_meta = evidence_index.get('notes', {}) _known_sorted = sorted(notes_meta.keys()) _sample = _known_sorted[:48] - _note_id_hint = ( - 'Known note ids in evidence index (sample): ' - + (', '.join(_sample) if _sample else '(none)') - ) + _note_id_hint = 'Known note ids in evidence index (sample): ' + (', '.join(_sample) if _sample else '(none)') if len(_known_sorted) > len(_sample): _note_id_hint += f' … (+{len(_known_sorted) - len(_sample)} more)' @@ -836,24 +724,18 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: note_data = self._load_note_content(paths, note_id) if note_data: - notes_content.append({ - 'note_id': - note_id, - 'title': - meta.get('title', note_data.get('title', '')), - 'content': - note_data.get('content', ''), - 'contradicts': - note_data.get('contradicts', ''), - 'summary': - meta.get('summary', note_data.get('summary', '')), - 'sources': - meta.get('sources', note_data.get('sources', [])), - 'quality_score': - meta.get('quality_score', note_data.get('quality_score')), - 'tags': - meta.get('tags', note_data.get('tags', [])), - }) + notes_content.append( + { + 'note_id': note_id, + 'title': meta.get('title', note_data.get('title', '')), + 'content': note_data.get('content', ''), + 'contradicts': note_data.get('contradicts', ''), + 'summary': meta.get('summary', note_data.get('summary', '')), + 'sources': meta.get('sources', note_data.get('sources', [])), + 'quality_score': meta.get('quality_score', note_data.get('quality_score')), + 'tags': meta.get('tags', note_data.get('tags', [])), + } + ) else: notes_content.append(_missing_note_entry(note_id, meta)) @@ -865,31 +747,23 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: note_data = self._load_note_content(paths, note_id) if note_data: - notes_content.append({ - 'note_id': - note_id, - 'title': - meta.get('title', note_data.get('title', '')), - 'content': - note_data.get('content', ''), - 'contradicts': - note_data.get('contradicts', ''), - 'summary': - meta.get('summary', note_data.get('summary', '')), - 'sources': - meta.get('sources', note_data.get('sources', [])), - 'quality_score': - meta.get('quality_score', - note_data.get('quality_score')), - 'tags': - meta.get('tags', note_data.get('tags', [])), - }) + notes_content.append( + { + 'note_id': note_id, + 'title': meta.get('title', note_data.get('title', '')), + 'content': note_data.get('content', ''), + 'contradicts': note_data.get('contradicts', ''), + 'summary': meta.get('summary', note_data.get('summary', '')), + 'sources': meta.get('sources', note_data.get('sources', [])), + 'quality_score': meta.get('quality_score', note_data.get('quality_score')), + 'tags': meta.get('tags', note_data.get('tags', [])), + } + ) else: notes_content.append(_missing_note_entry(note_id, meta)) # Build meta (only ids that resolved to on-disk notes for this bundle) - candidate_evidence = list( - dict.fromkeys(list(cand_kept) + list(rel_kept))) + candidate_evidence = list(dict.fromkeys(list(cand_kept) + list(rel_kept))) meta = { 'chapter_id': chapter_id, 'chapter_title': chapter['title'], @@ -902,8 +776,7 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: } # Save meta.json - meta_path = os.path.join(paths['chapters_dir'], - f'chapter_{chapter_id:02d}_meta.json') + meta_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}_meta.json') with file_lock(paths['lock_dir'], f'chapter_{chapter_id}_meta'): _write_text(meta_path, _json_dumps(meta)) @@ -913,20 +786,13 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: self._save_outline(paths, outline) out_bundle: Dict[str, Any] = { - 'status': - 'ok', - 'chapter_id': - chapter_id, - 'chapter_title': - chapter['title'], - 'chapter_goals': - chapter.get('goals', []), - 'evidence_count': - len(notes_content), - 'meta_path': - os.path.relpath(meta_path, self.output_dir), - 'notes_content': - notes_content, + 'status': 'ok', + 'chapter_id': chapter_id, + 'chapter_title': chapter['title'], + 'chapter_goals': chapter.get('goals', []), + 'evidence_count': len(notes_content), + 'meta_path': os.path.relpath(meta_path, self.output_dir), + 'notes_content': notes_content, } skipped: Dict[str, List[str]] = {} if cand_dropped: @@ -955,10 +821,7 @@ async def commit_chapter( # Validate outline exists outline = self._load_outline(paths) if outline is None: - return _json_dumps({ - 'status': 'error', - 'message': 'Outline not created yet.' - }) + return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) # Find and update chapter chapter_found = False @@ -973,14 +836,10 @@ async def commit_chapter( break if not chapter_found: - return _json_dumps({ - 'status': 'error', - 'message': f'Chapter {chapter_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) # Write chapter file - chapter_path = os.path.join(paths['chapters_dir'], - f'chapter_{chapter_id:02d}.md') + chapter_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}.md') with file_lock(paths['lock_dir'], f'chapter_{chapter_id}'): _write_text(chapter_path, content) @@ -988,8 +847,7 @@ async def commit_chapter( with file_lock(paths['lock_dir'], 'report_outline'): self._save_outline(paths, outline) - meta_path = os.path.join(paths['chapters_dir'], - f'chapter_{chapter_id:02d}_meta.json') + meta_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}_meta.json') meta = _safe_read_json(meta_path) meta = meta if isinstance(meta, dict) else {} meta['reranked_evidence'] = list(reranked_evidence or []) @@ -997,31 +855,27 @@ async def commit_chapter( with file_lock(paths['lock_dir'], f'chapter_{chapter_id}_meta'): _write_text(meta_path, _json_dumps(meta)) - return _json_dumps({ - 'status': - 'ok', - 'chapter_id': - chapter_id, - 'chapter_title': - chapter_title, - 'path': - os.path.relpath(chapter_path, self.output_dir), - 'content_length': - len(content), - 'reranked_evidence': - reranked_evidence or [], - 'cited_urls': - cited_urls or [], - }) + return _json_dumps( + { + 'status': 'ok', + 'chapter_id': chapter_id, + 'chapter_title': chapter_title, + 'path': os.path.relpath(chapter_path, self.output_dir), + 'content_length': len(content), + 'reranked_evidence': reranked_evidence or [], + 'cited_urls': cited_urls or [], + } + ) async def load_chunk(self, chunk_ids: List[str]) -> str: """Load raw chunk content. Reserved for future implementation.""" - return _json_dumps({ - 'status': 'not_implemented', - 'message': - 'Chunk storage not enabled in this version. Use evidence notes directly.', - 'chunk_ids': chunk_ids, - }) + return _json_dumps( + { + 'status': 'not_implemented', + 'message': 'Chunk storage not enabled in this version. Use evidence notes directly.', + 'chunk_ids': chunk_ids, + } + ) async def commit_conflict( self, @@ -1050,16 +904,14 @@ async def commit_conflict( conflicts['conflicts'].append(conflict_entry) self._save_conflict(paths, conflicts) - return _json_dumps({ - 'status': - 'ok', - 'conflict_id': - conflict_id, - 'total_conflicts': - len(conflicts['conflicts']), - 'conflict_path': - os.path.relpath(paths['conflict_json'], self.output_dir), - }) + return _json_dumps( + { + 'status': 'ok', + 'conflict_id': conflict_id, + 'total_conflicts': len(conflicts['conflicts']), + 'conflict_path': os.path.relpath(paths['conflict_json'], self.output_dir), + } + ) async def update_outline( self, @@ -1073,10 +925,7 @@ async def update_outline( with file_lock(paths['lock_dir'], 'report_outline'): outline = self._load_outline(paths) if outline is None: - return _json_dumps({ - 'status': 'error', - 'message': 'Outline not created yet.' - }) + return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) chapter_found = False invalid_candidate_removed: List[str] = [] @@ -1087,23 +936,16 @@ async def update_outline( if 'goals' in updates: ch['goals'] = updates['goals'] if 'sections_description' in updates: - ch['sections_description'] = updates[ - 'sections_description'] + ch['sections_description'] = updates['sections_description'] if 'candidate_evidence' in updates: - kept, dropped = self._filter_candidate_evidence( - paths, updates['candidate_evidence']) + kept, dropped = self._filter_candidate_evidence(paths, updates['candidate_evidence']) ch['candidate_evidence'] = kept invalid_candidate_removed = dropped chapter_found = True break if not chapter_found: - return _json_dumps({ - 'status': - 'error', - 'message': - f'Chapter {chapter_id} not found.' - }) + return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) self._save_outline(paths, outline) @@ -1115,8 +957,7 @@ async def update_outline( if invalid_candidate_removed: out['invalid_candidate_evidence_removed'] = invalid_candidate_removed out['invalid_candidate_evidence_note'] = ( - 'These ids were removed from candidate_evidence because no ' - 'matching evidence/notes/note_.md exists.' + 'These ids were removed from candidate_evidence because no matching evidence/notes/note_.md exists.' ) return _json_dumps(out) @@ -1130,47 +971,38 @@ async def assemble_draft( outline = self._load_outline(paths) if outline is None: - return _json_dumps({ - 'status': 'error', - 'message': 'Outline not created yet.' - }) + return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) # Collect chapter contents chapters_content = [] missing_chapters = [] for ch in outline.get('chapters', []): - chapter_path = os.path.join(paths['chapters_dir'], - f"chapter_{ch['chapter_id']:02d}.md") + chapter_path = os.path.join(paths['chapters_dir'], f"chapter_{ch['chapter_id']:02d}.md") if os.path.exists(chapter_path): with open(chapter_path, 'r', encoding='utf-8') as f: - chapters_content.append({ - 'id': - ch['chapter_id'], - 'title': - ch['title'], - 'content': - f.read(), - 'reranked_evidence': - ch.get('reranked_evidence', []), - 'cited_urls': - ch.get('cited_urls', []), - }) + chapters_content.append( + { + 'id': ch['chapter_id'], + 'title': ch['title'], + 'content': f.read(), + 'reranked_evidence': ch.get('reranked_evidence', []), + 'cited_urls': ch.get('cited_urls', []), + } + ) else: missing_chapters.append(ch['chapter_id']) if missing_chapters: - return _json_dumps({ - 'status': - 'error', - 'message': - f'The following chapters are not completed yet: {missing_chapters}', - }) + return _json_dumps( + { + 'status': 'error', + 'message': f'The following chapters are not completed yet: {missing_chapters}', + } + ) # Build draft - draft_lines = [ - f"# {outline.get('title', 'Research Report')} (Draft)", '' - ] + draft_lines = [f"# {outline.get('title', 'Research Report')} (Draft)", ''] # Table of contents if include_toc: @@ -1178,8 +1010,7 @@ async def assemble_draft( draft_lines.append('') for ch in chapters_content: anchor = ch['title'].replace(' ', '-').lower() - draft_lines.append( - f"- [Chapter {ch['id']} {ch['title']}](#{anchor})") + draft_lines.append(f"- [Chapter {ch['id']} {ch['title']}](#{anchor})") draft_lines.append('') # Chapters @@ -1194,7 +1025,7 @@ async def assemble_draft( cited_urls = set() for ch in chapters_content: - for url in (ch.get('cited_urls') or []): + for url in ch.get('cited_urls') or []: cited_urls.add(url) all_cited = set() @@ -1226,33 +1057,31 @@ async def assemble_draft( conflicts_list = conflicts_data.get('conflicts', []) conflicts_summary = [] for c in conflicts_list: - conflicts_summary.append({ - 'id': c.get('id'), - 'description': c.get('description'), - 'chapter_id': c.get('chapter_id'), - 'resolution': c.get('resolution'), - }) - - return _json_dumps({ - 'status': - 'ok', - 'draft_path': - os.path.relpath(paths['draft_md'], self.output_dir), - 'chapters_count': - len(chapters_content), - 'content_length': - len(draft_content), - 'conflicts_count': - len(conflicts_list), - 'conflicts_summary': - conflicts_summary, - 'next_step_reminder': - ('Review the draft and conflicts, then generate the final report. ' - 'Note: the draft cannot be used as the final report; ' - 'do not replace report content with references or pointers to other content or files ' - '(e.g., "details are in chapter_2.md", "see draft.md for more details").' - ), - }) + conflicts_summary.append( + { + 'id': c.get('id'), + 'description': c.get('description'), + 'chapter_id': c.get('chapter_id'), + 'resolution': c.get('resolution'), + } + ) + + return _json_dumps( + { + 'status': 'ok', + 'draft_path': os.path.relpath(paths['draft_md'], self.output_dir), + 'chapters_count': len(chapters_content), + 'content_length': len(draft_content), + 'conflicts_count': len(conflicts_list), + 'conflicts_summary': conflicts_summary, + 'next_step_reminder': ( + 'Review the draft and conflicts, then generate the final report. ' + 'Note: the draft cannot be used as the final report; ' + 'do not replace report content with references or pointers to other content or files ' + '(e.g., "details are in chapter_2.md", "see draft.md for more details").' + ), + } + ) async def get_status(self) -> str: """Get current report generation progress.""" @@ -1262,52 +1091,40 @@ async def get_status(self) -> str: conflicts = self._load_conflict(paths) if outline is None: - return _json_dumps({ - 'status': - 'not_started', - 'outline_exists': - False, - 'chapters': [], - 'conflicts_count': - len(conflicts.get('conflicts', [])), - }) + return _json_dumps( + { + 'status': 'not_started', + 'outline_exists': False, + 'chapters': [], + 'conflicts_count': len(conflicts.get('conflicts', [])), + } + ) chapters_status = [] for ch in outline.get('chapters', []): - chapter_path = os.path.join(paths['chapters_dir'], - f"chapter_{ch['chapter_id']:02d}.md") - chapters_status.append({ - 'chapter_id': - ch['chapter_id'], - 'title': - ch['title'], - 'status': - ch.get('status', 'pending'), - 'file_exists': - os.path.exists(chapter_path), - 'candidate_evidence_count': - len(ch.get('candidate_evidence', [])), - }) - - completed = sum(1 for ch in chapters_status - if ch['status'] == 'completed') + chapter_path = os.path.join(paths['chapters_dir'], f"chapter_{ch['chapter_id']:02d}.md") + chapters_status.append( + { + 'chapter_id': ch['chapter_id'], + 'title': ch['title'], + 'status': ch.get('status', 'pending'), + 'file_exists': os.path.exists(chapter_path), + 'candidate_evidence_count': len(ch.get('candidate_evidence', [])), + } + ) + + completed = sum(1 for ch in chapters_status if ch['status'] == 'completed') total = len(chapters_status) - return _json_dumps({ - 'status': - 'in_progress' if completed < total else 'completed', - 'outline_exists': - True, - 'report_title': - outline.get('title', ''), - 'progress': - f'{completed}/{total}', - 'chapters': - chapters_status, - 'conflicts_count': - len(conflicts.get('conflicts', [])), - 'draft_exists': - os.path.exists(paths['draft_md']), - 'report_exists': - os.path.exists(paths['report_md']), - }) + return _json_dumps( + { + 'status': 'in_progress' if completed < total else 'completed', + 'outline_exists': True, + 'report_title': outline.get('title', ''), + 'progress': f'{completed}/{total}', + 'chapters': chapters_status, + 'conflicts_count': len(conflicts.get('conflicts', [])), + 'draft_exists': os.path.exists(paths['draft_md']), + 'report_exists': os.path.exists(paths['report_md']), + } + ) diff --git a/projects/fin_research/aggregator.py b/projects/fin_research/aggregator.py index 2380e58e3..10ac55f40 100644 --- a/projects/fin_research/aggregator.py +++ b/projects/fin_research/aggregator.py @@ -1,14 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from typing import Any, AsyncGenerator, List, Union -import json from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig + from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() @@ -18,32 +19,25 @@ class AggregatorAgent(LLMAgent): Aggregator Agent that aggregates the reports from SearchAgent and CollectorAgent. """ - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs + ): super().__init__(config, tag, trust_remote_code, **kwargs) async def run( - self, inputs: Union[str, List[str], List[Message], - List[List[Message]]], **kwargs + self, inputs: Union[str, List[str], List[Message], List[List[Message]]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: reports = {} # Dict of reports if isinstance(inputs, list): if isinstance(inputs[0], str): - refractory_inputs = [[Message(role='user', content=item)] - for item in inputs - ] # multiple parent nodes + refractory_inputs = [[Message(role='user', content=item)] for item in inputs] # multiple parent nodes elif isinstance(inputs[0], Message): refractory_inputs = [inputs] # single parent node elif len(inputs) > 1 and isinstance(inputs[0], list): refractory_inputs = inputs # multiple parent nodes else: - raise ValueError( - f"Invalid input type: List[{type(inputs[0]) if inputs else 'empty list'}]" - ) + raise ValueError(f"Invalid input type: List[{type(inputs[0]) if inputs else 'empty list'}]") elif isinstance(inputs, str): refractory_inputs = [[Message(role='user', content=inputs)]] else: @@ -74,9 +68,11 @@ async def run( with open(report_path, 'r', encoding='utf-8') as f: report = f.read() - report_type = ('**Financial Data Analysis Report**' - if 'analysis' in report_path else - '**Online Sentiment Analysis Report**') + report_type = ( + '**Financial Data Analysis Report**' + if 'analysis' in report_path + else '**Online Sentiment Analysis Report**' + ) reports[report_type] = report if report else message.content plan = {} @@ -88,14 +84,15 @@ async def run( if content: # Only load if file is not empty plan.update(json.loads(content)) except Exception as e: - logger.warning( - f'Failed to load plan.json: {e}. Using empty plan.') + logger.warning(f'Failed to load plan.json: {e}. Using empty plan.') return await super().run( - messages= - (f'The reports from the SearchAgent and AnalystAgent are as follows:\n' - f'{json.dumps(reports, ensure_ascii=False, indent=2)}\n' - f'Please integrate the reports into a comprehensive financial analysis report.\n' - f'Please review the original plan for the financial analysis task:\n' - f'{json.dumps(plan, ensure_ascii=False, indent=2)}\n'), - kwargs=kwargs) + messages=( + f'The reports from the SearchAgent and AnalystAgent are as follows:\n' + f'{json.dumps(reports, ensure_ascii=False, indent=2)}\n' + f'Please integrate the reports into a comprehensive financial analysis report.\n' + f'Please review the original plan for the financial analysis task:\n' + f'{json.dumps(plan, ensure_ascii=False, indent=2)}\n' + ), + kwargs=kwargs, + ) diff --git a/projects/fin_research/callbacks/aggregator_callback.py b/projects/fin_research/callbacks/aggregator_callback.py index b827f18a6..913a0252c 100644 --- a/projects/fin_research/callbacks/aggregator_callback.py +++ b/projects/fin_research/callbacks/aggregator_callback.py @@ -3,19 +3,19 @@ import re from typing import List +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class AggregatorCallback(Callback): - """Save output plan to local disk. - """ + """Save output plan to local disk.""" def __init__(self, config: DictConfig): super().__init__(config) @@ -37,7 +37,8 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): r'\s*\[ACT=(?:outline|partial_report|final_report)\]\s*:?\s*.*?(?:\n|\.)', '', message.content, - flags=re.MULTILINE).strip() + flags=re.MULTILINE, + ).strip() f.write(filtered_content) break logger.info(f'Aggregator report saved to {self.report_path}') diff --git a/projects/fin_research/callbacks/analyst_callback.py b/projects/fin_research/callbacks/analyst_callback.py index 31af7306b..b0952216d 100644 --- a/projects/fin_research/callbacks/analyst_callback.py +++ b/projects/fin_research/callbacks/analyst_callback.py @@ -1,34 +1,30 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re from pathlib import Path from typing import List -import json +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class AnalystCallback(Callback): - """Save output plan to local disk. - """ + """Save output plan to local disk.""" def __init__(self, config: DictConfig): super().__init__(config) - self.report_path = self.config.get( - 'report_path', - os.path.join(self.config.output_dir, 'analysis_report.md')) + self.report_path = self.config.get('report_path', os.path.join(self.config.output_dir, 'analysis_report.md')) def _resolve_data_root(self) -> str: - code_exec_cfg = getattr( - getattr(self.config, 'tools', {}), 'code_executor', None) - impl = getattr(code_exec_cfg, 'implementation', - 'sandbox') if code_exec_cfg else 'sandbox' + code_exec_cfg = getattr(getattr(self.config, 'tools', {}), 'code_executor', None) + impl = getattr(code_exec_cfg, 'implementation', 'sandbox') if code_exec_cfg else 'sandbox' if isinstance(impl, str) and impl.lower() == 'sandbox': return '/data' @@ -40,8 +36,7 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): for message in messages: if message.role == 'system': message.content = message.content.replace('\\\n', '') - message.content = message.content.replace( - '', self._resolve_data_root()) + message.content = message.content.replace('', self._resolve_data_root()) elif message.role == 'assistant': if '[ACT=summary]' in message.content: summary_messages['collector_summary'] = message.content @@ -49,32 +44,30 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): summary_messages['collector_plan'] = message.content if os.path.exists(os.path.join(self.config.output_dir, 'plan.json')): - with open(os.path.join(self.config.output_dir, 'plan.json'), - 'r') as f: + with open(os.path.join(self.config.output_dir, 'plan.json'), 'r') as f: plan = json.load(f) if not plan: - logger.error( - 'The plan.json file is empty, please check the file.') + logger.error('The plan.json file is empty, please check the file.') user_message = Message( role='user', - content= - (f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' - f'Please follow the plan to complete the data analysis task.\n' - f'IMPORTANT: Review the input analysis specification provided under "financial_data_dimension"' - )) + content=( + f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' + f'Please follow the plan to complete the data analysis task.\n' + f'IMPORTANT: Review the input analysis specification provided under "financial_data_dimension"' + ), + ) else: user_message = Message( role='user', - content= - ('Please conduct data analysis in accordance with the research plan followed during the data ' - 'collection phase and the results obtained from data collection.' - )) + content=( + 'Please conduct data analysis in accordance with the research plan followed during the data ' + 'collection phase and the results obtained from data collection.' + ), + ) # Add the summary of the data collection phase to the user message (add plan if exists) if summary_messages['collector_summary']: - messages[:] = [ - message for message in messages if message.role == 'system' - ] + messages[:] = [message for message in messages if message.role == 'system'] summary_messages = ( f'The summary of the data collection phase is as follows:\n' f'{json.dumps(summary_messages, ensure_ascii=False, indent=2)}' @@ -86,15 +79,11 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): for message in messages[::-1]: if message.role == 'assistant' and not message.tool_calls: with open(self.report_path, 'w') as f: - filtered_content = re.sub( - r'\s*\[ACT=(?:code|collect|report|fix)\]\s*', '', - message.content).strip() + filtered_content = re.sub(r'\s*\[ACT=(?:code|collect|report|fix)\]\s*', '', message.content).strip() f.write(filtered_content) break user_message = Message( - role='user', - content=json.dumps({'report_path': self.report_path}, - ensure_ascii=False, - indent=2)) + role='user', content=json.dumps({'report_path': self.report_path}, ensure_ascii=False, indent=2) + ) messages.append(user_message) diff --git a/projects/fin_research/callbacks/collector_callback.py b/projects/fin_research/callbacks/collector_callback.py index caffbcb71..6aeb11bec 100644 --- a/projects/fin_research/callbacks/collector_callback.py +++ b/projects/fin_research/callbacks/collector_callback.py @@ -1,30 +1,28 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from pathlib import Path from typing import List -import json +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class CollectorCallback(Callback): - """Save output plan to local disk. - """ + """Save output plan to local disk.""" def __init__(self, config: DictConfig): super().__init__(config) def _resolve_data_root(self) -> str: - code_exec_cfg = getattr( - getattr(self.config, 'tools', {}), 'code_executor', None) - impl = getattr(code_exec_cfg, 'implementation', - 'sandbox') if code_exec_cfg else 'sandbox' + code_exec_cfg = getattr(getattr(self.config, 'tools', {}), 'code_executor', None) + impl = getattr(code_exec_cfg, 'implementation', 'sandbox') if code_exec_cfg else 'sandbox' if isinstance(impl, str) and impl.lower() == 'sandbox': return '/data' @@ -35,16 +33,13 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): for message in messages: if message.role == 'system': message.content = message.content.replace('\\\n', '') - message.content = message.content.replace( - '', self._resolve_data_root()) + message.content = message.content.replace('', self._resolve_data_root()) if os.path.exists(os.path.join(self.config.output_dir, 'plan.json')): - with open(os.path.join(self.config.output_dir, 'plan.json'), - 'r') as f: + with open(os.path.join(self.config.output_dir, 'plan.json'), 'r') as f: plan = json.load(f) if not plan: - logger.error( - 'The plan.json file is empty, please check the file.') + logger.error('The plan.json file is empty, please check the file.') if messages[-1].role == 'user': messages[-1].content = ( f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' @@ -53,20 +48,23 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): elif messages[-1].role in ('assistant', 'tool', 'system'): user_message = Message( role='user', - content= - (f'The complete plan for the current global financial analysis task is as follows:\n{plan}\n' - f'Please follow the plan to complete the data collection task.\n' - )) + content=( + f'The complete plan for the current global financial analysis task is as follows:\n{plan}\n' + f'Please follow the plan to complete the data collection task.\n' + ), + ) messages.append(user_message) messages[:] = [ - messages[i] for i in range(len(messages)) - if (messages[i].role == 'system') or ( - i == (len(messages) - 1) and messages[i].role == 'user') + messages[i] + for i in range(len(messages)) + if (messages[i].role == 'system') or (i == (len(messages) - 1) and messages[i].role == 'user') ] else: user_message = Message( role='user', - content= - ('Please conduct data collection in accordance with the research plan ' - 'provided in orchestrator\'s output.')) + content=( + 'Please conduct data collection in accordance with the research plan ' + 'provided in orchestrator\'s output.' + ), + ) messages.append(user_message) diff --git a/projects/fin_research/callbacks/file_parser.py b/projects/fin_research/callbacks/file_parser.py index 67f1bdd69..83465be55 100644 --- a/projects/fin_research/callbacks/file_parser.py +++ b/projects/fin_research/callbacks/file_parser.py @@ -2,9 +2,7 @@ from typing import List, Optional, Tuple -def extract_code_blocks(text: str, - target_filename: Optional[str] = None - ) -> Tuple[List, str]: +def extract_code_blocks(text: str, target_filename: Optional[str] = None) -> Tuple[List, str]: """Extract code blocks from the given text. ```py:a.py diff --git a/projects/fin_research/callbacks/orchestrator_callback.py b/projects/fin_research/callbacks/orchestrator_callback.py index d9b507dce..d3456fba2 100644 --- a/projects/fin_research/callbacks/orchestrator_callback.py +++ b/projects/fin_research/callbacks/orchestrator_callback.py @@ -3,19 +3,19 @@ from typing import List from file_parser import extract_code_blocks +from omegaconf import DictConfig + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class OrchestratorCallback(Callback): - """Save output plan to local disk. - """ + """Save output plan to local disk.""" def __init__(self, config: DictConfig): super().__init__(config) @@ -37,8 +37,7 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): all_files, _ = extract_code_blocks(content) results = [] for f in all_files: - result = await self.file_system.write_file(f['filename'], - f['code']) + result = await self.file_system.write_file(f['filename'], f['code']) results.append(result) r = '\n'.join(results) diff --git a/projects/fin_research/searcher.py b/projects/fin_research/searcher.py index 23703fa64..4966e3873 100644 --- a/projects/fin_research/searcher.py +++ b/projects/fin_research/searcher.py @@ -1,16 +1,16 @@ +import json import os from typing import List, Union -import json from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig + from ms_agent.agent.code_agent import CodeAgent from ms_agent.llm import Message from ms_agent.llm.openai import OpenAIChat from ms_agent.tools.search_engine import get_web_search_tool from ms_agent.utils import get_logger -from ms_agent.workflow.deep_research.research_workflow_beta import \ - ResearchWorkflowBeta -from omegaconf import DictConfig +from ms_agent.workflow.deep_research.research_workflow_beta import ResearchWorkflowBeta logger = get_logger() @@ -20,51 +20,39 @@ class SearchAgent(CodeAgent): """Agent wrapper that delegates work to ResearchWorkflowBeta.""" - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) if isinstance(self.config, DictConfig): if hasattr(self.config, 'llm'): llm_config = self.config.llm - api_key = getattr(llm_config, 'openai_api_key', - '') or os.getenv('OPENAI_API_KEY') - base_url = getattr(llm_config, 'openai_base_url', - '') or os.getenv('OPENAI_BASE_URL') - model = getattr(llm_config, 'model', - '') or 'Qwen/Qwen3-235B-A22B-Instruct-2507' - self.chat_client = OpenAIChat( - api_key=api_key, base_url=base_url, model=model) + api_key = getattr(llm_config, 'openai_api_key', '') or os.getenv('OPENAI_API_KEY') + base_url = getattr(llm_config, 'openai_base_url', '') or os.getenv('OPENAI_BASE_URL') + model = getattr(llm_config, 'model', '') or 'Qwen/Qwen3-235B-A22B-Instruct-2507' + self.chat_client = OpenAIChat(api_key=api_key, base_url=base_url, model=model) else: - raise ValueError( - 'LLM configuration not found, SearchAgent requires OpenAI compatible API.' - ) + raise ValueError('LLM configuration not found, SearchAgent requires OpenAI compatible API.') - if hasattr(self.config, 'tools') and hasattr( - self.config.tools, 'search_engine'): + if hasattr(self.config, 'tools') and hasattr(self.config.tools, 'search_engine'): self.search_engine = get_web_search_tool( - config_file=getattr(self.config.tools.search_engine, - 'config_file', '')) + config_file=getattr(self.config.tools.search_engine, 'config_file', '') + ) else: raise ValueError('Search engine configuration not found.') self.workdir = getattr(self.config, 'output_dir', './output') self.use_ray = getattr(self.config, 'use_ray', False) - self.report_prefix = getattr(self.config, 'report_prefix', - 'sentiment_') + self.report_prefix = getattr(self.config, 'report_prefix', 'sentiment_') - async def run(self, inputs: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: workflow = ResearchWorkflowBeta( client=self.chat_client, search_engine=self.search_engine, workdir=self.workdir, use_ray=self.use_ray, enable_multimodal=False, - report_prefix=self.report_prefix) + report_prefix=self.report_prefix, + ) if inputs is None: return [Message(role='assistant', content='')] @@ -74,29 +62,24 @@ async def run(self, inputs: Union[str, List[Message]], instruction = {} for message in inputs[::-1]: if message.role == 'assistant': - instruction = json.loads( - extract_code_blocks(message.content)[0][0].get( - 'code', {})) + instruction = json.loads(extract_code_blocks(message.content)[0][0].get('code', {})) break - if not instruction and os.path.exists( - os.path.join(self.workdir, 'plan.json')): + if not instruction and os.path.exists(os.path.join(self.workdir, 'plan.json')): with open(os.path.join(self.workdir, 'plan.json'), 'r') as f: instruction = json.load(f) user_prompt = json.dumps( { - 'public_sentiment_dimension': - instruction.get('public_sentiment_dimension', {}), + 'public_sentiment_dimension': instruction.get('public_sentiment_dimension', {}), }, ensure_ascii=False, - indent=2) + indent=2, + ) elif isinstance(inputs, str): user_prompt = inputs else: - raise ValueError( - 'Invalid input type, SearchAgent requires a string or list of messages.' - ) + raise ValueError('Invalid input type, SearchAgent requires a string or list of messages.') report_path = await workflow.run( user_prompt=user_prompt, @@ -107,7 +90,5 @@ async def run(self, inputs: Union[str, List[Message]], ) result_content = report_path if report_path else 'No report generated.' - result_content = json.dumps({'report_path': report_path}, - ensure_ascii=False, - indent=2) + result_content = json.dumps({'report_path': report_path}, ensure_ascii=False, indent=2) return [Message(role='user', content=result_content)] diff --git a/projects/fin_research/time_handler.py b/projects/fin_research/time_handler.py index dd01bc8ae..93c8b92b7 100644 --- a/projects/fin_research/time_handler.py +++ b/projects/fin_research/time_handler.py @@ -2,9 +2,10 @@ from datetime import datetime from typing import Any -from ms_agent.config.config import ConfigLifecycleHandler from omegaconf import DictConfig +from ms_agent.config.config import ConfigLifecycleHandler + class TimeHandler(ConfigLifecycleHandler): """Config handler that injects current date/time into prompts""" @@ -24,8 +25,7 @@ def task_begin(self, config: DictConfig, tag: str) -> DictConfig: def traverse_and_replace(_config: Any): if isinstance(_config, DictConfig): for name, value in _config.items(): - if isinstance(value, DictConfig) or isinstance( - value, list): + if isinstance(value, DictConfig) or isinstance(value, list): traverse_and_replace(value) elif isinstance(value, str): new_value = value @@ -33,8 +33,7 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace( - placeholder, var_value) + new_value = new_value.replace(placeholder, var_value) setattr(_config, name, new_value) elif isinstance(_config, list): @@ -47,8 +46,7 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace( - placeholder, var_value) + new_value = new_value.replace(placeholder, var_value) _config[i] = new_value traverse_and_replace(config) diff --git a/projects/fin_research/tools/principle_skill.py b/projects/fin_research/tools/principle_skill.py index 19882f497..99e77ac9e 100644 --- a/projects/fin_research/tools/principle_skill.py +++ b/projects/fin_research/tools/principle_skill.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # flake8: noqa +import json import os from typing import Any, Dict, List, Optional, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger @@ -81,8 +81,7 @@ class PrincipleSkill(ToolBase): def __init__(self, config): super().__init__(config) - tools_cfg = getattr(config, 'tools', - None) if config is not None else None + tools_cfg = getattr(config, 'tools', None) if config is not None else None self.exclude_func(getattr(tools_cfg, 'principle_skill', None)) configured_dir = None @@ -96,9 +95,7 @@ def __init__(self, config): self.principle_dir = configured_dir or default_dir # Build a mapping from normalized user inputs to on-disk filenames and display names - self._name_to_file: Dict[str, - Tuple[str, - str]] = self._build_principle_index() + self._name_to_file: Dict[str, Tuple[str, str]] = self._build_principle_index() async def connect(self): # Warn once if the directory cannot be found; still operate to allow deferred config @@ -114,66 +111,61 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='load_principles', server_name='principle_skill', - description= - (f'Load one or more analysis principles (concept + how to apply to ' - f'financial analysis) and return their curated Markdown content.\n\n' - f'This is a single-aggregator tool designed to fetch multiple principles ' - f'in one call. Provide a list of requested principles via the "principles" ' - f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' - f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' - f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' - f'When format is "markdown" (default), the tool returns a single combined ' - f'Markdown string (optionally including section titles). When format is ' - f'"json", the tool returns a JSON object mapping principle to content.\n' - f'{PRINCIPLE_GUIDE}\n' - f'{ROUTING_GUIDE}\n'), + description=( + f'Load one or more analysis principles (concept + how to apply to ' + f'financial analysis) and return their curated Markdown content.\n\n' + f'This is a single-aggregator tool designed to fetch multiple principles ' + f'in one call. Provide a list of requested principles via the "principles" ' + f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' + f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' + f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' + f'When format is "markdown" (default), the tool returns a single combined ' + f'Markdown string (optionally including section titles). When format is ' + f'"json", the tool returns a JSON object mapping principle to content.\n' + f'{PRINCIPLE_GUIDE}\n' + f'{ROUTING_GUIDE}\n' + ), parameters={ 'type': 'object', 'properties': { 'principles': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('List of principles to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' - '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' - ), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'List of principles to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' + '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' + ), }, 'format': { - 'type': - 'string', + 'type': 'string', 'enum': ['markdown', 'json'], - 'description': - ('Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping principle to content). Default: "markdown".' - ), + 'description': ( + 'Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping principle to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': - 'boolean', - 'description': - ('When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical principle title. Default: true.' - ), + 'type': 'boolean', + 'description': ( + 'When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical principle title. Default: true.' + ), }, 'join_with': { - 'type': - 'string', - 'description': - ('When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".'), + 'type': 'string', + 'description': ( + 'When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".' + ), }, 'strict': { - 'type': - 'boolean', - 'description': - ('If true, unknown principles cause an error. If false, unknown ' - 'items are ignored with a note in the output. Default: false.' - ), + 'type': 'boolean', + 'description': ( + 'If true, unknown principles cause an error. If false, unknown ' + 'items are ignored with a note in the output. Default: false.' + ), }, }, 'required': ['principles'], @@ -184,8 +176,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) async def load_principles( @@ -204,10 +195,7 @@ async def load_principles( if not principles: return json.dumps( - { - 'success': False, - 'error': 'No principles provided.' - }, + {'success': False, 'error': 'No principles provided.'}, ensure_ascii=False, indent=2, ) @@ -223,12 +211,7 @@ async def load_principles( if unknown and strict: return json.dumps( - { - 'success': - False, - 'error': - 'Unknown principles (strict mode): ' + ', '.join(unknown) - }, + {'success': False, 'error': 'Unknown principles (strict mode): ' + ', '.join(unknown)}, ensure_ascii=False, indent=2, ) @@ -241,15 +224,11 @@ async def load_principles( content = f.read().strip() loaded[canonical_title] = content except Exception as e: # noqa - loaded[ - canonical_title] = f'Failed to load {filename}: {str(e)}' + loaded[canonical_title] = f'Failed to load {filename}: {str(e)}' if not loaded: return json.dumps( - { - 'success': False, - 'error': 'Failed to load any principles.' - }, + {'success': False, 'error': 'Failed to load any principles.'}, ensure_ascii=False, indent=2, ) @@ -272,14 +251,10 @@ async def load_principles( sections.append(content) if unknown and not strict: - sections.append( - f'> Note: Unknown principles ignored: {", ".join(unknown)}') + sections.append(f'> Note: Unknown principles ignored: {", ".join(unknown)}') return json.dumps( - { - 'success': True, - 'sections': sections - }, + {'success': True, 'sections': sections}, ensure_ascii=False, indent=2, ) @@ -288,23 +263,24 @@ def _build_principle_index(self) -> Dict[str, Tuple[str, str]]: """Return mapping from normalized query → (filename, canonical title).""" entries: List[Tuple[List[str], str, str]] = [ # synonyms, filename, canonical title - (['mece', 'mutually exclusive and collectively exhaustive'], - 'MECE.md', 'MECE'), - ([ - 'pyramid', 'minto', 'minto pyramid', 'pyramid principle', - 'minto_pyramid' - ], 'Minto_Pyramid.md', 'Pyramid (Minto Pyramid)'), + (['mece', 'mutually exclusive and collectively exhaustive'], 'MECE.md', 'MECE'), + ( + ['pyramid', 'minto', 'minto pyramid', 'pyramid principle', 'minto_pyramid'], + 'Minto_Pyramid.md', + 'Pyramid (Minto Pyramid)', + ), (['swot', 'swot analysis'], 'SWOT.md', 'SWOT'), - (['value chain', 'value-chain', - 'value_chain'], 'Value_Chain.md', 'Value Chain'), - ([ - 'pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', - '8020' - ], 'Pareto_80-20.md', 'Pareto (80/20 Rule)'), - ([ - 'boston matrix', 'bcg', 'boston consulting group', - 'boston_matrix', 'boston' - ], 'Boston_Matrix.md', 'Boston Matrix (BCG)'), + (['value chain', 'value-chain', 'value_chain'], 'Value_Chain.md', 'Value Chain'), + ( + ['pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', '8020'], + 'Pareto_80-20.md', + 'Pareto (80/20 Rule)', + ), + ( + ['boston matrix', 'bcg', 'boston consulting group', 'boston_matrix', 'boston'], + 'Boston_Matrix.md', + 'Boston Matrix (BCG)', + ), ] index: Dict[str, Tuple[str, str]] = {} diff --git a/projects/fin_research/tools/spec_loader.py b/projects/fin_research/tools/spec_loader.py index 6ed8e1e21..a2c86bdcf 100644 --- a/projects/fin_research/tools/spec_loader.py +++ b/projects/fin_research/tools/spec_loader.py @@ -1,14 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # flake8: noqa +import json import os +from spec_constant import PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE from typing import Any, Dict, List, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, - WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) logger = get_logger() @@ -36,14 +35,11 @@ class SpecLoader(ToolBase): def __init__(self, config): super().__init__(config) - tools_cfg = getattr(config, 'tools', - None) if config is not None else None - spec_cfg = getattr(tools_cfg, 'spec_loader', - None) if tools_cfg is not None else None + tools_cfg = getattr(config, 'tools', None) if config is not None else None + spec_cfg = getattr(tools_cfg, 'spec_loader', None) if tools_cfg is not None else None self.exclude_func(spec_cfg) - configured_dir = getattr(spec_cfg, 'spec_dir', - None) if spec_cfg is not None else None + configured_dir = getattr(spec_cfg, 'spec_dir', None) if spec_cfg is not None else None default_dir = os.path.join(os.getcwd(), self.SPEC_DIR) self.spec_dir = configured_dir or default_dir @@ -61,66 +57,61 @@ async def get_tools(self) -> Dict[str, Any]: Tool( tool_name='load_writing_specs', server_name='spec_loader', - description= - ('Load one or more writing-style specs (rules + examples) and return ' - 'their curated Markdown content. Use this when you are unsure about ' - 'how to structure or phrase a financial report in an analyst-like style.\n\n' - 'Supported spec identifiers (case-insensitive, synonyms allowed):\n' - '- structure → section depth / headings\n' - '- methods → how much to expose MECE/SWOT/etc.\n' - '- bullets → bullets vs paragraphs\n' - '- focus → task focus and relevance\n' - '- tone → analyst-style voice\n' - '- density → length and information density control\n\n' - 'Provide a list of requested writing specs via the "writing_specs" parameter.\n\n' - f'{WRITING_SPEC_GUIDE}\n' - f'{WRITING_ROUTING_GUIDE}\n'), + description=( + 'Load one or more writing-style specs (rules + examples) and return ' + 'their curated Markdown content. Use this when you are unsure about ' + 'how to structure or phrase a financial report in an analyst-like style.\n\n' + 'Supported spec identifiers (case-insensitive, synonyms allowed):\n' + '- structure → section depth / headings\n' + '- methods → how much to expose MECE/SWOT/etc.\n' + '- bullets → bullets vs paragraphs\n' + '- focus → task focus and relevance\n' + '- tone → analyst-style voice\n' + '- density → length and information density control\n\n' + 'Provide a list of requested writing specs via the "writing_specs" parameter.\n\n' + f'{WRITING_SPEC_GUIDE}\n' + f'{WRITING_ROUTING_GUIDE}\n' + ), parameters={ 'type': 'object', 'properties': { 'writing_specs': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('List of writing specs to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- structure\n- methods\n- bullets\n- focus\n- tone\n- density\n' - ), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'List of writing specs to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- structure\n- methods\n- bullets\n- focus\n- tone\n- density\n' + ), }, 'format': { - 'type': - 'string', + 'type': 'string', 'enum': ['markdown', 'json'], - 'description': - ('Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping spec to content). Default: "markdown".' - ), + 'description': ( + 'Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping spec to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': - 'boolean', - 'description': - ('When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical spec title. Default: false.' - ), + 'type': 'boolean', + 'description': ( + 'When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical spec title. Default: false.' + ), }, 'join_with': { - 'type': - 'string', - 'description': - ('When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".'), + 'type': 'string', + 'description': ( + 'When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".' + ), }, 'strict': { - 'type': - 'boolean', - 'description': - ('If true, unknown specs cause an error. If false, unknown items are ' - 'ignored with a note in the output. Default: false.' - ), + 'type': 'boolean', + 'description': ( + 'If true, unknown specs cause an error. If false, unknown items are ' + 'ignored with a note in the output. Default: false.' + ), }, }, 'required': ['writing_specs'], @@ -130,94 +121,83 @@ async def get_tools(self) -> Dict[str, Any]: Tool( tool_name='load_principle_specs', server_name='spec_loader', - description= - (f'Load one or more analysis principles (concept + how to apply to ' - f'financial analysis) and return their curated Markdown content.\n\n' - f'This is a single-aggregator tool designed to fetch multiple principles ' - f'in one call. Provide a list of requested principles via the "principles" ' - f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' - f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' - f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' - f'When format is "markdown" (default), the tool returns a single combined ' - f'Markdown string (optionally including section titles). When format is ' - f'"json", the tool returns a JSON object mapping principle to content.\n' - f'{PRINCIPLE_SPEC_GUIDE}\n' - f'{PRINCIPLE_ROUTING_GUIDE}\n'), + description=( + f'Load one or more analysis principles (concept + how to apply to ' + f'financial analysis) and return their curated Markdown content.\n\n' + f'This is a single-aggregator tool designed to fetch multiple principles ' + f'in one call. Provide a list of requested principles via the "principles" ' + f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' + f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' + f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' + f'When format is "markdown" (default), the tool returns a single combined ' + f'Markdown string (optionally including section titles). When format is ' + f'"json", the tool returns a JSON object mapping principle to content.\n' + f'{PRINCIPLE_SPEC_GUIDE}\n' + f'{PRINCIPLE_ROUTING_GUIDE}\n' + ), parameters={ 'type': 'object', 'properties': { 'principles': { - 'type': - 'array', - 'items': { - 'type': 'string' - }, - 'description': - ('List of principles to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' - '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' - ), + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'List of principles to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' + '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' + ), }, 'format': { - 'type': - 'string', + 'type': 'string', 'enum': ['markdown', 'json'], - 'description': - ('Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping principle to content). Default: "markdown".' - ), + 'description': ( + 'Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping principle to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': - 'boolean', - 'description': - ('When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical principle title. Default: false.' - ), + 'type': 'boolean', + 'description': ( + 'When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical principle title. Default: false.' + ), }, 'join_with': { - 'type': - 'string', - 'description': - ('When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".'), + 'type': 'string', + 'description': ( + 'When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".' + ), }, 'strict': { - 'type': - 'boolean', - 'description': - ('If true, unknown principles cause an error. If false, unknown ' - 'items are ignored with a note in the output. Default: false.' - ), + 'type': 'boolean', + 'description': ( + 'If true, unknown principles cause an error. If false, unknown ' + 'items are ignored with a note in the output. Default: false.' + ), }, }, 'required': ['principles'], 'additionalProperties': False, }, - ) + ), ] } if hasattr(self, 'exclude_functions') and self.exclude_functions: - tools['spec_loader'] = [ - t for t in tools['spec_loader'] - if t['tool_name'] not in self.exclude_functions - ] + tools['spec_loader'] = [t for t in tools['spec_loader'] if t['tool_name'] not in self.exclude_functions] return tools - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) - async def load_writing_specs(self, writing_specs: List[str], - **kwargs) -> str: + async def load_writing_specs(self, writing_specs: List[str], **kwargs) -> str: writing_spec_map = self._build_writing_spec_index() return await self.load_specs(writing_spec_map, writing_specs, **kwargs) - async def load_principle_specs(self, principles: List[str], - **kwargs) -> str: + async def load_principle_specs(self, principles: List[str], **kwargs) -> str: principle_map = self._build_principle_spec_index() return await self.load_specs(principle_map, principles, **kwargs) @@ -238,10 +218,7 @@ async def load_specs( if not specs: return json.dumps( - { - 'success': False, - 'error': 'No specs provided.' - }, + {'success': False, 'error': 'No specs provided.'}, ensure_ascii=False, indent=2, ) @@ -259,8 +236,7 @@ async def load_specs( return json.dumps( { 'success': False, - 'error': - 'Unknown specs (strict mode): ' + ', '.join(unknown), + 'error': 'Unknown specs (strict mode): ' + ', '.join(unknown), }, ensure_ascii=False, indent=2, @@ -274,15 +250,11 @@ async def load_specs( content = f.read().strip() loaded[canonical_title] = content except Exception as e: # noqa - loaded[ - canonical_title] = f'Failed to load {filename}: {str(e)}' + loaded[canonical_title] = f'Failed to load {filename}: {str(e)}' if not loaded: return json.dumps( - { - 'success': False, - 'error': 'Failed to load any specs.' - }, + {'success': False, 'error': 'Failed to load any specs.'}, ensure_ascii=False, indent=2, ) @@ -305,16 +277,9 @@ async def load_specs( sections.append(content) if unknown and not strict: - sections.append( - f'> Note: Unknown specs ignored: {", ".join(unknown)}') + sections.append(f'> Note: Unknown specs ignored: {", ".join(unknown)}') - return json.dumps( - { - 'success': True, - 'sections': join_with.join(sections) - }, - ensure_ascii=False, - indent=2) + return json.dumps({'success': True, 'sections': join_with.join(sections)}, ensure_ascii=False, indent=2) def _build_writing_spec_index(self) -> Dict[str, Tuple[str, str]]: """Return writing spec mapping from normalized query → (filename, canonical title).""" @@ -326,18 +291,12 @@ def _build_writing_spec_index(self) -> Dict[str, Tuple[str, str]]: 'Structure & Layering', ), ( - [ - 'methods', 'methodology', 'framework exposure', - 'methodology exposure' - ], + ['methods', 'methodology', 'framework exposure', 'methodology exposure'], 'writing_specs/Methodology_Exposure.md', 'Methodology Exposure', ), ( - [ - 'bullets', 'bullet', 'bullets & paragraphs', - 'paragraph rhythm' - ], + ['bullets', 'bullet', 'bullets & paragraphs', 'paragraph rhythm'], 'writing_specs/Bullets_Paragraph_Rhythm.md', 'Bullets & Paragraph Rhythm', ), @@ -368,23 +327,24 @@ def _build_principle_spec_index(self) -> Dict[str, Tuple[str, str]]: """Return principle spec mapping from normalized query → (filename, canonical title).""" entries: List[Tuple[List[str], str, str]] = [ # synonyms, filename, canonical title - (['mece', 'mutually exclusive and collectively exhaustive'], - 'principle_specs/MECE.md', 'MECE'), - ([ - 'pyramid', 'minto', 'minto pyramid', 'pyramid principle', - 'minto_pyramid' - ], 'principle_specs/Minto_Pyramid.md', 'Pyramid (Minto Pyramid)'), + (['mece', 'mutually exclusive and collectively exhaustive'], 'principle_specs/MECE.md', 'MECE'), + ( + ['pyramid', 'minto', 'minto pyramid', 'pyramid principle', 'minto_pyramid'], + 'principle_specs/Minto_Pyramid.md', + 'Pyramid (Minto Pyramid)', + ), (['swot', 'swot analysis'], 'principle_specs/SWOT.md', 'SWOT'), - (['value chain', 'value-chain', - 'value_chain'], 'principle_specs/Value_Chain.md', 'Value Chain'), - ([ - 'pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', - '8020' - ], 'principle_specs/Pareto_80-20.md', 'Pareto (80/20 Rule)'), - ([ - 'boston matrix', 'bcg', 'boston consulting group', - 'boston_matrix', 'boston' - ], 'principle_specs/Boston_Matrix.md', 'Boston Matrix (BCG)'), + (['value chain', 'value-chain', 'value_chain'], 'principle_specs/Value_Chain.md', 'Value Chain'), + ( + ['pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', '8020'], + 'principle_specs/Pareto_80-20.md', + 'Pareto (80/20 Rule)', + ), + ( + ['boston matrix', 'bcg', 'boston consulting group', 'boston_matrix', 'boston'], + 'principle_specs/Boston_Matrix.md', + 'Boston Matrix (BCG)', + ), ] index: Dict[str, Tuple[str, str]] = {} diff --git a/projects/singularity_cinema/compose_video/agent.py b/projects/singularity_cinema/compose_video/agent.py index 09995c629..c2d5eae4f 100644 --- a/projects/singularity_cinema/compose_video/agent.py +++ b/projects/singularity_cinema/compose_video/agent.py @@ -1,33 +1,28 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import math import os import shutil from copy import deepcopy -import json import moviepy as mp from moviepy import AudioClip +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class ComposeVideo(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') - self.background_effect = getattr(self.config, 'background_effect', - None) + self.background_effect = getattr(self.config, 'background_effect', None) self.bg_path = os.path.join(self.work_dir, 'background.png') # Determine render directory based on engine @@ -45,9 +40,17 @@ def __init__(self, self.preset = getattr(self.config.video, 'preset', 'ultrafast') self.fps = getattr(self.config.video, 'fps', 24) - def compose_final_video(self, background_path, foreground_paths, - audio_paths, subtitle_paths, illustration_paths, - video_paths, segments, output_path): + def compose_final_video( + self, + background_path, + foreground_paths, + audio_paths, + subtitle_paths, + illustration_paths, + video_paths, + segments, + output_path, + ): segment_durations = [] logger.info('Composing the final video.') @@ -59,8 +62,7 @@ def compose_final_video(self, background_path, foreground_paths, segment = segments[i] is_video_frame = 'video' in segment use_video_soundtrack = self.config.use_video_soundtrack and is_video_frame - if audio_path and os.path.exists( - audio_path) and not use_video_soundtrack: + if audio_path and os.path.exists(audio_path) and not use_video_soundtrack: try: audio_clip = mp.AudioFileClip(audio_path) # Use actual audio duration + small pause, no minimum enforcement @@ -71,15 +73,10 @@ def compose_final_video(self, background_path, foreground_paths, else: actual_duration = None if not use_video_soundtrack: - raise ValueError( - f'File {audio_path} does not exist, run again to generate it.' - ) + raise ValueError(f'File {audio_path} does not exist, run again to generate it.') - if i < len(foreground_paths - ) and foreground_paths[i] and os.path.exists( - foreground_paths[i]): - animation_clip = mp.VideoFileClip( - foreground_paths[i], has_mask=True) + if i < len(foreground_paths) and foreground_paths[i] and os.path.exists(foreground_paths[i]): + animation_clip = mp.VideoFileClip(foreground_paths[i], has_mask=True) animation_duration = animation_clip.duration animation_clip.close() @@ -91,20 +88,16 @@ def compose_final_video(self, background_path, foreground_paths, logger.info('Step1: Compose video for each segment.') segment_videos = [] - for i, (duration, - segment) in enumerate(zip(segment_durations, segments)): + for i, (duration, segment) in enumerate(zip(segment_durations, segments)): if duration is not None: - logger.info( - f'Processing {i + 1} segment - {duration:.1f} seconds.') + logger.info(f'Processing {i + 1} segment - {duration:.1f} seconds.') else: - logger.info( - f'Processing {i + 1} segment - use video soundtrack.') + logger.info(f'Processing {i + 1} segment - use video soundtrack.') current_video_clips = [] # Check if this segment uses generated video instead of illustration - use_generated_video = 'video' in segment and video_paths[ - i] and os.path.exists(video_paths[i]) + use_generated_video = 'video' in segment and video_paths[i] and os.path.exists(video_paths[i]) if use_generated_video: # Use generated video as base layer @@ -125,8 +118,7 @@ def compose_final_video(self, background_path, foreground_paths, video_available_w, video_available_h = 1920, 1080 video_scale_w = video_available_w / video_original_w video_scale_h = video_available_h / video_original_h - video_scale = max(video_scale_w, - video_scale_h) # Cover mode + video_scale = max(video_scale_w, video_scale_h) # Cover mode video_new_w = int(video_original_w * video_scale) video_new_h = int(video_original_h * video_scale) @@ -136,16 +128,13 @@ def compose_final_video(self, background_path, foreground_paths, video_new_h += 1 if video_new_w > 0 and video_new_h > 0: - video_clip = video_clip.resized( - (video_new_w, video_new_h)) + video_clip = video_clip.resized((video_new_w, video_new_h)) video_clip = video_clip.with_position('center') # Extract and preserve video audio before adjusting duration video_audio = None if video_clip.audio is not None: - logger.info( - f'Extracting audio from generated video {i + 1}' - ) + logger.info(f'Extracting audio from generated video {i + 1}') video_audio = video_clip.audio segment_video_audios.append(video_audio) @@ -155,40 +144,32 @@ def compose_final_video(self, background_path, foreground_paths, assert duration is not None and duration > 0 # Adjust video duration to match segment duration if video_clip.duration < duration: - logger.info( - f'Video {i + 1} is shorter than segment, extending to {duration:.1f}s' - ) - video_clip = video_clip.with_duration( - duration) + logger.info(f'Video {i + 1} is shorter than segment, extending to {duration:.1f}s') + video_clip = video_clip.with_duration(duration) elif video_clip.duration > duration: - logger.info( - f'Video {i + 1} is longer than segment, trimming to {duration:.1f}s' - ) - video_clip = video_clip.subclipped( - 0, duration) + logger.info(f'Video {i + 1} is longer than segment, trimming to {duration:.1f}s') + video_clip = video_clip.subclipped(0, duration) current_video_clips.append(video_clip) else: - logger.error( - f'Invalid scaled video dimensions: {video_new_w}x{video_new_h}' - ) + logger.error(f'Invalid scaled video dimensions: {video_new_w}x{video_new_h}') video_clip.close() use_generated_video = False except Exception as e: - logger.error( - f'Failed to process video for segment {i + 1}: {e}') + logger.error(f'Failed to process video for segment {i + 1}: {e}') use_generated_video = False segment_video_audios.append(None) else: segment_video_audios.append(None) # Add illustration as base layer (if not using generated video) - if not use_generated_video and i < len( - illustration_paths - ) and illustration_paths[i] and os.path.exists( - illustration_paths[i]): - illustration_clip = mp.ImageClip( - illustration_paths[i], duration=duration) + if ( + not use_generated_video + and i < len(illustration_paths) + and illustration_paths[i] + and os.path.exists(illustration_paths[i]) + ): + illustration_clip = mp.ImageClip(illustration_paths[i], duration=duration) bg_original_w, bg_original_h = illustration_clip.size # Validate image dimensions @@ -214,12 +195,10 @@ def compose_final_video(self, background_path, foreground_paths, # Ensure dimensions are positive if bg_new_w <= 0 or bg_new_h <= 0: - logger.error( - f'Invalid scaled dimensions: {bg_new_w}x{bg_new_h}') + logger.error(f'Invalid scaled dimensions: {bg_new_w}x{bg_new_h}') continue - illustration_clip = illustration_clip.resized( - (bg_new_w, bg_new_h)) + illustration_clip = illustration_clip.resized((bg_new_w, bg_new_h)) exit_duration = 1.0 start_animation_time = max(duration - exit_duration, 0) @@ -240,13 +219,11 @@ def make_ken_burns(t): progress = t / kb_duration if kb_duration > 0 else 0 progress = min(1.0, progress) # Cubic easing for smooth acceleration/deceleration - eased_progress = progress * progress * ( - 3.0 - 2.0 * progress) + eased_progress = progress * progress * (3.0 - 2.0 * progress) if eased_progress > 1.0: eased_progress = 1.0 # Calculate current zoom level - current_zoom = zoom_start + ( - zoom_end - zoom_start) * eased_progress + current_zoom = zoom_start + (zoom_end - zoom_start) * eased_progress # Calculate new dimensions with validation zoom_w = int(kb_base_w * current_zoom) zoom_h = int(kb_base_h * current_zoom) @@ -257,26 +234,20 @@ def make_ken_burns(t): return zoom_w, zoom_h # Apply the zoom effect with resizing over time - illustration_clip = illustration_clip.resized( - make_ken_burns) + illustration_clip = illustration_clip.resized(make_ken_burns) # Keep image centered and stable throughout the animation - illustration_clip = illustration_clip.with_position( - 'center') + illustration_clip = illustration_clip.with_position('center') elif self.background_effect == 'slide': # TODO legacy code, untested # Default slide left animation - def illustration_pos_factory(idx, start_x, end_x, bg_h, - start_animation_time, - exit_duration): - + def illustration_pos_factory(idx, start_x, end_x, bg_h, start_animation_time, exit_duration): def illustration_pos(t): y = (1080 - bg_h) // 2 if t < start_animation_time: x = start_x elif t < start_animation_time + exit_duration: - progress = ( - t - start_animation_time) / exit_duration + progress = (t - start_animation_time) / exit_duration progress = min(max(progress, 0), 1) x = start_x + (end_x - start_x) * progress else: @@ -286,22 +257,19 @@ def illustration_pos(t): return illustration_pos illustration_clip = illustration_clip.with_position( - illustration_pos_factory(i, (1920 - bg_new_w) // 2, - -bg_new_w, bg_new_h, - start_animation_time, - exit_duration)) + illustration_pos_factory( + i, (1920 - bg_new_w) // 2, -bg_new_w, bg_new_h, start_animation_time, exit_duration + ) + ) current_video_clips.append(illustration_clip) # Add foreground animation layer - if i < len(foreground_paths - ) and foreground_paths[i] and os.path.exists( - foreground_paths[i]): + if i < len(foreground_paths) and foreground_paths[i] and os.path.exists(foreground_paths[i]): fg_clip = mp.VideoFileClip(foreground_paths[i], has_mask=True) original_w, original_h = fg_clip.size - available_w, available_h = ( - 1250, 700) if self.config.use_subtitle else (1450, 800) + available_w, available_h = (1250, 700) if self.config.use_subtitle else (1450, 800) scale_w = available_w / original_w scale_h = available_h / original_h scale = min(scale_w, scale_h, 1.0) @@ -313,9 +281,7 @@ def illustration_pos(t): if new_w > 0 and new_h > 0: fg_clip = fg_clip.resized((new_w, new_h)) else: - logger.error( - f'Invalid scaled foreground dimensions: {new_w}x{new_h}' - ) + logger.error(f'Invalid scaled foreground dimensions: {new_w}x{new_h}') fg_clip.close() continue @@ -327,8 +293,7 @@ def illustration_pos(t): fg_clip = fg_clip.with_duration(duration) current_video_clips.append(fg_clip) if self.config.use_subtitle: - if duration is not None and i < len( - subtitle_paths) and subtitle_paths[i]: + if duration is not None and i < len(subtitle_paths) and subtitle_paths[i]: segment_subs = subtitle_paths[i] num_subs = len(segment_subs) sub_duration = duration / num_subs @@ -345,17 +310,13 @@ def illustration_pos(t): ) continue - subtitle_clip = mp.ImageClip( - sub_path, duration=sub_duration) + subtitle_clip = mp.ImageClip(sub_path, duration=sub_duration) subtitle_y = 900 - subtitle_clip = subtitle_clip.with_position( - ('center', subtitle_y)) - subtitle_clip = subtitle_clip.with_start( - k * sub_duration) + subtitle_clip = subtitle_clip.with_position(('center', subtitle_y)) + subtitle_clip = subtitle_clip.with_start(k * sub_duration) current_video_clips.append(subtitle_clip) except Exception as e: - logger.error( - f'Failed to load subtitle {sub_path}: {e}') + logger.error(f'Failed to load subtitle {sub_path}: {e}') # Add background as top layer (transparent PNG with decorative elements) if background_path and os.path.exists(background_path): @@ -364,26 +325,24 @@ def illustration_pos(t): current_video_clips.append(bg_clip) if current_video_clips: - segment_video = mp.CompositeVideoClip( - current_video_clips, size=(1920, 1080)) + segment_video = mp.CompositeVideoClip(current_video_clips, size=(1920, 1080)) segment_videos.append(segment_video) logger.info('Step2: Combine all video segments.') - final_video = mp.concatenate_videoclips( - segment_videos, method='compose') + final_video = mp.concatenate_videoclips(segment_videos, method='compose') logger.info('Step3: Compose audios.') if audio_paths: valid_audio_clips = [] - for i, (audio_path, duration, segment) in enumerate( - zip(audio_paths, segment_durations, segments)): + for i, (audio_path, duration, segment) in enumerate(zip(audio_paths, segment_durations, segments)): segment_audio = None # Check if this segment has generated video audio - if i < len(segment_video_audios) and segment_video_audios[ - i] is not None and self.config.use_video_soundtrack: - logger.info( - f'Using audio from generated video for segment {i + 1}' - ) + if ( + i < len(segment_video_audios) + and segment_video_audios[i] is not None + and self.config.use_video_soundtrack + ): + logger.info(f'Using audio from generated video for segment {i + 1}') segment_audio = segment_video_audios[i] elif audio_path and os.path.exists(audio_path): # Use TTS audio if no video audio available @@ -393,14 +352,9 @@ def illustration_pos(t): if audio_clip.duration > duration: audio_clip = audio_clip.subclipped(0, duration) elif audio_clip.duration < duration: - - silence = AudioClip( - lambda t: [0, 0], - duration=duration - - audio_clip.duration).with_fps(44100) + silence = AudioClip(lambda t: [0, 0], duration=duration - audio_clip.duration).with_fps(44100) # silence = silence.set_channels(2) - audio_clip = mp.concatenate_audioclips( - [audio_clip, silence]) + audio_clip = mp.concatenate_audioclips([audio_clip, silence]) segment_audio = audio_clip if segment_audio is not None: @@ -408,18 +362,12 @@ def illustration_pos(t): if valid_audio_clips: final_audio = mp.concatenate_audioclips(valid_audio_clips) - logger.info( - f'Audio composing done: {final_audio.duration:.1f} seconds.' - ) + logger.info(f'Audio composing done: {final_audio.duration:.1f} seconds.') if final_audio.duration > final_video.duration: - final_audio = final_audio.subclipped( - 0, final_video.duration) + final_audio = final_audio.subclipped(0, final_video.duration) elif final_audio.duration < final_video.duration: - silence = AudioClip( - lambda t: [0, 0], - duration=final_video.duration - final_audio.duration) - final_audio = mp.concatenate_audioclips( - [final_audio, silence]) + silence = AudioClip(lambda t: [0, 0], duration=final_video.duration - final_audio.duration) + final_audio = mp.concatenate_audioclips([final_audio, silence]) final_video = final_video.with_audio(final_audio) @@ -427,43 +375,31 @@ def illustration_pos(t): if os.path.exists(self.config.bg_audio_path): bg_music_path = self.config.bg_audio_path else: - bg_music_path = os.path.join(self.config.local_dir, - self.config.bg_audio_path) + bg_music_path = os.path.join(self.config.local_dir, self.config.bg_audio_path) else: bg_music_path = '' - if os.path.exists( - bg_music_path) and not self.config.use_video_soundtrack: + if os.path.exists(bg_music_path) and not self.config.use_video_soundtrack: bg_music = mp.AudioFileClip(bg_music_path) if bg_music.duration < final_video.duration: - repeat_times = int( - final_video.duration / bg_music.duration) + 1 - bg_music = mp.concatenate_audioclips([bg_music] - * repeat_times) + repeat_times = int(final_video.duration / bg_music.duration) + 1 + bg_music = mp.concatenate_audioclips([bg_music] * repeat_times) bg_music = bg_music.subclipped(0, final_video.duration) elif bg_music.duration > final_video.duration: bg_music = bg_music.subclipped(0, final_video.duration) - bg_music = bg_music.with_volume_scaled( - self.config.bg_audio_volume) + bg_music = bg_music.with_volume_scaled(self.config.bg_audio_volume) if final_video.audio: - tts_audio = final_video.audio.with_duration( - final_video.duration).with_volume_scaled(1.0) + tts_audio = final_video.audio.with_duration(final_video.duration).with_volume_scaled(1.0) bg_audio = bg_music.with_duration(final_video.duration) - mixed_audio = mp.CompositeAudioClip( - [tts_audio, - bg_audio]).with_duration(final_video.duration) + mixed_audio = mp.CompositeAudioClip([tts_audio, bg_audio]).with_duration(final_video.duration) else: - mixed_audio = bg_music.with_duration( - final_video.duration).with_volume_scaled(0.3) + mixed_audio = bg_music.with_duration(final_video.duration).with_volume_scaled(0.3) final_video = final_video.with_audio(mixed_audio) assert final_video is not None logger.info('Rendering final video...') - logger.info( - f'Total video duration: {final_video.duration:.1f} seconds') + logger.info(f'Total video duration: {final_video.duration:.1f} seconds') logger.info(f'Video resolution: {final_video.size}') - logger.info( - f"Audio status: {'Has audio' if final_video.audio else 'No audio'}" - ) + logger.info(f"Audio status: {'Has audio' if final_video.audio else 'No audio'}") logger.info(f'final_video type: {type(final_video)}') logger.info(f'final_video attributes: {dir(final_video)}') @@ -480,7 +416,8 @@ def illustration_pos(t): audio_bitrate='192k', audio_fps=44100, preset=self.preset, - write_logfile=False) + write_logfile=False, + ) logger.info(f'file saved: {output_path}') @@ -506,19 +443,14 @@ async def execute_code(self, messages, **kwargs): illustration_paths = [] video_paths = [] for i, segment in enumerate(segments): - illustration_paths.append( - os.path.join(self.images_dir, f'illustration_{i + 1}.png')) - foreground_paths.append( - os.path.join(self.render_dir, f'scene_{i + 1}', - f'Scene{i+1}.mov')) - audio_paths.append( - os.path.join(self.tts_dir, f'segment_{i + 1}.mp3')) + illustration_paths.append(os.path.join(self.images_dir, f'illustration_{i + 1}.png')) + foreground_paths.append(os.path.join(self.render_dir, f'scene_{i + 1}', f'Scene{i + 1}.mov')) + audio_paths.append(os.path.join(self.tts_dir, f'segment_{i + 1}.mp3')) segment_subtitles = [] j = 0 while True: - sub_path = os.path.join(self.subtitle_dir, - f'bilingual_subtitle_{i + 1}_{j}.png') + sub_path = os.path.join(self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') if os.path.exists(sub_path): segment_subtitles.append(sub_path) j += 1 @@ -526,8 +458,7 @@ async def execute_code(self, messages, **kwargs): break subtitle_paths.append(segment_subtitles) - video_paths.append( - os.path.join(self.videos_dir, f'video_{i + 1}.mp4')) + video_paths.append(os.path.join(self.videos_dir, f'video_{i + 1}.mp4')) self.compose_final_video( background_path=self.bg_path, @@ -537,5 +468,6 @@ async def execute_code(self, messages, **kwargs): illustration_paths=illustration_paths, video_paths=video_paths, segments=segments, - output_path=final_video_path) + output_path=final_video_path, + ) return messages diff --git a/projects/singularity_cinema/create_background/agent.py b/projects/singularity_cinema/create_background/agent.py index 44e510b13..3cc233ad7 100644 --- a/projects/singularity_cinema/create_background/agent.py +++ b/projects/singularity_cinema/create_background/agent.py @@ -3,23 +3,19 @@ import textwrap import matplotlib.font_manager as fm +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() class CreateBackground(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.bg_path = os.path.join(self.work_dir, 'background.png') @@ -30,10 +26,10 @@ def __init__(self, def get_font(self, size): candidates = list(self.fonts) import subprocess + for name in candidates: try: - font_path = subprocess.check_output( - ['fc-match', '-f', '%{file}\n', name], text=True).strip() + font_path = subprocess.check_output(['fc-match', '-f', '%{file}\n', name], text=True).strip() font = ImageFont.truetype(font_path, size) font.getmask('中') @@ -64,7 +60,7 @@ async def execute_code(self, messages, **kwargs): 'padding': 50, 'line_width': 8, 'subtitle_offset': 40, - 'line_position_offset': 140 + 'line_position_offset': 140, } # Create image with transparent background (RGBA mode) @@ -78,27 +74,18 @@ async def execute_code(self, messages, **kwargs): y_position = config['padding'] for line in title_lines: bbox = draw.textbbox((0, 0), line, font=title_font) - draw.text((config['padding'], y_position), - line, - font=title_font, - fill=slogan_subtitle_color) + draw.text((config['padding'], y_position), line, font=title_font, fill=slogan_subtitle_color) y_position += (bbox[3] - bbox[1]) + config['line_spacing'] subtitle_lines = self.slogan y_position = config['padding'] for i, line in enumerate(subtitle_lines): bbox = draw.textbbox((0, 0), line, font=subtitle_font) - x_offset = width - bbox[2] - (config['padding'] + 30) + ( - i * config['subtitle_offset']) - draw.text((x_offset, y_position), - line, - font=subtitle_font, - fill=slogan_subtitle_color) + x_offset = width - bbox[2] - (config['padding'] + 30) + (i * config['subtitle_offset']) + draw.text((x_offset, y_position), line, font=subtitle_font, fill=slogan_subtitle_color) y_position += bbox[3] - bbox[1] + 5 line_y = height - config['padding'] - config['line_position_offset'] if self.config.use_subtitle: - draw.line([(0, line_y), (width, line_y)], - fill=slogan_subtitle_color, - width=config['line_width']) + draw.line([(0, line_y), (width, line_y)], fill=slogan_subtitle_color, width=config['line_width']) image.save(self.bg_path) return messages diff --git a/projects/singularity_cinema/generate_animation/agent.py b/projects/singularity_cinema/generate_animation/agent.py index 26837050b..e273b2e84 100644 --- a/projects/singularity_cinema/generate_animation/agent.py +++ b/projects/singularity_cinema/generate_animation/agent.py @@ -3,17 +3,13 @@ import os import sys -from ms_agent.agent import CodeAgent from omegaconf import DictConfig +from ms_agent.agent import CodeAgent -class GenerateAnimation(CodeAgent): - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): +class GenerateAnimation(CodeAgent): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) async def execute_code(self, messages, **kwargs): @@ -21,15 +17,15 @@ async def execute_code(self, messages, **kwargs): sys.path.insert(0, os.path.dirname(__file__)) if engine == 'manim': from generate_manim_code import GenerateManimCode + sys.path.pop(0) - agent = GenerateManimCode(self.config, self.tag, - self.trust_remote_code, **kwargs) + agent = GenerateManimCode(self.config, self.tag, self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) elif engine == 'remotion': from generate_remotion_code import GenerateRemotionCode + sys.path.pop(0) - agent = GenerateRemotionCode(self.config, self.tag, - self.trust_remote_code, **kwargs) + agent = GenerateRemotionCode(self.config, self.tag, self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) else: raise ValueError(f'Unknown animation engine: {engine}') diff --git a/projects/singularity_cinema/generate_animation/generate_manim_code.py b/projects/singularity_cinema/generate_animation/generate_manim_code.py index 990bab72d..43a6f768b 100644 --- a/projects/singularity_cinema/generate_animation/generate_manim_code.py +++ b/projects/singularity_cinema/generate_animation/generate_manim_code.py @@ -1,25 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -import json +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class GenerateManimCode(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) @@ -27,8 +23,7 @@ def __init__(self, self.manim_code_dir = os.path.join(self.work_dir, 'manim_code') os.makedirs(self.manim_code_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -45,8 +40,7 @@ async def execute_code(self, messages: Union[str, List[Message]], with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_manim_code_static, seg, dur, - idx, self.config, self.images_dir): idx + executor.submit(self._generate_manim_code_static, seg, dur, idx, self.config, self.images_dir): idx for seg, dur, idx in tasks } for future in as_completed(futures): @@ -54,20 +48,16 @@ async def execute_code(self, messages: Union[str, List[Message]], manim_code[idx] = future.result() for i, code in enumerate(manim_code): - manim_file = os.path.join(self.manim_code_dir, - f'segment_{i + 1}.py') + manim_file = os.path.join(self.manim_code_dir, f'segment_{i + 1}.py') with open(manim_file, 'w') as f: f.write(code) return messages @staticmethod - def _generate_manim_code_static(segment, audio_duration, i, config, - image_dir): + def _generate_manim_code_static(segment, audio_duration, i, config, image_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) - return GenerateManimCode._generate_manim_impl(llm, segment, - audio_duration, i, - image_dir, config) + return GenerateManimCode._generate_manim_impl(llm, segment, audio_duration, i, image_dir, config) @staticmethod def get_image_size(filename): @@ -82,8 +72,7 @@ def get_all_images_info(segment, i, image_dir): # Now check for files corresponding to these descriptions for idx, desc in enumerate(descriptions): - foreground_image = os.path.join( - image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): size = GenerateManimCode.get_image_size(foreground_image) @@ -94,8 +83,7 @@ def get_all_images_info(segment, i, image_dir): } all_images_info.append(image_info) - image_info_file = os.path.join( - os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -107,13 +95,11 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, - config): + def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, config): class_name = f'Scene{i + 1}' content = segment['content'] manim_requirement = segment['manim'] - images_info = GenerateManimCode.get_all_images_info( - segment, i, image_dir) + images_info = GenerateManimCode.get_all_images_info(segment, i, image_dir) if images_info: images_info = json.dumps(images_info, indent=4, ensure_ascii=False) else: @@ -177,8 +163,7 @@ def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, """ logger.info(f'正在生成 manim 代码:{content}') - _response_message = llm.generate( - [Message(role='user', content=prompt)], temperature=0.3) + _response_message = llm.generate([Message(role='user', content=prompt)], temperature=0.3) response = _response_message.content if '```python' in response: manim_code = response.split('```python')[1].split('```')[0] diff --git a/projects/singularity_cinema/generate_animation/generate_remotion_code.py b/projects/singularity_cinema/generate_animation/generate_remotion_code.py index 1007a60b1..11ea68acd 100644 --- a/projects/singularity_cinema/generate_animation/generate_remotion_code.py +++ b/projects/singularity_cinema/generate_animation/generate_remotion_code.py @@ -1,27 +1,23 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import glob +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -import json +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class GenerateRemotionCode(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) @@ -29,8 +25,7 @@ def __init__(self, self.remotion_code_dir = os.path.join(self.work_dir, 'remotion_code') os.makedirs(self.remotion_code_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -43,8 +38,7 @@ async def execute_code(self, messages: Union[str, List[Message]], animation_requirement = segment.get('remotion') if animation_requirement is not None: # Check if file already exists - remotion_file = os.path.join(self.remotion_code_dir, - f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') if os.path.exists(remotion_file): continue tasks.append((segment, audio_info['audio_duration'], i)) @@ -53,16 +47,14 @@ async def execute_code(self, messages: Union[str, List[Message]], # Load existing files for skipped segments for i in range(len(segments)): - remotion_file = os.path.join(self.remotion_code_dir, - f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') if os.path.exists(remotion_file): with open(remotion_file, 'r', encoding='utf-8') as f: remotion_code[i] = f.read() with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_remotion_code_static, seg, dur, - idx, self.config, self.images_dir): idx + executor.submit(self._generate_remotion_code_static, seg, dur, idx, self.config, self.images_dir): idx for seg, dur, idx in tasks } for future in as_completed(futures): @@ -70,19 +62,16 @@ async def execute_code(self, messages: Union[str, List[Message]], remotion_code[idx] = future.result() for i, code in enumerate(remotion_code): - remotion_file = os.path.join(self.remotion_code_dir, - f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') with open(remotion_file, 'w', encoding='utf-8') as f: f.write(code) return messages @staticmethod - def _generate_remotion_code_static(segment, audio_duration, i, config, - image_dir): + def _generate_remotion_code_static(segment, audio_duration, i, config, image_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) - return GenerateRemotionCode._generate_remotion_impl( - llm, segment, audio_duration, i, image_dir, config) + return GenerateRemotionCode._generate_remotion_impl(llm, segment, audio_duration, i, image_dir, config) @staticmethod def get_image_size(filename): @@ -95,23 +84,17 @@ def get_all_images_info(segment, i, image_dir): foreground = segment.get('foreground', []) for idx, _req in enumerate(foreground): - foreground_image = os.path.join( - image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): size = GenerateRemotionCode.get_image_size(foreground_image) image_info = { - 'filename': - os.path.join('images', os.path.basename( - foreground_image)), # Use basename for Remotion - 'size': - size, - 'description': - _req, + 'filename': os.path.join('images', os.path.basename(foreground_image)), # Use basename for Remotion + 'size': size, + 'description': _req, } all_images_info.append(image_info) - image_info_file = os.path.join( - os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -123,13 +106,11 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, - config): + def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, config): component_name = f'Segment{i + 1}' content = segment['content'] animation_requirement = segment['remotion'] - images_info = GenerateRemotionCode.get_all_images_info( - segment, i, image_dir) + images_info = GenerateRemotionCode.get_all_images_info(segment, i, image_dir) # Inject image info with code snippets. images_info_str = '' @@ -210,14 +191,11 @@ def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, """ logger.info(f'正在生成 remotion 代码:{content}') - _response_message = llm.generate( - [Message(role='user', content=prompt)], temperature=0.3) + _response_message = llm.generate([Message(role='user', content=prompt)], temperature=0.3) response = _response_message.content # Robust code extraction using regex - code_match = re.search( - r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, - re.DOTALL) + code_match = re.search(r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, re.DOTALL) if code_match: code = code_match.group(1) else: diff --git a/projects/singularity_cinema/generate_audio/agent.py b/projects/singularity_cinema/generate_audio/agent.py index 499779c73..670454a60 100644 --- a/projects/singularity_cinema/generate_audio/agent.py +++ b/projects/singularity_cinema/generate_audio/agent.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import os import shutil from copy import deepcopy @@ -7,34 +8,28 @@ from typing import List import edge_tts -import json import numpy as np from moviepy import AudioClip, AudioFileClip +from omegaconf import DictConfig + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.tools.audio_generator import AudioGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @dataclass class Pattern: - name: str pattern: str tags: List[str] = field(default_factory=list) class GenerateAudio(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.llm: OpenAI = LLM.from_config(self.config) @@ -58,17 +53,18 @@ async def execute_code(self, messages, **kwargs): assert len(audio_durations) == len(audio_paths) audio_info = [] for audio_path, audio_duration in zip(audio_paths, audio_durations): - audio_info.append({ - 'audio_path': audio_path, - 'audio_duration': audio_duration, - }) + audio_info.append( + { + 'audio_path': audio_path, + 'audio_duration': audio_duration, + } + ) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'w') as f: f.write(json.dumps(audio_info, indent=4, ensure_ascii=False)) return messages @staticmethod async def create_silent_audio(output_path, duration=5.0): - def make_frame(t): return np.array([0.0, 0.0]) @@ -85,8 +81,7 @@ async def audio_generate(self, text, output_file, speaker='male'): os.makedirs(output_dir, exist_ok=True) _config = deepcopy(self.config) _config.tools.audio_generator = _config.audio_generator - _temp_file = await AudioGenerator(self.config).generate_audio( - text, speaker=voice, rate=rate, pitch=pitch) + _temp_file = await AudioGenerator(self.config).generate_audio(text, speaker=voice, rate=rate, pitch=pitch) shutil.move(_temp_file, output_file) @staticmethod diff --git a/projects/singularity_cinema/generate_illustration_prompts/agent.py b/projects/singularity_cinema/generate_illustration_prompts/agent.py index 5529dc67a..13bf1c78d 100644 --- a/projects/singularity_cinema/generate_illustration_prompts/agent.py +++ b/projects/singularity_cinema/generate_illustration_prompts/agent.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import time @@ -6,25 +7,23 @@ from dataclasses import dataclass, field from typing import List, Optional, Union -import json +from omegaconf import DictConfig + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @dataclass class Pattern: - name: str pattern: str tags: List[str] = field(default_factory=list) class GenerateIllustrationPrompts(CodeAgent): - # Background prompt generator (t2i) system = """你是一名提示词工程师,负责为短视频生成一张背景图。 @@ -44,20 +43,14 @@ class GenerateIllustrationPrompts(CodeAgent): - 不要留白:使用适当的背景填充图像,尽量不要使用白色背景 """ - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) - self.illustration_prompts_dir = os.path.join(self.work_dir, - 'illustration_prompts') + self.illustration_prompts_dir = os.path.join(self.work_dir, 'illustration_prompts') os.makedirs(self.illustration_prompts_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) logger.info('Generating illustration prompts.') @@ -66,9 +59,9 @@ async def execute_code(self, messages: Union[str, List[Message]], with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_illustration_prompts_static, i, - segment, self.config, - self.illustration_prompts_dir): i + executor.submit( + self._generate_illustration_prompts_static, i, segment, self.config, self.illustration_prompts_dir + ): i for i, segment in tasks } for future in as_completed(futures): @@ -76,16 +69,14 @@ async def execute_code(self, messages: Union[str, List[Message]], return messages @staticmethod - def _generate_illustration_prompts_static(i, segment, config, - illustration_prompts_dir): + def _generate_illustration_prompts_static(i, segment, config, illustration_prompts_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) max_retries = 10 if config.background == 'image': for attempt in range(max_retries): try: - GenerateIllustrationPrompts._generate_illustration_impl( - llm, i, segment, illustration_prompts_dir) + GenerateIllustrationPrompts._generate_illustration_impl(llm, i, segment, illustration_prompts_dir) break except Exception: time.sleep(2) @@ -93,26 +84,20 @@ def _generate_illustration_prompts_static(i, segment, config, if config.foreground == 'image': for attempt in range(max_retries): try: - GenerateIllustrationPrompts._generate_foreground_impl( - llm, i, segment, illustration_prompts_dir) + GenerateIllustrationPrompts._generate_foreground_impl(llm, i, segment, illustration_prompts_dir) break except Exception: time.sleep(2) @staticmethod def _generate_illustration_impl(llm, i, segment, illustration_prompts_dir): - if os.path.exists( - os.path.join(illustration_prompts_dir, f'segment_{i+1}.txt')): + if os.path.exists(os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt')): return background_concept = segment.get('background') - logger.info( - f'Generating background prompt from plan: {background_concept}') + logger.info(f'Generating background prompt from plan: {background_concept}') - with open( - os.path.join( - os.path.dirname(illustration_prompts_dir), 'topic.txt'), - 'r') as f: + with open(os.path.join(os.path.dirname(illustration_prompts_dir), 'topic.txt'), 'r') as f: topic = f.read() query = ( f'User original topic: {topic}\n' @@ -126,47 +111,35 @@ def _generate_illustration_impl(llm, i, segment, illustration_prompts_dir): response = llm.generate(inputs).content.strip() # Strip thinking tags - response = re.sub( - r'.*?', '', response, flags=re.DOTALL).strip() + response = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() - with open( - os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt'), - 'w') as f: + with open(os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt'), 'w') as f: f.write(response) @staticmethod def _generate_foreground_impl(llm, i, segment, illustration_prompts_dir): foreground_assets = segment.get('foreground') for idx, asset_desc in enumerate(foreground_assets): - file_path = os.path.join(illustration_prompts_dir, - f'segment_{i+1}_foreground_{idx+1}.txt') + file_path = os.path.join(illustration_prompts_dir, f'segment_{i + 1}_foreground_{idx + 1}.txt') if os.path.exists(file_path): continue - logger.info( - f'Generating foreground_{idx} prompt from plan: {asset_desc}') + logger.info(f'Generating foreground_{idx} prompt from plan: {asset_desc}') - with open( - os.path.join( - os.path.dirname(illustration_prompts_dir), - 'topic.txt'), 'r') as f: + with open(os.path.join(os.path.dirname(illustration_prompts_dir), 'topic.txt'), 'r') as f: topic = f.read() - query = (f'User original topic: {topic}\n' - f'Design a single foreground asset: {asset_desc}\n') + query = f'User original topic: {topic}\nDesign a single foreground asset: {asset_desc}\n' inputs = [ - Message( - role='system', - content=GenerateIllustrationPrompts.system_foreground), + Message(role='system', content=GenerateIllustrationPrompts.system_foreground), Message(role='user', content=query), ] response = llm.generate(inputs).content.strip() # Strip thinking tags - response = re.sub( - r'.*?', '', response, flags=re.DOTALL).strip() + response = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() with open(file_path, 'w') as f: f.write(response) diff --git a/projects/singularity_cinema/generate_images/agent.py b/projects/singularity_cinema/generate_images/agent.py index e21eb325d..42a3b8d03 100644 --- a/projects/singularity_cinema/generate_images/agent.py +++ b/projects/singularity_cinema/generate_images/agent.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import os import re import shutil @@ -9,55 +10,42 @@ from typing import List, Union import aiohttp -import json import numpy as np +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.image_generator import ImageGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class GenerateImages(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 't2i_num_parallel', 1) self.fusion = self.fade - self.illustration_prompts_dir = os.path.join(self.work_dir, - 'illustration_prompts') + self.illustration_prompts_dir = os.path.join(self.work_dir, 'illustration_prompts') self.images_dir = os.path.join(self.work_dir, 'images') os.makedirs(self.images_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) illustration_prompts = [] for i in range(len(segments)): - illustration_path = os.path.join(self.illustration_prompts_dir, - f'segment_{i+1}.txt') - if self.config.background == 'image' and os.path.exists( - illustration_path): + illustration_path = os.path.join(self.illustration_prompts_dir, f'segment_{i + 1}.txt') + if self.config.background == 'image' and os.path.exists(illustration_path): with open(illustration_path, 'r') as f: illustration_prompts.append(f.read()) else: illustration_prompts.append(None) logger.info('Generating images.') - tasks = [ - (i, segment, prompt) - for i, (segment, - prompt) in enumerate(zip(segments, illustration_prompts)) - ] + tasks = [(i, segment, prompt) for i, (segment, prompt) in enumerate(zip(segments, illustration_prompts))] # Use ThreadPoolExecutor for parallel execution with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: @@ -67,16 +55,13 @@ async def execute_code(self, messages: Union[str, List[Message]], final_prompt = prompt_text if final_prompt: # Remove thinking tags if present - final_prompt = re.sub( - r'.*?', - '', - final_prompt, - flags=re.DOTALL).strip() + final_prompt = re.sub(r'.*?', '', final_prompt, flags=re.DOTALL).strip() futures.append( - executor.submit(self._process_single_illustration_static, - i, segment, final_prompt, self.config, - self.images_dir)) + executor.submit( + self._process_single_illustration_static, i, segment, final_prompt, self.config, self.images_dir + ) + ) # Wait for all tasks to complete for future in futures: future.result() @@ -84,30 +69,27 @@ async def execute_code(self, messages: Union[str, List[Message]], return messages @staticmethod - def _process_single_illustration_static(i, segment, prompt, config, - images_dir): + def _process_single_illustration_static(i, segment, prompt, config, images_dir): """Static method for thread pool execution""" # Create new event loop for this thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: loop.run_until_complete( - GenerateImages._process_single_illustration_impl( - i, segment, prompt, config, images_dir)) + GenerateImages._process_single_illustration_impl(i, segment, prompt, config, images_dir) + ) loop.run_until_complete( - GenerateImages._process_foreground_illustration_impl( - i, segment, config, images_dir)) + GenerateImages._process_foreground_illustration_impl(i, segment, config, images_dir) + ) finally: loop.close() @staticmethod - async def _process_single_illustration_impl(i, segment, prompt, config, - images_dir): + async def _process_single_illustration_impl(i, segment, prompt, config, images_dir): """Implementation of single illustration processing""" if config.background != 'image': # Generate a 2000x2000 solid color image - logger.info( - f'Generating solid color background for segment {i + 1}.') + logger.info(f'Generating solid color background for segment {i + 1}.') output_path = os.path.join(images_dir, f'illustration_{i + 1}.png') if not os.path.exists(output_path): # Create a 2000x2000 image with the color defined in config.background @@ -115,8 +97,7 @@ async def _process_single_illustration_impl(i, segment, prompt, config, img.save(output_path) else: logger.info(f'Generating image for: {prompt}.') - img_path = os.path.join(images_dir, - f'illustration_{i + 1}_origin.png') + img_path = os.path.join(images_dir, f'illustration_{i + 1}_origin.png') output_path = os.path.join(images_dir, f'illustration_{i + 1}.png') if os.path.exists(output_path): return @@ -133,16 +114,13 @@ async def _process_single_illustration_impl(i, segment, prompt, config, elif hasattr(_config.image_generator, 'size'): kwargs['size'] = _config.image_generator.size - logger.info( - f'Generating image. Prompt: {prompt[:50]}... kwargs: {kwargs}') + logger.info(f'Generating image. Prompt: {prompt[:50]}... kwargs: {kwargs}') _temp_file = await image_generator.generate_image(prompt, **kwargs) # Check directly if the return is a valid file path if not _temp_file or not os.path.exists(_temp_file): - logger.error( - f'Background image generation failed for segment {i + 1}. Result: {_temp_file}' - ) + logger.error(f'Background image generation failed for segment {i + 1}. Result: {_temp_file}') return shutil.move(_temp_file, img_path) @@ -154,27 +132,22 @@ async def _process_single_illustration_impl(i, segment, prompt, config, pass @staticmethod - async def _process_foreground_illustration_impl(i, segment, config, - images_dir): + async def _process_foreground_illustration_impl(i, segment, config, images_dir): """Implementation of foreground illustration processing""" if config.foreground != 'image': return logger.info(f'Generating foreground image for: segment {i}.') work_dir = getattr(config, 'output_dir', 'output') - illustration_prompts_dir = os.path.join(work_dir, - 'illustration_prompts') + illustration_prompts_dir = os.path.join(work_dir, 'illustration_prompts') foreground_assets = segment.get('foreground', []) for idx, _req in enumerate(foreground_assets): - foreground_image = os.path.join( - images_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join(images_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): continue - foreground_prompt_path = os.path.join( - illustration_prompts_dir, - f'segment_{i+1}_foreground_{idx+1}.txt') + foreground_prompt_path = os.path.join(illustration_prompts_dir, f'segment_{i + 1}_foreground_{idx + 1}.txt') assert os.path.exists(foreground_prompt_path) @@ -182,9 +155,7 @@ async def _process_foreground_illustration_impl(i, segment, config, prompt_text = f.read() # Clean Prompt from Thinking process - prompt = re.sub( - r'.*?', '', prompt_text, - flags=re.DOTALL).strip() + prompt = re.sub(r'.*?', '', prompt_text, flags=re.DOTALL).strip() _config = deepcopy(config) _config.tools.image_generator = _config.image_generator @@ -205,18 +176,12 @@ async def _process_foreground_illustration_impl(i, segment, config, os.remove(_temp_file) @staticmethod - def fade(input_image, - output_image, - segment, - fade_factor=0.3, - brightness_boost=80, - opacity=1.0): + def fade(input_image, output_image, segment, fade_factor=0.3, brightness_boost=80, opacity=1.0): # Support both 'manim' and 'remotion' keys for animation detection has_animation = segment.get('manim') or segment.get('remotion') img = Image.open(input_image).convert('RGBA') if has_animation: - logger.info( - 'Applying fade effect to background image (Animation present)') + logger.info('Applying fade effect to background image (Animation present)') arr = np.array(img, dtype=np.float32) arr[..., :3] = arr[..., :3] * fade_factor + brightness_boost arr[..., :3] = np.clip(arr[..., :3], 0, 255) diff --git a/projects/singularity_cinema/generate_script/agent.py b/projects/singularity_cinema/generate_script/agent.py index fc0a726cc..694f626e2 100644 --- a/projects/singularity_cinema/generate_script/agent.py +++ b/projects/singularity_cinema/generate_script/agent.py @@ -3,21 +3,17 @@ from copy import deepcopy from typing import List +from omegaconf import DictConfig + from ms_agent import LLMAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class GenerateScript(LLMAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') extra = getattr(self.config, 'extra_requirement', '') diff --git a/projects/singularity_cinema/generate_subtitle/agent.py b/projects/singularity_cinema/generate_subtitle/agent.py index efa873851..64904af99 100644 --- a/projects/singularity_cinema/generate_subtitle/agent.py +++ b/projects/singularity_cinema/generate_subtitle/agent.py @@ -1,16 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re from typing import List -import json import matplotlib.font_manager as fm +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() @@ -38,7 +39,7 @@ def _chunk_tokens(tokens: List[str], max_len: int) -> List[str]: if len(t) > max_len: # If a single token exceeds max_len, split it for i in range(0, len(t), max_len): - sub = t[i:i + max_len] + sub = t[i : i + max_len] if cur: chunks.append(cur.strip()) cur = '' @@ -51,8 +52,7 @@ def _chunk_tokens(tokens: List[str], max_len: int) -> List[str]: continue # If t is punctuation and can be merged with previous chunk (allowing slight overflow) - if _is_punct(t) and cur and len(cur) + len( - t) <= max_len + PUNCTUATION_OVERFLOW_ALLOWANCE: + if _is_punct(t) and cur and len(cur) + len(t) <= max_len + PUNCTUATION_OVERFLOW_ALLOWANCE: cur = (cur + t).strip() + ' ' continue @@ -81,17 +81,11 @@ def _clean_chunks(chunks: List[str], max_len: int) -> List[str]: class GenerateSubtitle(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.llm: OpenAI = LLM.from_config(self.config) - self.subtitle_translate = getattr(self.config, 'subtitle_translate', - None) + self.subtitle_translate = getattr(self.config, 'subtitle_translate', None) self.subtitle_dir = os.path.join(self.work_dir, 'subtitles') os.makedirs(self.subtitle_dir, exist_ok=True) self.fonts = self.config.fonts @@ -108,20 +102,15 @@ async def execute_code(self, messages, **kwargs): for j, chunk_text in enumerate(text_chunks): subtitle = None if self.subtitle_translate: - subtitle = await self.translate_text( - chunk_text, self.subtitle_translate) + subtitle = await self.translate_text(chunk_text, self.subtitle_translate) - output_file = os.path.join( - self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') + output_file = os.path.join(self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') if os.path.exists(output_file): continue self.create_bilingual_subtitle_image( - source=chunk_text, - target=subtitle, - output_file=output_file, - width=1720, - height=180) + source=chunk_text, target=subtitle, output_file=output_file, width=1720, height=180 + ) return messages def split_text_to_chunks(self, text, max_len: int = 30): @@ -133,7 +122,6 @@ def split_text_to_chunks(self, text, max_len: int = 30): return _clean_chunks(chunks, max_len) async def translate_text(self, text, to_lang): - prompt = f"""You are a professional translation expert specializing in accurately and fluently translating text into {to_lang}. ## Skills @@ -147,7 +135,7 @@ async def translate_text(self, text, to_lang): - Output only the translation result without any explanations. Now translate: -""" # noqa +""" # noqa messages = [ Message(role='system', content=prompt), Message(role='user', content=text), @@ -201,14 +189,16 @@ def smart_wrap_text(self, text, max_lines=2, chars_per_line=50): return lines if lines else [text] - def create_subtitle_image(self, - text, - width=1720, - height=120, - font_size=28, - text_color='black', - bg_color='rgba(0,0,0,0)', - chars_per_line=50): + def create_subtitle_image( + self, + text, + width=1720, + height=120, + font_size=28, + text_color='black', + bg_color='rgba(0,0,0,0)', + chars_per_line=50, + ): font = self.get_font(font_size) min_font_size = 18 max_height = 500 @@ -217,15 +207,13 @@ def create_subtitle_image(self, while font_size >= min_font_size: if font_size != original_font_size: font = self.get_font(font_size) - lines = self.smart_wrap_text( - text, max_lines=2, chars_per_line=chars_per_line) + lines = self.smart_wrap_text(text, max_lines=2, chars_per_line=chars_per_line) line_height = font_size + 8 total_text_height = len(lines) * line_height all_lines_fit = True for line in lines: - bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox( - (0, 0), line, font=font) + bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), line, font=font) line_width = bbox[2] - bbox[0] if line_width > width * 0.95: all_lines_fit = False @@ -257,28 +245,18 @@ def create_subtitle_image(self, draw.text((x, y), line, fill=text_color, font=font) return img, actual_height - def create_bilingual_subtitle_image(self, - source, - output_file, - target='', - width=1720, - height=180): + def create_bilingual_subtitle_image(self, source, output_file, target='', width=1720, height=180): main_font_size = 32 target_font_size = 22 main_target_gap = 6 pattern = r'^[a-zA-Z0-9\s.,!?;:\'"()-]+$' chars_per_line = 50 if not bool(re.match(pattern, source)) else 100 if target: - target_chars_per_line = 50 if not bool(re.match(pattern, - target)) else 100 + target_chars_per_line = 50 if not bool(re.match(pattern, target)) else 100 main_img, main_height = self.create_subtitle_image( - source, - width, - height, - main_font_size, - 'black', - chars_per_line=chars_per_line) + source, width, height, main_font_size, 'black', chars_per_line=chars_per_line + ) if target and target.strip(): target_chars_per_line = 100 @@ -288,13 +266,12 @@ def create_bilingual_subtitle_image(self, height, target_font_size, '#404040', # Darker gray for better visibility - chars_per_line=target_chars_per_line) + chars_per_line=target_chars_per_line, + ) total_height = main_height + target_height + main_target_gap - combined_img = Image.new('RGBA', (width, total_height), - (0, 0, 0, 0)) + combined_img = Image.new('RGBA', (width, total_height), (0, 0, 0, 0)) combined_img.paste(main_img, (0, 0), main_img) - combined_img.paste(target_img, (0, main_height + main_target_gap), - target_img) + combined_img.paste(target_img, (0, main_height + main_target_gap), target_img) final_img = combined_img final_height = total_height else: diff --git a/projects/singularity_cinema/generate_video/agent.py b/projects/singularity_cinema/generate_video/agent.py index aaecd7d5f..f58e34c3c 100644 --- a/projects/singularity_cinema/generate_video/agent.py +++ b/projects/singularity_cinema/generate_video/agent.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import os import shutil from concurrent.futures import ThreadPoolExecutor @@ -7,23 +8,18 @@ from typing import List, Union import aiohttp -import json +from omegaconf import DictConfig + from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class GenerateVideo(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 't2v_num_parallel', 1) @@ -31,30 +27,24 @@ def __init__(self, self.videos_dir = os.path.join(self.work_dir, 'videos') os.makedirs(self.videos_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) video_prompts = [] for i in range(len(segments)): if 'video' in segments[i]: - with open( - os.path.join(self.video_prompts_dir, - f'segment_{i + 1}.txt'), 'r') as f: + with open(os.path.join(self.video_prompts_dir, f'segment_{i + 1}.txt'), 'r') as f: video_prompts.append(f.read()) else: video_prompts.append(None) logger.info('Generating videos.') - tasks = [(i, segment, prompt) - for i, (segment, - prompt) in enumerate(zip(segments, video_prompts))] + tasks = [(i, segment, prompt) for i, (segment, prompt) in enumerate(zip(segments, video_prompts))] # Use ThreadPoolExecutor for parallel execution with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = [ - executor.submit(self._process_single_video_static, i, segment, - prompt, self.config, self.videos_dir) + executor.submit(self._process_single_video_static, i, segment, prompt, self.config, self.videos_dir) for i, segment, prompt in tasks ] # Wait for all tasks to complete @@ -70,25 +60,19 @@ def _process_single_video_static(i, segment, prompt, config, videos_dir): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete( - GenerateVideo._process_single_video_impl( - i, segment, prompt, config, videos_dir)) + loop.run_until_complete(GenerateVideo._process_single_video_impl(i, segment, prompt, config, videos_dir)) finally: loop.close() @staticmethod - async def _process_single_video_impl(i, segment, prompt, config, - videos_dir): + async def _process_single_video_impl(i, segment, prompt, config, videos_dir): if prompt is None: - logger.info( - f'Skipping video generation for segment {i + 1} (no video prompt).' - ) + logger.info(f'Skipping video generation for segment {i + 1} (no video prompt).') return output_path = os.path.join(videos_dir, f'video_{i + 1}.mp4') if os.path.exists(output_path): - logger.info( - f'Video already exists for segment {i + 1}: {output_path}') + logger.info(f'Video already exists for segment {i + 1}: {output_path}') return logger.info(f'Generating video for segment {i + 1}: {prompt}') @@ -107,6 +91,5 @@ async def _process_single_video_impl(i, segment, prompt, config, _config.tools.video_generator = _config.video_generator video_generator = VideoGenerator(_config) - _temp_file = await video_generator.generate_video( - prompt, seconds=fit_duration) + _temp_file = await video_generator.generate_video(prompt, seconds=fit_duration) shutil.move(_temp_file, output_path) diff --git a/projects/singularity_cinema/generate_video_prompts/agent.py b/projects/singularity_cinema/generate_video_prompts/agent.py index 2cd091f3b..6a66e562d 100644 --- a/projects/singularity_cinema/generate_video_prompts/agent.py +++ b/projects/singularity_cinema/generate_video_prompts/agent.py @@ -1,20 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -import json +from omegaconf import DictConfig + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class GenerateVideoPrompts(CodeAgent): - - system = (""" + system = """ You are an expert in creating scene descriptions for video generation. Based on given knowledge points or storyboard scripts, generate detailed English descriptions for creating text-to-video content that align with specified themes and styles. @@ -33,21 +33,16 @@ class GenerateVideoPrompts(CodeAgent): - Output approximately 200 words in English. - Return ONLY the prompt description. Do not include style keywords unless requested, and do not add explanations or markers. - """) + """ - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) self.video_prompts_dir = os.path.join(self.work_dir, 'video_prompts') os.makedirs(self.video_prompts_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: if not self.config.use_text2video: return messages with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: @@ -60,27 +55,30 @@ async def execute_code(self, messages: Union[str, List[Message]], with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_video_prompts_static, i, - segment, self.config, topic, self.system, - self.video_prompts_dir): i - for i, segment in tasks if 'video' in segment + executor.submit( + self._generate_video_prompts_static, + i, + segment, + self.config, + topic, + self.system, + self.video_prompts_dir, + ): i + for i, segment in tasks + if 'video' in segment } for future in as_completed(futures): future.result() return messages @staticmethod - def _generate_video_prompts_static(i, segment, config, topic, system, - video_prompts_dir): + def _generate_video_prompts_static(i, segment, config, topic, system, video_prompts_dir): llm = LLM.from_config(config) - GenerateVideoPrompts._generate_video_prompt_impl( - llm, i, segment, topic, system, video_prompts_dir, config) + GenerateVideoPrompts._generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_dir, config) @staticmethod - def _generate_video_prompt_impl(llm, i, segment, topic, system, - video_prompts_dir, config): - if os.path.exists( - os.path.join(video_prompts_dir, f'segment_{i+1}.txt')): + def _generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_dir, config): + if os.path.exists(os.path.join(video_prompts_dir, f'segment_{i + 1}.txt')): return work_dir = os.path.dirname(video_prompts_dir) @@ -95,10 +93,12 @@ def _generate_video_prompt_impl(llm, i, segment, topic, system, break video = segment['video'] - query = (f'The user original request is: {topic}, ' - f'illustration based on: {segment["content"]}, ' - f'Video duration: {fit_duration}, ' - f'Requirements from the storyboard designer: {video}') + query = ( + f'The user original request is: {topic}, ' + f'illustration based on: {segment["content"]}, ' + f'Video duration: {fit_duration}, ' + f'Requirements from the storyboard designer: {video}' + ) logger.info(f'Generating video prompt for : {segment["content"]}.') inputs = [ Message(role='system', content=system), @@ -107,7 +107,5 @@ def _generate_video_prompt_impl(llm, i, segment, topic, system, _response_message = llm.generate(inputs) response = _response_message.content prompt = response.strip() - with open( - os.path.join(video_prompts_dir, f'segment_{i + 1}.txt'), - 'w') as f: + with open(os.path.join(video_prompts_dir, f'segment_{i + 1}.txt'), 'w') as f: f.write(prompt) diff --git a/projects/singularity_cinema/parse_images/agent.py b/projects/singularity_cinema/parse_images/agent.py index 138803dc0..8a5e324c7 100644 --- a/projects/singularity_cinema/parse_images/agent.py +++ b/projects/singularity_cinema/parse_images/agent.py @@ -1,45 +1,39 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib +import json import os import re from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from urllib.request import urlretrieve -import json +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class ParseImages(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') _config = deepcopy(config) delattr(_config, 'llm') _config.llm = DictConfig({}) for key, value in _config.mllm.items(): - key = key[len('mllm_'):] + key = key[len('mllm_') :] setattr(_config.llm, key, value) _config.generation_config = DictConfig({'temperature': 0.3}) if 'extra_body' in config.generation_config: _config.generation_config.extra_body = config.generation_config.extra_body self.mllm: OpenAI = LLM.from_config(_config) - logger.info( - f"Using MLLM for image parsing: {getattr(self.mllm, 'model', None)}" - ) + logger.info(f"Using MLLM for image parsing: {getattr(self.mllm, 'model', None)}") self.image_dir = os.path.join(self.work_dir, 'images') os.makedirs(self.image_dir, exist_ok=True) @@ -130,22 +124,18 @@ def get_image_description(self, filename): image_data = image_file.read() base64_image = base64.b64encode(image_data).decode('utf-8') - _content = [{ - 'type': - 'text', - 'text': - ('Describe this image in under 50 words. Be objective and accurate. For charts/graphs, ' - 'analyze axis labels and data to explain what the chart shows and its purpose, ' - 'not just the chart type. Provide enough detail to distinguish it from other images.' - 'Return only the requested image description. Do not add any other content.' - ) - }, { - 'type': 'image_url', - 'image_url': { - 'url': f'data:image/png;base64,{base64_image}', - 'detail': 'high' - } - }] + _content = [ + { + 'type': 'text', + 'text': ( + 'Describe this image in under 50 words. Be objective and accurate. For charts/graphs, ' + 'analyze axis labels and data to explain what the chart shows and its purpose, ' + 'not just the chart type. Provide enough detail to distinguish it from other images.' + 'Return only the requested image description. Do not add any other content.' + ), + }, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}', 'detail': 'high'}}, + ] messages = [ Message(role='user', content=_content), diff --git a/projects/singularity_cinema/render_animation/agent.py b/projects/singularity_cinema/render_animation/agent.py index 1b3007c66..1340608a6 100644 --- a/projects/singularity_cinema/render_animation/agent.py +++ b/projects/singularity_cinema/render_animation/agent.py @@ -3,17 +3,13 @@ import os import sys -from ms_agent.agent import CodeAgent from omegaconf import DictConfig +from ms_agent.agent import CodeAgent -class RenderAnimation(CodeAgent): - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): +class RenderAnimation(CodeAgent): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) async def execute_code(self, messages, **kwargs): @@ -21,15 +17,15 @@ async def execute_code(self, messages, **kwargs): sys.path.insert(0, os.path.dirname(__file__)) if engine == 'manim': from render_manim import RenderManim + sys.path.pop(0) - agent = RenderManim(self.config, self.tag, self.trust_remote_code, - **kwargs) + agent = RenderManim(self.config, self.tag, self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) elif engine == 'remotion': from render_remotion import RenderRemotion + sys.path.pop(0) - agent = RenderRemotion(self.config, self.tag, - self.trust_remote_code, **kwargs) + agent = RenderRemotion(self.config, self.tag, self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) else: raise ValueError(f'Unknown animation engine: {engine}') diff --git a/projects/singularity_cinema/render_animation/render_manim.py b/projects/singularity_cinema/render_animation/render_manim.py index 19b9abc27..59686e226 100644 --- a/projects/singularity_cinema/render_animation/render_manim.py +++ b/projects/singularity_cinema/render_animation/render_manim.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import os import re import shutil @@ -9,48 +10,41 @@ from os import getcwd from typing import List, Union -import json from moviepy import VideoFileClip +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() class RenderManim(CodeAgent): - window_size = (1250, 700) - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) if not self.config.use_subtitle: self.window_size = (1450, 800) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) self.manim_render_timeout = getattr( - self.config, 'animation_render_timeout', - getattr(self.config, 'manim_render_timeout', 300)) + self.config, 'animation_render_timeout', getattr(self.config, 'manim_render_timeout', 300) + ) self.render_dir = os.path.join(self.work_dir, 'manim_render') self.code_fix_round = getattr(self.config, 'code_fix_round', 5) self.mllm_check_round = getattr(self.config, 'mllm_fix_round', 1) os.makedirs(self.render_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) manim_code_dir = os.path.join(self.work_dir, 'manim_code') manim_code = [] for i in range(len(segments)): - with open(os.path.join(manim_code_dir, f'segment_{i+1}.py'), - 'r') as f: + with open(os.path.join(manim_code_dir, f'segment_{i + 1}.py'), 'r') as f: manim_code.append(f.read()) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: audio_infos = json.load(f) @@ -59,17 +53,25 @@ async def execute_code(self, messages: Union[str, List[Message]], tasks = [ (i, segment, code, audio_info['audio_duration']) - for i, (segment, code, audio_info - ) in enumerate(zip(segments, manim_code, audio_infos)) + for i, (segment, code, audio_info) in enumerate(zip(segments, manim_code, audio_infos)) ] with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._render_manim_scene_static, i, segment, - code, duration, self.config, self.work_dir, - self.render_dir, self.window_size, - self.manim_render_timeout, self.code_fix_round, - self.mllm_check_round): i + executor.submit( + self._render_manim_scene_static, + i, + segment, + code, + duration, + self.config, + self.work_dir, + self.render_dir, + self.window_size, + self.manim_render_timeout, + self.code_fix_round, + self.mllm_check_round, + ): i for i, segment, code, duration in tasks } for future in as_completed(futures): @@ -78,23 +80,54 @@ async def execute_code(self, messages: Union[str, List[Message]], return messages @staticmethod - def _render_manim_scene_static(i, segment, code, audio_duration, config, - work_dir, render_dir, window_size, - manim_render_timeout, code_fix_round, - mllm_check_round): + def _render_manim_scene_static( + i, + segment, + code, + audio_duration, + config, + work_dir, + render_dir, + window_size, + manim_render_timeout, + code_fix_round, + mllm_check_round, + ): """Static method for multiprocessing""" llm = LLM.from_config(config) - return RenderManim._render_manim_impl(llm, i, segment, code, - audio_duration, work_dir, - render_dir, window_size, - manim_render_timeout, config, - code_fix_round, mllm_check_round) + return RenderManim._render_manim_impl( + llm, + i, + segment, + code, + audio_duration, + work_dir, + render_dir, + window_size, + manim_render_timeout, + config, + code_fix_round, + mllm_check_round, + ) @staticmethod - def _render_manim_impl(llm, i, segment, code, audio_duration, work_dir, - render_dir, window_size, manim_render_timeout, - config, code_fix_round, mllm_check_round): - scene_name = f'Scene{i+1}' # sometimes actual_scene_name cannot find matched class, so do not change this name + def _render_manim_impl( + llm, + i, + segment, + code, + audio_duration, + work_dir, + render_dir, + window_size, + manim_render_timeout, + config, + code_fix_round, + mllm_check_round, + ): + scene_name = ( + f'Scene{i + 1}' # sometimes actual_scene_name cannot find matched class, so do not change this name + ) logger.info(f'Rendering manim code for: scene_{i + 1}') output_dir = os.path.join(render_dir, f'scene_{i + 1}') os.makedirs(output_dir, exist_ok=True) @@ -126,10 +159,16 @@ def _render_manim_impl(llm, i, segment, code, audio_duration, work_dir, env['LC_ALL'] = 'zh_CN.UTF-8' window_size_str = ','.join([str(x) for x in window_size]) cmd = [ - 'manim', 'render', '-ql', '--transparent', '--format=mov', - f'--resolution={window_size_str}', '--disable_caching', - f'--media_dir={os.path.dirname(code_file)}', code_file, - actual_scene_name + 'manim', + 'render', + '-ql', + '--transparent', + '--format=mov', + f'--resolution={window_size_str}', + '--disable_caching', + f'--media_dir={os.path.dirname(code_file)}', + code_file, + actual_scene_name, ] try: @@ -141,15 +180,14 @@ def _render_manim_impl(llm, i, segment, code, audio_duration, work_dir, text=True, encoding='utf-8', errors='ignore', - env=env) + env=env, + ) # Wait for process to complete with timeout - stdout, stderr = process.communicate( - timeout=manim_render_timeout) + stdout, stderr = process.communicate(timeout=manim_render_timeout) # Create result object compatible with original logic class Result: - def __init__(self, returncode, stdout, stderr): self.returncode = returncode self.stdout = stdout @@ -158,41 +196,64 @@ def __init__(self, returncode, stdout, stderr): result = Result(process.returncode, stdout, stderr) output_text = (result.stdout or '') + (result.stderr or '') except subprocess.TimeoutExpired as e: - output_text = (e.stdout.decode('utf-8', errors='ignore') - if e.stdout else '') + ( - e.stderr.decode('utf-8', errors='ignore') - if e.stderr else '') # noqa + output_text = (e.stdout.decode('utf-8', errors='ignore') if e.stdout else '') + ( + e.stderr.decode('utf-8', errors='ignore') if e.stderr else '' + ) # noqa logger.error( f'Manim rendering timed out after {manim_render_timeout} ' - f'seconds for {actual_scene_name}, output: {output_text}') + f'seconds for {actual_scene_name}, output: {output_text}' + ) logger.info('Trying to fix manim code.') code, fix_history = RenderManim._fix_manim_code_impl( - llm, output_text, fix_history, code, manim_requirement, - class_name, content, audio_duration, segment, i, work_dir) + llm, + output_text, + fix_history, + code, + manim_requirement, + class_name, + content, + audio_duration, + segment, + i, + work_dir, + ) continue if result.returncode != 0: - logger.warning( - f'Manim command exited with code {result.returncode}') + logger.warning(f'Manim command exited with code {result.returncode}') logger.warning(f'Output: {output_text}') real_error_indicators = [ - 'SyntaxError', 'NameError', 'ImportError', - 'AttributeError', 'TypeError', 'ValueError', - 'ModuleNotFoundError', 'Traceback', 'Error:', - 'Failed to render', 'unexpected keyword argument', - 'got an unexpected', 'invalid syntax' + 'SyntaxError', + 'NameError', + 'ImportError', + 'AttributeError', + 'TypeError', + 'ValueError', + 'ModuleNotFoundError', + 'Traceback', + 'Error:', + 'Failed to render', + 'unexpected keyword argument', + 'got an unexpected', + 'invalid syntax', ] - if any([ - error_indicator in output_text - for error_indicator in real_error_indicators - ]): + if any([error_indicator in output_text for error_indicator in real_error_indicators]): logger.info('Trying to fix manim code.') code, fix_history = RenderManim._fix_manim_code_impl( - llm, output_text, fix_history, code, manim_requirement, - class_name, content, audio_duration, segment, i, - work_dir) + llm, + output_text, + fix_history, + code, + manim_requirement, + class_name, + content, + audio_duration, + segment, + i, + work_dir, + ) continue for root, dirs, files in os.walk(output_dir): @@ -200,31 +261,37 @@ def __init__(self, returncode, stdout, stderr): if file == f'{actual_scene_name}.mov': found_file = os.path.join(root, file) if not RenderManim.verify_and_fix_mov_file(found_file): - fixed_path = RenderManim.convert_mov_to_compatible( - found_file) + fixed_path = RenderManim.convert_mov_to_compatible(found_file) if fixed_path: found_file = fixed_path shutil.copy2(found_file, output_path) - scaled_path = RenderManim.scale_video_to_fit( - output_path, target_size=window_size) + scaled_path = RenderManim.scale_video_to_fit(output_path, target_size=window_size) if scaled_path and scaled_path != output_path: shutil.rmtree(output_path, ignore_errors=True) shutil.copy2(scaled_path, output_path) final_file_path = output_path if not final_file_path: - logger.error( - f'Manim file: {class_name} not found, trying to fix manim code.' - ) + logger.error(f'Manim file: {class_name} not found, trying to fix manim code.') code, fix_history = RenderManim._fix_manim_code_impl( - llm, output_text, fix_history, code, manim_requirement, - class_name, content, audio_duration, segment, i, work_dir) + llm, + output_text, + fix_history, + code, + manim_requirement, + class_name, + content, + audio_duration, + segment, + i, + work_dir, + ) else: if cur_check_round >= mllm_max_check_round: break output_text = RenderManim.check_manim_quality( - final_file_path, work_dir, i, config, segment, - cur_check_round) + final_file_path, work_dir, i, config, segment, cur_check_round + ) cur_check_round += 1 if output_text: try: @@ -233,18 +300,26 @@ def __init__(self, returncode, stdout, stderr): except OSError: pass logger.info( - f'Trying to fix manim code of segment {i+1}, because model checking not passed: \n{output_text}' + f'Trying to fix manim code of segment {i + 1}, because model checking not passed: \n{output_text}' ) code, fix_history = RenderManim._fix_manim_code_impl( - llm, output_text, fix_history, code, manim_requirement, - class_name, content, audio_duration, segment, i, - work_dir) + llm, + output_text, + fix_history, + code, + manim_requirement, + class_name, + content, + audio_duration, + segment, + i, + work_dir, + ) continue else: break if final_file_path: - RenderManim._extract_preview_frames_static(final_file_path, i, - work_dir, 'final') + RenderManim._extract_preview_frames_static(final_file_path, i, work_dir, 'final') manim_code_dir = os.path.join(work_dir, 'manim_code') manim_file = os.path.join(manim_code_dir, f'segment_{i + 1}.py') with open(manim_file, 'w') as f: @@ -253,14 +328,13 @@ def __init__(self, returncode, stdout, stderr): raise FileNotFoundError(final_file_path) @staticmethod - def check_manim_quality(final_file_path, work_dir, i, config, segment, - cur_check_round): + def check_manim_quality(final_file_path, work_dir, i, config, segment, cur_check_round): _mm_config = deepcopy(config) delattr(_mm_config, 'llm') _mm_config.llm = DictConfig({}) _mm_config.generation_config = DictConfig({'temperature': 0.3}) for key, value in _mm_config.mllm.items(): - key = key[len('mllm_'):] + key = key[len('mllm_') :] setattr(_mm_config.llm, key, value) test_system = """**Role Definition** @@ -314,44 +388,34 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, The right component is squeezed to the edge. Fix suggestion: Reduce the width of the four left components, move the right component further right... ``` -"""# noqa +""" # noqa - test_images = RenderManim._extract_preview_frames_static( - final_file_path, i, work_dir, cur_check_round) + test_images = RenderManim._extract_preview_frames_static(final_file_path, i, work_dir, cur_check_round) llm = LLM.from_config(_mm_config) - logger.info( - f"Using mllm model for manim quality check: {getattr(llm, 'model', None)}" - ) + logger.info(f"Using mllm model for manim quality check: {getattr(llm, 'model', None)}") - frame_names = [ - 'the middle frame of the animation', - 'the last frame of the animation' - ] + frame_names = ['the middle frame of the animation', 'the last frame of the animation'] content = segment['content'] manim_requirement = segment['manim'] all_issues = [] - for idx, (image_path, - frame_name) in enumerate(zip(test_images, frame_names)): + for idx, (image_path, frame_name) in enumerate(zip(test_images, frame_names)): with open(image_path, 'rb') as image_file: image_data = image_file.read() base64_image = base64.b64encode(image_data).decode('utf-8') - _content = [{ - 'type': - 'text', - 'text': - (f'The checked frame is: {frame_name} of this animation\n' - f'The content of this animation: {content}\n' - f'The manim animation requirement: {manim_requirement}\n' - f'You must carefully check the animation layout issues.') - }, { - 'type': 'image_url', - 'image_url': { - 'url': f'data:image/png;base64,{base64_image}', - 'detail': 'high' - } - }] + _content = [ + { + 'type': 'text', + 'text': ( + f'The checked frame is: {frame_name} of this animation\n' + f'The content of this animation: {content}\n' + f'The manim animation requirement: {manim_requirement}\n' + f'You must carefully check the animation layout issues.' + ), + }, + {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}', 'detail': 'high'}}, + ] messages = [ Message(role='system', content=test_system), @@ -366,8 +430,7 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, issues.append(issue) issues = '\n'.join(issues).strip() if issues: - issues = (f'The checked frame is: {frame_name}\n' - f'Problems found: {issues}\n') + issues = f'The checked frame is: {frame_name}\nProblems found: {issues}\n' pattern = r'(.*?)' desc = [] @@ -375,17 +438,14 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, desc.append(_desc) desc = '\n'.join(desc).strip() if issues and desc: - issues = (f'{issues}' - f'The detail description of this frame: {desc}\n') + issues = f'{issues}The detail description of this frame: {desc}\n' all_issues.append(issues) all_issues = '\n\n'.join(all_issues).strip() return all_issues @staticmethod - def _extract_preview_frames_static(video_path, segment_id, work_dir, - cur_check_round): - + def _extract_preview_frames_static(video_path, segment_id, work_dir, cur_check_round): test_dir = os.path.join(work_dir, 'manim_test') os.makedirs(test_dir, exist_ok=True) video = VideoFileClip(video_path) @@ -395,10 +455,7 @@ def _extract_preview_frames_static(video_path, segment_id, work_dir, preview_paths = [] for frame_idx, timestamp in timestamps.items(): - output_path = os.path.join( - test_dir, - f'segment_{segment_id + 1}_round{cur_check_round}_{frame_idx}.png' - ) + output_path = os.path.join(test_dir, f'segment_{segment_id + 1}_round{cur_check_round}_{frame_idx}.png') video.save_frame(output_path, t=timestamp) preview_paths.append(output_path) video.close() @@ -414,8 +471,7 @@ def get_all_images_info(segment, i, image_dir): all_images_info = [] foreground = segment.get('foreground', []) for idx, _req in enumerate(foreground): - foreground_image = os.path.join( - image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') size = RenderManim.get_image_size(foreground_image) image_info = { 'filename': foreground_image, @@ -424,8 +480,7 @@ def get_all_images_info(segment, i, image_dir): } all_images_info.append(image_info) - image_info_file = os.path.join( - os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -437,9 +492,19 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _fix_manim_code_impl(llm, error_log, fix_history, manim_code, - manim_requirement, class_name, content, - audio_duration, segment, i, work_dir): + def _fix_manim_code_impl( + llm, + error_log, + fix_history, + manim_code, + manim_requirement, + class_name, + content, + audio_duration, + segment, + i, + work_dir, + ): image_dir = os.path.join(work_dir, 'images') images_info = RenderManim.get_all_images_info(segment, i, image_dir) @@ -455,7 +520,7 @@ def _fix_manim_code_impl(llm, error_log, fix_history, manim_code, * Scale the images. Do not use the original size, carefully rescale the images to match the requirements below: * The image size on the canvas depend on its importance, important image occupies more spaces * Use 1/4 space of the canvas for each image -""" # noqa +""" # noqa else: image_prompt = '' @@ -521,7 +586,7 @@ def _fix_manim_code_impl(llm, error_log, fix_history, manim_code, - **don't remove any image or its effects when making modifications** Please precisely fix the detected issues while maintaining the richness and creativity of the animation. -""" # noqa +""" # noqa inputs = [Message(role='user', content=fix_request)] _response_message = llm.generate(inputs) response = _response_message.content @@ -534,7 +599,8 @@ def _fix_manim_code_impl(llm, error_log, fix_history, manim_code, fix_history = ( f'You have a fix history which generates the code which is given to you:\n\n{fix_request}\n\n' f'If last error is same with latest error, **You probably find a wrong root cause**, ' - f'Check carefully and fix it again.**') + f'Check carefully and fix it again.**' + ) return manim_code, fix_history @staticmethod @@ -556,7 +622,8 @@ def convert_mov_to_compatible(mov_path): fps=24, verbose=False, logger=None, - ffmpeg_params=['-pix_fmt', 'yuva420p']) + ffmpeg_params=['-pix_fmt', 'yuva420p'], + ) clip.close() if RenderManim.verify_and_fix_mov_file(fixed_path): @@ -592,7 +659,8 @@ def scale_video_to_fit(video_path, target_size=None): audio_codec='aac' if scaled_clip.audio else None, fps=24, verbose=False, - logger=None) + logger=None, + ) clip.close() scaled_clip.close() diff --git a/projects/singularity_cinema/render_animation/render_remotion.py b/projects/singularity_cinema/render_animation/render_remotion.py index e58e2927e..cb3c75f7a 100644 --- a/projects/singularity_cinema/render_animation/render_remotion.py +++ b/projects/singularity_cinema/render_animation/render_remotion.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import shutil @@ -8,42 +9,34 @@ from collections import defaultdict from typing import List, Optional, Tuple, Union -import json from moviepy import VideoFileClip +from omegaconf import DictConfig + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class RenderRemotion(CodeAgent): - - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 5) # When enabled, render compositions one-by-one and attempt a fix immediately on failure. # This reduces wasted work when one broken Segment TSX causes global bundler failure. - self.render_immediate_fix = getattr(self.config, - 'render_immediate_fix', True) + self.render_immediate_fix = getattr(self.config, 'render_immediate_fix', True) self.render_dir = os.path.join(self.work_dir, 'remotion_render') - self.remotion_project_dir = os.path.join(self.work_dir, - 'remotion_project') + self.remotion_project_dir = os.path.join(self.work_dir, 'remotion_project') self.remotion_code_dir = os.path.join(self.work_dir, 'remotion_code') self.images_dir = os.path.join(self.work_dir, 'images') self.code_fix_round = getattr(self.config, 'code_fix_round', 3) # Default to 1 to ensure visual quality check runs at least once unless explicitly disabled (-1) self.mllm_check_round = getattr(self.config, 'mllm_fix_round', 1) # Maximum times to attempt automatic visual fixes per segment - self.max_visual_fix_rounds = getattr(self.config, - 'max_visual_fix_rounds', 2) + self.max_visual_fix_rounds = getattr(self.config, 'max_visual_fix_rounds', 2) # Track per-segment visual failure counts self.visual_fail_counts = defaultdict(int) # Track scale per segment for edge clipping retry @@ -51,8 +44,7 @@ def __init__(self, os.makedirs(self.render_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], - **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -65,51 +57,37 @@ async def execute_code(self, messages: Union[str, List[Message]], self._ensure_browser(self.remotion_project_dir) logger.info('Installing dependencies...') - subprocess.run( - 'npm install', - cwd=self.remotion_project_dir, - shell=True, - check=True) + subprocess.run('npm install', cwd=self.remotion_project_dir, shell=True, check=True) segment_status = { - i: os.path.exists( - os.path.join(self.render_dir, f'scene_{i+1}', - f'Scene{i+1}.mov')) + i: os.path.exists(os.path.join(self.render_dir, f'scene_{i + 1}', f'Scene{i + 1}.mov')) for i in range(len(segments)) } for round_idx in range(self.code_fix_round + 1): # Identify segments needing render (all initially, then only failed ones) - segments_to_render = [ - i for i, status in segment_status.items() if status is not True - ] + segments_to_render = [i for i, status in segment_status.items() if status is not True] if not segments_to_render: logger.info('All segments rendered successfully.') break - logger.info( - f'Round {round_idx + 1}: Rendering {len(segments_to_render)} segments...' - ) + logger.info(f'Round {round_idx + 1}: Rendering {len(segments_to_render)} segments...') results = {} def _read_current_code(seg_i: int) -> str: - code_path = os.path.join(self.remotion_code_dir, - f'Segment{seg_i+1}.tsx') + code_path = os.path.join(self.remotion_code_dir, f'Segment{seg_i + 1}.tsx') if os.path.exists(code_path): with open(code_path, 'r', encoding='utf-8') as f: return f.read() - project_code_path = os.path.join(self.remotion_project_dir, - 'src', - f'Segment{seg_i+1}.tsx') + project_code_path = os.path.join(self.remotion_project_dir, 'src', f'Segment{seg_i + 1}.tsx') if os.path.exists(project_code_path): with open(project_code_path, 'r', encoding='utf-8') as f: return f.read() return '' - def _extract_error_segment_indices( - log_text: Optional[str]) -> List[int]: + def _extract_error_segment_indices(log_text: Optional[str]) -> List[int]: if not log_text: return [] # esbuild/webpack error lines usually include: ...\src\Segment15.tsx:... @@ -144,9 +122,7 @@ def _extract_error_segment_indices( if not success and error_log and 'EDGE_CLIPPING' in error_log: new_scale = 0.8 self.segment_scales[i] = new_scale - logger.info( - f'Edge clipping detected for segment {i+1}, reducing scale to {new_scale}' - ) + logger.info(f'Edge clipping detected for segment {i + 1}, reducing scale to {new_scale}') # Update Root.tsx with new scale self._update_root_tsx_for_segment(i) segment_status[i] = False # Force retry @@ -158,14 +134,12 @@ def _extract_error_segment_indices( # If bundler fails globally, error_log points to the culprit file. culprit_indices = _extract_error_segment_indices(error_log) to_fix = culprit_indices if culprit_indices else [i] - to_fix = sorted( - {idx - for idx in to_fix if 0 <= idx < len(segments)}) + to_fix = sorted({idx for idx in to_fix if 0 <= idx < len(segments)}) # If the error points to OTHER segments, it means the current segment failed due to global breakage. # Pause and fix the root cause first. logger.info( - f'Immediate fix triggered by failure on segment {i+1}. Fix targets: {[x+1 for x in to_fix]}' + f'Immediate fix triggered by failure on segment {i + 1}. Fix targets: {[x + 1 for x in to_fix]}' ) # Apply fixes @@ -173,49 +147,40 @@ def _extract_error_segment_indices( err_text = error_log or 'Unknown error' current_code = _read_current_code(fix_i) _, fixed_code = self._fix_code_static( - fix_i, err_text, current_code, self.config, - self.remotion_project_dir) + fix_i, err_text, current_code, self.config, self.remotion_project_dir + ) if fixed_code: self._update_segment_code(fix_i, fixed_code) # If we fixed a different segment, we should probably reset its status too if fix_i != i: - segment_status[ - fix_i] = False # Force re-render of the culprit later if it was skipped + segment_status[fix_i] = False # Force re-render of the culprit later if it was skipped return messages def _update_segment_code(self, i, code): # Update in remotion_code_dir (source of truth) - src_file = os.path.join(self.remotion_code_dir, f'Segment{i+1}.tsx') + src_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') with open(src_file, 'w', encoding='utf-8') as f: f.write(code) # Update in remotion_project_dir (execution env) - dst_file = os.path.join(self.remotion_project_dir, 'src', - f'Segment{i+1}.tsx') + dst_file = os.path.join(self.remotion_project_dir, 'src', f'Segment{i + 1}.tsx') with open(dst_file, 'w', encoding='utf-8') as f: f.write(code) def _setup_remotion_project(self, segments, audio_infos): # 1. Create project structure - os.makedirs( - os.path.join(self.remotion_project_dir, 'src'), exist_ok=True) - os.makedirs( - os.path.join(self.remotion_project_dir, 'public', 'images'), - exist_ok=True) + os.makedirs(os.path.join(self.remotion_project_dir, 'src'), exist_ok=True) + os.makedirs(os.path.join(self.remotion_project_dir, 'public', 'images'), exist_ok=True) # Some generated TSX may import assets via relative paths like `./images/foo.png`. # Keep a mirrored copy under `src/images` to avoid bundler module resolution failures. - os.makedirs( - os.path.join(self.remotion_project_dir, 'src', 'images'), - exist_ok=True) + os.makedirs(os.path.join(self.remotion_project_dir, 'src', 'images'), exist_ok=True) if os.path.exists(self.images_dir): for file in os.listdir(self.images_dir): src = os.path.join(self.images_dir, file) - dst_public = os.path.join(self.remotion_project_dir, 'public', - 'images', file) - dst_src = os.path.join(self.remotion_project_dir, 'src', - 'images', file) + dst_public = os.path.join(self.remotion_project_dir, 'public', 'images', file) + dst_src = os.path.join(self.remotion_project_dir, 'src', 'images', file) for dst in (dst_public, dst_src): shutil.copy(src, dst) @@ -231,24 +196,18 @@ def _setup_remotion_project(self, segments, audio_infos): # Extract filename from absolute path filename = os.path.basename(original_path) # Copy to public/images and src/images - dst_public = os.path.join(self.remotion_project_dir, - 'public', 'images', filename) - dst_src = os.path.join(self.remotion_project_dir, 'src', - 'images', filename) + dst_public = os.path.join(self.remotion_project_dir, 'public', 'images', filename) + dst_src = os.path.join(self.remotion_project_dir, 'src', 'images', filename) shutil.copy(original_path, dst_public) shutil.copy(original_path, dst_src) # Store mapping for path replacement user_image_mapping[original_path] = f'images/{filename}' - logger.info( - f'Copied user image: {original_path} -> images/{filename}' - ) + logger.info(f'Copied user image: {original_path} -> images/{filename}') # 3. Copy generated code and replace absolute paths for i in range(len(segments)): - src_file = os.path.join(self.remotion_code_dir, - f'Segment{i+1}.tsx') - dst_file = os.path.join(self.remotion_project_dir, 'src', - f'Segment{i+1}.tsx') + src_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') + dst_file = os.path.join(self.remotion_project_dir, 'src', f'Segment{i + 1}.tsx') if os.path.exists(src_file): # Read file content with open(src_file, 'r', encoding='utf-8') as f: @@ -264,13 +223,13 @@ def _setup_remotion_project(self, segments, audio_infos): else: with open(dst_file, 'w') as f: f.write( - f"import React from 'react';\nexport const Segment{i+1} = () =>
Missing Segment
;" + f"import React from 'react';\nexport const Segment{i + 1} = () =>
Missing Segment
;" ) else: # Create a dummy file if missing to prevent build failure with open(dst_file, 'w') as f: f.write( - f"import React from 'react';\nexport const Segment{i+1} = () =>
Missing Segment
;" + f"import React from 'react';\nexport const Segment{i + 1} = () =>
Missing Segment
;" ) # 4. Create package.json with locked versions @@ -285,18 +244,14 @@ def _setup_remotion_project(self, segments, audio_infos): '@remotion/bundler': '^4.0.0', '@remotion/renderer': '^4.0.0', '@remotion/shapes': '^4.0.0', - '@remotion/media-utils': '^4.0.0' - } + '@remotion/media-utils': '^4.0.0', + }, } - with open( - os.path.join(self.remotion_project_dir, 'package.json'), - 'w') as f: + with open(os.path.join(self.remotion_project_dir, 'package.json'), 'w') as f: json.dump(package_json, f, indent=2) # 5. Create src/index.ts - with open( - os.path.join(self.remotion_project_dir, 'src', 'index.ts'), - 'w') as f: + with open(os.path.join(self.remotion_project_dir, 'src', 'index.ts'), 'w') as f: f.write("import { registerRoot } from 'remotion';\n") f.write("import { RemotionRoot } from './Root';\n") f.write('registerRoot(RemotionRoot);\n') @@ -305,6 +260,7 @@ def _setup_remotion_project(self, segments, audio_infos): self._generate_root_tsx(segments, audio_infos) # 7. Create tsconfig.json + def _generate_root_tsx(self, segments, audio_infos): """Generate Root.tsx with dynamic scale support""" fps = self.config.video.fps @@ -318,13 +274,14 @@ def _generate_root_tsx(self, segments, audio_infos): root_content = "import React from 'react';\n" root_content += "import { Composition } from 'remotion';\n" for i in range(len(segments)): - root_content += f"import * as Segment{i+1}_NS from './Segment{i+1}';\n" + root_content += f"import * as Segment{i + 1}_NS from './Segment{i + 1}';\n" root_content += '\nexport const RemotionRoot: React.FC = () => {\n' for i in range(len(segments)): root_content += ( - f' const Segment{i+1} = Segment{i+1}_NS.default || ' - f'Segment{i+1}_NS.Segment{i+1} || (() => null);\n') + f' const Segment{i + 1} = Segment{i + 1}_NS.default || ' + f'Segment{i + 1}_NS.Segment{i + 1} || (() => null);\n' + ) root_content += ' return (\n' root_content += ' <>\n' @@ -334,8 +291,8 @@ def _generate_root_tsx(self, segments, audio_infos): # Get scale from tracking dict or use default scale = self.segment_scales.get(i, 0.9) root_content += ' Tuple[int, bool, Optional[str]]: + i, segment, duration, config, work_dir, render_dir, remotion_project_dir, mllm_check_round=0, scale=0.9 + ) -> Tuple[int, bool, Optional[str]]: """Static method for multiprocessing""" - composition_id = f'Segment{i+1}' - output_dir_scene = os.path.join(render_dir, f'scene_{i+1}') + composition_id = f'Segment{i + 1}' + output_dir_scene = os.path.join(render_dir, f'scene_{i + 1}') os.makedirs(output_dir_scene, exist_ok=True) - output_path = os.path.abspath( - os.path.join(output_dir_scene, f'Scene{i+1}.mov')) + output_path = os.path.abspath(os.path.join(output_dir_scene, f'Scene{i + 1}.mov')) logger.info(f'Rendering {composition_id} to {output_path}') # Determine remotion command if os.name == 'nt': - remotion_cmd = os.path.abspath( - os.path.join(remotion_project_dir, 'node_modules', '.bin', - 'remotion.cmd')) + remotion_cmd = os.path.abspath(os.path.join(remotion_project_dir, 'node_modules', '.bin', 'remotion.cmd')) else: - remotion_cmd = os.path.abspath( - os.path.join(remotion_project_dir, 'node_modules', '.bin', - 'remotion')) + remotion_cmd = os.path.abspath(os.path.join(remotion_project_dir, 'node_modules', '.bin', 'remotion')) if not os.path.exists(remotion_cmd): remotion_cmd = 'npx remotion' @@ -594,24 +515,21 @@ def _render_remotion_scene_static( '--prores-profile=4444', '--pixel-format=yuva444p10le', '--image-format=png', - '--every-nth-frame=1' # Render every frame for smooth animation + '--every-nth-frame=1', # Render every frame for smooth animation ] # Try to find browser executable (Local > System) browser_executable = None - remotion_cache_dir = os.path.join(remotion_project_dir, 'node_modules', - '.remotion') + remotion_cache_dir = os.path.join(remotion_project_dir, 'node_modules', '.remotion') # 1. Check Local Cache if os.path.exists(remotion_cache_dir): for root, _, files in os.walk(remotion_cache_dir): if 'chrome-headless-shell.exe' in files: - browser_executable = os.path.abspath( - os.path.join(root, 'chrome-headless-shell.exe')) + browser_executable = os.path.abspath(os.path.join(root, 'chrome-headless-shell.exe')) break elif 'chrome-headless-shell' in files: - browser_executable = os.path.abspath( - os.path.join(root, 'chrome-headless-shell')) + browser_executable = os.path.abspath(os.path.join(root, 'chrome-headless-shell')) break # 2. Check System Chrome if not found locally @@ -621,22 +539,20 @@ def _render_remotion_scene_static( if not browser_executable: # shutil is imported at module level browser_executable = ( - shutil.which('chrome') or shutil.which('google-chrome') + shutil.which('chrome') + or shutil.which('google-chrome') or shutil.which('chromium') - or shutil.which('chromium-browser')) + or shutil.which('chromium-browser') + ) if not browser_executable and os.name == 'nt': possible_paths = [ r'C:\Program Files\Google\Chrome\Application\chrome.exe', r'C:\Program Files (x86)\Google\Chrome\Application\chrome.exe', - os.path.expandvars( - r'%LOCALAPPDATA%\Google\Chrome\Application\chrome.exe' - ), + os.path.expandvars(r'%LOCALAPPDATA%\Google\Chrome\Application\chrome.exe'), r'C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe', r'C:\Program Files\Microsoft\Edge\Application\msedge.exe', - os.path.expandvars( - r'%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe' - ), + os.path.expandvars(r'%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe'), ] for p in possible_paths: if os.path.exists(p): @@ -647,10 +563,9 @@ def _render_remotion_scene_static( logger.info(f'Using browser executable: {browser_executable}') base_cmd.extend(['--browser-executable', browser_executable]) # Add stability flags - base_cmd.extend([ - '--chromium-options', - 'no-sandbox,disable-setuid-sandbox,disable-gpu,disable-dev-shm-usage' - ]) + base_cmd.extend( + ['--chromium-options', 'no-sandbox,disable-setuid-sandbox,disable-gpu,disable-dev-shm-usage'] + ) if os.name == 'nt' and 'remotion.cmd' in remotion_cmd: cmd = [remotion_cmd] + base_cmd @@ -666,8 +581,7 @@ def _render_remotion_scene_static( # But subprocess.run with shell=True and a list is tricky. # It's safer to join the command into a string for shell=True on Windows. if os.name == 'nt': - cmd_str = ' '.join( - [f'"{arg}"' if ' ' in arg else arg for arg in cmd]) + cmd_str = ' '.join([f'"{arg}"' if ' ' in arg else arg for arg in cmd]) result = subprocess.run( cmd_str, cwd=remotion_project_dir, @@ -675,7 +589,8 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore') + errors='ignore', + ) else: result = subprocess.run( cmd, @@ -684,7 +599,8 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore') + errors='ignore', + ) else: result = subprocess.run( cmd, @@ -693,15 +609,13 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore') + errors='ignore', + ) if result.returncode != 0: # Capture output was set to True to allow smart error detection. - log_content = (result.stderr or '') + '\n' + ( - result.stdout or '') - logger.warning( - f'Rendering failed for {composition_id}. Log (except): {log_content[:500]}...' - ) + log_content = (result.stderr or '') + '\n' + (result.stdout or '') + logger.warning(f'Rendering failed for {composition_id}. Log (except): {log_content[:500]}...') return i, False, log_content else: logger.info(f'Rendered {composition_id} successfully.') @@ -718,8 +632,8 @@ def _check_edge_clipping(frame_path, threshold=10): Returns True if clipping detected (colored pixels at edges). """ try: - from PIL import Image import numpy as np + from PIL import Image img = Image.open(frame_path).convert('RGB') pixels = np.array(img) @@ -731,8 +645,7 @@ def _check_edge_clipping(frame_path, threshold=10): left_edge = pixels[:, 0, :] right_edge = pixels[:, width - 1, :] - edges = np.concatenate( - [top_edge, bottom_edge, left_edge, right_edge]) + edges = np.concatenate([top_edge, bottom_edge, left_edge, right_edge]) # Check if pixels are near black (0,0,0) or white (255,255,255) near_black = np.all(edges < threshold, axis=1) @@ -749,7 +662,6 @@ def _check_edge_clipping(frame_path, threshold=10): @staticmethod def _extract_preview_frames_static(video_path, segment_id, work_dir): - test_dir = os.path.join(work_dir, 'remotion_test') os.makedirs(test_dir, exist_ok=True) video = VideoFileClip(video_path) @@ -759,39 +671,31 @@ def _extract_preview_frames_static(video_path, segment_id, work_dir): preview_paths = [] for frame_idx, timestamp in timestamps.items(): - output_path = os.path.join( - test_dir, f'segment_{segment_id + 1}_{frame_idx}.png') + output_path = os.path.join(test_dir, f'segment_{segment_id + 1}_{frame_idx}.png') video.save_frame(output_path, t=timestamp) preview_paths.append(output_path) video.close() return preview_paths @staticmethod - def _fix_code_static(i, - error_log, - code, - config, - remotion_project_dir=None): + def _fix_code_static(i, error_log, code, config, remotion_project_dir=None): """Static method for multiprocessing fix""" if not code: return i, '' # 3. Use LLM to fix remaining issues. llm = LLM.from_config(config) - logger.info(f'Fixing code for segment {i+1} with LLM...') - return i, RenderRemotion._fix_code_impl(llm, error_log, code, - remotion_project_dir) + logger.info(f'Fixing code for segment {i + 1} with LLM...') + return i, RenderRemotion._fix_code_impl(llm, error_log, code, remotion_project_dir) @staticmethod def _fix_code_impl(llm, error_log, code, remotion_project_dir=None): available_images_info = '' if remotion_project_dir: - images_path = os.path.join(remotion_project_dir, 'public', - 'images') + images_path = os.path.join(remotion_project_dir, 'public', 'images') if os.path.exists(images_path): files = sorted(os.listdir(images_path)) - available_images_info = '\nAvailable images in public/images/:\n' + '\n'.join( - [f'- {f}' for f in files]) + available_images_info = '\nAvailable images in public/images/:\n' + '\n'.join([f'- {f}' for f in files]) if 'VISUAL CHECK FAILED' in error_log: fix_prompt = f""" @@ -858,9 +762,7 @@ def _fix_code_impl(llm, error_log, code, remotion_project_dir=None): response = _response_message.content # Robust code extraction using regex - code_match = re.search( - r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, - re.DOTALL) + code_match = re.search(r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, re.DOTALL) if code_match: code = code_match.group(1) else: diff --git a/projects/singularity_cinema/segment/agent.py b/projects/singularity_cinema/segment/agent.py index 35df82260..a72804a25 100644 --- a/projects/singularity_cinema/segment/agent.py +++ b/projects/singularity_cinema/segment/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from copy import deepcopy -import json +from omegaconf import DictConfig + from ms_agent.agent import LLMAgent from ms_agent.llm import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() class Segment(LLMAgent): - system = """你是一名动画分镜设计师。现在有一个短视频场景需要进行分镜设计。分镜需要满足以下条件: - 每个分镜包含: @@ -60,7 +60,7 @@ class Segment(LLMAgent): ... ] ``` -""" # noqa +""" # noqa video_prompt = """- 你可以使用文生视频功能来渲染某些镜头,这可以增强短视频的整体趣味性和可读性 * 当使用文生视频渲染某些镜头时,返回的结构应该只包含三个字段:index、content和video。不要包含其他字段如{animation_engine}、background等。换句话说,文生视频镜头不应该包含动画引擎或背景图片 @@ -69,13 +69,9 @@ class Segment(LLMAgent): * **生成具有强动态效果的视频,而不是只有镜头移动的静态场景。你需要在视频中讲好你的故事** * video字段包含你对文生视频生成的要求。注意生成的视频如何与前后镜头协调 * 如果你使用多个文生视频镜头,注意保持角色、建筑、动物等的ID一致性 - * 需要叙述摄像机和镜头信息,集中于讲述故事、推进情节和深化主题""" # noqa + * 需要叙述摄像机和镜头信息,集中于讲述故事、推进情节和深化主题""" # noqa - def __init__(self, - config: DictConfig, - tag: str, - trust_remote_code: bool = False, - **kwargs): + def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): _config = deepcopy(config) _config.tools = DictConfig({}) super().__init__(_config, tag, trust_remote_code, **kwargs) @@ -89,8 +85,7 @@ async def create_messages(self, messages): video_prompt = self.video_prompt if self.config.use_text2video else '' video_prompt = video_prompt.format(animation_engine=self.engine) - system = system.format( - video_prompt=video_prompt, animation_engine=self.engine) + system = system.format(video_prompt=video_prompt, animation_engine=self.engine) return [ Message(role='system', content=system), @@ -110,10 +105,7 @@ async def run(self, messages, **kwargs): if self.config.background != 'image': image_prompt = f'\n\n背景图片无需生成,是纯色:{self.config.background}\n\n' - query = (f'原始主题:\n\n{topic}\n\n' - f'原始脚本:\n\n{script}\n\n' - f'{image_prompt}' - f'请完成你的动画分镜设计:\n') + query = f'原始主题:\n\n{topic}\n\n原始脚本:\n\n{script}\n\n{image_prompt}请完成你的动画分镜设计:\n' messages = await super().run(query, **kwargs) response = messages[-1].content if '```json' in response: @@ -147,9 +139,10 @@ async def run(self, messages, **kwargs): return messages async def add_images(self, segments, topic, script, **kwargs): - - video_prompt = ('注意:不需要修改包含video字段的镜头。这些镜头是文生视频镜头,它不需要背景、动画或前景图片。' - '只需在返回值中保留并返回这些镜头的index即可。') + video_prompt = ( + '注意:不需要修改包含video字段的镜头。这些镜头是文生视频镜头,它不需要背景、动画或前景图片。' + '只需在返回值中保留并返回这些镜头的index即可。' + ) if not self.config.use_text2video: video_prompt = '' @@ -208,14 +201,13 @@ async def add_images(self, segments, topic, script, **kwargs): ] 现在开始: -""" # noqa +""" # noqa # Format the system prompt with the actual engine name animation_engine = self.engine animation_engine_cap = animation_engine.capitalize() system = system.format( - video_prompt=video_prompt, - animation_engine=animation_engine, - animation_engine_cap=animation_engine_cap) + video_prompt=video_prompt, animation_engine=animation_engine, animation_engine_cap=animation_engine_cap + ) new_image_info = '未提供图片。' name_mapping = {} @@ -223,9 +215,7 @@ async def add_images(self, segments, topic, script, **kwargs): with open(os.path.join(self.work_dir, 'image_info.txt'), 'r') as f: image_info = f.readlines() - image_info = [ - image.strip() for image in image_info if image.strip() - ] + image_info = [image.strip() for image in image_info if image.strip()] image_list = [] for i, info in enumerate(image_info): info = json.loads(info) @@ -242,7 +232,8 @@ async def add_images(self, segments, topic, script, **kwargs): f'原始脚本:\n\n{script}\n\n' f'原始分镜:\n\n{json.dumps(segments, ensure_ascii=False, indent=4)}\n\n' f'用户提供的图片:\n\n{new_image_info}\n\n' - f'请完成你的图片设计:\n') + f'请完成你的图片设计:\n' + ) messages = [ Message(role='system', content=system), Message(role='user', content=query), diff --git a/setup.py b/setup.py index 1bded7e04..a79918b67 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,10 @@ # !/usr/bin/env python import os import shutil +from typing import List + from setuptools import find_packages, setup from setuptools.command.build_py import build_py as _build_py -from typing import List def readme(): @@ -41,6 +42,7 @@ def parse_requirements(fname='requirements.txt', with_version=True): import re import sys from os.path import exists + require_fpath = fname def parse_line(line): @@ -70,8 +72,7 @@ def parse_line(line): if ';' in rest: # Handle platform specific dependencies # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, - rest.split(';')) + version, platform_deps = map(str.strip, rest.split(';')) info['platform_deps'] = platform_deps else: version = rest # NOQA @@ -85,8 +86,7 @@ def parse_require_file(fpath): if line.startswith('http'): print('skip http requirements %s' % line) continue - if line and not line.startswith('#') and not line.startswith( - '--'): + if line and not line.startswith('#') and not line.startswith('--'): for info in parse_line(line): yield info elif line and line.startswith('--find-links'): @@ -121,7 +121,6 @@ def gen_packages_items(): class build_py(_build_py): - def run(self): super().run() @@ -147,8 +146,7 @@ def _build_and_copy_webui(self): webui_src = os.path.join(repo_root, 'webui') if not os.path.isdir(webui_src): - print( - 'Warning: webui directory not found, skipping webui packaging') + print('Warning: webui directory not found, skipping webui packaging') return frontend_src = os.path.join(webui_src, 'frontend') @@ -156,17 +154,11 @@ def _build_and_copy_webui(self): # Check if npm is available try: - subprocess.run(['npm', '--version'], - capture_output=True, - check=True, - timeout=5) + subprocess.run(['npm', '--version'], capture_output=True, check=True, timeout=5) npm_available = True - except (subprocess.CalledProcessError, FileNotFoundError, - subprocess.TimeoutExpired): + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): npm_available = False - print( - 'Warning: npm not found, cannot build frontend. WebUI may not work properly.' - ) + print('Warning: npm not found, cannot build frontend. WebUI may not work properly.') # Build frontend if npm is available if npm_available and os.path.isdir(frontend_src): @@ -177,24 +169,16 @@ def _build_and_copy_webui(self): if not os.path.exists(node_modules): print('Installing frontend dependencies...') try: - subprocess.run(['npm', 'install'], - cwd=frontend_src, - check=True, - timeout=300) - except (subprocess.CalledProcessError, - subprocess.TimeoutExpired) as e: + subprocess.run(['npm', 'install'], cwd=frontend_src, check=True, timeout=300) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: print(f'Warning: npm install failed: {e}') return # Build frontend try: - subprocess.run(['npm', 'run', 'build'], - cwd=frontend_src, - check=True, - timeout=300) + subprocess.run(['npm', 'run', 'build'], cwd=frontend_src, check=True, timeout=300) print('Frontend built successfully') - except (subprocess.CalledProcessError, - subprocess.TimeoutExpired) as e: + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: print(f'Warning: npm build failed: {e}') return @@ -219,23 +203,17 @@ def _build_and_copy_webui(self): shutil.copytree(frontend_dist_src, frontend_dst) print(f'Copied frontend dist to {frontend_dst}') else: - print( - 'Warning: frontend dist not found, WebUI may not work in production mode' - ) + print('Warning: frontend dist not found, WebUI may not work in production mode') if __name__ == '__main__': - print( - 'Usage: `python setup.py sdist bdist_wheel` or `pip install .[framework]` from source code' - ) + print('Usage: `python setup.py sdist bdist_wheel` or `pip install .[framework]` from source code') - install_requires, deps_link = parse_requirements( - 'requirements/framework.txt') + install_requires, deps_link = parse_requirements('requirements/framework.txt') extra_requires = {} all_requires = [] - extra_requires['research'], _ = parse_requirements( - 'requirements/research.txt') + extra_requires['research'], _ = parse_requirements('requirements/research.txt') extra_requires['code'], _ = parse_requirements('requirements/code.txt') extra_requires['webui'], _ = parse_requirements('requirements/webui.txt') all_requires.extend(install_requires) @@ -247,8 +225,7 @@ def _build_and_copy_webui(self): setup( name='ms-agent', version=get_version(), - description= - 'MS-Agent: Lightweight Framework for Empowering Agents with Autonomous Exploration', + description='MS-Agent: Lightweight Framework for Empowering Agents with Autonomous Exploration', long_description=readme(), long_description_content_type='text/markdown', author='The ModelScope teams', @@ -280,8 +257,7 @@ def _build_and_copy_webui(self): license='Apache License 2.0', install_requires=install_requires, extras_require=extra_requires, - entry_points={ - 'console_scripts': ['ms-agent=ms_agent.cli.cli:run_cmd'] - }, + entry_points={'console_scripts': ['ms-agent=ms_agent.cli.cli:run_cmd']}, dependency_links=deps_link, - zip_safe=False) + zip_safe=False, + ) diff --git a/shell-grep-glob-workspace-policy.md b/shell-grep-glob-workspace-policy.md deleted file mode 100644 index ac4e3f912..000000000 --- a/shell-grep-glob-workspace-policy.md +++ /dev/null @@ -1,225 +0,0 @@ -# Shell / Grep / Glob 与策略内核架构方案 - -本文档描述在 modelscope-agent 中为 **Shell**、**Grep**、**Glob** 提供统一的安全、权限、沙箱与产物管理的设计,以及与 **`feat/agent-tool-overhaul`** 分支中 **TaskManager**(后台 Agent、预留 Shell)的兼容方式。 - ---- - -## 1. 目标与边界 - -### 目标 - -- 在「同一工作区、同一沙箱视图」下,为 **Shell / Grep / Glob** 提供统一的: - - **安全**(命令与路径约束) - - **权限**(只读 / 写工作区 / 网络等分级) - - **沙箱**(本地子进程 vs Docker enclave 等与现有 `CodeExecutionTool` 对齐) - - **产物管理**(大 stdout/stderr 落盘、预览、配额) -- **默认 `allow_list`(允许根路径)包含 `output_dir`**(及其规范化的绝对路径),可配置追加其它根。 - -### 边界 - -- **不替代** `FileSystemTool` 的精确编辑与读缓存等语义;Shell 面向构建、包管理、复杂管道。 -- **Grep / Glob** 作为**只读发现面**的独立工具,减少对裸 shell 的依赖;复杂 `find -exec` 等仍可由受控 Shell 在更高权限模式下完成(若产品允许)。 - ---- - -## 2. 分层架构 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Tool Facade 层 │ -│ ShellTool │ GrepTool │ GlobTool (独立 JSON Schema) │ -└────────────┬───────────────────────────────┬────────────────┘ - │ │ -┌────────────▼───────────────────────────────▼────────────────┐ -│ WorkspacePolicyKernel(策略内核,纯逻辑、可单测) │ -│ - roots: 默认含 canonical(output_dir),可配置追加 │ -│ - allow_list / deny_list 合并与优先级 │ -│ - resolve_path(rel|abs) → 必须在 allow_roots 下 │ -│ - classify(op): read | search | mutate | exec | network_hint │ -└────────────┬────────────────────────────────────────────────┘ - │ -┌────────────▼────────────────────────────────────────────────┐ -│ SandboxRuntime(执行面,可替换实现) │ -│ - LocalProcessRuntime(asyncio subprocess,cwd=workspace) │ -│ - EnclaveRuntime(现有 ms_enclave / CodeExecutionTool 路径) │ -│ - 会话级 sandbox_id / working_dir 与挂载点一致 │ -└────────────┬────────────────────────────────────────────────┘ - │ -┌────────────▼────────────────────────────────────────────────┐ -│ ArtifactManager(产物管理) │ -│ - 超阈值 stdout/stderr → 落盘 + preview + 相对路径引用 │ -│ - 按 task_id / tool_call_id 分目录 │ -│ - TTL / 总配额(建议:output_dir/.ms_agent_artifacts/) │ -└─────────────────────────────────────────────────────────────┘ -``` - -**原则**:Grep/Glob 的**主路径**不是「拼一条 shell 给模型」;内部可调用 `rg` 或文件系统 walk,但必须经过 **PolicyKernel** 与 **SandboxRuntime**,输出经 **ArtifactManager**。 - ---- - -## 3. WorkspacePolicyKernel(共享策略内核) - -### 3.1 默认 allow_list(允许根集合) - -- 初始化:`allow_roots = { canonical_abs(output_dir) }`。 -- 配置可追加,例如:`tools.code_executor.extra_allow_roots` 或 `tools.workspace_policy.allow`(列表),合并去重。 -- Shell / Grep / Glob 涉及的 **`path`、`cwd`、搜索根目录** 均先执行 `resolve_under_allow_roots()`;失败则**拒绝**并返回结构化错误(不静默改路径到其它目录)。 - -### 3.2 权限与操作分类(建议) - -| 类别 | 示例 | Shell | Grep | Glob | -|------|------|-------|------|------| -| read | 读取工作区内文件 | 受模式 + 策略约束 | ✓ | ✓ | -| search | 内容/文件名发现 | 可引导至 Grep/Glob | ✓ | ✓ | -| mutate | rm、chmod、git 写入等 | 需 `workspace_write` | — | — | -| network | curl、pip 等 | 需显式 **network** 能力位 | — | — | - -Shell 在 **`read_only`** 模式下:仅允许白名单类命令(如 `git status`/`diff`/`log`、只读参数的 `rg` 等),并对重定向、写入工作区外等行为做拒绝或降级(可用前缀表 + 危险模式黑名单,必要时辅以轻量解析)。 - -### 3.3 Shell 安全补充 - -- **固定 cwd**:默认 `workspace_root`(与 `output_dir` 或沙箱内挂载点一致)。 -- **环境变量**:最小集或白名单继承;避免将宿主敏感变量原样传入。 -- **命令预处理**:与现有 `CodeExecutionTool.shell_executor` 思路一致——含 `| && ; > <` 等时使用 `sh -lc` 与安全 quoting;另加**命令长度上限**、**可配置的危险构造限制**(如嵌套命令替换,按产品分级)。 -- **(暂时不做)** 与 `FileSystemTool` 的「写前必读 / staleness」策略对齐:对会修改工作区文件的 Shell 子类共享元数据(若产品需要强一致)。 - ---- - -## 4. SandboxRuntime(共享沙箱) - -- **会话级**:每个 Agent 运行周期内一个 `SandboxSession`(或复用现有 `sandbox_id`)。 -- **Shell / Grep / Glob** 共用同一 **`working_dir` / 挂载视图** 与同一 **`SandboxRuntime` 实现**(本地 `asyncio` 子进程 vs Docker enclave,由 `implementation: sandbox | python_env` 等与现有一致)。 -- **Grep**:在 enclave 内调用 `rg` 或使用宿主 `ripgrep` 库(由部署二选一);**Glob**:在策略解析后的根上做目录遍历或 `pathspec`,避免默认可执行任意 `find -exec`。 - ---- - -## 5. ArtifactManager(产物管理) - -- **阈值**:例如 stdout+stderr 合计超过 N KB 则 spill 至 - `{output_dir}/.ms_agent_artifacts/{tool_name}/{task_or_call_id}.txt`(路径可配置)。 -- **返回**:JSON 中包含 `preview`(首尾若干字符/行)、`artifact_path`(相对 `output_dir`)、`truncated: true`。 -- **与 TaskManager 配合**:后台任务完成时,`TaskManager.complete(task_id, result)` 的 `result` 宜为「短摘要 + artifact 路径」,避免通知与下一轮上下文被撑爆。 - ---- - -## 6. GrepTool / GlobTool(独立工具、共享内核) - -- **输入**:结构化字段(如 pattern、path、glob、head_limit、offset、output_mode),不把「整条 shell」作为唯一 API。 -- **实现**:内部调用 `SandboxRuntime.exec_rg(...)` 或在策略内核限定根上的 glob 遍历;**禁止**由用户可控字符串直接拼接未校验的 shell。 -- **共享**:同一 `WorkspacePolicyKernel` + `SandboxRuntime` + `ArtifactManager`(由 `ToolManager` 或执行类工具在初始化时注入)。 -- **注册**:在 `ToolManager` 中作为独立 `ToolBase`(可一个 server 多个 tool,或两个 server);与 `file_system` 解耦,保持 `file_system` 精简。 - ---- - -## 7. 与 `feat/agent-tool-overhaul` 的 Task 体系兼容 - -### 7.1 分支中的现状(摘要) - -- **`TaskManager`**(`ms_agent/utils/task_manager.py`):进程级后台任务注册表;`BackgroundTask` 中 **`task_type` 注释已包含 `'agent' | 'shell'`**。 -- **`AgentTool`**:`run_in_background` 时 `register(task_type='agent', proc=mp.Process, ...)`,watcher 在子进程结束后调用 `complete` / `fail`;`LLMAgent` 通过 `set_task_manager` 注入同一 `TaskManager`,每轮 `drain_notifications()` 将完成事件注入对话。 - -### 7.2 Shell 后台(与 Agent 对称) - -**建议接口** - -- **同步**:`shell_executor(command, timeout)` → 行为与现网接近,但走 PolicyKernel + ArtifactManager。 -- **后台**:增加 `run_in_background: bool`(或等价命名), **`__call_id`**(与 `AgentTool` 注入一致,便于对账与「推后台」扩展)。 - -**后台行为** - -1. `task_id = task_manager.register(task_type='shell', tool_name='shell_executor', description=command[:200], proc=...)` -2. `proc` 可为 **`asyncio.create_subprocess_*` 返回的 `Process`**(与 Agent 的 `mp.Process` 不同,需在 **`TaskManager.kill` / `kill_all` 中扩展**:对 `asyncio.subprocess.Process` 调用 `kill()` / `terminate()`,并处理已结束进程)。 -3. `asyncio.create_task(watcher)`:等待结束 → `ArtifactManager.maybe_spill` → `await task_manager.complete(task_id, result_str)`(失败则 `fail`)。 - -**立即返回 JSON**(与 Agent 后台对齐,便于统一文档与客户端): - -```json -{ - "status": "async_launched", - "task_id": "", - "tool_name": "shell_executor" -} -``` - -### 7.3 LLMAgent 接线 - -- 与 overhaul 一致:构造 `TaskManager()`,遍历 `extra_tools`,若实现 **`set_task_manager(self.task_manager)`** 则注入。 -- **`LocalCodeExecutionTool` / 未来的 `SecureShellTool`** 实现 `set_task_manager`,与 `AgentTool` 共享同一 `TaskManager` 实例。 - -### 7.4 长同步 Shell → Escape 到后台 - -- 与 `AgentTool._run_sync_escapable` 类似:同步 Shell 带 `sync_timeout_s`,超时或显式信号后取消当前子进程并改为 `register(task_type='shell', ...)` 后台重跑或仅保留已产出部分(产品二选一)。 -- 若存在 **TaskControlTool** 类机制,可复用「`__call_id` + escape 事件」模式,Shell 侧维护 `call_id → Process` 映射以支持 **kill / escape**。 - -### 7.5 兼容对照表 - -| 能力 | overhaul 行为 | 本方案落点 | -|------|----------------|------------| -| 后台 Agent | `register(task_type='agent', proc=Process)` | 不变 | -| 预留 Shell | `task_type` 含 `'shell'` | `shell_executor(run_in_background=true)` 走同一 register / complete / fail | -| 回合内通知 | `drain_notifications()` | Shell 完成同样入队 | -| Kill / 清理 | `kill` / `kill_all` | 扩展支持 asyncio 子进程;watcher `finally` 释放资源 | - ---- - -## 8. 配置示例(OmegaConf / YAML 意向) - -```yaml -tools: - workspace_policy: - allow_roots: [] # 追加;默认已含 output_dir - deny_globs: ['**/.git/**'] - code_executor: - implementation: python_env # or sandbox - shell: - default_mode: workspace_write # read_only | workspace_write - max_output_kb: 256 - wall_time_s: 900 - grep: - default_head_limit: 250 - glob: - max_files: 100 -``` - ---- - -## 9. 实施顺序建议 - -1. 抽出 **`WorkspacePolicyKernel`** + 单元测试(路径解析、默认 `output_dir`、追加 allow)。 -2. 实现 **`ArtifactManager`**,接到现有 `shell_executor` 返回(先本地工具、后接沙箱)。 -3. 将 **`TaskManager`**(overhaul)合入主线并 **扩展 `kill` 支持 `asyncio.subprocess.Process`**。 -4. **`LocalCodeExecutionTool.set_task_manager` + `run_in_background` 的 `shell_executor`**。 -5. 新增 **GrepTool / GlobTool** façade,共享上述内核与运行时。 -6. 更新文档与系统提示:默认 **发现用 Grep/Glob,构建用 Shell,改文件用 file_system**。 - ---- - -## 10. 设计取舍小结 - -- **Shell**:强约束的通用执行面 + 后台,与 **TaskManager** 统一生命周期与通知。 -- **Grep / Glob**:独立 Schema、只读、易截断,与 Shell **共享策略与沙箱**,避免把一切搜索都绑在一条 shell 字符串上。 -- **默认 allow_roots 含 `output_dir`**:与现有 Agent 工作区模型一致,减少越权访问宿主路径的风险。 - ---- - -## 修订记录 - -| 日期 | 说明 | -|------|------| -| 2026-04-13 | 初版:根据设计与 `feat/agent-tool-overhaul` 中 TaskManager / AgentTool 后台模型整理成文。 | -| 2026-04-13 | 实现落地:见下文「实现映射」。 | - -## 11. 实现映射(代码位置) - -| 组件 | 路径 | -|------|------| -| WorkspacePolicyKernel | `ms_agent/utils/workspace_policy.py` | -| ArtifactManager | `ms_agent/utils/artifact_manager.py` | -| TaskManager | `ms_agent/utils/task_manager.py` | -| Shell 策略 / 产物 / 后台 | `ms_agent/tools/code/local_code_executor.py`(`set_task_manager`、`shell_executor`) | -| Grep / Glob | `ms_agent/tools/filesystem_tool.py` 中 `grep` / `glob` 工具(与 `read_file` / `edit_file` / `write_file` 同属 `file_system` server;用 `tools.file_system.include` / `exclude` 控制)。可选键:`grep_timeout_s`、`grep_head_limit`、`glob_max_files`;`include` 短名 `read` / `edit` / `write` 分别等价 `read_file` / `edit_file` / `write_file`。 | -| `__call_id` 注入 shell | `ms_agent/tools/tool_manager.py` | -| TaskManager 与通知 | `ms_agent/agent/llm_agent.py`(`prepare_tools` / `cleanup_tools` / `_append_task_notifications`) | -| 单测 | `tests/utils/test_workspace_policy.py` | - -**未在本阶段实现**:文档 §7.4 长同步 Shell escape 到后台;Docker `CodeExecutionTool` 侧 shell 与策略对齐(仍沿用原沙箱实现)。 diff --git a/webui/backend/agent_runner.py b/webui/backend/agent_runner.py index 1c122ff27..37d3ec3b7 100644 --- a/webui/backend/agent_runner.py +++ b/webui/backend/agent_runner.py @@ -3,6 +3,7 @@ Agent runner for MS-Agent Web UI Manages the execution of ms-agent through subprocess with log streaming. """ + import asyncio import os import re @@ -18,16 +19,18 @@ class AgentRunner: """Runs ms-agent as a subprocess with output streaming""" - def __init__(self, - session_id: str, - project: Dict[str, Any], - config_manager, - on_output: Callable[[Dict[str, Any]], None] = None, - on_log: Callable[[Dict[str, Any]], None] = None, - on_progress: Callable[[Dict[str, Any]], None] = None, - on_complete: Callable[[Dict[str, Any]], None] = None, - on_error: Callable[[Dict[str, Any]], None] = None, - workflow_type: str = 'standard'): + def __init__( + self, + session_id: str, + project: Dict[str, Any], + config_manager, + on_output: Callable[[Dict[str, Any]], None] = None, + on_log: Callable[[Dict[str, Any]], None] = None, + on_progress: Callable[[Dict[str, Any]], None] = None, + on_complete: Callable[[Dict[str, Any]], None] = None, + on_error: Callable[[Dict[str, Any]], None] = None, + workflow_type: str = 'standard', + ): self.session_id = session_id self.project = project self.config_manager = config_manager @@ -53,8 +56,7 @@ def __init__(self, self._current_tool_args = None # Current tool arguments self._current_tool_result = None # Current tool result self._tool_call_json_buffer = '' # Buffer for collecting multi-line JSON tool call info - self._is_chat_mode = project.get( - 'id') == '__chat__' # Simple chat mode flag + self._is_chat_mode = project.get('id') == '__chat__' # Simple chat mode flag self._chat_response_buffer = '' # Buffer for chat mode responses async def start(self, query: str): @@ -73,11 +75,13 @@ async def start(self, query: str): # Log the command if self.on_log: - self.on_log({ - 'level': 'info', - 'message': f'Starting agent: {" ".join(cmd[:5])}...', - 'timestamp': datetime.now().isoformat() - }) + self.on_log( + { + 'level': 'info', + 'message': f'Starting agent: {" ".join(cmd[:5])}...', + 'timestamp': datetime.now().isoformat(), + } + ) # Start subprocess self.process = await asyncio.create_subprocess_exec( @@ -87,7 +91,8 @@ async def start(self, query: str): stdin=asyncio.subprocess.PIPE, env=env, cwd=self.project['path'], - start_new_session=True) + start_new_session=True, + ) print(f'[Runner] Process started with PID: {self.process.pid}') @@ -97,6 +102,7 @@ async def start(self, query: str): except Exception as e: print(f'[Runner] ERROR: {e}') import traceback + traceback.print_exc() if self.on_error: self.on_error({'message': str(e), 'type': 'startup_error'}) @@ -142,35 +148,25 @@ async def send_input(self, text: str): if not self.process: print('[Runner] ERROR: Process is None, cannot send input') if self.on_error: - self.on_error({ - 'message': - 'Agent process is not running. Please start a new conversation.', - 'type': 'input_error' - }) + self.on_error( + {'message': 'Agent process is not running. Please start a new conversation.', 'type': 'input_error'} + ) return # Check if process has exited if self.process.returncode is not None: - print( - f'[Runner] ERROR: Process has exited with code {self.process.returncode}, cannot send input' - ) + print(f'[Runner] ERROR: Process has exited with code {self.process.returncode}, cannot send input') if self.on_error: - self.on_error({ - 'message': - 'Agent process has terminated. Please start a new conversation.', - 'type': 'input_error' - }) + self.on_error( + {'message': 'Agent process has terminated. Please start a new conversation.', 'type': 'input_error'} + ) return # Check if stdin is available if not self.process.stdin: print('[Runner] ERROR: Process stdin is None, cannot send input') if self.on_error: - self.on_error({ - 'message': - 'Cannot send input: process stdin is not available.', - 'type': 'input_error' - }) + self.on_error({'message': 'Cannot send input: process stdin is not available.', 'type': 'input_error'}) return print(f'[Runner] Sending input to agent: {text[:100]}...') @@ -188,11 +184,12 @@ async def send_input(self, text: str): except (BrokenPipeError, RuntimeError, OSError) as e: print(f'[Runner] ERROR: Failed to send input: {e}') if self.on_error: - self.on_error({ - 'message': - f'Failed to send input: Process may have terminated. Error: {str(e)}', - 'type': 'input_error' - }) + self.on_error( + { + 'message': f'Failed to send input: Process may have terminated. Error: {str(e)}', + 'type': 'input_error', + } + ) # Mark process as not running self.is_running = False self._waiting_for_input = False @@ -208,8 +205,7 @@ def _build_command(self, query: str) -> list: workflow_type = getattr(self, '_workflow_type', 'standard') if workflow_type == 'simple' and project_type == 'workflow': # For code_genesis with simple workflow, use simple_workflow.yaml - simple_config_file = os.path.join(project_path, - 'simple_workflow.yaml') + simple_config_file = os.path.join(project_path, 'simple_workflow.yaml') if os.path.exists(simple_config_file): config_file = simple_config_file @@ -221,10 +217,7 @@ def _build_command(self, query: str) -> list: if project_type == 'workflow' or project_type == 'agent': # Use ms-agent CLI command (installed via entry point) - cmd = [ - 'ms-agent', 'run', '--config', config_file, - '--trust_remote_code', 'true' - ] + cmd = ['ms-agent', 'run', '--config', config_file, '--trust_remote_code', 'true'] if query: cmd.extend(['--query', query]) @@ -234,126 +227,85 @@ def _build_command(self, query: str) -> list: # Add LLM config from user settings llm_config = self.config_manager.get_llm_config() - temperature_enabled = bool( - llm_config.get('temperature_enabled', False)) + temperature_enabled = bool(llm_config.get('temperature_enabled', False)) if llm_config.get('api_key'): provider = llm_config.get('provider', 'modelscope') if provider == 'modelscope': - cmd.extend( - ['--llm.modelscope_api_key', llm_config['api_key']]) + cmd.extend(['--llm.modelscope_api_key', llm_config['api_key']]) # Set llm.service to modelscope to ensure the correct service is used cmd.extend(['--llm.service', 'modelscope']) # Pass base_url if set by user if llm_config.get('base_url'): - cmd.extend([ - '--llm.modelscope_base_url', llm_config['base_url'] - ]) + cmd.extend(['--llm.modelscope_base_url', llm_config['base_url']]) # Pass model if set by user if llm_config.get('model'): cmd.extend(['--llm.model', llm_config['model']]) # Pass temperature if set by user (in generation_config) - if temperature_enabled and llm_config.get( - 'temperature') is not None: - cmd.extend([ - '--generation_config.temperature', - str(llm_config['temperature']) - ]) + if temperature_enabled and llm_config.get('temperature') is not None: + cmd.extend(['--generation_config.temperature', str(llm_config['temperature'])]) # Pass max_tokens if set by user (in generation_config) if llm_config.get('max_tokens'): - cmd.extend([ - '--generation_config.max_tokens', - str(llm_config['max_tokens']) - ]) + cmd.extend(['--generation_config.max_tokens', str(llm_config['max_tokens'])]) elif provider == 'openai': cmd.extend(['--llm.openai_api_key', llm_config['api_key']]) # Set llm.service to openai to ensure the correct service is used cmd.extend(['--llm.service', 'openai']) # Pass base_url if set by user if llm_config.get('base_url'): - cmd.extend( - ['--llm.openai_base_url', llm_config['base_url']]) + cmd.extend(['--llm.openai_base_url', llm_config['base_url']]) # Pass model if set by user if llm_config.get('model'): cmd.extend(['--llm.model', llm_config['model']]) # Pass temperature if set by user (in generation_config) - if temperature_enabled and llm_config.get( - 'temperature') is not None: - cmd.extend([ - '--generation_config.temperature', - str(llm_config['temperature']) - ]) + if temperature_enabled and llm_config.get('temperature') is not None: + cmd.extend(['--generation_config.temperature', str(llm_config['temperature'])]) # Pass max_tokens if set by user (in generation_config) if llm_config.get('max_tokens'): - cmd.extend([ - '--generation_config.max_tokens', - str(llm_config['max_tokens']) - ]) + cmd.extend(['--generation_config.max_tokens', str(llm_config['max_tokens'])]) # Add edit_file_config from user settings (skip for chat mode) if self.project.get('id') != '__chat__': edit_file_config = self.config_manager.get_edit_file_config() if edit_file_config.get('api_key'): # If API key is provided, pass edit_file_config - cmd.extend([ - '--tools.file_system.edit_file_config.api_key', - edit_file_config['api_key'] - ]) + cmd.extend(['--tools.file_system.edit_file_config.api_key', edit_file_config['api_key']]) if edit_file_config.get('base_url'): - cmd.extend([ - '--tools.file_system.edit_file_config.base_url', - edit_file_config['base_url'] - ]) + cmd.extend(['--tools.file_system.edit_file_config.base_url', edit_file_config['base_url']]) if edit_file_config.get('diff_model'): - cmd.extend([ - '--tools.file_system.edit_file_config.diff_model', - edit_file_config['diff_model'] - ]) + cmd.extend(['--tools.file_system.edit_file_config.diff_model', edit_file_config['diff_model']]) else: # If no API key, exclude edit_file from tools # Read the current include list from config file and remove edit_file try: with open(config_file, 'r', encoding='utf-8') as f: config_data = yaml.safe_load(f) - if config_data and 'tools' in config_data and 'file_system' in config_data[ - 'tools']: - include_list = config_data['tools'][ - 'file_system'].get('include', []) - if isinstance( - include_list, - list) and 'edit_file' in include_list: + if config_data and 'tools' in config_data and 'file_system' in config_data['tools']: + include_list = config_data['tools']['file_system'].get('include', []) + if isinstance(include_list, list) and 'edit_file' in include_list: # Remove edit_file from the list - filtered_include = [ - tool for tool in include_list - if tool != 'edit_file' - ] + filtered_include = [tool for tool in include_list if tool != 'edit_file'] # Pass the filtered list as comma-separated string - cmd.extend([ - '--tools.file_system.include', - ','.join(filtered_include) - ]) + cmd.extend(['--tools.file_system.include', ','.join(filtered_include)]) except Exception as e: - print( - f'[Runner] Warning: Could not read config file to exclude edit_file: {e}' - ) + print(f'[Runner] Warning: Could not read config file to exclude edit_file: {e}') # Fallback: explicitly exclude edit_file - cmd.extend( - ['--tools.file_system.exclude', 'edit_file']) + cmd.extend(['--tools.file_system.exclude', 'edit_file']) # Add EdgeOne Pages API token and project name from user settings - edgeone_pages_config = self.config_manager.get_edgeone_pages_config( - ) + edgeone_pages_config = self.config_manager.get_edgeone_pages_config() if edgeone_pages_config.get('api_token'): # If API token is provided, pass it to the MCP server config - cmd.extend([ - '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_API_TOKEN', - edgeone_pages_config['api_token'] - ]) + cmd.extend( + ['--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_API_TOKEN', edgeone_pages_config['api_token']] + ) if edgeone_pages_config.get('project_name'): # If project name is provided, pass it to the MCP server config - cmd.extend([ - '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_PROJECT_NAME', - edgeone_pages_config['project_name'] - ]) + cmd.extend( + [ + '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_PROJECT_NAME', + edgeone_pages_config['project_name'], + ] + ) elif project_type == 'script': # Run the script directly @@ -386,9 +338,7 @@ async def _read_output(self): # Check if process has exited if self.process.returncode is not None and not process_exited: process_exited = True - print( - f'[Runner] Process exited with code: {self.process.returncode}' - ) + print(f'[Runner] Process exited with code: {self.process.returncode}') # Continue reading remaining output even after process exits # This ensures we don't miss any URLs or important messages if not self.process.stdout: @@ -405,8 +355,7 @@ async def _read_output(self): try: # Use shorter timeout after process exits to read remaining data faster timeout = 0.1 if process_exited else 1.0 - line = await asyncio.wait_for( - self.process.stdout.readline(), timeout=timeout) + line = await asyncio.wait_for(self.process.stdout.readline(), timeout=timeout) except asyncio.TimeoutError: # Timeout - check if we're waiting for input if self._waiting_for_input: @@ -417,14 +366,14 @@ async def _read_output(self): self._flush_chat_response() # Send waiting_input message to enable frontend input if self.on_output and not self._waiting_input_sent: - self.on_output({ - 'type': 'waiting_input', - 'content': '', - 'role': 'system', - 'metadata': { - 'waiting': True + self.on_output( + { + 'type': 'waiting_input', + 'content': '', + 'role': 'system', + 'metadata': {'waiting': True}, } - }) + ) self._waiting_input_sent = True # Process is still alive, continue waiting continue @@ -463,26 +412,21 @@ async def _read_output(self): self._flush_chat_response() # Send waiting_input message to enable frontend input if self.on_output and not self._waiting_input_sent: - self.on_output({ - 'type': 'waiting_input', - 'content': '', - 'role': 'system', - 'metadata': { - 'waiting': True + self.on_output( + { + 'type': 'waiting_input', + 'content': '', + 'role': 'system', + 'metadata': {'waiting': True}, } - }) + ) self._waiting_input_sent = True - print( - '[Runner] Agent is waiting for user input, keeping process alive...' - ) + print('[Runner] Agent is waiting for user input, keeping process alive...') # Keep process alive and wait for input - await asyncio.sleep( - 0.5) # Small delay to avoid busy waiting + await asyncio.sleep(0.5) # Small delay to avoid busy waiting continue else: - print( - '[Runner] Process exited while waiting for input' - ) + print('[Runner] Process exited while waiting for input') # Process exited, but continue reading any remaining output # Don't break yet - there might be more data in stdout buffer process_exited = True @@ -493,13 +437,13 @@ async def _read_output(self): # Reset empty line count when we get actual data empty_line_count = 0 text = line.decode('utf-8', errors='replace').rstrip() - print(f'[Runner] Output: {text[:200]}' - if len(text) > 200 else f'[Runner] Output: {text}') + print(f'[Runner] Output: {text[:200]}' if len(text) > 200 else f'[Runner] Output: {text}') try: await self._process_line(text) except Exception as e: print(f'[Runner] ERROR processing line: {e}') import traceback + traceback.print_exc() # Wait for process to complete and handle completion @@ -516,50 +460,46 @@ async def _read_output(self): self._flush_chat_response() # Flush any accumulated assistant output before handling completion - if self._collecting_assistant_output and self._accumulated_output.strip( - ): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', - self._accumulated_output.strip()) + if self._collecting_assistant_output and self._accumulated_output.strip(): + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) - print( - f'[Runner] Flushing accumulated output on process exit: {cleaned[:200]}...' - ) + print(f'[Runner] Flushing accumulated output on process exit: {cleaned[:200]}...') if cleaned and self.on_output: - self.on_output({ - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': { - 'agent': self._current_step or 'agent' + self.on_output( + { + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': {'agent': self._current_step or 'agent'}, } - }) + ) self._accumulated_output = '' self._collecting_assistant_output = False # If stop was requested, do not report as completion/error if self._stop_requested: if self.on_log: - self.on_log({ - 'level': 'info', - 'message': 'Agent stopped by user', - 'timestamp': datetime.now().isoformat() - }) + self.on_log( + { + 'level': 'info', + 'message': 'Agent stopped by user', + 'timestamp': datetime.now().isoformat(), + } + ) return # Complete current step if any before handling exit if self._current_step and self.on_output: - self.on_output({ - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': { - 'step': self._current_step, - 'status': 'completed' + self.on_output( + { + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': {'step': self._current_step, 'status': 'completed'}, } - }) + ) # If Refine step completes successfully, it should be waiting for input - if return_code == 0 and self._current_step.lower( - ) == 'refine': + if return_code == 0 and self._current_step.lower() == 'refine': self._waiting_for_input = True self._current_step = None @@ -570,57 +510,48 @@ async def _read_output(self): if return_code == 0: # Send waiting_input message if not already sent if self.on_output and not self._waiting_input_sent: - self.on_output({ - 'type': - 'waiting_input', - 'content': - ('✅ Initial refinement completed. ' - 'You can now provide additional feedback or modifications.' - ), - 'role': - 'system', - 'metadata': { - 'waiting': True + self.on_output( + { + 'type': 'waiting_input', + 'content': ( + '✅ Initial refinement completed. ' + 'You can now provide additional feedback or modifications.' + ), + 'role': 'system', + 'metadata': {'waiting': True}, } - }) + ) self._waiting_input_sent = True if self.on_complete: - self.on_complete({ - 'status': - 'success', - 'message': - 'Agent completed successfully' - }) + self.on_complete({'status': 'success', 'message': 'Agent completed successfully'}) else: if self.on_error: - self.on_error({ - 'message': - ('Agent process terminated while waiting for input. ' - f'Exit code: {return_code}'), - 'type': - 'process_exit_error', - 'code': - return_code - }) + self.on_error( + { + 'message': ( + f'Agent process terminated while waiting for input. Exit code: {return_code}' + ), + 'type': 'process_exit_error', + 'code': return_code, + } + ) elif return_code == 0: if self.on_complete: - self.on_complete({ - 'status': - 'success', - 'message': - 'Agent completed successfully' - }) + self.on_complete({'status': 'success', 'message': 'Agent completed successfully'}) else: if self.on_error: - self.on_error({ - 'message': f'Agent exited with code {return_code}', - 'type': 'exit_error', - 'code': return_code - }) + self.on_error( + { + 'message': f'Agent exited with code {return_code}', + 'type': 'exit_error', + 'code': return_code, + } + ) except Exception as e: print(f'[Runner] Read error: {e}') import traceback + traceback.print_exc() if not self._stop_requested and self.on_error: self.on_error({'message': str(e), 'type': 'read_error'}) @@ -659,8 +590,7 @@ async def _process_chat_line(self, line: str): else: self._tool_call_json_buffer = cleaned # Check if we have a complete JSON object - if cleaned == '}' and self._tool_call_json_buffer.strip( - ).startswith('{'): + if cleaned == '}' and self._tool_call_json_buffer.strip().startswith('{'): self._flush_tool_call() return @@ -668,14 +598,9 @@ async def _process_chat_line(self, line: str): if 'execute tool call' in line: if self.on_output: is_error = 'error' in line.lower() - self.on_output({ - 'type': 'tool_result', - 'content': cleaned, - 'role': 'assistant', - 'metadata': { - 'is_error': is_error - } - }) + self.on_output( + {'type': 'tool_result', 'content': cleaned, 'role': 'assistant', 'metadata': {'is_error': is_error}} + ) return # Detect [assistant]: marker - start collecting @@ -704,23 +629,25 @@ async def _process_chat_line(self, line: str): def _flush_tool_call(self): """Send tool call information to frontend""" - if self._is_chat_mode and self._tool_call_json_buffer.strip( - ) and self.on_output: + if self._is_chat_mode and self._tool_call_json_buffer.strip() and self.on_output: try: import json + tool_data = json.loads(self._tool_call_json_buffer) tool_name = tool_data.get('tool_name', 'unknown') print(f'[Runner] Tool call: {tool_name}') - self.on_output({ - 'type': 'tool_call', - 'content': '', - 'role': 'assistant', - 'metadata': { - 'tool_name': tool_name, - 'arguments': tool_data.get('arguments', {}), - 'id': tool_data.get('id', '') + self.on_output( + { + 'type': 'tool_call', + 'content': '', + 'role': 'assistant', + 'metadata': { + 'tool_name': tool_name, + 'arguments': tool_data.get('arguments', {}), + 'id': tool_data.get('id', ''), + }, } - }) + ) except json.JSONDecodeError: print('[Runner] Failed to parse tool call JSON') self._tool_call_json_buffer = '' @@ -728,17 +655,11 @@ def _flush_tool_call(self): def _flush_chat_response(self): """Send final chat response with done=True""" - if self._is_chat_mode and self._chat_response_buffer.strip( - ) and self.on_output: - print( - f'[Runner] Chat complete: {len(self._chat_response_buffer)} chars' + if self._is_chat_mode and self._chat_response_buffer.strip() and self.on_output: + print(f'[Runner] Chat complete: {len(self._chat_response_buffer)} chars') + self.on_output( + {'type': 'stream', 'content': self._chat_response_buffer.strip(), 'role': 'assistant', 'done': True} ) - self.on_output({ - 'type': 'stream', - 'content': self._chat_response_buffer.strip(), - 'role': 'assistant', - 'done': True - }) self._chat_response_buffer = '' # Don't reset _collecting_assistant_output here - more content may come # It will be reset when we see [tool_calling]: or [user]: or process exits @@ -759,6 +680,7 @@ async def _process_line(self, line: str): if '[INFO:ms_agent]' in line: # Check if there's an agent name tag [xxx] after [INFO:ms_agent] import re + if not re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', line): return @@ -766,14 +688,13 @@ async def _process_line(self, line: str): if self.on_log: log_level = self._detect_log_level(line) cleaned_message = self._clean_log_prefix(line) - await self.on_log({ - 'level': - log_level, - 'message': - cleaned_message if cleaned_message else line, - 'timestamp': - datetime.now().isoformat() - }) + await self.on_log( + { + 'level': log_level, + 'message': cleaned_message if cleaned_message else line, + 'timestamp': datetime.now().isoformat(), + } + ) # Parse for special patterns (use original line for pattern matching) await self._detect_patterns(line) @@ -822,25 +743,23 @@ def _scan_and_send_output_files(self, programmer_step=None): file_path = line.split(':')[0].strip() generated_files.append(file_path) - print( - f'[Runner] Found {len(generated_files)} files in tasks.txt: {generated_files}' - ) + print(f'[Runner] Found {len(generated_files)} files in tasks.txt: {generated_files}') # Send all files in one batch if generated_files and self.on_output: - self.on_output({ - 'type': 'file_output', - 'content': generated_files, # Send as array - 'role': 'assistant', - 'metadata': { - 'files': generated_files, - 'source': 'tasks.txt' + self.on_output( + { + 'type': 'file_output', + 'content': generated_files, # Send as array + 'role': 'assistant', + 'metadata': {'files': generated_files, 'source': 'tasks.txt'}, } - }) + ) except Exception as e: print(f'[Runner] Error reading tasks.txt: {e}') import traceback + traceback.print_exc() async def _detect_patterns(self, line: str): @@ -852,40 +771,34 @@ async def _detect_patterns(self, line: str): url_match = re.search(r'"url":\s*"(https?://[^"]+)"', line) # Pattern 2: Direct URL like "https://mcp.edgeone.site/share/..." if not url_match: - url_match = re.search(r'(https?://mcp\.edgeone\.site/[^\s]+)', - line) + url_match = re.search(r'(https?://mcp\.edgeone\.site/[^\s]+)', line) # Pattern 3: EdgeOne Pages URL like "https://...edgeone.cool?..." # BUT skip if this is a curl command line (testing command, not actual deployment URL) if not url_match and 'curl -s' not in line and 'curl ' not in line: - url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', - line) + url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', line) # Pattern 4: Also check for edgeone.site URLs in any format (fallback) # BUT skip if this is a curl command line if not url_match and 'curl -s' not in line and 'curl ' not in line: - url_match = re.search(r'(https?://[^\s]*edgeone\.site[^\s]*)', - line) + url_match = re.search(r'(https?://[^\s]*edgeone\.site[^\s]*)', line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) deployment_url = deployment_url.replace('\\&', '&') - print( - f'[Runner] Detected deployment URL (early): {deployment_url} from line: {line[:100]}' - ) + print(f'[Runner] Detected deployment URL (early): {deployment_url} from line: {line[:100]}') if self.on_output: - self.on_output({ - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': { - 'url': deployment_url + self.on_output( + { + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': {'url': deployment_url}, } - }) + ) # Continue processing - don't return yet, other patterns might also match # Detect OpenAI API errors and other API errors # Check for OpenAI error patterns - if 'openai.' in line.lower() and ('error' in line.lower() - or 'Error' in line): + if 'openai.' in line.lower() and ('error' in line.lower() or 'Error' in line): error_message = line.strip() # Try to extract error details from the line # Pattern: openai.NotFoundError: Error code: 404 - {'error': {'message': '...', ...}} @@ -893,12 +806,11 @@ async def _detect_patterns(self, line: str): if json_match: try: import json + error_data = json.loads(json_match.group(0)) - if 'error' in error_data and 'message' in error_data[ - 'error']: + if 'error' in error_data and 'message' in error_data['error']: error_msg = error_data['error']['message'] - error_type = error_data['error'].get( - 'type', 'API Error') + error_type = error_data['error'].get('type', 'API Error') error_message = f'**{error_type}**: {error_msg}' except Exception: pass @@ -908,14 +820,14 @@ async def _detect_patterns(self, line: str): self.on_error({'message': error_message, 'type': 'api_error'}) # Also send as output message so it appears in the conversation if self.on_output: - self.on_output({ - 'type': 'error', - 'content': error_message, - 'role': 'system', - 'metadata': { - 'error_type': 'api_error' + self.on_output( + { + 'type': 'error', + 'content': error_message, + 'role': 'system', + 'metadata': {'error_type': 'api_error'}, } - }) + ) return # Detect other error patterns @@ -932,78 +844,72 @@ async def _detect_patterns(self, line: str): if json_match: try: import json + error_data = json.loads(json_match.group(0)) - if 'error' in error_data and 'message' in error_data[ - 'error']: + if 'error' in error_data and 'message' in error_data['error']: error_msg = error_data['error']['message'] - error_type = error_data['error'].get( - 'type', 'API Error') + error_type = error_data['error'].get('type', 'API Error') error_message = f'**{error_type}**: {error_msg}' except Exception: pass print(f'[Runner] Detected API error: {error_message}') if self.on_error: - self.on_error({ - 'message': - error_message, - 'type': - 'api_error', - 'code': - error_match.group(1) if error_match.groups() else None - }) + self.on_error( + { + 'message': error_message, + 'type': 'api_error', + 'code': error_match.group(1) if error_match.groups() else None, + } + ) # Also send as output message so it appears in the conversation if self.on_output: - self.on_output({ - 'type': 'error', - 'content': error_message, - 'role': 'system', - 'metadata': { - 'error_type': 'api_error' + self.on_output( + { + 'type': 'error', + 'content': error_message, + 'role': 'system', + 'metadata': {'error_type': 'api_error'}, } - }) + ) return # Detect workflow step beginning: "[tag] Agent tag task beginning." - begin_match = re.search( - r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+beginning', line) + begin_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+beginning', line) if begin_match: step_name = begin_match.group(1) # Skip sub-steps and programmer agents (handled separately) - if (('-r' in step_name and '-' in step_name.split('-r')[-1]) - or step_name.startswith('programmer-')): + if ('-r' in step_name and '-' in step_name.split('-r')[-1]) or step_name.startswith('programmer-'): return print(f'[Runner] Step beginning: {step_name}') # Flush previous step if exists if self._current_step and self._accumulated_output.strip(): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', - self._accumulated_output.strip()) + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) if cleaned and self.on_output: - self.on_output({ - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': { - 'agent': self._current_step + self.on_output( + { + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': {'agent': self._current_step}, } - }) + ) self._accumulated_output = '' self._collecting_assistant_output = False if self._current_step and self.on_output: - self.on_output({ - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': { - 'step': self._current_step, - 'status': 'completed' + self.on_output( + { + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': {'step': self._current_step, 'status': 'completed'}, } - }) + ) # Start new step self._current_step = step_name @@ -1011,29 +917,35 @@ async def _detect_patterns(self, line: str): self._workflow_steps.append(step_name) step_status = { - s: ('completed' if i < self._workflow_steps.index(step_name) - else 'running' if s == step_name else 'pending') + s: ( + 'completed' + if i < self._workflow_steps.index(step_name) + else 'running' + if s == step_name + else 'pending' + ) for i, s in enumerate(self._workflow_steps) } if self.on_progress: - self.on_progress({ - 'type': 'workflow', - 'current_step': step_name, - 'steps': self._workflow_steps.copy(), - 'step_status': step_status - }) + self.on_progress( + { + 'type': 'workflow', + 'current_step': step_name, + 'steps': self._workflow_steps.copy(), + 'step_status': step_status, + } + ) if self.on_output: - self.on_output({ - 'type': 'step_start', - 'content': step_name, - 'role': 'assistant', - 'metadata': { - 'step': step_name, - 'status': 'running' + self.on_output( + { + 'type': 'step_start', + 'content': step_name, + 'role': 'assistant', + 'metadata': {'step': step_name, 'status': 'running'}, } - }) + ) # If Refine step is starting, scan tasks.txt for all generated files # This ensures files are detected after Coding phase completes @@ -1048,40 +960,35 @@ async def _detect_patterns(self, line: str): programmer_agent = f'programmer-{programmer_match.group(1)}' # If this is FIRST programmer agent, trigger coding step start - if not self._current_step or not self._current_step.startswith( - 'programmer-'): - print( - f'[Runner] First programmer agent detected: {programmer_agent} - starting coding step' - ) + if not self._current_step or not self._current_step.startswith('programmer-'): + print(f'[Runner] First programmer agent detected: {programmer_agent} - starting coding step') # Flush previous step's output if self._current_step and self._accumulated_output.strip(): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', - self._accumulated_output.strip()) + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) if cleaned and self.on_output: - self.on_output({ - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': { - 'agent': self._current_step + self.on_output( + { + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': {'agent': self._current_step}, } - }) + ) self._accumulated_output = '' self._collecting_assistant_output = False # Mark previous step complete if self._current_step and self.on_output: - self.on_output({ - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': { - 'step': self._current_step, - 'status': 'completed' + self.on_output( + { + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': {'step': self._current_step, 'status': 'completed'}, } - }) + ) # Start coding step self._current_step = programmer_agent @@ -1089,29 +996,35 @@ async def _detect_patterns(self, line: str): self._workflow_steps.append('coding') step_status = { - s: ('completed' if i < self._workflow_steps.index('coding') - else 'running' if s == 'coding' else 'pending') + s: ( + 'completed' + if i < self._workflow_steps.index('coding') + else 'running' + if s == 'coding' + else 'pending' + ) for i, s in enumerate(self._workflow_steps) } if self.on_progress: - self.on_progress({ - 'type': 'workflow', - 'current_step': 'coding', - 'steps': self._workflow_steps.copy(), - 'step_status': step_status - }) + self.on_progress( + { + 'type': 'workflow', + 'current_step': 'coding', + 'steps': self._workflow_steps.copy(), + 'step_status': step_status, + } + ) if self.on_output: - self.on_output({ - 'type': 'step_start', - 'content': 'coding', - 'role': 'assistant', - 'metadata': { - 'step': 'coding', - 'status': 'running' + self.on_output( + { + 'type': 'step_start', + 'content': 'coding', + 'role': 'assistant', + 'metadata': {'step': 'coding', 'status': 'running'}, } - }) + ) # Update current programmer agent elif programmer_agent != self._current_step: @@ -1119,23 +1032,21 @@ async def _detect_patterns(self, line: str): # Helper to flush accumulated assistant output def flush_accumulated_output(): - print(f'[Runner] flush_accumulated_output called: ' - f'collecting={self._collecting_assistant_output}, ' - f'buffer_len={len(self._accumulated_output)}') + print( + f'[Runner] flush_accumulated_output called: ' + f'collecting={self._collecting_assistant_output}, ' + f'buffer_len={len(self._accumulated_output)}' + ) print( f'[Runner] Buffer content: {self._accumulated_output[:200]}...' - if len(self._accumulated_output) > 200 else - f'[Runner] Buffer content: {self._accumulated_output}') - if self._collecting_assistant_output and self._accumulated_output.strip( - ): + if len(self._accumulated_output) > 200 + else f'[Runner] Buffer content: {self._accumulated_output}' + ) + if self._collecting_assistant_output and self._accumulated_output.strip(): # Clean log prefixes - cleaned_content = re.sub(r'\[INFO:ms_agent\]\s*', '', - self._accumulated_output.strip()) - cleaned_content = re.sub( - r'\[([^\]]+)\]\s*', '', cleaned_content, count=1) - print( - f'[Runner] Flushing assistant output: {cleaned_content[:100]}...' - ) + cleaned_content = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) + cleaned_content = re.sub(r'\[([^\]]+)\]\s*', '', cleaned_content, count=1) + print(f'[Runner] Flushing assistant output: {cleaned_content[:100]}...') # Map agent name for display agent_name = self._current_step or 'agent' @@ -1144,30 +1055,30 @@ def flush_accumulated_output(): display_agent = 'coding' if cleaned_content and self.on_output: - self.on_output({ - 'type': 'agent_output', - 'content': cleaned_content, - 'role': 'assistant', - 'metadata': { - 'agent': display_agent + self.on_output( + { + 'type': 'agent_output', + 'content': cleaned_content, + 'role': 'assistant', + 'metadata': {'agent': display_agent}, } - }) + ) self._accumulated_output = '' self._collecting_assistant_output = False else: - print(f'[Runner] flush_accumulated_output skipped: ' - f'collecting={self._collecting_assistant_output}, ' - f'has_content={bool(self._accumulated_output.strip())}') + print( + f'[Runner] flush_accumulated_output skipped: ' + f'collecting={self._collecting_assistant_output}, ' + f'has_content={bool(self._accumulated_output.strip())}' + ) # Detect workflow step finished: "[tag] Agent tag task finished." - end_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+finished', - line) + end_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+finished', line) if end_match: step_name = end_match.group(1) # Skip install (handled by programmer detection) and sub-steps - if step_name == 'install' or ('-r' in step_name and '-' - in step_name.split('-r')[-1]): + if step_name == 'install' or ('-r' in step_name and '-' in step_name.split('-r')[-1]): return # Skip flush for refine (already flushed during collection) @@ -1200,28 +1111,30 @@ def flush_accumulated_output(): # Build step status dict - all steps up to current are completed step_status = {} for s in self._workflow_steps: - step_status[s] = 'completed' if self._workflow_steps.index( - s) <= self._workflow_steps.index(step_name) else 'pending' + step_status[s] = ( + 'completed' if self._workflow_steps.index(s) <= self._workflow_steps.index(step_name) else 'pending' + ) if self.on_progress: - self.on_progress({ - 'type': 'workflow', - 'current_step': step_name, - 'steps': self._workflow_steps.copy(), - 'step_status': step_status - }) + self.on_progress( + { + 'type': 'workflow', + 'current_step': step_name, + 'steps': self._workflow_steps.copy(), + 'step_status': step_status, + } + ) # Send step complete message if self.on_output: - self.on_output({ - 'type': 'step_complete', - 'content': step_name, - 'role': 'assistant', - 'metadata': { - 'step': step_name, - 'status': 'completed' + self.on_output( + { + 'type': 'step_complete', + 'content': step_name, + 'role': 'assistant', + 'metadata': {'step': step_name, 'status': 'completed'}, } - }) + ) # Clear current step since it's completed self._current_step = None @@ -1230,8 +1143,7 @@ def flush_accumulated_output(): # Clean log prefixes from line # Detect assistant output: "[tag] [assistant]:" if '[assistant]:' in line: - in_coding = self._current_step and self._current_step.startswith( - 'programmer-') + in_coding = self._current_step and self._current_step.startswith('programmer-') if not in_coding: # Start collecting (don't send first line immediately) @@ -1252,27 +1164,24 @@ def flush_accumulated_output(): # Continue collecting assistant output elif self._collecting_assistant_output: # Skip if in coding phase - if self._current_step and self._current_step.startswith( - 'programmer-'): + if self._current_step and self._current_step.startswith('programmer-'): self._collecting_assistant_output = False self._accumulated_output = '' # Don't return - continue processing else: # Check if new pattern starts - if '[tool_calling]:' in line or ('[assistant]:' in line - and 'Agent' not in line): + if '[tool_calling]:' in line or ('[assistant]:' in line and 'Agent' not in line): if self._accumulated_output.strip(): - cleaned = self._clean_log_prefix( - self._accumulated_output.strip()) + cleaned = self._clean_log_prefix(self._accumulated_output.strip()) if cleaned and self.on_output: - self.on_output({ - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': { - 'agent': self._current_step or 'agent' + self.on_output( + { + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': {'agent': self._current_step or 'agent'}, } - }) + ) self._accumulated_output = '' self._collecting_assistant_output = False else: @@ -1282,46 +1191,37 @@ def flush_accumulated_output(): if cleaned_line: self._accumulated_output += cleaned_line + '\n' # Check for EdgeOne deployment URL in this line - url_match = re.search( - r'(https?://[^\s]*edgeone\.cool[^\s]*)', - cleaned_line) + url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', cleaned_line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) - deployment_url = deployment_url.replace( - '\\&', '&') - print( - f'[Runner] Detected deployment URL in assistant: {deployment_url}' - ) + deployment_url = deployment_url.replace('\\&', '&') + print(f'[Runner] Detected deployment URL in assistant: {deployment_url}') if self.on_output: - self.on_output({ - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': { - 'url': deployment_url + self.on_output( + { + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': {'url': deployment_url}, } - }) + ) # Check for waiting for input pattern - if ('Waiting for user feedback' in line - or 'Waiting for user input from stdin' - in line): + if 'Waiting for user feedback' in line or 'Waiting for user input from stdin' in line: print('[Runner] Agent waiting for user input') self._waiting_for_input = True if self.on_output and not self._waiting_input_sent: - self.on_output({ - 'type': - 'waiting_input', - 'content': - ('✅ Initial refinement completed. ' - 'You can now provide additional feedback or modifications.' - ), - 'role': - 'system', - 'metadata': { - 'waiting': True + self.on_output( + { + 'type': 'waiting_input', + 'content': ( + '✅ Initial refinement completed. ' + 'You can now provide additional feedback or modifications.' + ), + 'role': 'system', + 'metadata': {'waiting': True}, } - }) + ) self._waiting_input_sent = True return @@ -1340,8 +1240,7 @@ def flush_accumulated_output(): self._tool_call_json_buffer = json_part elif json_part: # Try to extract tool name directly if it's not JSON format - tool_match = re.search(r'([\w\-]+(?:---[\w\-]+)?)', - json_part) + tool_match = re.search(r'([\w\-]+(?:---[\w\-]+)?)', json_part) if tool_match: self._current_tool_name = tool_match.group(1) return @@ -1351,8 +1250,7 @@ def flush_accumulated_output(): # Extract agent name from line if available (for better matching) agent_name_from_line = None if '[INFO:ms_agent]' in line: - agent_match = re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', - line) + agent_match = re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', line) if agent_match: agent_name_from_line = agent_match.group(1) @@ -1380,49 +1278,43 @@ def flush_accumulated_output(): self._tool_call_json_buffer += cleaned_line # Only try to parse when buffer contains tool_name and ends with } - if (self._tool_call_json_buffer - and '"tool_name"' in self._tool_call_json_buffer - and self._tool_call_json_buffer.strip().endswith('}')): + if ( + self._tool_call_json_buffer + and '"tool_name"' in self._tool_call_json_buffer + and self._tool_call_json_buffer.strip().endswith('}') + ): try: import json + tool_info = json.loads(self._tool_call_json_buffer) print('[Runner] Parsed tool JSON successfully') - tool_name = tool_info.get('tool_name') or tool_info.get( - 'name', 'unknown') + tool_name = tool_info.get('tool_name') or tool_info.get('name', 'unknown') tool_args = tool_info.get('arguments', {}) print(f'[Runner] Extracted tool_name: {tool_name}') if tool_name and tool_name != 'unknown': self._current_tool_name = tool_name self._current_tool_args = tool_args agent_name = agent_name_from_line or self._current_step or 'agent' - print( - f'[Runner] Sending tool call: {tool_name}, agent: {agent_name}' - ) + print(f'[Runner] Sending tool call: {tool_name}, agent: {agent_name}') if self.on_output: - self.on_output({ - 'type': 'tool_call', - 'content': f'调用工具: {tool_name}', - 'role': 'assistant', - 'metadata': { - 'tool_name': tool_name, - 'tool_args': tool_args, - 'agent': agent_name + self.on_output( + { + 'type': 'tool_call', + 'content': f'调用工具: {tool_name}', + 'role': 'assistant', + 'metadata': {'tool_name': tool_name, 'tool_args': tool_args, 'agent': agent_name}, } - }) + ) # Clear buffer but KEEP collecting - there may be more tool calls self._tool_call_json_buffer = '' # Don't return or stop collecting - next line might be another tool call JSON else: - print( - f'[Runner] WARNING: Invalid tool_name: {tool_name}' - ) + print(f'[Runner] WARNING: Invalid tool_name: {tool_name}') except json.JSONDecodeError as e: # JSON not complete yet, keep collecting # Only log if we have tool_name - helps debug parsing issues if '"tool_name"' in self._tool_call_json_buffer: - print( - f'[Runner] JSON incomplete, continuing... (error: {str(e)[:50]})' - ) + print(f'[Runner] JSON incomplete, continuing... (error: {str(e)[:50]})') except Exception as e: print(f'[Runner] Error parsing tool JSON: {e}') @@ -1430,23 +1322,18 @@ def flush_accumulated_output(): if '[assistant]:' in line or 'Agent' in line and 'task' in line or '[tool_result]:' in line: # If we have partial data, try to send it if self._tool_call_json_buffer: - tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', - self._tool_call_json_buffer) + tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', self._tool_call_json_buffer) if tool_name_match: tool_name = tool_name_match.group(1) # Try to extract arguments - handle nested JSON objects - args_start = self._tool_call_json_buffer.find( - '"arguments"') + args_start = self._tool_call_json_buffer.find('"arguments"') tool_args = {} if args_start != -1: - brace_start = self._tool_call_json_buffer.find( - '{', args_start) + brace_start = self._tool_call_json_buffer.find('{', args_start) if brace_start != -1: brace_count = 0 brace_end = brace_start - for i in range( - brace_start, - len(self._tool_call_json_buffer)): + for i in range(brace_start, len(self._tool_call_json_buffer)): if self._tool_call_json_buffer[i] == '{': brace_count += 1 elif self._tool_call_json_buffer[i] == '}': @@ -1456,8 +1343,7 @@ def flush_accumulated_output(): break if brace_end > brace_start: - args_str = self._tool_call_json_buffer[ - brace_start:brace_end] + args_str = self._tool_call_json_buffer[brace_start:brace_end] try: tool_args = json.loads(args_str) except Exception: @@ -1470,16 +1356,14 @@ def flush_accumulated_output(): f'{tool_name}, agent: {agent_name}, args: {tool_args}' ) if self.on_output: - self.on_output({ - 'type': 'tool_call', - 'content': f'调用工具: {tool_name}', - 'role': 'assistant', - 'metadata': { - 'tool_name': tool_name, - 'tool_args': tool_args, - 'agent': agent_name + self.on_output( + { + 'type': 'tool_call', + 'content': f'调用工具: {tool_name}', + 'role': 'assistant', + 'metadata': {'tool_name': tool_name, 'tool_args': tool_args, 'agent': agent_name}, } - }) + ) self._collecting_tool_call = False self._tool_call_json_buffer = '' return @@ -1495,16 +1379,18 @@ def flush_accumulated_output(): self._current_tool_result = result_content # Send tool result immediately if we have tool name if self._current_tool_name and self.on_output: - self.on_output({ - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': result_content, - 'agent': self._current_step or 'agent' + self.on_output( + { + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': result_content, + 'agent': self._current_step or 'agent', + }, } - }) + ) # Reset tool info self._current_tool_name = None self._current_tool_result = None @@ -1522,57 +1408,51 @@ def flush_accumulated_output(): # Check for EdgeOne deployment URL in tool result # Pattern 1: JSON format with edgeone.cool or edgeone.site - url_match = re.search( - r'"url":\s*"(https?://[^"]+edgeone\.(cool|site)[^"]+)"', - line) + url_match = re.search(r'"url":\s*"(https?://[^"]+edgeone\.(cool|site)[^"]+)"', line) # Pattern 2: Direct URL with edgeone.cool or edgeone.site if not url_match: - url_match = re.search( - r'(https?://[^\s]*edgeone\.(cool|site)[^\s]*)', line) + url_match = re.search(r'(https?://[^\s]*edgeone\.(cool|site)[^\s]*)', line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) deployment_url = deployment_url.replace('\\&', '&') - print( - f'[Runner] Detected deployment URL in tool result: {deployment_url}' - ) + print(f'[Runner] Detected deployment URL in tool result: {deployment_url}') if self.on_output: - self.on_output({ - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': { - 'url': deployment_url + self.on_output( + { + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': {'url': deployment_url}, } - }) + ) # After deployment success, prompt user for further input self._waiting_for_input = True if not self._waiting_input_sent: - self.on_output({ - 'type': 'waiting_input', - 'content': - 'You can now provide additional feedback or visit the deployed site.', - 'role': 'system', - 'metadata': { - 'waiting': True, - 'deployment_complete': True + self.on_output( + { + 'type': 'waiting_input', + 'content': 'You can now provide additional feedback or visit the deployed site.', + 'role': 'system', + 'metadata': {'waiting': True, 'deployment_complete': True}, } - }) + ) self._waiting_input_sent = True # Send result if we have tool name and accumulated enough content - if self._current_tool_name and len( - self._current_tool_result) > 100 and self.on_output: - self.on_output({ - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': self._current_tool_result, - 'agent': self._current_step or 'agent' + if self._current_tool_name and len(self._current_tool_result) > 100 and self.on_output: + self.on_output( + { + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': self._current_tool_result, + 'agent': self._current_step or 'agent', + }, } - }) + ) # Reset self._current_tool_name = None self._current_tool_result = None @@ -1580,68 +1460,52 @@ def flush_accumulated_output(): elif '[assistant]:' in line or '[tool_calling]:' in line or 'Agent' in line and 'task' in line: # Hit a new pattern, send accumulated result if self._current_tool_name and self._current_tool_result and self.on_output: - self.on_output({ - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': self._current_tool_result, - 'agent': self._current_step or 'agent' + self.on_output( + { + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': self._current_tool_result, + 'agent': self._current_step or 'agent', + }, } - }) + ) self._current_tool_name = None self._current_tool_result = None self._collecting_tool_result = False return # Detect file writing - file_match = re.search(r'writing file:?\s*["\']?([^\s"\']+)["\']?', - line.lower()) + file_match = re.search(r'writing file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) if not file_match: - file_match = re.search( - r'creating file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) + file_match = re.search(r'creating file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) if file_match and self.on_progress: filename = file_match.group(1) - self.on_progress({ - 'type': 'file', - 'file': filename, - 'status': 'writing' - }) + self.on_progress({'type': 'file', 'file': filename, 'status': 'writing'}) return # Detect file written/created/saved - multiple patterns - file_keywords = [ - 'file created', 'file written', 'file saved', 'saved to:', - 'wrote to', 'generated:', 'output:' - ] + file_keywords = ['file created', 'file written', 'file saved', 'saved to:', 'wrote to', 'generated:', 'output:'] if any(keyword in line.lower() for keyword in file_keywords): # Try to extract filename with extension # More strict pattern: must have a proper filename with extension, not just numbers - file_match = re.search( - r'["\']?([a-zA-Z0-9_\-][^\s"\'\/\[\]]*\.[a-zA-Z0-9]+)["\']?', - line) + file_match = re.search(r'["\']?([a-zA-Z0-9_\-][^\s"\'\/\[\]]*\.[a-zA-Z0-9]+)["\']?', line) if file_match and self.on_progress: filename = file_match.group(1) # Validate filename: must not be just numbers or version numbers like "0.0" - if filename and not re.match(r'^\d+\.\d+$', - filename) and len(filename) > 2: + if filename and not re.match(r'^\d+\.\d+$', filename) and len(filename) > 2: # Strip 'programmer-' prefix from filename if filename.startswith('programmer-'): - filename = filename[len('programmer-'):] + filename = filename[len('programmer-') :] print(f'[Runner] Detected file output: {filename}') # Only send progress update (file_output will be sent from tasks.txt) - self.on_progress({ - 'type': 'file', - 'file': filename, - 'status': 'completed' - }) + self.on_progress({'type': 'file', 'file': filename, 'status': 'completed'}) return # Detect output file paths (e.g., "output/user_story.txt" standalone) - output_path_match = re.search( - r'(?:^|\s)((?:output|projects)/[^\s]+\.[a-zA-Z0-9]+)(?:\s|$)', - line) + output_path_match = re.search(r'(?:^|\s)((?:output|projects)/[^\s]+\.[a-zA-Z0-9]+)(?:\s|$)', line) if output_path_match and self.on_progress: filename = output_path_match.group(1) # Strip 'programmer-' prefix from basename only (not from path) @@ -1649,17 +1513,13 @@ def flush_accumulated_output(): if '/' in filename: parts = filename.rsplit('/', 1) if len(parts) == 2 and parts[1].startswith('programmer-'): - parts[1] = parts[1][len('programmer-'):] + parts[1] = parts[1][len('programmer-') :] filename = '/'.join(parts) elif filename.startswith('programmer-'): - filename = filename[len('programmer-'):] + filename = filename[len('programmer-') :] print(f'[Runner] Detected output path: {filename}') # Only send progress update (file_output will be sent from tasks.txt) - self.on_progress({ - 'type': 'file', - 'file': filename, - 'status': 'completed' - }) + self.on_progress({'type': 'file', 'file': filename, 'status': 'completed'}) return # Deployment URL detection moved to the beginning of _detect_patterns @@ -1669,22 +1529,23 @@ def flush_accumulated_output(): # Pattern: "✅ Initial refinement completed. You can now provide..." # Also detect: "Agent completed initial refinement. Waiting for user feedback." # Also detect: "Waiting for user input from stdin..." - if ('Initial refinement completed' in line - or 'provide additional feedback' in line - or 'Waiting for user feedback' in line - or 'Agent completed initial refinement' in line - or 'Waiting for user input from stdin' in line): + if ( + 'Initial refinement completed' in line + or 'provide additional feedback' in line + or 'Waiting for user feedback' in line + or 'Agent completed initial refinement' in line + or 'Waiting for user input from stdin' in line + ): print('[Runner] Agent waiting for user input') self._waiting_for_input = True # Mark that agent is waiting for input if self.on_output and not self._waiting_input_sent: - self.on_output({ - 'type': 'waiting_input', - 'content': - '✅ Initial refinement completed. You can now provide additional feedback or modifications.', - 'role': 'system', - 'metadata': { - 'waiting': True + self.on_output( + { + 'type': 'waiting_input', + 'content': '✅ Initial refinement completed. You can now provide additional feedback or modifications.', + 'role': 'system', + 'metadata': {'waiting': True}, } - }) + ) self._waiting_input_sent = True return diff --git a/webui/backend/api.py b/webui/backend/api.py index 126ac4c02..d0664ad9b 100644 --- a/webui/backend/api.py +++ b/webui/backend/api.py @@ -2,6 +2,7 @@ """ API endpoints for the MS-Agent Web UI """ + import mimetypes import os from pathlib import Path @@ -10,6 +11,7 @@ from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse from pydantic import BaseModel, Field + # Import shared instances from shared import config_manager, project_discovery, session_manager @@ -17,8 +19,7 @@ def get_backend_root() -> Path: - return Path(__file__).resolve().parents[ - 1] # equal to dirname(dirname(__file__)) + return Path(__file__).resolve().parents[1] # equal to dirname(dirname(__file__)) def get_session_root(session_id: str) -> Path: @@ -46,8 +47,7 @@ class ProjectInfo(BaseModel): class SessionCreate(BaseModel): project_id: Optional[str] = None # Optional for chat mode query: Optional[str] = None - workflow_type: Optional[ - str] = 'standard' # 'standard' or 'simple' for code_genesis + workflow_type: Optional[str] = 'standard' # 'standard' or 'simple' for code_genesis session_type: Optional[str] = 'project' # 'project' or 'chat' @@ -99,14 +99,10 @@ class DeepResearchSearchConfig(BaseModel): class DeepResearchConfig(BaseModel): - researcher: DeepResearchAgentConfig = Field( - default_factory=DeepResearchAgentConfig) - searcher: DeepResearchAgentConfig = Field( - default_factory=DeepResearchAgentConfig) - reporter: DeepResearchAgentConfig = Field( - default_factory=DeepResearchAgentConfig) - search: DeepResearchSearchConfig = Field( - default_factory=DeepResearchSearchConfig) + researcher: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) + searcher: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) + reporter: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) + search: DeepResearchSearchConfig = Field(default_factory=DeepResearchSearchConfig) class MCPServer(BaseModel): @@ -129,9 +125,7 @@ class GlobalConfig(BaseModel): @router.get('/projects', response_model=List[ProjectInfo]) async def list_projects(): """List all available projects""" - print( - f'project_discovery.discover_projects(): {project_discovery.discover_projects()}' - ) + print(f'project_discovery.discover_projects(): {project_discovery.discover_projects()}') return project_discovery.discover_projects() @@ -154,8 +148,7 @@ async def get_project_readme(project_id: str): @router.get('/projects/{project_id}/workflow') -async def get_project_workflow(project_id: str, - session_id: Optional[str] = None): +async def get_project_workflow(project_id: str, session_id: Optional[str] = None): """Get the workflow configuration for a project If session_id is provided, returns the workflow based on the session's workflow_type. @@ -188,12 +181,12 @@ async def get_project_workflow(project_id: str, try: import yaml + with open(workflow_file, 'r', encoding='utf-8') as f: workflow_data = yaml.safe_load(f) return {'workflow': workflow_data, 'workflow_type': workflow_type} except Exception as e: - raise HTTPException( - status_code=500, detail=f'Error reading workflow file: {str(e)}') + raise HTTPException(status_code=500, detail=f'Error reading workflow file: {str(e)}') # Session Endpoints @@ -204,10 +197,8 @@ async def create_session(session_data: SessionCreate): if session_data.session_type == 'chat': # Create chat session without requiring a project session = session_manager.create_session( - project_id='__chat__', - project_name='Chat Assistant', - workflow_type='standard', - session_type='chat') + project_id='__chat__', project_name='Chat Assistant', workflow_type='standard', session_type='chat' + ) return session # For project mode, validate project exists @@ -219,15 +210,14 @@ async def create_session(session_data: SessionCreate): workflow_type = session_data.workflow_type or 'standard' if project.get('supports_workflow_switch'): if workflow_type not in ['standard', 'simple']: - raise HTTPException( - status_code=400, - detail="workflow_type must be 'standard' or 'simple'") + raise HTTPException(status_code=400, detail="workflow_type must be 'standard' or 'simple'") session = session_manager.create_session( project_id=session_data.project_id, project_name=project['name'], workflow_type=workflow_type, - session_type='project') + session_type='project', + ) return session @@ -265,8 +255,7 @@ async def get_session_messages(session_id: str): @router.get('/sessions/{session_id}/dr_events') -async def get_session_dr_events(session_id: str, - after_id: Optional[int] = Query(None, ge=0)): +async def get_session_dr_events(session_id: str, after_id: Optional[int] = Query(None, ge=0)): """Get deep research event history for a session.""" events = session_manager.list_dr_events(session_id, after_id) if events is None: @@ -369,8 +358,7 @@ async def update_deep_research_config(config: DeepResearchConfig): @router.post('/config/mcp/servers') async def add_mcp_server(server: MCPServer): """Add a new MCP server""" - config_manager.add_mcp_server(server.name, - server.model_dump(exclude={'name'})) + config_manager.add_mcp_server(server.name, server.model_dump(exclude={'name'})) return {'status': 'added'} @@ -392,38 +380,14 @@ async def list_available_models(): { 'provider': 'modelscope', 'model': 'Qwen/Qwen3-235B-A22B-Instruct-2507', - 'display_name': 'Qwen3-235B (Recommended)' - }, - { - 'provider': 'modelscope', - 'model': 'Qwen/Qwen2.5-72B-Instruct', - 'display_name': 'Qwen2.5-72B' - }, - { - 'provider': 'modelscope', - 'model': 'Qwen/Qwen2.5-32B-Instruct', - 'display_name': 'Qwen2.5-32B' - }, - { - 'provider': 'modelscope', - 'model': 'deepseek-ai/DeepSeek-V3', - 'display_name': 'DeepSeek-V3' - }, - { - 'provider': 'openai', - 'model': 'gpt-4o', - 'display_name': 'GPT-4o' - }, - { - 'provider': 'openai', - 'model': 'gpt-4o-mini', - 'display_name': 'GPT-4o Mini' - }, - { - 'provider': 'anthropic', - 'model': 'claude-3-5-sonnet-20241022', - 'display_name': 'Claude 3.5 Sonnet' + 'display_name': 'Qwen3-235B (Recommended)', }, + {'provider': 'modelscope', 'model': 'Qwen/Qwen2.5-72B-Instruct', 'display_name': 'Qwen2.5-72B'}, + {'provider': 'modelscope', 'model': 'Qwen/Qwen2.5-32B-Instruct', 'display_name': 'Qwen2.5-32B'}, + {'provider': 'modelscope', 'model': 'deepseek-ai/DeepSeek-V3', 'display_name': 'DeepSeek-V3'}, + {'provider': 'openai', 'model': 'gpt-4o', 'display_name': 'GPT-4o'}, + {'provider': 'openai', 'model': 'gpt-4o-mini', 'display_name': 'GPT-4o Mini'}, + {'provider': 'anthropic', 'model': 'claude-3-5-sonnet-20241022', 'display_name': 'Claude 3.5 Sonnet'}, ] } @@ -437,26 +401,22 @@ class FileReadRequest(BaseModel): @router.get('/files/list') async def list_output_files( - output_dir: Optional[str] = Query(default='output'), - session_id: Optional[str] = Query(default=None), - root_dir: Optional[str] = Query(default=None), + output_dir: Optional[str] = Query(default='output'), + session_id: Optional[str] = Query(default=None), + root_dir: Optional[str] = Query(default=None), ): """List all files under root_dir as a tree structure. root_dir: optional. If not provided, defaults to ms-agent/output. Also supports 'projects' or 'projects/xxx' etc. """ # Excluded folders - exclude_dirs = { - 'node_modules', '__pycache__', '.git', '.venv', 'venv', 'dist', 'build' - } + exclude_dirs = {'node_modules', '__pycache__', '.git', '.venv', 'venv', 'dist', 'build'} # Base directories (same way as read_file_content) - base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) projects_dir = os.path.join(base_dir, 'projects') if session_id: - session_root = get_session_root(session_id) resolved_root = (session_root / '').resolve() @@ -522,14 +482,15 @@ def build_tree(dir_path: str) -> dict: # Return RELATIVE path to resolved_root (better for frontend + read API) rel_path = os.path.relpath(full_path, resolved_root) - result['files'].append({ - 'name': item, - 'path': rel_path, # <-- relative path - 'abs_path': - full_path, # optional: if you still want absolute for debugging - 'size': os.path.getsize(full_path), - 'modified': os.path.getmtime(full_path) - }) + result['files'].append( + { + 'name': item, + 'path': rel_path, # <-- relative path + 'abs_path': full_path, # optional: if you still want absolute for debugging + 'size': os.path.getsize(full_path), + 'modified': os.path.getmtime(full_path), + } + ) result['files'].sort(key=lambda x: x['modified'], reverse=True) return result @@ -540,12 +501,10 @@ def build_tree(dir_path: str) -> dict: def get_allowed_roots(): - base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) output_dir = os.path.join(base_dir, 'output') projects_dir = os.path.join(base_dir, 'projects') - return base_dir, os.path.normpath(output_dir), os.path.normpath( - projects_dir) + return base_dir, os.path.normpath(output_dir), os.path.normpath(projects_dir) def resolve_root_dir(root_dir: Optional[str]) -> str: @@ -576,8 +535,7 @@ def resolve_root_dir(root_dir: Optional[str]) -> str: cand1 = os.path.join(output_dir, rd) cand2 = os.path.join(projects_dir, rd) # choose existing one if possible, otherwise default to cand1 - resolved = cand1 if os.path.exists(cand1) else ( - cand2 if os.path.exists(cand2) else cand1) + resolved = cand1 if os.path.exists(cand1) else (cand2 if os.path.exists(cand2) else cand1) resolved = os.path.normpath(os.path.abspath(resolved)) @@ -603,14 +561,11 @@ def resolve_file_path(root_dir_abs: str, file_path: str) -> str: elif file_path.startswith('projects/'): # Special case: if path starts with 'projects/', resolve from base_dir # This handles: projects/code_genesis/output/config.js - base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - full_path = os.path.normpath( - os.path.abspath(os.path.join(base_dir, file_path))) + base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + full_path = os.path.normpath(os.path.abspath(os.path.join(base_dir, file_path))) else: # Try multiple locations - base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) candidates = [ # First try with root_dir_abs (for session-based access) @@ -625,8 +580,7 @@ def resolve_file_path(root_dir_abs: str, file_path: str) -> str: for project_name in os.listdir(projects_dir): project_path = os.path.join(projects_dir, project_name) if os.path.isdir(project_path): - candidates.append( - os.path.join(project_path, 'output', file_path)) + candidates.append(os.path.join(project_path, 'output', file_path)) except (OSError, PermissionError): pass @@ -660,12 +614,10 @@ async def read_file_content(request: FileReadRequest): full_path = resolve_file_path(root_abs, request.path) if not os.path.exists(full_path): - raise HTTPException( - status_code=404, detail=f'File not found: {full_path}') + raise HTTPException(status_code=404, detail=f'File not found: {full_path}') if not os.path.isfile(full_path): - raise HTTPException( - status_code=400, detail=f'Path {full_path} is not a file') + raise HTTPException(status_code=400, detail=f'Path {full_path} is not a file') # limit 1MB file_size = os.path.getsize(full_path) if file_size > 1024 * 1024: @@ -706,19 +658,17 @@ async def read_file_content(request: FileReadRequest): 'root_dir': root_abs, 'filename': os.path.basename(full_path), 'language': language, - 'size': file_size + 'size': file_size, } except UnicodeDecodeError: raise HTTPException(status_code=400, detail='File is not a text file') except Exception as e: - raise HTTPException( - status_code=500, detail=f'Error reading file: {str(e)}') + raise HTTPException(status_code=500, detail=f'Error reading file: {str(e)}') def resolve_and_check_path(file_path: str) -> str: """Resolve file path, trying multiple locations""" - base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) if os.path.isabs(file_path): full_path = file_path @@ -748,9 +698,7 @@ def resolve_and_check_path(file_path: str) -> str: project_path = os.path.join(projects_dir, project_name) if os.path.isdir(project_path): # Try project/output/filename - candidates.append( - os.path.join(project_path, 'output', - file_path)) + candidates.append(os.path.join(project_path, 'output', file_path)) except (OSError, PermissionError): pass @@ -764,8 +712,7 @@ def resolve_and_check_path(file_path: str) -> str: if not full_path: # If not found, use the first candidate for error message - full_path = os.path.normpath( - candidates[0] if candidates else file_path) + full_path = os.path.normpath(candidates[0] if candidates else file_path) full_path = os.path.normpath(full_path) @@ -775,18 +722,15 @@ def resolve_and_check_path(file_path: str) -> str: # TODO: Security check: ensure `full_path` is within configured allowed roots. if not os.path.exists(full_path): - raise HTTPException( - status_code=404, detail=f'File not found: {full_path}') + raise HTTPException(status_code=404, detail=f'File not found: {full_path}') if not os.path.isfile(full_path): - raise HTTPException( - status_code=400, detail=f'Path {full_path} is not a file') + raise HTTPException(status_code=400, detail=f'Path {full_path} is not a file') return full_path @router.get('/files/stream') -async def stream_file(path: str, - session_id: Optional[str] = Query(default=None)): +async def stream_file(path: str, session_id: Optional[str] = Query(default=None)): if session_id: session_root = get_session_root(session_id) root_abs = str(session_root.resolve()) @@ -800,8 +744,5 @@ async def stream_file(path: str, full_path, media_type=media_type, filename=os.path.basename(full_path), - headers={ - 'Content-Disposition': - f'inline; filename="{os.path.basename(full_path)}"' - }, + headers={'Content-Disposition': f'inline; filename="{os.path.basename(full_path)}"'}, ) diff --git a/webui/backend/config_manager.py b/webui/backend/config_manager.py index eddf94915..9d1511af8 100644 --- a/webui/backend/config_manager.py +++ b/webui/backend/config_manager.py @@ -3,12 +3,12 @@ Configuration management for MS-Agent Web UI Handles global settings, LLM configuration, and MCP server configuration. """ + +import json import os from threading import Lock from typing import Any, Dict, Optional -import json - class ConfigManager: """Manages global configuration for the Web UI""" @@ -21,46 +21,23 @@ class ConfigManager: 'base_url': 'https://api-inference.modelscope.cn/v1/', 'temperature': None, 'temperature_enabled': False, - 'max_tokens': None + 'max_tokens': None, }, 'deep_research': { - 'researcher': { - 'model': '', - 'api_key': '', - 'base_url': '' - }, - 'searcher': { - 'model': '', - 'api_key': '', - 'base_url': '' - }, - 'reporter': { - 'model': '', - 'api_key': '', - 'base_url': '' - }, - 'search': { - 'summarizer_model': '', - 'summarizer_api_key': '', - 'summarizer_base_url': '' - } - }, - 'edit_file_config': { - 'api_key': '', - 'base_url': 'https://api.morphllm.com/v1', - 'diff_model': 'morph-v3-fast' - }, - 'edgeone_pages': { - 'api_token': '', - 'project_name': '' + 'researcher': {'model': '', 'api_key': '', 'base_url': ''}, + 'searcher': {'model': '', 'api_key': '', 'base_url': ''}, + 'reporter': {'model': '', 'api_key': '', 'base_url': ''}, + 'search': {'summarizer_model': '', 'summarizer_api_key': '', 'summarizer_base_url': ''}, }, + 'edit_file_config': {'api_key': '', 'base_url': 'https://api.morphllm.com/v1', 'diff_model': 'morph-v3-fast'}, + 'edgeone_pages': {'api_token': '', 'project_name': ''}, 'search_keys': { 'exa_api_key': '', 'serpapi_api_key': '', }, 'mcp_servers': {}, 'theme': 'dark', - 'output_dir': './output' + 'output_dir': './output', } def __init__(self, config_dir: str): @@ -108,10 +85,7 @@ def _save_config(self): """Save configuration to file""" with self._lock: # Save main config (without mcp_servers) - config_to_save = { - k: v - for k, v in self._config.items() if k != 'mcp_servers' - } + config_to_save = {k: v for k, v in self._config.items() if k != 'mcp_servers'} with open(self.config_file, 'w', encoding='utf-8') as f: json.dump(config_to_save, f, indent=2) @@ -158,8 +132,7 @@ def update_mcp_config(self, mcp_config: Dict[str, Any]): def get_edit_file_config(self) -> Dict[str, Any]: """Get edit_file_config configuration""" config = self._load_config() - return config.get('edit_file_config', - self.DEFAULT_CONFIG['edit_file_config']) + return config.get('edit_file_config', self.DEFAULT_CONFIG['edit_file_config']) def update_edit_file_config(self, edit_file_config: Dict[str, Any]): """Update edit_file_config configuration""" @@ -170,11 +143,9 @@ def update_edit_file_config(self, edit_file_config: Dict[str, Any]): def get_edgeone_pages_config(self) -> Dict[str, Any]: """Get EdgeOne Pages configuration""" config = self._load_config() - return config.get('edgeone_pages', - self.DEFAULT_CONFIG['edgeone_pages']) + return config.get('edgeone_pages', self.DEFAULT_CONFIG['edgeone_pages']) - def update_edgeone_pages_config(self, edgeone_pages_config: Dict[str, - Any]): + def update_edgeone_pages_config(self, edgeone_pages_config: Dict[str, Any]): """Update EdgeOne Pages configuration""" self._load_config() self._config['edgeone_pages'] = edgeone_pages_config @@ -194,11 +165,9 @@ def update_search_keys(self, search_keys: Dict[str, Any]): def get_deep_research_config(self) -> Dict[str, Any]: """Get deep research configuration""" config = self._load_config() - return config.get('deep_research', - self.DEFAULT_CONFIG['deep_research']) + return config.get('deep_research', self.DEFAULT_CONFIG['deep_research']) - def update_deep_research_config(self, deep_research_config: Dict[str, - Any]): + def update_deep_research_config(self, deep_research_config: Dict[str, Any]): """Update deep research configuration""" self._load_config() self._config['deep_research'] = deep_research_config diff --git a/webui/backend/deep_research_eventizer.py b/webui/backend/deep_research_eventizer.py index 84949674a..261f3b513 100644 --- a/webui/backend/deep_research_eventizer.py +++ b/webui/backend/deep_research_eventizer.py @@ -1,6 +1,6 @@ +import json from typing import Any, Callable, Dict, List, Optional -import json from ms_agent.llm.utils import Message, ToolCall @@ -16,7 +16,6 @@ def _stringify_content(content: Any) -> str: class HistoryEventizer: - def __init__( self, emit: Callable[[Dict[str, Any]], None], @@ -52,8 +51,7 @@ def reset(self) -> None: self._tool_call_args = {} self._tool_call_names = {} - def _wrap_event(self, event_type: str, - payload: Dict[str, Any]) -> Dict[str, Any]: + def _wrap_event(self, event_type: str, payload: Dict[str, Any]) -> Dict[str, Any]: event: Dict[str, Any] = {'type': event_type, 'payload': payload} if self._session_id: event['session_id'] = self._session_id @@ -67,7 +65,7 @@ def _emit_event(self, event_type: str, payload: Dict[str, Any]) -> None: def _should_reset(self, messages: List[Message]) -> bool: if len(messages) < len(self._prev_messages): return True - for idx, msg in enumerate(messages[:len(self._prev_messages)]): + for idx, msg in enumerate(messages[: len(self._prev_messages)]): if msg.role != self._prev_messages[idx].role: return True return False @@ -87,8 +85,11 @@ def _ensure_message_id(self, idx: int, message: Message) -> str: def _is_subagent_tool(self, tool_name: str) -> bool: if not tool_name: return False - return tool_name.startswith('agent_tools---') or tool_name.endswith( - 'searcher_tool') or tool_name.endswith('reporter_tool') + return ( + tool_name.startswith('agent_tools---') + or tool_name.endswith('searcher_tool') + or tool_name.endswith('reporter_tool') + ) def _extract_tool_name(self, call: ToolCall) -> str: if not isinstance(call, dict): @@ -125,8 +126,7 @@ def _parse_tool_args(self, raw: Any) -> Dict[str, Any]: return {'request': raw} return {} - def _build_subagent_title(self, tool_name: str, - tool_args: Dict[str, Any]) -> str: + def _build_subagent_title(self, tool_name: str, tool_args: Dict[str, Any]) -> str: if 'searcher' in tool_name: base = 'Searcher' elif 'reporter' in tool_name: @@ -140,21 +140,18 @@ def _build_subagent_title(self, tool_name: str, try: parsed = json.loads(request) if isinstance(parsed, dict): - summary = parsed.get('task_id') or parsed.get( - '调研目标') or parsed.get('目标') + summary = parsed.get('task_id') or parsed.get('调研目标') or parsed.get('目标') except Exception: summary = None if summary is None: summary = request.strip().splitlines()[0][:80] return f'{base}: {summary}' if summary else base - def _record_tool_call(self, call_id: str, tool_name: str, - tool_args: Dict[str, Any]) -> tuple[bool, bool]: + def _record_tool_call(self, call_id: str, tool_name: str, tool_args: Dict[str, Any]) -> tuple[bool, bool]: is_new = call_id not in self._seen_tool_calls prev_args = self._tool_call_args.get(call_id) prev_name = self._tool_call_names.get(call_id) - if (not is_new and prev_args == tool_args - and (not tool_name or tool_name == prev_name)): + if not is_new and prev_args == tool_args and (not tool_name or tool_name == prev_name): return False, False self._seen_tool_calls.add(call_id) self._tool_call_args[call_id] = tool_args @@ -162,12 +159,10 @@ def _record_tool_call(self, call_id: str, tool_name: str, self._tool_call_names[call_id] = tool_name return True, is_new - def _maybe_emit_todos(self, tool_name: str, result_text: str, - call_id: Optional[str]) -> None: + def _maybe_emit_todos(self, tool_name: str, result_text: str, call_id: Optional[str]) -> None: if not tool_name: return - if not ('todo_list---todo_write' in tool_name - or 'todo_list---todo_read' in tool_name): + if not ('todo_list---todo_write' in tool_name or 'todo_list---todo_read' in tool_name): return try: parsed = json.loads(result_text) @@ -184,8 +179,7 @@ def _maybe_emit_todos(self, tool_name: str, result_text: str, payload['call_id'] = call_id self._emit_event('dr.state', payload) - def _emit_assistant_delta(self, message_id: str, delta: str, - full: str) -> None: + def _emit_assistant_delta(self, message_id: str, delta: str, full: str) -> None: payload = { 'message_id': message_id, 'delta': delta, @@ -193,8 +187,7 @@ def _emit_assistant_delta(self, message_id: str, delta: str, } self._emit_event('dr.chat.message.delta', payload) - def _emit_subagent_delta(self, message_id: str, delta: str, - full: str) -> None: + def _emit_subagent_delta(self, message_id: str, delta: str, full: str) -> None: payload = { 'card_id': self._card_id, 'message_id': message_id, @@ -203,8 +196,7 @@ def _emit_subagent_delta(self, message_id: str, delta: str, } self._emit_event('dr.subagent.message.delta', payload) - def _emit_subagent_message(self, message_id: str, role: str, - content: str) -> None: + def _emit_subagent_message(self, message_id: str, role: str, content: str) -> None: payload = { 'card_id': self._card_id, 'message_id': message_id, @@ -221,8 +213,7 @@ def _emit_assistant_completed(self, message_id: str, content: str) -> None: } self._emit_event('dr.chat.message.completed', payload) - def _emit_chat_message(self, message_id: str, role: str, content: str, - name: Optional[str]) -> None: + def _emit_chat_message(self, message_id: str, role: str, content: str, name: Optional[str]) -> None: payload = { 'message_id': message_id, 'role': role, @@ -232,19 +223,15 @@ def _emit_chat_message(self, message_id: str, role: str, content: str, payload['name'] = name self._emit_event('dr.chat.message', payload) - def _process_tool_calls(self, message_id: str, - tool_calls: List[ToolCall]) -> None: + def _process_tool_calls(self, message_id: str, tool_calls: List[ToolCall]) -> None: for idx, call in enumerate(tool_calls or []): call_id = call.get('id') or f'{message_id}-call-{idx}' tool_name = self._extract_tool_name(call) - tool_args = self._parse_tool_args( - self._extract_tool_args_raw(call)) - should_emit, is_new = self._record_tool_call( - call_id, tool_name, tool_args) + tool_args = self._parse_tool_args(self._extract_tool_args_raw(call)) + should_emit, is_new = self._record_tool_call(call_id, tool_name, tool_args) if not should_emit: continue - category = 'subagent' if self._is_subagent_tool( - tool_name) else 'normal' + category = 'subagent' if self._is_subagent_tool(tool_name) else 'normal' payload = { 'call_id': call_id, 'source_message_id': message_id, @@ -289,10 +276,13 @@ def _process_tool_result(self, message: Message) -> None: self._emit_event('dr.tool.result', payload) if call_id in self._subagent_call_ids: summary = result_text.strip().splitlines()[0][:160] - self._emit_event('dr.subagent.card.completed', { - 'card_id': call_id, - 'summary': summary, - }) + self._emit_event( + 'dr.subagent.card.completed', + { + 'card_id': call_id, + 'summary': summary, + }, + ) self._maybe_emit_todos(tool_name, result_text, call_id) def process(self, messages: List[Message]) -> None: @@ -312,16 +302,14 @@ def process(self, messages: List[Message]) -> None: prev_content = self._assistant_contents.get(message_id, '') if content and content != prev_content: if content.startswith(prev_content): - delta = content[len(prev_content):] + delta = content[len(prev_content) :] else: delta = content if delta: - self._emit_assistant_delta(message_id, delta, - content) + self._emit_assistant_delta(message_id, delta, content) self._assistant_contents[message_id] = content if message.tool_calls: - self._process_tool_calls(message_id, - message.tool_calls) + self._process_tool_calls(message_id, message.tool_calls) elif role == 'tool': self._process_tool_result(message) else: @@ -330,31 +318,25 @@ def process(self, messages: List[Message]) -> None: if idx >= prev_len: content = _stringify_content(message.content) if content: - self._emit_chat_message( - message_id, role, content, - getattr(message, 'name', None)) + self._emit_chat_message(message_id, role, content, getattr(message, 'name', None)) else: if role == 'assistant': content = _stringify_content(message.content) prev_content = self._assistant_contents.get(message_id, '') if content and content != prev_content: if content.startswith(prev_content): - delta = content[len(prev_content):] + delta = content[len(prev_content) :] else: delta = content if delta: - self._emit_subagent_delta(message_id, delta, - content) + self._emit_subagent_delta(message_id, delta, content) self._assistant_contents[message_id] = content if message.tool_calls: for idx, call in enumerate(message.tool_calls or []): - call_id = call.get( - 'id') or f'{message_id}-call-{idx}' + call_id = call.get('id') or f'{message_id}-call-{idx}' tool_name = self._extract_tool_name(call) - tool_args = self._parse_tool_args( - self._extract_tool_args_raw(call)) - should_emit, is_new = self._record_tool_call( - call_id, tool_name, tool_args) + tool_args = self._parse_tool_args(self._extract_tool_args_raw(call)) + should_emit, is_new = self._record_tool_call(call_id, tool_name, tool_args) if not should_emit: continue payload = { @@ -375,8 +357,7 @@ def process(self, messages: List[Message]) -> None: if role != 'tool' and idx >= prev_len: content = _stringify_content(message.content) if content: - self._emit_subagent_message( - message_id, role, content) + self._emit_subagent_message(message_id, role, content) if role == 'tool' and message.tool_call_id: call_id = message.tool_call_id if call_id in self._seen_tool_results: @@ -392,11 +373,8 @@ def process(self, messages: List[Message]) -> None: tool_args = self._tool_call_args.get(call_id) if tool_args is not None: payload['tool'] = { - 'name': - (tool_name - or self._tool_call_names.get(call_id, '')), - 'arguments': - tool_args, + 'name': (tool_name or self._tool_call_names.get(call_id, '')), + 'arguments': tool_args, } self._emit_event('dr.subagent.tool.result', payload) diff --git a/webui/backend/deep_research_worker.py b/webui/backend/deep_research_worker.py index e8a3480a3..5ca62420b 100644 --- a/webui/backend/deep_research_worker.py +++ b/webui/backend/deep_research_worker.py @@ -1,5 +1,6 @@ import argparse import asyncio +import json import os import signal import sys @@ -7,11 +8,11 @@ from pathlib import Path from typing import Any, Dict, Optional -import json from deep_research_eventizer import HistoryEventizer # noqa: E402 +from omegaconf import OmegaConf + from ms_agent.agent.loader import AgentLoader from ms_agent.tools.agent_tool import AgentTool -from omegaconf import OmegaConf BACKEND_DIR = Path(__file__).resolve().parent if str(BACKEND_DIR) not in sys.path: @@ -21,7 +22,6 @@ class NullWriter: - def write(self, _: str) -> int: return 0 @@ -30,7 +30,6 @@ def flush(self) -> None: class NDJSONEmitter: - def __init__(self, stream) -> None: self._stream = stream @@ -71,21 +70,16 @@ def _normalize_agent_override(raw: Optional[Dict[str, Any]]) -> Dict[str, str]: } -def _resolve_agent_llm_config(role: str, llm_config: Dict[str, Any], - dr_config: Dict[str, Any]) -> Dict[str, str]: +def _resolve_agent_llm_config(role: str, llm_config: Dict[str, Any], dr_config: Dict[str, Any]) -> Dict[str, str]: overrides = _normalize_agent_override((dr_config or {}).get(role)) return { - 'model': - overrides.get('model') or str(llm_config.get('model') or ''), - 'api_key': - overrides.get('api_key') or str(llm_config.get('api_key') or ''), - 'base_url': - overrides.get('base_url') or str(llm_config.get('base_url') or ''), + 'model': overrides.get('model') or str(llm_config.get('model') or ''), + 'api_key': overrides.get('api_key') or str(llm_config.get('api_key') or ''), + 'base_url': overrides.get('base_url') or str(llm_config.get('base_url') or ''), } -def _normalize_search_override( - raw: Optional[Dict[str, Any]]) -> Dict[str, str]: +def _normalize_search_override(raw: Optional[Dict[str, Any]]) -> Dict[str, str]: raw = raw or {} return { 'summarizer_model': str(raw.get('summarizer_model') or ''), @@ -95,8 +89,8 @@ def _normalize_search_override( def _build_config_override( - llm_config: Dict[str, Any], output_dir: str, - dr_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + llm_config: Dict[str, Any], output_dir: str, dr_config: Dict[str, Any] +) -> Optional[Dict[str, Any]]: override: Dict[str, Any] = {} if output_dir: override['output_dir'] = output_dir @@ -132,8 +126,7 @@ def _build_config_override( return override or None -async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, - session_id: str) -> None: +async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, session_id: str) -> None: last_snapshot: Dict[str, tuple[int, float]] = {} output_path = Path(output_dir) ignore_dirs = {'.locks', '__pycache__'} @@ -155,25 +148,23 @@ async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, except OSError: continue snapshot[rel_path] = (stat.st_size, stat.st_mtime) - files.append({ - 'path': rel_path, - 'relative_path': rel_path, - 'size': stat.st_size, - 'modified': stat.st_mtime, - }) + files.append( + { + 'path': rel_path, + 'relative_path': rel_path, + 'size': stat.st_size, + 'modified': stat.st_mtime, + } + ) if snapshot != last_snapshot: - emitter.emit({ - 'type': 'dr.artifact.updated', - 'payload': { - 'files': - sorted( - files, - key=lambda x: x.get('modified', 0), - reverse=True) - }, - 'session_id': session_id, - }) + emitter.emit( + { + 'type': 'dr.artifact.updated', + 'payload': {'files': sorted(files, key=lambda x: x.get('modified', 0), reverse=True)}, + 'session_id': session_id, + } + ) last_snapshot = snapshot await asyncio.sleep(1.0) @@ -181,16 +172,14 @@ async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, async def run_worker(args: argparse.Namespace) -> None: emitter = NDJSONEmitter(sys.__stdout__) - main_eventizer = HistoryEventizer( - emitter.emit, channel='main', session_id=args.session_id) + main_eventizer = HistoryEventizer(emitter.emit, channel='main', session_id=args.session_id) subagent_eventizers: Dict[str, HistoryEventizer] = {} loop = asyncio.get_running_loop() subagent_queue: asyncio.Queue = asyncio.Queue() def chunk_callback(*, event_type: str, data: Dict[str, Any]) -> None: - loop.call_soon_threadsafe(subagent_queue.put_nowait, - (event_type, data)) + loop.call_soon_threadsafe(subagent_queue.put_nowait, (event_type, data)) async def consume_subagent_events(): while True: @@ -215,10 +204,8 @@ async def consume_subagent_events(): llm_config = _load_llm_config() dr_config = _load_deep_research_config() - config_override = _build_config_override(llm_config, args.output_dir, - dr_config) - config_override = OmegaConf.create( - config_override) if config_override else None + config_override = _build_config_override(llm_config, args.output_dir, dr_config) + config_override = OmegaConf.create(config_override) if config_override else None agent = AgentLoader.build( config_dir_or_id=args.config, @@ -246,13 +233,10 @@ async def prepare_tools_with_callback(): tool_name = str(spec.tool_name or '') if 'searcher' in tool_name: - resolved = _resolve_agent_llm_config( - 'searcher', llm_config, dr_config) - search_override = _normalize_search_override( - (dr_config or {}).get('search')) + resolved = _resolve_agent_llm_config('searcher', llm_config, dr_config) + search_override = _normalize_search_override((dr_config or {}).get('search')) elif 'reporter' in tool_name: - resolved = _resolve_agent_llm_config( - 'reporter', llm_config, dr_config) + resolved = _resolve_agent_llm_config('reporter', llm_config, dr_config) search_override = {} else: resolved = {} @@ -273,16 +257,11 @@ async def prepare_tools_with_callback(): tools_cfg = dict(updated.get('tools') or {}) web_cfg = dict(tools_cfg.get('web_search') or {}) if search_override.get('summarizer_model'): - web_cfg['summarizer_model'] = search_override[ - 'summarizer_model'] + web_cfg['summarizer_model'] = search_override['summarizer_model'] if search_override.get('summarizer_api_key'): - web_cfg[ - 'summarizer_api_key'] = search_override[ - 'summarizer_api_key'] + web_cfg['summarizer_api_key'] = search_override['summarizer_api_key'] if search_override.get('summarizer_base_url'): - web_cfg[ - 'summarizer_base_url'] = search_override[ - 'summarizer_base_url'] + web_cfg['summarizer_base_url'] = search_override['summarizer_base_url'] if web_cfg: tools_cfg['web_search'] = web_cfg updated['tools'] = tools_cfg @@ -297,8 +276,7 @@ async def prepare_tools_with_callback(): agent.prepare_tools = prepare_tools_with_callback - artifact_task = asyncio.create_task( - _watch_artifacts(args.output_dir, emitter, args.session_id)) + artifact_task = asyncio.create_task(_watch_artifacts(args.output_dir, emitter, args.session_id)) subagent_task = asyncio.create_task(consume_subagent_events()) had_error = False @@ -312,40 +290,48 @@ async def prepare_tools_with_callback(): main_eventizer.process(result) except Exception as exc: had_error = True - emitter.emit({ - 'type': 'dr.worker.error', - 'payload': { - 'error': str(exc), - 'traceback': traceback.format_exc(), - }, - 'session_id': args.session_id, - }) - emitter.emit({ - 'type': 'error', - 'message': str(exc), - }) + emitter.emit( + { + 'type': 'dr.worker.error', + 'payload': { + 'error': str(exc), + 'traceback': traceback.format_exc(), + }, + 'session_id': args.session_id, + } + ) + emitter.emit( + { + 'type': 'error', + 'message': str(exc), + } + ) raise finally: main_eventizer.finalize() - emitter.emit({ - 'type': 'dr.worker.exited', - 'payload': { - 'status': 'completed' - }, - 'session_id': args.session_id, - }) + emitter.emit( + { + 'type': 'dr.worker.exited', + 'payload': {'status': 'completed'}, + 'session_id': args.session_id, + } + ) if STOP_REQUESTED: - emitter.emit({ - 'type': 'status', - 'status': 'stopped', - }) + emitter.emit( + { + 'type': 'status', + 'status': 'stopped', + } + ) elif not had_error: - emitter.emit({ - 'type': 'complete', - 'result': { - 'status': 'success', - }, - }) + emitter.emit( + { + 'type': 'complete', + 'result': { + 'status': 'success', + }, + } + ) subagent_queue.put_nowait((None, None)) artifact_task.cancel() subagent_task.cancel() diff --git a/webui/backend/deep_research_worker_manager.py b/webui/backend/deep_research_worker_manager.py index a5eb29d05..76c4b9d84 100644 --- a/webui/backend/deep_research_worker_manager.py +++ b/webui/backend/deep_research_worker_manager.py @@ -1,4 +1,5 @@ import asyncio +import json import os import signal import sys @@ -6,13 +7,9 @@ from pathlib import Path from typing import Any, Awaitable, Callable, Dict, Optional -import json - class DeepResearchWorkerManager: - - def __init__(self, send_event: Callable[[str, Dict[str, Any]], - Awaitable[None]]): + def __init__(self, send_event: Callable[[str, Dict[str, Any]], Awaitable[None]]): self._send_event = send_event self._processes: Dict[str, asyncio.subprocess.Process] = {} self._stdout_tasks: Dict[str, asyncio.Task] = {} @@ -26,18 +23,18 @@ def _get_worker_path(self) -> Path: return Path(__file__).resolve().parent / 'deep_research_worker.py' def _build_env( - self, env_vars: Optional[Dict[str, str]], - llm_config: Optional[Dict[str, Any]], - deep_research_config: Optional[Dict[str, Any]]) -> Dict[str, str]: + self, + env_vars: Optional[Dict[str, str]], + llm_config: Optional[Dict[str, Any]], + deep_research_config: Optional[Dict[str, Any]], + ) -> Dict[str, str]: env = os.environ.copy() if env_vars: env.update({k: v for k, v in env_vars.items() if v}) if llm_config: - env['MS_AGENT_LLM_CONFIG'] = json.dumps( - llm_config, ensure_ascii=False) + env['MS_AGENT_LLM_CONFIG'] = json.dumps(llm_config, ensure_ascii=False) if deep_research_config: - env['MS_AGENT_DEEP_RESEARCH_CONFIG'] = json.dumps( - deep_research_config, ensure_ascii=False) + env['MS_AGENT_DEEP_RESEARCH_CONFIG'] = json.dumps(deep_research_config, ensure_ascii=False) api_key = (llm_config or {}).get('api_key') base_url = (llm_config or {}).get('base_url') @@ -49,20 +46,20 @@ def _build_env( repo_root = str(self._get_repo_root()) existing_path = env.get('PYTHONPATH', '') if repo_root not in existing_path.split(os.pathsep): - env['PYTHONPATH'] = repo_root + ( - os.pathsep + existing_path if existing_path else '') + env['PYTHONPATH'] = repo_root + (os.pathsep + existing_path if existing_path else '') return env async def start( - self, - session_id: str, - *, - query: str, - config_path: str, - output_dir: str, - env_vars: Optional[Dict[str, str]] = None, - llm_config: Optional[Dict[str, Any]] = None, - deep_research_config: Optional[Dict[str, Any]] = None) -> None: + self, + session_id: str, + *, + query: str, + config_path: str, + output_dir: str, + env_vars: Optional[Dict[str, str]] = None, + llm_config: Optional[Dict[str, Any]] = None, + deep_research_config: Optional[Dict[str, Any]] = None, + ) -> None: if session_id in self._processes: await self.stop(session_id) @@ -95,17 +92,17 @@ async def start( ) self._processes[session_id] = process - self._stdout_tasks[session_id] = asyncio.create_task( - self._read_stdout(session_id, process)) - self._stderr_tasks[session_id] = asyncio.create_task( - self._read_stderr(session_id, process)) + self._stdout_tasks[session_id] = asyncio.create_task(self._read_stdout(session_id, process)) + self._stderr_tasks[session_id] = asyncio.create_task(self._read_stderr(session_id, process)) await self._send_event( - session_id, { + session_id, + { 'type': 'log', 'level': 'info', 'message': f'Deep research worker started (pid={process.pid})', 'timestamp': datetime.now().isoformat(), - }) + }, + ) async def stop(self, session_id: str) -> None: process = self._processes.get(session_id) @@ -135,8 +132,7 @@ async def stop(self, session_id: str) -> None: finally: self._cleanup(session_id) - async def _read_stdout(self, session_id: str, - process: asyncio.subprocess.Process) -> None: + async def _read_stdout(self, session_id: str, process: asyncio.subprocess.Process) -> None: if not process.stdout: return while True: @@ -162,20 +158,16 @@ async def _read_stdout(self, session_id: str, return_code = None if return_code not in (None, 0) and session_id not in self._stopping: await self._send_event( - session_id, { - 'type': - 'error', - 'message': - f'Deep research worker exited with code {return_code}', - }) - await self._send_event(session_id, { - 'type': 'status', - 'status': 'error' - }) + session_id, + { + 'type': 'error', + 'message': f'Deep research worker exited with code {return_code}', + }, + ) + await self._send_event(session_id, {'type': 'status', 'status': 'error'}) self._cleanup(session_id) - async def _read_stderr(self, session_id: str, - process: asyncio.subprocess.Process) -> None: + async def _read_stderr(self, session_id: str, process: asyncio.subprocess.Process) -> None: if not process.stderr: return while True: @@ -188,12 +180,14 @@ async def _read_stderr(self, session_id: str, sys.stderr.write(text) sys.stderr.flush() await self._send_event( - session_id, { + session_id, + { 'type': 'log', 'level': 'error', 'message': f'[deep_research_worker] {text.strip()}', 'timestamp': datetime.now().isoformat(), - }) + }, + ) except Exception: pass diff --git a/webui/backend/main.py b/webui/backend/main.py index 2afcebbe3..2f2d17291 100644 --- a/webui/backend/main.py +++ b/webui/backend/main.py @@ -3,6 +3,7 @@ MS-Agent Web UI Backend Server Provides REST API and WebSocket endpoints for the ms-agent framework. """ + import os import sys @@ -15,15 +16,11 @@ from websocket_handler import router as ws_router # Add ms-agent to path -MS_AGENT_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), '..', '..', 'ms-agent')) +MS_AGENT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'ms-agent')) if MS_AGENT_PATH not in sys.path: sys.path.insert(0, MS_AGENT_PATH) -app = FastAPI( - title='MS-Agent Web UI', - description='Web interface for the MS-Agent framework', - version='1.0.0') +app = FastAPI(title='MS-Agent Web UI', description='Web interface for the MS-Agent framework', version='1.0.0') # CORS configuration app.add_middleware( @@ -41,10 +38,7 @@ # Serve static files in production STATIC_DIR = os.path.join(os.path.dirname(__file__), '..', 'frontend', 'dist') if os.path.exists(STATIC_DIR): - app.mount( - '/assets', - StaticFiles(directory=os.path.join(STATIC_DIR, 'assets')), - name='assets') + app.mount('/assets', StaticFiles(directory=os.path.join(STATIC_DIR, 'assets')), name='assets') @app.get('/{full_path:path}') async def serve_spa(full_path: str): @@ -64,19 +58,19 @@ async def health_check(): def main(): """Start the server""" import argparse + parser = argparse.ArgumentParser(description='MS-Agent Web UI Server') parser.add_argument('--host', default='0.0.0.0', help='Host to bind') parser.add_argument('--port', type=int, default=7860, help='Port to bind') - parser.add_argument( - '--reload', action='store_true', help='Enable auto-reload') + parser.add_argument('--reload', action='store_true', help='Enable auto-reload') args = parser.parse_args() - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(' MS-Agent Web UI Server') - print(f"{'='*60}") + print(f"{'=' * 60}") print(f' Server running at: http://{args.host}:{args.port}') print(f' API documentation: http://{args.host}:{args.port}/docs') - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") uvicorn.run('main:app', host=args.host, port=args.port, reload=args.reload) diff --git a/webui/backend/project_discovery.py b/webui/backend/project_discovery.py index e30bf639d..43f9b9e0c 100644 --- a/webui/backend/project_discovery.py +++ b/webui/backend/project_discovery.py @@ -3,6 +3,7 @@ Project discovery module for MS-Agent Web UI Discovers and manages available projects from the ms-agent/projects directory. """ + import os import re from typing import Any, Dict, List, Optional @@ -18,8 +19,7 @@ def __init__(self, projects_dir: str): self.projects_dir = projects_dir self._projects_cache: Optional[List[Dict[str, Any]]] = None - def discover_projects(self, - force_refresh: bool = False) -> List[Dict[str, Any]]: + def discover_projects(self, force_refresh: bool = False) -> List[Dict[str, Any]]: """Discover all available projects""" if self._projects_cache is not None and not force_refresh: return self._projects_cache @@ -32,8 +32,7 @@ def discover_projects(self, for item in os.listdir(self.projects_dir): item_path = os.path.join(self.projects_dir, item) # Only show projects in the whitelist - if os.path.isdir(item_path) and not item.startswith( - '.') and item in self.VISIBLE_PROJECTS: + if os.path.isdir(item_path) and not item.startswith('.') and item in self.VISIBLE_PROJECTS: project_info = self._analyze_project(item, item_path) if project_info: projects.append(project_info) @@ -53,24 +52,24 @@ def _build_virtual_projects(self) -> List[Dict[str, Any]]: researcher_yaml = os.path.join(v2_root, 'researcher.yaml') if os.path.exists(researcher_yaml): readme_path = os.path.join(v2_root, 'README.md') - description = self._extract_description( - readme_path) if os.path.exists(readme_path) else '' - projects.append({ - 'id': 'deep_research_v2', - 'name': 'deep_research_v2', - 'display_name': 'Deep Research', - 'description': description, - 'type': 'agent', - 'path': v2_root, - 'has_readme': os.path.exists(readme_path), - 'config_file': researcher_yaml, - 'supports_workflow_switch': False - }) + description = self._extract_description(readme_path) if os.path.exists(readme_path) else '' + projects.append( + { + 'id': 'deep_research_v2', + 'name': 'deep_research_v2', + 'display_name': 'Deep Research', + 'description': description, + 'type': 'agent', + 'path': v2_root, + 'has_readme': os.path.exists(readme_path), + 'config_file': researcher_yaml, + 'supports_workflow_switch': False, + } + ) return projects - def _analyze_project(self, name: str, - path: str) -> Optional[Dict[str, Any]]: + def _analyze_project(self, name: str, path: str) -> Optional[Dict[str, Any]]: """Analyze a project directory and extract its information""" # Check for workflow.yaml or agent.yaml workflow_file = os.path.join(path, 'workflow.yaml') @@ -95,16 +94,14 @@ def _analyze_project(self, name: str, # Check if project supports workflow switching (e.g., code_genesis) supports_workflow_switch = False - if project_type == 'workflow' and name == 'code_genesis' and os.path.exists( - simple_workflow_file): + if project_type == 'workflow' and name == 'code_genesis' and os.path.exists(simple_workflow_file): supports_workflow_switch = True # Generate display name from directory name display_name = self._format_display_name(name) # Extract description from README if available - description = self._extract_description(readme_file) if os.path.exists( - readme_file) else '' + description = self._extract_description(readme_file) if os.path.exists(readme_file) else '' return { 'id': name, @@ -115,7 +112,7 @@ def _analyze_project(self, name: str, 'path': path, 'has_readme': os.path.exists(readme_file), 'config_file': config_file, - 'supports_workflow_switch': supports_workflow_switch + 'supports_workflow_switch': supports_workflow_switch, } def _format_display_name(self, name: str) -> str: @@ -141,8 +138,7 @@ def _extract_description(self, readme_path: str) -> str: stripped = line.strip() # Skip headers and empty lines at the beginning if not in_description: - if stripped and not stripped.startswith( - '#') and not stripped.startswith('['): + if stripped and not stripped.startswith('#') and not stripped.startswith('['): in_description = True description_lines.append(stripped) else: @@ -188,6 +184,7 @@ def get_project_config(self, project_id: str) -> Optional[Dict[str, Any]]: try: import yaml + with open(project['config_file'], 'r', encoding='utf-8') as f: return yaml.safe_load(f) except Exception: diff --git a/webui/backend/session_manager.py b/webui/backend/session_manager.py index 1ee20587b..5852c386a 100644 --- a/webui/backend/session_manager.py +++ b/webui/backend/session_manager.py @@ -3,6 +3,7 @@ Session management for MS-Agent Web UI Handles session lifecycle and message history. """ + import uuid from datetime import datetime from threading import Lock @@ -19,11 +20,9 @@ def __init__(self): self._dr_event_counters: Dict[str, int] = {} self._lock = Lock() - def create_session(self, - project_id: str, - project_name: str, - workflow_type: str = 'standard', - session_type: str = 'project') -> Dict[str, Any]: + def create_session( + self, project_id: str, project_name: str, workflow_type: str = 'standard', session_type: str = 'project' + ) -> Dict[str, Any]: """Create a new session""" session_id = str(uuid.uuid4()) session = { @@ -36,7 +35,7 @@ def create_session(self, 'file_progress': None, 'current_step': None, 'workflow_type': workflow_type, # 'standard' or 'simple' - 'session_type': session_type # 'project' or 'chat' + 'session_type': session_type, # 'project' or 'chat' } with self._lock: @@ -78,12 +77,9 @@ def list_sessions(self) -> List[Dict[str, Any]]: """List all sessions""" return list(self._sessions.values()) - def add_message(self, - session_id: str, - role: str, - content: str, - message_type: str = 'text', - metadata: Dict[str, Any] = None) -> bool: + def add_message( + self, session_id: str, role: str, content: str, message_type: str = 'text', metadata: Dict[str, Any] = None + ) -> bool: """Add a message to a session""" if session_id not in self._sessions: return False @@ -94,7 +90,7 @@ def add_message(self, 'content': content, 'type': message_type, # text, tool_call, tool_result, error, log 'timestamp': datetime.now().isoformat(), - 'metadata': metadata or {} + 'metadata': metadata or {}, } with self._lock: @@ -110,8 +106,7 @@ def get_messages(self, session_id: str) -> Optional[List[Dict[str, Any]]]: return None return self._messages.get(session_id, []) - def add_dr_event(self, session_id: str, - event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def add_dr_event(self, session_id: str, event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Add a deep research event for replay.""" if session_id not in self._sessions: return None @@ -123,19 +118,14 @@ def add_dr_event(self, session_id: str, self._dr_events.setdefault(session_id, []).append(stored) return stored - def list_dr_events( - self, - session_id: str, - after_id: Optional[int] = None) -> Optional[List[Dict[str, Any]]]: + def list_dr_events(self, session_id: str, after_id: Optional[int] = None) -> Optional[List[Dict[str, Any]]]: """List deep research events for a session.""" if session_id not in self._sessions: return None events = self._dr_events.get(session_id, []) if after_id is None: return list(events) - return [ - event for event in events if event.get('event_id', 0) > after_id - ] + return [event for event in events if event.get('event_id', 0) > after_id] def update_last_message(self, session_id: str, content: str) -> bool: """Update the content of the last message (for streaming)""" @@ -146,8 +136,7 @@ def update_last_message(self, session_id: str, content: str) -> bool: self._messages[session_id][-1]['content'] = content return True - def set_workflow_progress(self, session_id: str, - progress: Dict[str, Any]) -> bool: + def set_workflow_progress(self, session_id: str, progress: Dict[str, Any]) -> bool: """Set workflow progress for a session""" if session_id not in self._sessions: return False @@ -156,8 +145,7 @@ def set_workflow_progress(self, session_id: str, self._sessions[session_id]['workflow_progress'] = progress return True - def set_file_progress(self, session_id: str, progress: Dict[str, - Any]) -> bool: + def set_file_progress(self, session_id: str, progress: Dict[str, Any]) -> bool: """Set file writing progress for a session""" if session_id not in self._sessions: return False diff --git a/webui/backend/shared.py b/webui/backend/shared.py index 1d3ce9754..c34b09944 100644 --- a/webui/backend/shared.py +++ b/webui/backend/shared.py @@ -3,6 +3,7 @@ Shared instances for backend modules. Ensures api.py and websocket_handler.py use the same manager instances. """ + import os from config_manager import ConfigManager @@ -10,8 +11,7 @@ from session_manager import SessionManager # Initialize paths -BASE_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) PROJECTS_DIR = os.path.join(BASE_DIR, 'projects') # Use ~/.ms_agent/ for configuration storage (privacy-sensitive data) CONFIG_DIR = os.path.expanduser('~/.ms_agent') diff --git a/webui/backend/websocket_handler.py b/webui/backend/websocket_handler.py index 7b2830b17..fd5915d37 100644 --- a/webui/backend/websocket_handler.py +++ b/webui/backend/websocket_handler.py @@ -3,16 +3,18 @@ WebSocket handler for real-time communication Handles agent execution, log streaming, and progress updates. """ + import asyncio +import json import os from datetime import datetime from pathlib import Path from typing import Any, Dict, Set -import json from agent_runner import AgentRunner from deep_research_worker_manager import DeepResearchWorkerManager from fastapi import APIRouter, WebSocket, WebSocketDisconnect + # Import shared instances from shared import config_manager, project_discovery, session_manager @@ -75,8 +77,7 @@ async def broadcast_log(self, log_entry: Dict[str, Any]): agent_tasks: Dict[str, asyncio.Task] = {} -async def _update_deep_research_status(session_id: str, - event: Dict[str, Any]) -> None: +async def _update_deep_research_status(session_id: str, event: Dict[str, Any]) -> None: event_type = event.get('type') if event_type == 'status': status = event.get('status') @@ -98,8 +99,7 @@ async def _update_deep_research_status(session_id: str, session_manager.update_session(session_id, {'status': 'error'}) -async def _send_deep_research_event(session_id: str, event: Dict[str, - Any]) -> None: +async def _send_deep_research_event(session_id: str, event: Dict[str, Any]) -> None: event_type = str(event.get('type') or '') stored_event = event if event_type.startswith('dr.'): @@ -127,8 +127,7 @@ async def websocket_session(websocket: WebSocket, session_id: str): print(f'[WS] Client disconnected from session: {session_id}') connection_manager.disconnect(websocket, session_id) session = session_manager.get_session(session_id) - is_deep_research = bool( - session and session.get('project_id') == 'deep_research_v2') + is_deep_research = bool(session and session.get('project_id') == 'deep_research_v2') # Stop agent if running if session_id in agent_runners: await agent_runners[session_id].stop() @@ -153,8 +152,7 @@ async def websocket_logs(websocket: WebSocket): connection_manager.disconnect(websocket) -async def handle_session_message(session_id: str, data: Dict[str, Any], - websocket: WebSocket): +async def handle_session_message(session_id: str, data: Dict[str, Any], websocket: WebSocket): """Handle incoming WebSocket messages""" action = data.get('action') @@ -168,18 +166,14 @@ async def handle_session_message(session_id: str, data: Dict[str, Any], await send_status(session_id, websocket) -async def start_agent(session_id: str, data: Dict[str, Any], - websocket: WebSocket): +async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocket): """Start an agent for a session""" print(f'[Agent] Starting agent for session: {session_id}') session = session_manager.get_session(session_id) if not session: print(f'[Agent] ERROR: Session not found: {session_id}') - await websocket.send_json({ - 'type': 'error', - 'message': 'Session not found' - }) + await websocket.send_json({'type': 'error', 'message': 'Session not found'}) return session_type = session.get('session_type', 'project') @@ -188,6 +182,7 @@ async def start_agent(session_id: str, data: Dict[str, Any], if session_type == 'chat': # Create a virtual project for chat mode using the default agent.yaml import ms_agent + # Get ms_agent package installation path # Use __path__ which is always available for packages and gives real filesystem paths if hasattr(ms_agent, '__path__') and ms_agent.__path__: @@ -195,8 +190,7 @@ async def start_agent(session_id: str, data: Dict[str, Any], elif ms_agent.__file__ is not None: ms_agent_package_path = Path(ms_agent.__file__).parent else: - raise RuntimeError('Cannot determine ms_agent package path. ' - 'Please ensure ms_agent is properly installed.') + raise RuntimeError('Cannot determine ms_agent package path. Please ensure ms_agent is properly installed.') chat_config_path = ms_agent_package_path / 'agent' / 'agent.yaml' project = { @@ -208,17 +202,14 @@ async def start_agent(session_id: str, data: Dict[str, Any], 'path': str(ms_agent_package_path / 'agent'), 'config_file': str(chat_config_path), 'has_readme': False, - 'supports_workflow_switch': False + 'supports_workflow_switch': False, } else: # For project mode, get the project project = project_discovery.get_project(session['project_id']) if not project: print(f"[Agent] ERROR: Project not found: {session['project_id']}") - await websocket.send_json({ - 'type': 'error', - 'message': 'Project not found' - }) + await websocket.send_json({'type': 'error', 'message': 'Project not found'}) return # Clean up output directory for code_genesis before starting @@ -227,30 +218,32 @@ async def start_agent(session_id: str, data: Dict[str, Any], if os.path.exists(output_dir): try: import shutil + shutil.rmtree(output_dir) print(f'[Agent] Cleaned up output directory: {output_dir}') await connection_manager.send_to_session( - session_id, { + session_id, + { 'type': 'log', 'level': 'info', 'message': 'Cleaned up previous output directory', - 'timestamp': datetime.now().isoformat() - }) - except Exception as e: - print( - f'[Agent] WARNING: Failed to clean output directory: {e}' + 'timestamp': datetime.now().isoformat(), + }, ) + except Exception as e: + print(f'[Agent] WARNING: Failed to clean output directory: {e}') # Don't fail if cleanup fails, just log it # Get workflow_type from session (default to 'standard') workflow_type = session.get('workflow_type', 'standard') - print(f"[Agent] Project: {project['id']}, type: {project['type']}, " - f"config: {project['config_file']}, workflow_type: {workflow_type}") + print( + f"[Agent] Project: {project['id']}, type: {project['type']}, " + f"config: {project['config_file']}, workflow_type: {workflow_type}" + ) query = data.get('query', '') - print(f'[Agent] Query: {query[:100]}...' - if len(query) > 100 else f'[Agent] Query: {query}') + print(f'[Agent] Query: {query[:100]}...' if len(query) > 100 else f'[Agent] Query: {query}') # Add user message to session (but don't broadcast - frontend already has it) session_manager.add_message(session_id, 'user', query, 'text') @@ -269,16 +262,11 @@ async def start_agent(session_id: str, data: Dict[str, Any], deep_research_config=config_manager.get_deep_research_config(), ) session_manager.update_session(session_id, {'status': 'running'}) - await connection_manager.send_to_session(session_id, { - 'type': 'status', - 'status': 'running' - }) + await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) except Exception as e: await connection_manager.send_to_session( - session_id, { - 'type': 'error', - 'message': f'Worker 启动失败: {str(e)}' - }) + session_id, {'type': 'error', 'message': f'Worker 启动失败: {str(e)}'} + ) session_manager.update_session(session_id, {'status': 'error'}) return @@ -287,25 +275,19 @@ async def start_agent(session_id: str, data: Dict[str, Any], session_id=session_id, project=project, config_manager=config_manager, - on_output=lambda msg: asyncio.create_task( - on_agent_output(session_id, msg)), + on_output=lambda msg: asyncio.create_task(on_agent_output(session_id, msg)), on_log=lambda log: asyncio.create_task(on_agent_log(session_id, log)), - on_progress=lambda prog: asyncio.create_task( - on_agent_progress(session_id, prog)), - on_complete=lambda result: asyncio.create_task( - on_agent_complete(session_id, result)), - on_error=lambda err: asyncio.create_task( - on_agent_error(session_id, err)), - workflow_type=workflow_type) + on_progress=lambda prog: asyncio.create_task(on_agent_progress(session_id, prog)), + on_complete=lambda result: asyncio.create_task(on_agent_complete(session_id, result)), + on_error=lambda err: asyncio.create_task(on_agent_error(session_id, err)), + workflow_type=workflow_type, + ) agent_runners[session_id] = runner session_manager.update_session(session_id, {'status': 'running'}) # Notify session started - await connection_manager.send_to_session(session_id, { - 'type': 'status', - 'status': 'running' - }) + await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) # Start agent in background so the WS loop can still receive stop/input messages task = asyncio.create_task(runner.start(query)) @@ -328,10 +310,7 @@ async def stop_agent(session_id: str): del agent_tasks[session_id] session_manager.update_session(session_id, {'status': 'stopped'}) - await connection_manager.send_to_session(session_id, { - 'type': 'status', - 'status': 'stopped' - }) + await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'stopped'}) async def send_input(session_id: str, data: Dict[str, Any]): @@ -339,13 +318,15 @@ async def send_input(session_id: str, data: Dict[str, Any]): if session_id not in agent_runners: print(f'[WS] ERROR: Agent runner not found for session: {session_id}') await connection_manager.send_to_session( - session_id, { - 'type': - 'error', - 'message': - ('Agent is not running. The workflow may have completed. ' - 'Please start a new conversation or restart the agent.') - }) + session_id, + { + 'type': 'error', + 'message': ( + 'Agent is not running. The workflow may have completed. ' + 'Please start a new conversation or restart the agent.' + ), + }, + ) return input_text = data.get('input', '') @@ -354,26 +335,21 @@ async def send_input(session_id: str, data: Dict[str, Any]): # Check if process is still alive runner = agent_runners[session_id] if runner.process and runner.process.returncode is not None: - print( - f'[WS] ERROR: Process has exited with code {runner.process.returncode}' - ) + print(f'[WS] ERROR: Process has exited with code {runner.process.returncode}') await connection_manager.send_to_session( - session_id, { - 'type': - 'error', - 'message': - 'Agent process has terminated. The workflow completed. Please start a new conversation to continue.' - }) + session_id, + { + 'type': 'error', + 'message': 'Agent process has terminated. The workflow completed. Please start a new conversation to continue.', + }, + ) # Clean up the runner del agent_runners[session_id] return # Update session status to running session_manager.update_session(session_id, {'status': 'running'}) - await connection_manager.send_to_session(session_id, { - 'type': 'status', - 'status': 'running' - }) + await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) # Add user message to session session_manager.add_message(session_id, 'user', input_text, 'text') @@ -384,26 +360,18 @@ async def send_input(session_id: str, data: Dict[str, Any]): except Exception as e: print(f'[WS] ERROR: Failed to send input: {e}') await connection_manager.send_to_session( - session_id, { - 'type': - 'error', - 'message': - f'Failed to send input: {str(e)}. The process may have terminated.' - }) + session_id, + {'type': 'error', 'message': f'Failed to send input: {str(e)}. The process may have terminated.'}, + ) async def send_status(session_id: str, websocket: WebSocket): """Send current status to a client""" session = session_manager.get_session(session_id) if session: - await websocket.send_json({ - 'type': - 'status', - 'session': - session, - 'messages': - session_manager.get_messages(session_id) - }) + await websocket.send_json( + {'type': 'status', 'session': session, 'messages': session_manager.get_messages(session_id)} + ) async def on_agent_output(session_id: str, message: Dict[str, Any]): @@ -415,32 +383,27 @@ async def on_agent_output(session_id: str, message: Dict[str, Any]): if msg_type == 'stream': # Streaming update await connection_manager.send_to_session( - session_id, { - 'type': 'stream', - 'content': content, - 'done': message.get('done', False) - }) + session_id, {'type': 'stream', 'content': content, 'done': message.get('done', False)} + ) if message.get('done'): session_manager.add_message(session_id, role, content, 'text') else: - session_manager.add_message(session_id, role, content, msg_type, - message.get('metadata')) + session_manager.add_message(session_id, role, content, msg_type, message.get('metadata')) await connection_manager.send_to_session( - session_id, { + session_id, + { 'type': 'message', 'role': role, 'content': content, 'message_type': msg_type, - 'metadata': message.get('metadata') - }) + 'metadata': message.get('metadata'), + }, + ) async def on_agent_log(session_id: str, log: Dict[str, Any]): """Handle agent log""" - await connection_manager.send_to_session(session_id, { - 'type': 'log', - **log - }) + await connection_manager.send_to_session(session_id, {'type': 'log', **log}) await connection_manager.broadcast_log({'session_id': session_id, **log}) @@ -450,15 +413,11 @@ async def on_agent_progress(session_id: str, progress: Dict[str, Any]): if progress_type == 'workflow': session_manager.set_workflow_progress(session_id, progress) - session_manager.set_current_step(session_id, - progress.get('current_step')) + session_manager.set_current_step(session_id, progress.get('current_step')) elif progress_type == 'file': session_manager.set_file_progress(session_id, progress) - await connection_manager.send_to_session(session_id, { - 'type': 'progress', - **progress - }) + await connection_manager.send_to_session(session_id, {'type': 'progress', **progress}) async def on_agent_complete(session_id: str, result: Dict[str, Any]): @@ -471,17 +430,13 @@ async def on_agent_complete(session_id: str, result: Dict[str, Any]): agent_tasks[session_id].cancel() del agent_tasks[session_id] - await connection_manager.send_to_session(session_id, { - 'type': 'complete', - 'result': result - }) + await connection_manager.send_to_session(session_id, {'type': 'complete', 'result': result}) async def on_agent_error(session_id: str, error: Dict[str, Any]): """Handle agent error""" session_manager.update_session(session_id, {'status': 'error'}) - session_manager.add_message(session_id, 'system', - error.get('message', 'Unknown error'), 'error') + session_manager.add_message(session_id, 'system', error.get('message', 'Unknown error'), 'error') if session_id in agent_runners: del agent_runners[session_id] @@ -489,7 +444,4 @@ async def on_agent_error(session_id: str, error: Dict[str, Any]): agent_tasks[session_id].cancel() del agent_tasks[session_id] - await connection_manager.send_to_session(session_id, { - 'type': 'error', - **error - }) + await connection_manager.send_to_session(session_id, {'type': 'error', **error}) From 2d966115d2b4746c1e472333893d172d52f21de0 Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 28 Apr 2026 15:22:28 +0800 Subject: [PATCH 39/40] Revert "fix lint" This reverts commit 67668e9c9a8af3f299bf606d3a34cff461bcfe79. --- .gitignore | 8 +- .pre-commit-config.yaml | 26 +- .../nanobot_integration/test_mcp_tools.py | 40 +- ms-agent-skills/scripts/check_ms_agent.py | 69 +- ms_agent/agent/base.py | 44 +- ms_agent/agent/code_agent.py | 15 +- ms_agent/agent/llm_agent.py | 272 ++-- ms_agent/agent/loader.py | 60 +- ms_agent/agent/runtime.py | 1 + ms_agent/app/doc_research.py | 2 +- ms_agent/app/fin_research.py | 6 +- ms_agent/callbacks/base.py | 10 +- ms_agent/callbacks/input_callback.py | 3 +- ms_agent/capabilities/__init__.py | 9 +- ms_agent/capabilities/async_task.py | 17 +- ms_agent/capabilities/mcp_server.py | 41 +- ms_agent/capabilities/registry.py | 12 +- .../capabilities/wrappers/agent_delegate.py | 101 +- .../capabilities/wrappers/deep_research.py | 152 ++- ms_agent/capabilities/wrappers/filesystem.py | 96 +- .../capabilities/wrappers/lsp_code_server.py | 83 +- ms_agent/capabilities/wrappers/web_search.py | 68 +- ms_agent/cli/app.py | 37 +- ms_agent/cli/cli.py | 7 +- ms_agent/cli/run.py | 110 +- ms_agent/cli/ui.py | 78 +- ms_agent/config/config.py | 89 +- ms_agent/config/env.py | 4 +- ms_agent/llm/anthropic_llm.py | 164 +-- ms_agent/llm/dashscope_llm.py | 17 +- ms_agent/llm/deepseek_llm.py | 51 +- ms_agent/llm/llm.py | 20 +- ms_agent/llm/modelscope_llm.py | 16 +- ms_agent/llm/openai.py | 158 ++- ms_agent/llm/openai_llm.py | 395 +++--- ms_agent/llm/utils.py | 26 +- ms_agent/memory/base.py | 6 +- ms_agent/memory/condenser/code_condenser.py | 36 +- .../memory/condenser/context_compressor.py | 35 +- ms_agent/memory/condenser/refine_condenser.py | 60 +- ms_agent/memory/default_memory.py | 320 +++-- ms_agent/memory/diversity.py | 27 +- ms_agent/memory/memory_manager.py | 13 +- ms_agent/memory/utils.py | 4 +- ms_agent/prompting/file_resolver.py | 29 +- ms_agent/rag/base.py | 6 +- ms_agent/rag/extraction.py | 4 +- ms_agent/rag/extraction_manager.py | 67 +- ms_agent/rag/llama_index_rag.py | 157 +-- ms_agent/rag/schema.py | 1 - ms_agent/retriever/hybrid_retriever.py | 76 +- ms_agent/sandbox/sandbox.py | 91 +- ms_agent/skill/auto_skills.py | 8 +- ms_agent/skill/container.py | 418 ++++--- ms_agent/skill/loader.py | 22 +- ms_agent/skill/schema.py | 130 +- ms_agent/skill/spec.py | 14 +- ms_agent/tools/agent_tool.py | 350 +++--- ms_agent/tools/audio_generator/audio_gen.py | 21 +- ms_agent/tools/audio_generator/edge_tts.py | 15 +- ms_agent/tools/base.py | 15 +- ms_agent/tools/code/code_executor.py | 490 +++++--- ms_agent/tools/code/local_code_executor.py | 462 ++++--- ms_agent/tools/code/sandbox_manager.py | 43 +- ms_agent/tools/code_server/lsp_code_server.py | 408 +++--- ms_agent/tools/docling/chunker.py | 40 +- ms_agent/tools/docling/doc_loader.py | 17 +- ms_agent/tools/docling/doc_postprocess.py | 4 +- ms_agent/tools/docling/patches.py | 72 +- ms_agent/tools/fetch_playwright_fallback.py | 24 +- ms_agent/tools/filesystem_tool.py | 253 ++-- ms_agent/tools/findata/__init__.py | 3 +- ms_agent/tools/findata/akshare_source.py | 296 +++-- ms_agent/tools/findata/baostock_source.py | 173 ++- ms_agent/tools/findata/data_source_base.py | 25 +- ms_agent/tools/findata/findata_fetcher.py | 672 ++++++---- ms_agent/tools/findata/hybrid_source.py | 83 +- .../tools/image_generator/ds_image_gen.py | 33 +- .../tools/image_generator/google_image_gen.py | 7 +- ms_agent/tools/image_generator/image_gen.py | 32 +- .../tools/image_generator/ms_image_gen.py | 49 +- ms_agent/tools/jina_reader.py | 75 +- ms_agent/tools/mcp_client.py | 128 +- ms_agent/tools/mineru/pdf_parser.py | 33 +- ms_agent/tools/search/arxiv/__init__.py | 3 +- ms_agent/tools/search/arxiv/schema.py | 110 +- ms_agent/tools/search/arxiv/search.py | 109 +- ms_agent/tools/search/content_optimizer.py | 108 +- ms_agent/tools/search/exa/schema.py | 53 +- ms_agent/tools/search/exa/search.py | 103 +- ms_agent/tools/search/localsearch_tool.py | 76 +- ms_agent/tools/search/search_base.py | 37 +- ms_agent/tools/search/search_request.py | 2 +- ms_agent/tools/search/serpapi/__init__.py | 3 +- ms_agent/tools/search/serpapi/schema.py | 31 +- ms_agent/tools/search/serpapi/search.py | 50 +- ms_agent/tools/search/sirchmunk_search.py | 195 +-- ms_agent/tools/search/tavily/fetcher.py | 4 +- ms_agent/tools/search/tavily/http.py | 4 +- ms_agent/tools/search/tavily/schema.py | 32 +- ms_agent/tools/search/tavily/search.py | 59 +- ms_agent/tools/search/web_search_spill.py | 133 +- ms_agent/tools/search/websearch_tool.py | 755 ++++++----- ms_agent/tools/search_engine.py | 37 +- ms_agent/tools/task_control_tool.py | 26 +- ms_agent/tools/todolist_tool.py | 167 ++- ms_agent/tools/tool_manager.py | 98 +- .../tools/video_generator/ds_video_gen.py | 46 +- ms_agent/tools/video_generator/video_gen.py | 30 +- ms_agent/utils/__init__.py | 3 +- ms_agent/utils/artifact_manager.py | 46 +- ms_agent/utils/constants.py | 8 +- ms_agent/utils/llm_utils.py | 32 +- ms_agent/utils/logger.py | 6 +- ms_agent/utils/parser_utils.py | 205 +-- ms_agent/utils/patcher.py | 7 +- ms_agent/utils/push_to_hub.py | 193 +-- ms_agent/utils/rate_limiter.py | 83 +- ms_agent/utils/snapshot.py | 49 +- ms_agent/utils/stats.py | 32 +- ms_agent/utils/stream_writer.py | 77 +- ms_agent/utils/task_manager.py | 4 +- ms_agent/utils/thread_util.py | 37 +- ms_agent/utils/tokenizer_util.py | 8 +- ms_agent/utils/utils.py | 109 +- ms_agent/utils/workspace_policy.py | 30 +- ms_agent/workflow/base.py | 17 +- ms_agent/workflow/chain_workflow.py | 10 +- ms_agent/workflow/dag_workflow.py | 18 +- ms_agent/workflow/deep_research/__init__.py | 12 +- ms_agent/workflow/deep_research/principle.py | 39 +- .../workflow/deep_research/research_utils.py | 17 +- .../deep_research/research_workflow.py | 4 +- .../deep_research/research_workflow_beta.py | 33 +- ms_agent/workflow/loader.py | 25 +- .../code_genesis/tools/build_sandbox_image.py | 97 -- .../code_genesis/tools/build_sandbox_image.sh | 63 +- projects/code_genesis/workflow/api_search.py | 35 +- projects/code_genesis/workflow/architect.py | 1 + projects/code_genesis/workflow/coding.py | 222 ++-- projects/code_genesis/workflow/file_design.py | 16 +- projects/code_genesis/workflow/file_order.py | 14 +- projects/code_genesis/workflow/install.py | 4 +- projects/code_genesis/workflow/refine.py | 55 +- projects/code_genesis/workflow/user_story.py | 1 + projects/deep_research/run.py | 58 +- .../v2/callbacks/quality_checker.py | 143 ++- .../v2/callbacks/reporter_callback.py | 12 +- .../v2/callbacks/researcher_callback.py | 6 +- .../v2/callbacks/searcher_callback.py | 104 +- .../deep_research/v2/eval/dr_bench_runner.py | 190 +-- projects/deep_research/v2/reporter.py | 21 +- projects/deep_research/v2/researcher.py | 22 +- projects/deep_research/v2/time_handler.py | 12 +- .../deep_research/v2/tools/evidence_tool.py | 667 +++++----- .../deep_research/v2/tools/report_tool.py | 799 +++++++----- projects/fin_research/aggregator.py | 51 +- .../callbacks/aggregator_callback.py | 9 +- .../callbacks/analyst_callback.py | 61 +- .../callbacks/collector_callback.py | 46 +- .../fin_research/callbacks/file_parser.py | 4 +- .../callbacks/orchestrator_callback.py | 9 +- projects/fin_research/searcher.py | 67 +- projects/fin_research/time_handler.py | 12 +- .../fin_research/tools/principle_skill.py | 162 ++- projects/fin_research/tools/spec_loader.py | 278 +++-- .../singularity_cinema/compose_video/agent.py | 258 ++-- .../create_background/agent.py | 35 +- .../generate_animation/agent.py | 18 +- .../generate_animation/generate_manim_code.py | 45 +- .../generate_remotion_code.py | 64 +- .../generate_audio/agent.py | 27 +- .../generate_illustration_prompts/agent.py | 73 +- .../generate_images/agent.py | 99 +- .../generate_script/agent.py | 10 +- .../generate_subtitle/agent.py | 89 +- .../generate_video/agent.py | 43 +- .../generate_video_prompts/agent.py | 60 +- .../singularity_cinema/parse_images/agent.py | 48 +- .../render_animation/agent.py | 18 +- .../render_animation/render_manim.py | 336 ++--- .../render_animation/render_remotion.py | 318 +++-- projects/singularity_cinema/segment/agent.py | 45 +- setup.py | 66 +- shell-grep-glob-workspace-policy.md | 225 ++++ webui/backend/agent_runner.py | 1107 ++++++++++------- webui/backend/api.py | 175 ++- webui/backend/config_manager.py | 63 +- webui/backend/deep_research_eventizer.py | 104 +- webui/backend/deep_research_worker.py | 162 +-- webui/backend/deep_research_worker_manager.py | 84 +- webui/backend/main.py | 24 +- webui/backend/project_discovery.py | 49 +- webui/backend/session_manager.py | 40 +- webui/backend/shared.py | 4 +- webui/backend/websocket_handler.py | 194 +-- 196 files changed, 10843 insertions(+), 7310 deletions(-) delete mode 100644 projects/code_genesis/tools/build_sandbox_image.py create mode 100644 shell-grep-glob-workspace-policy.md diff --git a/.gitignore b/.gitignore index 2cef2d174..30dfa8d1f 100644 --- a/.gitignore +++ b/.gitignore @@ -32,13 +32,7 @@ wheels/ /temp **/tmp/ .env* -.claude* -# Local Colima/Lima state when using CODE_GENESIS_COLIMA_IN_REPO=1 -.colima/ -.xdg-cache/ -.xdg-config/ -.xdg-data/ -scripts/colima_proxy.local.env +.claude-trace/ /apps/agentfabric/tmp/ MANIFEST diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4aaa131d7..00657312b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,21 +2,23 @@ repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.6 - hooks: - - id: ruff-format - exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - - id: ruff - args: [--fix, --select, I] - exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + - repo: https://github.com/pycqa/flake8.git + rev: 4.0.0 hooks: - id: flake8 exclude: ^(thirdparty/|examples/|tests/|projects/agent_skills/|projects/fin_research/examples/|ms_agent/utils/prompts\.py) - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + - repo: https://github.com/PyCQA/isort.git + rev: 4.3.21 + hooks: + - id: isort + exclude: ^(examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) + - repo: https://github.com/pre-commit/mirrors-yapf.git + rev: v0.30.0 + hooks: + - id: yapf + exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) + - repo: https://github.com/pre-commit/pre-commit-hooks.git + rev: v3.1.0 hooks: - id: trailing-whitespace exclude: ^(thirdparty/|tests/|projects/fin_research/examples/|projects/agent_skills/) diff --git a/examples/nanobot_integration/test_mcp_tools.py b/examples/nanobot_integration/test_mcp_tools.py index c59313b30..65c9cad7d 100644 --- a/examples/nanobot_integration/test_mcp_tools.py +++ b/examples/nanobot_integration/test_mcp_tools.py @@ -37,9 +37,9 @@ async def connect(stack): async def list_tools(session): """List all available MCP tools.""" result = await session.list_tools() - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(f' Available MCP Tools ({len(result.tools)} total)') - print(f'{'=' * 60}\n') + print(f'{"=" * 60}\n') for tool in result.tools: desc = (tool.description or '')[:70] print(f' {tool.name:35s} {desc}') @@ -72,9 +72,9 @@ async def call_tool(session, name: str, args: dict): async def test_filesystem(session): """Test filesystem tools: write a file, then replace contents.""" - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(' TEST: Filesystem Tools') - print(f'{'=' * 60}') + print(f'{"=" * 60}') with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: f.write('def hello():\n print("Hello, World!")\n\nhello()\n') @@ -117,9 +117,9 @@ async def test_filesystem(session): async def test_deep_research_async(session): """Test async deep research tools (submit/check/get without actually running).""" - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(' TEST: Deep Research Async Tools') - print(f'{'=' * 60}') + print(f'{"=" * 60}') # submit_research_task should return immediately with a task_id # (it will fail to find the config in CI, but the response format is testable) @@ -128,7 +128,7 @@ async def test_deep_research_async(session): }) if 'error' in result: - print(f'\n submit_research_task returned error (expected if no config): {result['error']}') + print(f'\n submit_research_task returned error (expected if no config): {result["error"]}') print(' Testing check/get with a fake task_id...') result = await call_tool(session, 'check_research_progress', { @@ -151,21 +151,21 @@ async def test_deep_research_async(session): progress = await call_tool(session, 'check_research_progress', { 'task_id': task_id, }) - print(f' Progress check: status={progress['status']}') + print(f' Progress check: status={progress["status"]}') report = await call_tool(session, 'get_research_report', { 'task_id': task_id, }) - print(f' Report check: status={report['status']}') + print(f' Report check: status={report["status"]}') print('\n DEEP RESEARCH ASYNC TESTS: PASSED') async def test_web_search(session): """Test web_search tool with arxiv (no API key required).""" - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(' TEST: Web Search') - print(f'{'=' * 60}') + print(f'{"=" * 60}') result = await call_tool(session, 'web_search', { 'query': 'large language model agent framework', @@ -174,24 +174,24 @@ async def test_web_search(session): }) if 'error' in result: - print(f'\n web_search returned error: {result['error']}') + print(f'\n web_search returned error: {result["error"]}') print(' (This may happen if arxiv is unreachable)') else: - assert result['status'] == 'ok', f'Unexpected status: {result['status']}' + assert result['status'] == 'ok', f'Unexpected status: {result["status"]}' assert result['engine'] == 'arxiv' - print(f'\n Returned {result['count']} results:') + print(f'\n Returned {result["count"]} results:') for i, r in enumerate(result.get('results', []), 1): - print(f' {i}. {r.get('title', 'No title')[:60]}') - print(f' {r.get('url', '')}') + print(f' {i}. {r.get("title", "No title")[:60]}') + print(f' {r.get("url", "")}') print('\n WEB SEARCH TESTS: PASSED') async def test_agent_delegate(session): """Test agent delegate tools (async pattern only, to avoid blocking).""" - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(' TEST: Agent Delegate (Async)') - print(f'{'=' * 60}') + print(f'{"=" * 60}') # Test check/get/cancel with unknown task_id (safe, no LLM needed) result = await call_tool(session, 'check_agent_task', { @@ -251,9 +251,9 @@ async def main(): if args.test in ('ad', 'all'): await test_agent_delegate(session) - print(f'\n{'=' * 60}') + print(f'\n{"=" * 60}') print(' ALL TESTS PASSED') - print(f'{'=' * 60}\n') + print(f'{"=" * 60}\n') if __name__ == '__main__': diff --git a/ms-agent-skills/scripts/check_ms_agent.py b/ms-agent-skills/scripts/check_ms_agent.py index 93aea705c..64e4668fa 100644 --- a/ms-agent-skills/scripts/check_ms_agent.py +++ b/ms-agent-skills/scripts/check_ms_agent.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import subprocess import sys +import json + def check_import() -> dict: """Check that ms_agent is importable.""" try: import ms_agent # noqa: F401 - version = getattr(ms_agent, '__version__', 'unknown') return {'importable': True, 'version': version} except ImportError as e: @@ -23,28 +23,29 @@ def check_capabilities() -> dict: """ try: from ms_agent.capabilities import create_registry - registry = create_registry() caps = registry.list_all() return { - 'registry_ok': True, - 'count': len(caps), - 'capabilities': [ - { - 'name': c.name, - 'granularity': c.granularity, - 'summary': c.summary, - 'tags': c.tags, - } - for c in caps - ], + 'registry_ok': + True, + 'count': + len(caps), + 'capabilities': [{ + 'name': c.name, + 'granularity': c.granularity, + 'summary': c.summary, + 'tags': c.tags, + } for c in caps], } except ImportError: # ms_agent may not be on sys.path (e.g. dev mode without pip install). # Fall back to subprocess check which uses ``-m`` resolution. try: result = subprocess.run( - [sys.executable, '-m', 'ms_agent.capabilities.mcp_server', '--check'], + [ + sys.executable, '-m', 'ms_agent.capabilities.mcp_server', + '--check' + ], capture_output=True, text=True, timeout=30, @@ -55,11 +56,15 @@ def check_capabilities() -> dict: 'registry_ok': True, 'count': len(data.get('capabilities', [])), 'capabilities': data.get('capabilities', []), - 'note': 'verified via subprocess (package not on sys.path)', + 'note': + 'verified via subprocess (package not on sys.path)', } except Exception: pass - return {'registry_ok': False, 'error': 'ms_agent.capabilities not importable'} + return { + 'registry_ok': False, + 'error': 'ms_agent.capabilities not importable' + } except Exception as e: return {'registry_ok': False, 'error': str(e)} @@ -68,7 +73,10 @@ def check_mcp_server() -> dict: """Check that the MCP server can start in --check mode.""" try: result = subprocess.run( - [sys.executable, '-m', 'ms_agent.capabilities.mcp_server', '--check'], + [ + sys.executable, '-m', 'ms_agent.capabilities.mcp_server', + '--check' + ], capture_output=True, text=True, timeout=30, @@ -79,7 +87,10 @@ def check_mcp_server() -> dict: else: return {'mcp_server_ok': False, 'error': result.stderr.strip()} except subprocess.TimeoutExpired: - return {'mcp_server_ok': False, 'error': 'MCP server --check timed out'} + return { + 'mcp_server_ok': False, + 'error': 'MCP server --check timed out' + } except Exception as e: return {'mcp_server_ok': False, 'error': str(e)} @@ -88,7 +99,6 @@ def check_mcp_package() -> dict: """Check that the mcp Python package is installed.""" try: import mcp # noqa: F401 - version = getattr(mcp, '__version__', 'unknown') return {'installed': True, 'version': version} except ImportError: @@ -109,8 +119,7 @@ def main() -> None: all_ok = ( report['ms_agent'].get('importable', False) and report['capabilities'].get('registry_ok', False) - and report['mcp_server'].get('mcp_server_ok', False) - ) + and report['mcp_server'].get('mcp_server_ok', False)) report['overall_status'] = 'ok' if all_ok else 'issues_found' print(json.dumps(report, indent=2, ensure_ascii=False)) @@ -118,13 +127,21 @@ def main() -> None: if not all_ok: print('\n--- Issues ---', file=sys.stderr) if not report['ms_agent'].get('importable'): - print(' ms-agent is not installed. Run: pip install ms-agent', file=sys.stderr) + print( + ' ms-agent is not installed. Run: pip install ms-agent', + file=sys.stderr) if not report['mcp_package'].get('installed'): - print(' mcp package is not installed. Run: pip install mcp', file=sys.stderr) + print( + ' mcp package is not installed. Run: pip install mcp', + file=sys.stderr) if not report['capabilities'].get('registry_ok'): - print(f" Registry error: {report['capabilities'].get('error')}", file=sys.stderr) + print( + f" Registry error: {report['capabilities'].get('error')}", + file=sys.stderr) if not report['mcp_server'].get('mcp_server_ok'): - print(f" MCP server error: {report['mcp_server'].get('error')}", file=sys.stderr) + print( + f" MCP server error: {report['mcp_server'].get('error')}", + file=sys.stderr) sys.exit(1) diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index 7908a53dd..cb78d5ce2 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -3,11 +3,10 @@ from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, List, Tuple, Union -from omegaconf import DictConfig - from ms_agent.llm import Message from ms_agent.utils import read_history, save_history from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_RETRY_COUNT +from omegaconf import DictConfig class Agent(ABC): @@ -19,31 +18,36 @@ class Agent(ABC): retry_count = int(os.environ.get('AGENT_RETRY_COUNT', DEFAULT_RETRY_COUNT)) - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): - """ - Base class for all agents. Provides core functionality such as configuration loading, - lifecycle handling via external code, and defining the interface for agent execution. - - The agent can be initialized either with a config object directly or by loading from a config directory or ID. - If external code (e.g., custom handlers) is involved, the agent must be explicitly trusted via - `trust_remote_code=True`. - - Base class for all agents. Make sure your custom agents are derived from this class. - Args: - config (DictConfig): Pre-loaded configuration object. - tag (str): A custom tag for identifying this agent run. - trust_remote_code (bool): Whether to allow loading of external code (e.g., custom handler modules). + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): """ + Base class for all agents. Provides core functionality such as configuration loading, + lifecycle handling via external code, and defining the interface for agent execution. + + The agent can be initialized either with a config object directly or by loading from a config directory or ID. + If external code (e.g., custom handlers) is involved, the agent must be explicitly trusted via + `trust_remote_code=True`. + + Base class for all agents. Make sure your custom agents are derived from this class. + Args: + config (DictConfig): Pre-loaded configuration object. + tag (str): A custom tag for identifying this agent run. + trust_remote_code (bool): Whether to allow loading of external code (e.g., custom handler modules). + """ self.config = config self.tag = tag self.trust_remote_code = trust_remote_code self.config.tag = tag self.config.trust_remote_code = trust_remote_code - self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', + DEFAULT_OUTPUT_DIR) @abstractmethod async def run( - self, inputs: Union[str, List[Message]], **kwargs + self, inputs: Union[str, List[Message]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: """ Main method to execute the agent. @@ -61,7 +65,8 @@ async def run( """ raise NotImplementedError() - def read_history(self, messages: Any, **kwargs) -> Tuple[DictConfig, List[Message]]: + def read_history(self, messages: Any, + **kwargs) -> Tuple[DictConfig, List[Message]]: return read_history(self.output_dir, self.tag) def save_history(self, messages: Any, **kwargs): @@ -72,7 +77,6 @@ def save_history(self, messages: Any, **kwargs): def list_snapshots(self) -> list: """Return snapshots for this agent's output_dir, most recent first.""" from ms_agent.utils.snapshot import list_snapshots - return list_snapshots(self.output_dir) def rollback(self, commit_hash: str) -> bool: diff --git a/ms_agent/agent/code_agent.py b/ms_agent/agent/code_agent.py index 2200712ec..33b2b63a4 100644 --- a/ms_agent/agent/code_agent.py +++ b/ms_agent/agent/code_agent.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Any, List, Union -from omegaconf import DictConfig - from ms_agent.llm import Message +from omegaconf import DictConfig from .base import Agent @@ -13,11 +12,16 @@ class CodeAgent(Agent): AGENT_NAME = 'CodeAgent' - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.load_cache = kwargs.get('load_cache', False) - async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: + async def run(self, inputs: Union[str, List[Message]], + **kwargs) -> List[Message]: """Run the external code. Default implementation here does nothing. Args: @@ -38,5 +42,6 @@ async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message self.save_history(messages, **kwargs) return messages - async def execute_code(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, inputs: Union[str, List[Message]], + **kwargs) -> List[Message]: return inputs diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 379f439c3..14adfa8b3 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,7 +2,6 @@ import asyncio import importlib import inspect -import json import os.path import sys import threading @@ -11,8 +10,7 @@ from copy import deepcopy from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -from omegaconf import DictConfig, OmegaConf - +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.llm.llm import LLM @@ -23,10 +21,11 @@ from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history +from ms_agent.utils.task_manager import TaskManager from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot -from ms_agent.utils.task_manager import TaskManager +from omegaconf import DictConfig, OmegaConf from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -104,14 +103,17 @@ def resolve_enable_snapshots(config: Any) -> bool: like ``\"false\"`` coerced to boolean). """ if OmegaConf.is_config(config): - raw = OmegaConf.select(config, 'enable_snapshots', default=_MISSING_ENABLE_SNAPSHOTS) + raw = OmegaConf.select(config, 'enable_snapshots', + default=_MISSING_ENABLE_SNAPSHOTS) if raw is not _MISSING_ENABLE_SNAPSHOTS and raw is not None: return LLMAgent._coerce_enable_snapshots_value(raw) - sub = bool(OmegaConf.select(config, 'ms_agent_subagent', default=False)) + sub = bool( + OmegaConf.select(config, 'ms_agent_subagent', default=False)) return not sub if isinstance(config, dict): if 'enable_snapshots' in config and config['enable_snapshots'] is not None: - return LLMAgent._coerce_enable_snapshots_value(config['enable_snapshots']) + return LLMAgent._coerce_enable_snapshots_value( + config['enable_snapshots']) return not bool(config.get('ms_agent_subagent')) return True @@ -129,7 +131,8 @@ def __init__( **kwargs, ): if not hasattr(config, 'llm'): - default_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') + default_yaml = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -144,7 +147,8 @@ def __init__( self.load_cache = kwargs.get('load_cache', False) self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) - self.mcp_config: Dict[str, Any] = self.parse_mcp_servers(kwargs.get('mcp_config', {})) + self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( + kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -194,25 +198,31 @@ def _ensure_auto_skills(self) -> bool: from ms_agent.utils.docker_utils import is_docker_daemon_running if not is_docker_daemon_running(): - logger.warning('Docker not running, disabling sandbox for skills') + logger.warning( + 'Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container( + skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', + None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, + 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) - logger.info(f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills') + logger.info( + f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' + ) self._auto_skills_initialized = True return True @@ -291,7 +301,9 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = getattr(skills_config, 'stop_on_failure', True) if skills_config else True + stop_on_failure = ( + getattr(skills_config, 'stop_on_failure', True) + if skills_config else True) result = await self._auto_skills.run( query=query, @@ -315,7 +327,8 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append(Message(role='assistant', content=dag_result.chat_response)) + messages.append( + Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills @@ -346,7 +359,9 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: if output.output_files: content += f'**Generated files:** {list(output.output_files.values())}\n\n' - content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + content += ( + f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): @@ -372,14 +387,14 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: def rollback(self, commit_hash: str) -> bool: """Restore output_dir to snapshot and truncate message history.""" from ms_agent.utils.snapshot import restore_snapshot - ok, message_count = restore_snapshot(self.output_dir, commit_hash) if not ok: return False # Truncate saved history to the message count at snapshot time _, saved_messages = read_history(self.output_dir, self.tag) if saved_messages and message_count < len(saved_messages): - save_history(self.output_dir, self.tag, self.config, saved_messages[:message_count]) + save_history(self.output_dir, self.tag, self.config, + saved_messages[:message_count]) # Clear read cache on FileSystemTool so stale entries don't block edits if self.tool_manager is not None: for tool in self.tool_manager.extra_tools: @@ -407,7 +422,8 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile( + self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -439,19 +455,27 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: f'[External Code]A Config Lifecycle handler ' f'registered in the config: {handler_file}. ' f'\nThis is external code, if you trust this workflow, ' - f'please specify `--trust_remote_code true`' - ) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + f'please specify `--trust_remote_code true`') + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' if local_dir not in sys.path: sys.path.insert(0, local_dir) handler_module = importlib.import_module(handler_file) - module_classes = {name: cls for name, cls in inspect.getmembers(handler_module, inspect.isclass)} + module_classes = { + name: cls + for name, cls in inspect.getmembers(handler_module, + inspect.isclass) + } handler = None for name, handler_cls in module_classes.items(): - if handler_cls.__bases__[0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: + if (handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file): handler = handler_cls() - assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' + assert ( + handler is not None + ), f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -462,12 +486,15 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, + 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -485,19 +512,26 @@ def register_callback_from_config(self): if _callback.endswith('.py'): _callback = _callback[:-3] callback_file = importlib.import_module(_callback) - module_classes = {name: cls for name, cls in inspect.getmembers(callback_file, inspect.isclass)} + module_classes = { + name: cls + for name, cls in inspect.getmembers( + callback_file, inspect.isclass) + } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass(cls, Callback) and cls.__module__ == _callback: + if issubclass( + cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback](self.config)) + self.callbacks.append(callbacks_mapping[_callback]( + self.config)) async def on_task_begin(self, messages: List[Message]): self.log_output(f'Agent {self.tag} task beginning.') if self.resolve_enable_snapshots(self.config): _user_content = next( - ((getattr(m, 'content', '') or '')[:80] for m in messages if getattr(m, 'role', '') == 'user'), + ((getattr(m, 'content', '') or '')[:80] + for m in messages if getattr(m, 'role', '') == 'user'), '', ) take_snapshot( @@ -533,7 +567,8 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, + messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -543,9 +578,11 @@ async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: Returns: List[Message]: Updated message list including tool responses. """ - tool_call_result = await self.tool_manager.parallel_call_tool(messages[-1].tool_calls) + tool_call_result = await self.tool_manager.parallel_call_tool( + messages[-1].tool_calls) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip(tool_call_result, messages[-1].tool_calls): + for tool_call_result, tool_call_query in zip(tool_call_result, + messages[-1].tool_calls): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', @@ -587,7 +624,8 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -598,7 +636,8 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -609,7 +648,8 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) _THINKING_SEP = '─' * 40 @@ -650,16 +690,19 @@ def _write_thinking_footer(self): @property def system(self): - return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr( + getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr( + getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query - async def create_messages(self, messages: Union[List[Message], str]) -> List[Message]: + async def create_messages( + self, messages: Union[List[Message], str]) -> List[Message]: """ Convert input into a standardized list of messages. @@ -671,15 +714,18 @@ async def create_messages(self, messages: Union[List[Message], str]) -> List[Mes """ if isinstance(messages, list): system = self.system - if system is not None and messages[0].role == 'system' and system != messages[0].content: + if (system is not None and messages[0].role == 'system' + and system != messages[0].content): # Replace the existing system messages[0].content = system else: - assert isinstance(messages, str), ( - f'inputs can be either a list or a string, but current is {type(messages)}' - ) + assert isinstance( + messages, str + ), f'inputs can be either a list or a string, but current is {type(messages)}' messages = [ - Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message( + role='system', + content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -700,7 +746,8 @@ async def do_rag(self, messages: List[Message]): if self.rag is not None: user_message.content = await self.rag.query(query) - async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, + messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -715,7 +762,9 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: None if no skill processing or fallback to standard agent """ # Extract user query from normalized messages - query = messages[1].content if len(messages) > 1 and messages[1].role == 'user' else None + query = ( + messages[1].content + if len(messages) > 1 and messages[1].role == 'user' else None) if not query: return None @@ -729,7 +778,9 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: try: skills_config = self._get_skills_config() - auto_execute = getattr(skills_config, 'auto_execute', True) if skills_config else True + auto_execute = ( + getattr(skills_config, 'auto_execute', True) + if skills_config else True) if auto_execute: dag_result = await self.execute_skills(query) @@ -737,7 +788,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages(dag_result) + skill_messages = self._format_skill_result_as_messages( + dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -747,7 +799,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: return None except Exception as e: - logger.warning(f'Skill execution failed: {e}, falling back to standard agent') + logger.warning( + f'Skill execution failed: {e}, falling back to standard agent') self._skill_mode_active = False return None @@ -761,10 +814,11 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f'{mem_instance_type} not in memory_mapping, which supports: {list(memory_mapping.keys())}' - ) + f'{mem_instance_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') - shared_memory = await SharedMemoryManager.get_shared_memory(self.config, mem_instance_type) + shared_memory = await SharedMemoryManager.get_shared_memory( + self.config, mem_instance_type) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -773,8 +827,8 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f'{rag.name} not in rag_mapping, which supports: {list(rag_mapping.keys())}' - ) + f'{rag.name} not in rag_mapping, ' + f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: @@ -821,7 +875,8 @@ def log_output(self, content: Union[str, list]): for _line in line.split('\\n'): logger.info(f'[{self.tag}] {_line}') - def handle_new_response(self, messages: List[Message], response_message: Message): + def handle_new_response(self, messages: List[Message], + response_message: Message): assert response_message is not None, 'No response message generated from LLM.' if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -829,15 +884,18 @@ def handle_new_response(self, messages: List[Message], response_message: Message tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads(tool_call['arguments']) + tool_call['arguments'] = json.loads( + tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output( + json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if messages[-1].role == 'assistant' and not messages[-1].content and response_message.tool_calls: + if (messages[-1].role == 'assistant' and not messages[-1].content + and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' def _append_task_notifications(self, messages: List[Message]) -> List[Message]: @@ -852,7 +910,9 @@ def _append_task_notifications(self, messages: List[Message]) -> List[Message]: return messages @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step( + self, messages: List[Message] + ) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -890,17 +950,20 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A _response_message = None _printed_reasoning_header = False _printed_reasoning_footer = False - for _response_message in self.llm.generate(messages, tools=tools): + for _response_message in self.llm.generate( + messages, tools=tools): if is_first: messages.append(_response_message) is_first = False if self.show_reasoning: - reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') + or '') # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning) :] + new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: self._write_thinking_header() @@ -908,7 +971,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A self._write_reasoning(new_reasoning, dim=True) _reasoning = reasoning_text - new_content = _response_message.content[len(_content) :] + new_content = _response_message.content[len(_content):] if new_content: if _printed_reasoning_header and not _printed_reasoning_footer: self._write_thinking_footer() @@ -923,7 +986,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # Handle reasoning summaries that arrive after content if self.show_reasoning and _response_message is not None: - final_reasoning = getattr(_response_message, 'reasoning_content', '') or '' + final_reasoning = getattr(_response_message, + 'reasoning_content', '') or '' if final_reasoning and not _printed_reasoning_header: self._write_thinking_header() self._write_reasoning(final_reasoning, dim=True) @@ -933,7 +997,9 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = getattr(_response_message, 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') + or '') if reasoning_text: self._write_thinking_header() self._write_reasoning(reasoning_text, dim=True) @@ -961,7 +1027,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 + cache_creation_input_tokens = ( + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -970,14 +1037,17 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS += cache_creation_input_tokens # tokens in the current step - self.log_output(f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}') + self.log_output( + f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' + ) if cached_tokens or cache_creation_input_tokens: - self.log_output(f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}') + self.log_output( + f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' + ) # total tokens for the process so far self.log_output( f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' - f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}' - ) + f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' @@ -994,7 +1064,8 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history(self, messages: List[Message], **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history(self, messages: List[Message], + **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. @@ -1037,8 +1108,10 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: return user_id def _get_step_memory_info(self, memory_config: DictConfig): - user_id, agent_id, run_id, memory_type = get_memory_meta_safe(memory_config, 'add_after_step') - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + user_id, agent_id, run_id, memory_type = get_memory_meta_safe( + memory_config, 'add_after_step') + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1049,7 +1122,8 @@ def _get_run_memory_info(self, memory_config: DictConfig): 'add_after_task', default_user_id=getattr(memory_config, 'user_id', None), ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1060,14 +1134,18 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, memory_config) in enumerate(self.config.memory.items()): + for idx, (mem_instance_type, + memory_config) in enumerate(self.config.memory.items()): if add_type == 'add_after_task': - user_id, agent_id, run_id, memory_type = self._get_run_memory_info(memory_config) + user_id, agent_id, run_id, memory_type = self._get_run_memory_info( + memory_config) else: - user_id, agent_id, run_id, memory_type = self._get_step_memory_info(memory_config) + user_id, agent_id, run_id, memory_type = self._get_step_memory_info( + memory_config) if idx < tools_num: - if any(v is not None for v in [user_id, agent_id, run_id, memory_type]): + if any(v is not None + for v in [user_id, agent_id, run_id, memory_type]): await self.memory_tools[idx].add( messages, user_id=user_id, @@ -1096,9 +1174,11 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history(self.output_dir, task=self.tag, config=config, messages=messages) + save_history( + self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], + **kwargs) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1111,7 +1191,8 @@ async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> Async List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr(self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND) + self.max_chat_round = getattr(self.config, 'max_chat_round', + LLMAgent.DEFAULT_MAX_CHAT_ROUND) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1161,7 +1242,8 @@ async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> Async yield messages self.runtime.round += 1 # save memory and history - await self.add_memory(messages, add_type='add_after_step', **kwargs) + await self.add_memory( + messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1170,10 +1252,10 @@ async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> Async messages.append( Message( role='assistant', - content=f'Task {messages[1].content} was cutted off, because ' + content= + f'Task {messages[1].content} was cutted off, because ' f'max round({self.max_chat_round}) exceeded.', - ) - ) + )) self.runtime.should_stop = True yield messages @@ -1183,7 +1265,9 @@ async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> Async yield messages def _add_memory(): - asyncio.run(self.add_memory(messages, add_type='add_after_task', **kwargs)) + asyncio.run( + self.add_memory( + messages, add_type='add_after_task', **kwargs)) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) @@ -1192,19 +1276,23 @@ def _add_memory(): logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): - logger.error(f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}') + logger.error( + f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' + ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: - OmegaConf.update(self.config, 'generation_config.stream', True, merge=True) + OmegaConf.update( + self.config, 'generation_config.stream', True, merge=True) async def stream_generator(): - async for _chunk in self.run_loop(messages=messages, **kwargs): + async for _chunk in self.run_loop( + messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/agent/loader.py b/ms_agent/agent/loader.py index 7dd5f7a02..21b1687a7 100644 --- a/ms_agent/agent/loader.py +++ b/ms_agent/agent/loader.py @@ -5,30 +5,27 @@ import sys from typing import Dict, Optional -from omegaconf import DictConfig, OmegaConf - from ms_agent.config.config import Config from ms_agent.utils.constants import DEFAULT_AGENT_FILE, DEFAULT_TAG +from omegaconf import DictConfig, OmegaConf from .base import Agent class AgentLoader: + @classmethod - def build( - cls, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - tag: Optional[str] = None, - trust_remote_code: bool = False, - **kwargs, - ) -> Agent: + def build(cls, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + tag: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs) -> Agent: agent_config: Optional[DictConfig] = None if config_dir_or_id is not None: if not os.path.exists(config_dir_or_id): from modelscope import snapshot_download - config_dir_or_id = snapshot_download(config_dir_or_id) agent_config: DictConfig = Config.from_task(config_dir_or_id, env) if config is not None: @@ -43,30 +40,35 @@ def build( agent_tag = tag agent_config.tag = agent_tag agent_config.trust_remote_code = trust_remote_code - if getattr(agent_config, 'local_dir', None) is None and config_dir_or_id is not None: + if getattr(agent_config, 'local_dir', + None) is None and config_dir_or_id is not None: agent_config.local_dir = config_dir_or_id - from .code_agent import CodeAgent from .llm_agent import LLMAgent - + from .code_agent import CodeAgent agent_type = LLMAgent.AGENT_NAME if 'code_file' in kwargs: code_file = kwargs.pop('code_file') elif agent_config is not None: - agent_type = getattr(agent_config, 'type', '').lower() or agent_type.lower() + agent_type = getattr(agent_config, 'type', + '').lower() or agent_type.lower() code_file = getattr(agent_config, 'code_file', None) else: assert getattr(agent_config, 'local_dir', None) is not None - code_file = os.path.join(getattr(agent_config, 'local_dir', ''), DEFAULT_AGENT_FILE) + code_file = os.path.join( + getattr(agent_config, 'local_dir', ''), DEFAULT_AGENT_FILE) if code_file is not None: - agent_instance = cls._load_external_code(agent_config, code_file, **kwargs) + agent_instance = cls._load_external_code(agent_config, code_file, + **kwargs) else: assert agent_config is not None if agent_type == LLMAgent.AGENT_NAME.lower(): - agent_instance = LLMAgent(agent_config, agent_tag, trust_remote_code, **kwargs) + agent_instance = LLMAgent(agent_config, agent_tag, + trust_remote_code, **kwargs) elif agent_type == CodeAgent.AGENT_NAME.lower(): - agent_instance = CodeAgent(agent_config, agent_tag, trust_remote_code, **kwargs) + agent_instance = CodeAgent(agent_config, agent_tag, + trust_remote_code, **kwargs) else: raise ValueError(f'Unknown agent type: {agent_type}') return agent_instance @@ -77,8 +79,7 @@ def _load_external_code(cls, config, code_file, **kwargs) -> 'Agent': assert config.trust_remote_code, ( f'[External Code]A code file is required to run in the LLMAgent: {code_file}' f'\nThis is external code, if you trust this code file, ' - f'please specify `--trust_remote_code true`' - ) + f'please specify `--trust_remote_code true`') subdir = os.path.dirname(code_file) code_file = os.path.basename(code_file) local_dir = config.local_dir @@ -96,11 +97,20 @@ def _load_external_code(cls, config, code_file, **kwargs) -> 'Agent': if code_file in sys.modules: del sys.modules[code_file] code_module = importlib.import_module(code_file) - module_classes = {name: agent_cls for name, agent_cls in inspect.getmembers(code_module, inspect.isclass)} + module_classes = { + name: agent_cls + for name, agent_cls in inspect.getmembers(code_module, + inspect.isclass) + } agent_instance = None for name, agent_cls in module_classes.items(): - if Agent in agent_cls.__mro__[1:] and agent_cls.__module__ == code_file: - agent_instance = agent_cls(config, config.tag, trust_remote_code=config.trust_remote_code, **kwargs) + if Agent in agent_cls.__mro__[ + 1:] and agent_cls.__module__ == code_file: + agent_instance = agent_cls( + config, + config.tag, + trust_remote_code=config.trust_remote_code, + **kwargs) break assert agent_instance is not None, f'Cannot find a proper agent class in the external code file: {code_file}' if subdir_inserted: diff --git a/ms_agent/agent/runtime.py b/ms_agent/agent/runtime.py index 508eeaf0e..55a0dbf9e 100644 --- a/ms_agent/agent/runtime.py +++ b/ms_agent/agent/runtime.py @@ -7,6 +7,7 @@ @dataclass class Runtime: + should_stop: bool = False llm: LLM = None diff --git a/ms_agent/app/doc_research.py b/ms_agent/app/doc_research.py index 9cfdcf8bf..8c55f99c4 100644 --- a/ms_agent/app/doc_research.py +++ b/ms_agent/app/doc_research.py @@ -2076,7 +2076,7 @@ def initialize_page(request: gr.Request): session_status_html = f"""
📊 会话状态: {'已加载历史数据' if any(session_data.values()) else '新会话'} - {f'| 最后更新: {session_data.get('timestamp', '未知')}' if session_data.get('timestamp') else ''} + {f'| 最后更新: {session_data.get("timestamp", "未知")}' if session_data.get("timestamp") else ''}
""" if any(session_data.values()) else """
diff --git a/ms_agent/app/fin_research.py b/ms_agent/app/fin_research.py index a8344803e..1767d3e70 100644 --- a/ms_agent/app/fin_research.py +++ b/ms_agent/app/fin_research.py @@ -443,14 +443,14 @@ def build_fin_prompt( sections.append(f'Market / region focus: {markets.strip()}') if focus_areas: sections.append( - f'Priority analytical pillars: {', '.join(focus_areas)}') + f'Priority analytical pillars: {", ".join(focus_areas)}') if macro_view: sections.append(f'Macro sensitivity preference: {macro_view}') if extra_notes.strip(): sections.append(f'Additional analyst notes:\n{extra_notes.strip()}') instructions = [ - f'Desired deliverable style: {deliverable_style or 'Balanced'}', + f'Desired deliverable style: {deliverable_style or "Balanced"}', f'Analytical depth target (1-5): {analysis_depth}' ] if output_language: @@ -1137,7 +1137,7 @@ def format_result_summary(workdir: str, include_sentiment: bool, f'- 工作目录: {workdir}', ] if focus_areas: - lines.append(f'- 关注领域: {', '.join(focus_areas)}') + lines.append(f'- 关注领域: {", ".join(focus_areas)}') lines.append('请查阅过程报告及最终综合报告。') return '\n'.join(lines) diff --git a/ms_agent/callbacks/base.py b/ms_agent/callbacks/base.py index f4509e15b..849fe7069 100644 --- a/ms_agent/callbacks/base.py +++ b/ms_agent/callbacks/base.py @@ -1,17 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List -from omegaconf import DictConfig - from ms_agent.agent.runtime import Runtime from ms_agent.llm.utils import Message +from omegaconf import DictConfig class Callback: + def __init__(self, config: DictConfig): self.config = config - async def on_task_begin(self, runtime: Runtime, messages: List[Message]) -> None: + async def on_task_begin(self, runtime: Runtime, + messages: List[Message]) -> None: """Called when a task begins. Args: @@ -23,7 +24,8 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]) -> None """ pass - async def on_generate_response(self, runtime: Runtime, messages: List[Message]): + async def on_generate_response(self, runtime: Runtime, + messages: List[Message]): """Called before LLM generates response. Args: diff --git a/ms_agent/callbacks/input_callback.py b/ms_agent/callbacks/input_callback.py index a5bffe998..e44db1e31 100644 --- a/ms_agent/callbacks/input_callback.py +++ b/ms_agent/callbacks/input_callback.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List -from omegaconf import DictConfig - from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/capabilities/__init__.py b/ms_agent/capabilities/__init__.py index bded8c19a..73feb864e 100644 --- a/ms_agent/capabilities/__init__.py +++ b/ms_agent/capabilities/__init__.py @@ -23,7 +23,6 @@ """ from __future__ import annotations - from typing import Any from ms_agent.capabilities.descriptor import CapabilityDescriptor @@ -40,7 +39,13 @@ def create_registry(config: Any = None) -> CapabilityRegistry: """ registry = CapabilityRegistry() - from ms_agent.capabilities.wrappers import agent_delegate, deep_research, filesystem, lsp_code_server, web_search + from ms_agent.capabilities.wrappers import ( + agent_delegate, + deep_research, + filesystem, + lsp_code_server, + web_search, + ) filesystem.register_all(registry, config) lsp_code_server.register_all(registry, config) diff --git a/ms_agent/capabilities/async_task.py b/ms_agent/capabilities/async_task.py index c13db3a7b..74ac90fbe 100644 --- a/ms_agent/capabilities/async_task.py +++ b/ms_agent/capabilities/async_task.py @@ -145,7 +145,10 @@ def check( task = self._tasks.get(task_id) if task is None: known = [t.task_id for t in self._tasks.values()] - return {'error': f'Unknown task_id: {task_id}', 'known_tasks': known} + return { + 'error': f'Unknown task_id: {task_id}', + 'known_tasks': known + } info: dict[str, Any] = { 'task_id': task.task_id, @@ -162,7 +165,8 @@ def check( try: info.update(progress_fn(task)) except Exception: - logger.debug('progress_fn raised for task %s', task_id, exc_info=True) + logger.debug( + 'progress_fn raised for task %s', task_id, exc_info=True) return info @@ -178,7 +182,11 @@ def get_result(self, task_id: str) -> dict[str, Any]: 'message': 'Task is still in progress.', } if task.status == 'failed': - return {'task_id': task_id, 'status': 'failed', 'error': task.error} + return { + 'task_id': task_id, + 'status': 'failed', + 'error': task.error + } if task.status == 'cancelled': return {'task_id': task_id, 'status': 'cancelled'} return { @@ -194,7 +202,8 @@ async def cancel(self, task_id: str) -> dict[str, Any]: return {'error': f'Unknown task_id: {task_id}'} if task.status != 'running': return { - 'error': f'Task {task_id} is not running (status: {task.status})', + 'error': + f'Task {task_id} is not running (status: {task.status})', } # Cancel the asyncio task first diff --git a/ms_agent/capabilities/mcp_server.py b/ms_agent/capabilities/mcp_server.py index 06314a019..650dc33cf 100644 --- a/ms_agent/capabilities/mcp_server.py +++ b/ms_agent/capabilities/mcp_server.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import argparse -import json import logging import os import sys +import json from dotenv import find_dotenv, load_dotenv - from ms_agent.capabilities import create_registry logger = logging.getLogger(__name__) @@ -50,15 +49,13 @@ def _print_check() -> None: registry = create_registry() caps = registry.list_all() info = { - 'status': 'ok', - 'capabilities': [ - { - 'name': c.name, - 'granularity': c.granularity, - 'summary': c.summary, - } - for c in caps - ], + 'status': + 'ok', + 'capabilities': [{ + 'name': c.name, + 'granularity': c.granularity, + 'summary': c.summary, + } for c in caps], } print(json.dumps(info, indent=2)) @@ -119,8 +116,7 @@ def main() -> None: """ parser = argparse.ArgumentParser( - description='ms-agent MCP Capability Server', - ) + description='ms-agent MCP Capability Server', ) parser.add_argument( '--check', action='store_true', @@ -155,7 +151,8 @@ def main() -> None: from mcp.server.fastmcp import FastMCP except ImportError: print( - 'ERROR: The "mcp" package is required. Install it with:\n pip install mcp\n', + 'ERROR: The "mcp" package is required. Install it with:\n' + ' pip install mcp\n', file=sys.stderr, ) sys.exit(1) @@ -165,9 +162,8 @@ def main() -> None: server = FastMCP( 'ms-agent-capabilities', - instructions=( - 'ms-agent Capability Gateway. Provides deep research, LSP code validation, and advanced file-editing tools.' - ), + instructions=('ms-agent Capability Gateway. Provides deep research, ' + 'LSP code validation, and advanced file-editing tools.'), ) for cap in registry.list_all(): @@ -208,13 +204,18 @@ def _build_handler(registry, cap, workspace: str): for pname, pschema in properties.items(): py_type = type_map.get(pschema.get('type', 'string'), str) if pname in required_params: - params.append(inspect.Parameter(pname, inspect.Parameter.KEYWORD_ONLY, annotation=py_type)) + params.append( + inspect.Parameter( + pname, inspect.Parameter.KEYWORD_ONLY, annotation=py_type)) else: opt_type = typing.Optional[py_type] default = pschema.get('default') params.append( - inspect.Parameter(pname, inspect.Parameter.KEYWORD_ONLY, default=default, annotation=opt_type) - ) + inspect.Parameter( + pname, + inspect.Parameter.KEYWORD_ONLY, + default=default, + annotation=opt_type)) annotations[pname] = params[-1].annotation cap_name = cap.name diff --git a/ms_agent/capabilities/registry.py b/ms_agent/capabilities/registry.py index bc1f87516..5d0dbd4b2 100644 --- a/ms_agent/capabilities/registry.py +++ b/ms_agent/capabilities/registry.py @@ -21,7 +21,8 @@ def __init__(self) -> None: self._descriptors: dict[str, CapabilityDescriptor] = {} self._handlers: dict[str, Handler] = {} - def register(self, descriptor: CapabilityDescriptor, handler: Handler) -> None: + def register(self, descriptor: CapabilityDescriptor, + handler: Handler) -> None: if descriptor.name in self._descriptors: logger.warning('Overwriting capability %s', descriptor.name) self._descriptors[descriptor.name] = descriptor @@ -44,7 +45,8 @@ def discover( results = self.list_all() if granularity is not None: - levels = [granularity] if isinstance(granularity, str) else granularity + levels = [granularity] if isinstance(granularity, + str) else granularity results = [c for c in results if c.granularity in levels] if tags: @@ -54,12 +56,14 @@ def discover( if query: q = query.lower() results = [ - c for c in results if q in c.name.lower() or q in c.summary.lower() or q in c.description.lower() + c for c in results if q in c.name.lower() + or q in c.summary.lower() or q in c.description.lower() ] return results - async def invoke(self, name: str, args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: + async def invoke(self, name: str, args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Invoke a registered capability by name.""" if name not in self._handlers: return {'error': f'Unknown capability: {name}'} diff --git a/ms_agent/capabilities/wrappers/agent_delegate.py b/ms_agent/capabilities/wrappers/agent_delegate.py index e7a89905a..0c0cc2bf7 100644 --- a/ms_agent/capabilities/wrappers/agent_delegate.py +++ b/ms_agent/capabilities/wrappers/agent_delegate.py @@ -24,13 +24,13 @@ 'description': 'Custom system prompt for the agent (optional)', }, 'tools': { - 'type': 'string', - 'description': ( - 'Comma-separated basic tool component names to enable, e.g. ' - '"web_search,file_system,todo_list". Alias "filesystem" is accepted ' - 'for backward compatibility. Leave empty to use the default agent ' - 'config tools.' - ), + 'type': + 'string', + 'description': + ('Comma-separated basic tool component names to enable, e.g. ' + '"web_search,file_system,todo_list". Alias "filesystem" is accepted ' + 'for backward compatibility. Leave empty to use the default agent ' + 'config tools.'), }, 'max_rounds': { 'type': 'integer', @@ -69,13 +69,13 @@ name='delegate_task', version='0.1.0', granularity='project', - summary=('Delegate a task to an LLM agent that can use tools. Blocks until the agent completes.'), + summary=('Delegate a task to an LLM agent that can use tools. ' + 'Blocks until the agent completes.'), description=( 'Creates an LLMAgent with the given configuration, runs it on the ' 'provided query, and returns the final response text. The agent ' 'can use tools (web search, filesystem, etc.) to accomplish the ' - 'task. WARNING: this call blocks and may take minutes.' - ), + 'task. WARNING: this call blocks and may take minutes.'), input_schema={ 'type': 'object', 'properties': _DELEGATE_INPUT_PROPERTIES, @@ -84,8 +84,12 @@ output_schema={ 'type': 'object', 'properties': { - 'status': {'type': 'string'}, - 'response': {'type': 'string'}, + 'status': { + 'type': 'string' + }, + 'response': { + 'type': 'string' + }, }, }, tags=['agent', 'delegate', 'llm', 'sync'], @@ -96,12 +100,11 @@ name='submit_agent_task', version='0.1.0', granularity='project', - summary=('Submit an agent task to run in the background. Returns a task_id immediately.'), - description=( - 'Starts an LLMAgent in the background and returns a task_id. ' - 'Use check_agent_task(task_id) to poll progress and ' - 'get_agent_result(task_id) to retrieve the final response.' - ), + summary=('Submit an agent task to run in the background. ' + 'Returns a task_id immediately.'), + description=('Starts an LLMAgent in the background and returns a task_id. ' + 'Use check_agent_task(task_id) to poll progress and ' + 'get_agent_result(task_id) to retrieve the final response.'), input_schema={ 'type': 'object', 'properties': _DELEGATE_INPUT_PROPERTIES, @@ -110,8 +113,12 @@ output_schema={ 'type': 'object', 'properties': { - 'task_id': {'type': 'string'}, - 'status': {'type': 'string'}, + 'task_id': { + 'type': 'string' + }, + 'status': { + 'type': 'string' + }, }, }, tags=['agent', 'delegate', 'llm', 'async', 'submit'], @@ -123,9 +130,8 @@ version='0.1.0', granularity='tool', summary='Check progress of a background agent task.', - description=( - 'Polls the status of an agent task previously submitted via submit_agent_task. Returns the current status.' - ), + description=('Polls the status of an agent task previously submitted via ' + 'submit_agent_task. Returns the current status.'), input_schema={ 'type': 'object', 'properties': { @@ -145,10 +151,8 @@ version='0.1.0', granularity='tool', summary='Get the result of a completed agent task.', - description=( - 'Retrieves the final response from a completed agent task. ' - 'If the task is still running, returns a status message.' - ), + description=('Retrieves the final response from a completed agent task. ' + 'If the task is still running, returns a status message.'), input_schema={ 'type': 'object', 'properties': { @@ -205,7 +209,8 @@ def _build_basic_tools_config(tools_list: list[str] | None) -> dict[str, Any]: for raw_name in tools_list: tool_name = _BASIC_TOOL_ALIASES.get(raw_name) if tool_name is None: - logger.warning('Ignoring unsupported delegate tool name: %s', raw_name) + logger.warning('Ignoring unsupported delegate tool name: %s', + raw_name) continue if tool_name in tools_cfg: continue @@ -245,7 +250,9 @@ def _build_agent_config( # return. OmegaConf.merge(default, ours) lets our value win. safe_cbs: list[str] = [] if hasattr(config, 'callbacks') and config.callbacks: - safe_cbs = [c for c in config.callbacks if c not in ('input_callback',)] + safe_cbs = [ + c for c in config.callbacks if c not in ('input_callback', ) + ] OmegaConf.update(config, 'callbacks', safe_cbs, merge=False) OmegaConf.update(config, 'save_history', False, merge=True) @@ -257,7 +264,8 @@ def _build_agent_config( # Preserve explicit config_path tool settings when already present. if hasattr(existing_tools, tool_name): continue - OmegaConf.update(config, f'tools.{tool_name}', tool_cfg, merge=True) + OmegaConf.update( + config, f'tools.{tool_name}', tool_cfg, merge=True) return config @@ -289,7 +297,8 @@ async def _run_agent( """Create, run, and clean up an LLMAgent. Returns the response text.""" from ms_agent.agent.llm_agent import LLMAgent - config = _build_agent_config(config_path, system_prompt, tools_list, max_rounds) + config = _build_agent_config(config_path, system_prompt, tools_list, + max_rounds) agent = LLMAgent(config=config, tag='delegate') try: @@ -313,7 +322,8 @@ async def _run_agent( logger.debug('Error during agent tool cleanup', exc_info=True) -async def _handle_delegate_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_delegate_task(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Synchronous agent delegation -- blocks until the agent finishes.""" query = (args.get('query') or '').strip() if not query: @@ -345,7 +355,8 @@ async def _background_agent(task: AsyncTask) -> dict[str, Any]: return {'response': response} -async def _handle_submit_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_submit_agent_task(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Submit an agent task to run in the background.""" query = (args.get('query') or '').strip() if not query: @@ -363,27 +374,32 @@ async def _handle_submit_agent_task(args: dict[str, Any], **kwargs: Any) -> dict }, ) return { - 'task_id': task.task_id, - 'status': 'running', - 'message': ( - f'Agent task {task.task_id} started. Use check_agent_task(task_id="{task.task_id}") to poll status.' - ), + 'task_id': + task.task_id, + 'status': + 'running', + 'message': + (f'Agent task {task.task_id} started. ' + f'Use check_agent_task(task_id="{task.task_id}") to poll status.'), } -async def _handle_check_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_check_agent_task(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Check progress of a background agent task.""" return _manager.check(args['task_id']) -async def _handle_get_agent_result(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_get_agent_result(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Get the result of a completed agent task.""" task_id = args['task_id'] max_chars = args.get('max_chars', 50000) result = _manager.get_result(task_id) # Truncate response if needed - if result.get('status') == 'completed' and isinstance(result.get('result'), dict): + if result.get('status') == 'completed' and isinstance( + result.get('result'), dict): response = result['result'].get('response', '') truncated = len(response) > max_chars if truncated: @@ -395,7 +411,8 @@ async def _handle_get_agent_result(args: dict[str, Any], **kwargs: Any) -> dict[ return result -async def _handle_cancel_agent_task(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_cancel_agent_task(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Cancel a running agent task.""" return await _manager.cancel(args['task_id']) diff --git a/ms_agent/capabilities/wrappers/deep_research.py b/ms_agent/capabilities/wrappers/deep_research.py index 10672777f..68de160bb 100644 --- a/ms_agent/capabilities/wrappers/deep_research.py +++ b/ms_agent/capabilities/wrappers/deep_research.py @@ -18,17 +18,14 @@ name='submit_research_task', version='0.1.0', granularity='project', - summary=( - 'Submit a deep research task that runs in the background. ' - 'Returns a task_id immediately -- use check_research_progress ' - 'and get_research_report to poll results.' - ), + summary=('Submit a deep research task that runs in the background. ' + 'Returns a task_id immediately -- use check_research_progress ' + 'and get_research_report to poll results.'), description=( 'Launches the deep_research v2 pipeline as a background subprocess. ' 'The calling agent is NOT blocked and can continue other work. ' 'Use check_research_progress(task_id) to poll status, and ' - 'get_research_report(task_id) to retrieve the final report.' - ), + 'get_research_report(task_id) to retrieve the final report.'), input_schema={ 'type': 'object', 'properties': { @@ -37,12 +34,16 @@ 'description': 'The research question or topic to investigate', }, 'config_path': { - 'type': 'string', - 'description': ('Path to researcher.yaml config. Defaults to the bundled v2 config.'), + 'type': + 'string', + 'description': ('Path to researcher.yaml config. ' + 'Defaults to the bundled v2 config.'), }, 'output_dir': { - 'type': 'string', - 'description': 'Directory for research outputs (auto-generated if omitted)', + 'type': + 'string', + 'description': + 'Directory for research outputs (auto-generated if omitted)', }, }, 'required': ['query'], @@ -50,9 +51,15 @@ output_schema={ 'type': 'object', 'properties': { - 'task_id': {'type': 'string'}, - 'status': {'type': 'string'}, - 'output_dir': {'type': 'string'}, + 'task_id': { + 'type': 'string' + }, + 'status': { + 'type': 'string' + }, + 'output_dir': { + 'type': 'string' + }, }, }, tags=['research', 'search', 'report', 'async', 'submit'], @@ -63,14 +70,12 @@ name='check_research_progress', version='0.1.0', granularity='tool', - summary=( - 'Check the progress of a running deep research task. Returns status, evidence count, and latest activity.' - ), + summary=('Check the progress of a running deep research task. ' + 'Returns status, evidence count, and latest activity.'), description=( 'Polls the status of a research task previously submitted via ' 'submit_research_task. Inspects the output directory to report ' - 'how many evidence notes and analyses have been collected so far.' - ), + 'how many evidence notes and analyses have been collected so far.'), input_schema={ 'type': 'object', 'properties': { @@ -89,15 +94,12 @@ name='get_research_report', version='0.1.0', granularity='tool', - summary=( - 'Retrieve the final report from a completed deep research task. ' - 'Returns the report content or an error if not yet complete.' - ), + summary=('Retrieve the final report from a completed deep research task. ' + 'Returns the report content or an error if not yet complete.'), description=( 'Reads the final research report produced by a completed task. ' 'If the task is still running, returns a message to wait. ' - 'If completed, returns the full report markdown content.' - ), + 'If completed, returns the full report markdown content.'), input_schema={ 'type': 'object', 'properties': { @@ -124,14 +126,12 @@ granularity='project', summary=( 'Run deep research synchronously (BLOCKS until complete, 20-60 min). ' - 'Prefer submit_research_task for non-blocking usage.' - ), + 'Prefer submit_research_task for non-blocking usage.'), description=( 'Synchronous version that blocks until the research is complete. ' 'WARNING: This can take 20-60 minutes. Most MCP clients will ' 'timeout. Use submit_research_task + check_research_progress + ' - 'get_research_report for non-blocking async operation.' - ), + 'get_research_report for non-blocking async operation.'), input_schema={ 'type': 'object', 'properties': { @@ -139,8 +139,14 @@ 'type': 'string', 'description': 'The research question or topic to investigate', }, - 'config_path': {'type': 'string', 'description': 'Path to researcher.yaml'}, - 'output_dir': {'type': 'string', 'description': 'Output directory'}, + 'config_path': { + 'type': 'string', + 'description': 'Path to researcher.yaml' + }, + 'output_dir': { + 'type': 'string', + 'description': 'Output directory' + }, }, 'required': ['query'], }, @@ -153,12 +159,14 @@ def _find_default_config() -> str | None: """Locate the bundled deep_research v2 researcher.yaml.""" candidates = [ - os.path.join(os.path.dirname(__file__), '..', '..', '..', 'projects', 'deep_research', 'v2', 'researcher.yaml'), + os.path.join( + os.path.dirname(__file__), '..', '..', '..', 'projects', + 'deep_research', 'v2', 'researcher.yaml'), ] try: from importlib import resources as importlib_resources - - trav = importlib_resources.files('ms_agent').joinpath('projects', 'deep_research', 'v2', 'researcher.yaml') + trav = importlib_resources.files('ms_agent').joinpath( + 'projects', 'deep_research', 'v2', 'researcher.yaml') candidates.insert(0, str(trav)) except Exception: pass @@ -200,8 +208,12 @@ def _count_evidence(output_dir: str) -> dict[str, int]: notes_dir = os.path.join(evidence_dir, 'notes') analyses_dir = os.path.join(evidence_dir, 'analyses') return { - 'notes': len(list(Path(notes_dir).glob('*.md'))) if os.path.isdir(notes_dir) else 0, - 'analyses': len(list(Path(analyses_dir).glob('*.md'))) if os.path.isdir(analyses_dir) else 0, + 'notes': + len(list(Path(notes_dir).glob('*.md'))) + if os.path.isdir(notes_dir) else 0, + 'analyses': + len(list(Path(analyses_dir).glob('*.md'))) + if os.path.isdir(analyses_dir) else 0, } @@ -218,7 +230,8 @@ def _research_progress_fn(task: AsyncTask) -> dict[str, Any]: 'report_available': bool(report_path), } if task.status == 'completed': - info['report_path'] = task.metadata.get('report_path', '') or report_path + info['report_path'] = task.metadata.get('report_path', + '') or report_path return info @@ -258,7 +271,8 @@ async def _background_research(task: AsyncTask) -> dict[str, Any]: raise RuntimeError(stderr.decode('utf-8', errors='replace')[-2000:]) -async def _handle_submit(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_submit(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Submit a research task to run in the background.""" query: str = args['query'] config_path = args.get('config_path', '') or _find_default_config() or '' @@ -283,23 +297,28 @@ async def _handle_submit(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: ) return { - 'task_id': task.task_id, - 'status': 'running', - 'output_dir': output_dir, - 'message': ( - f'Research task {task.task_id} started. ' - f'Use check_research_progress(task_id="{task.task_id}") to poll status.' - ), + 'task_id': + task.task_id, + 'status': + 'running', + 'output_dir': + output_dir, + 'message': + (f'Research task {task.task_id} started. ' + f'Use check_research_progress(task_id="{task.task_id}") to poll status.' + ), } -async def _handle_check_progress(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_check_progress(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Check the progress of a running research task.""" task_id: str = args['task_id'] return _manager.check(task_id, progress_fn=_research_progress_fn) -async def _handle_get_report(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_get_report(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Retrieve the final report from a completed task.""" task_id: str = args['task_id'] max_chars: int = args.get('max_chars', 50000) @@ -311,14 +330,15 @@ async def _handle_get_report(args: dict[str, Any], **kwargs: Any) -> dict[str, A if task.status == 'running': evidence = _count_evidence(task.metadata.get('output_dir', '')) return { - 'task_id': task_id, - 'status': 'running', - 'message': ( - 'Research is still in progress. ' - f'Evidence collected so far: {evidence["notes"]} notes, ' - f'{evidence["analyses"]} analyses. ' - 'Please check again later.' - ), + 'task_id': + task_id, + 'status': + 'running', + 'message': + ('Research is still in progress. ' + f'Evidence collected so far: {evidence["notes"]} notes, ' + f'{evidence["analyses"]} analyses. ' + 'Please check again later.'), } if task.status == 'failed': @@ -329,7 +349,8 @@ async def _handle_get_report(args: dict[str, Any], **kwargs: Any) -> dict[str, A } output_dir = task.metadata.get('output_dir', '') - report_path = task.metadata.get('report_path', '') or _find_report(output_dir) + report_path = task.metadata.get('report_path', + '') or _find_report(output_dir) if not report_path or not os.path.isfile(report_path): return { 'task_id': task_id, @@ -354,14 +375,18 @@ async def _handle_get_report(args: dict[str, Any], **kwargs: Any) -> dict[str, A } -async def _handle_deep_research_sync(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_deep_research_sync(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Launch deep_research synchronously (blocks until complete).""" query: str = args['query'] config_path = args.get('config_path', '') or _find_default_config() or '' output_dir = args.get('output_dir', '') if not config_path or not os.path.isfile(config_path): - return {'status': 'failed', 'error': f'Config not found: {config_path}'} + return { + 'status': 'failed', + 'error': f'Config not found: {config_path}' + } if not output_dir: ts = time.strftime('%Y%m%d_%H%M%S') @@ -380,12 +405,16 @@ async def _handle_deep_research_sync(args: dict[str, Any], **kwargs: Any) -> dic await proc.wait() report_path = _find_report(output_dir) if proc.returncode == 0: - return {'status': 'completed', 'output_dir': output_dir, 'report_path': report_path} + return { + 'status': 'completed', + 'output_dir': output_dir, + 'report_path': report_path + } else: return { 'status': 'failed', 'output_dir': output_dir, - 'error': stderr.decode('utf-8', errors='replace')[-2000:], + 'error': stderr.decode('utf-8', errors='replace')[-2000:] } except Exception as e: return {'status': 'failed', 'error': str(e)} @@ -398,4 +427,5 @@ def register_all(registry: CapabilityRegistry, config: Any = None) -> None: registry.register(CHECK_PROGRESS_DESCRIPTOR, _handle_check_progress) registry.register(GET_REPORT_DESCRIPTOR, _handle_get_report) # Sync (for direct Python API or long-timeout scenarios) - registry.register(DEEP_RESEARCH_SYNC_DESCRIPTOR, _handle_deep_research_sync) + registry.register(DEEP_RESEARCH_SYNC_DESCRIPTOR, + _handle_deep_research_sync) diff --git a/ms_agent/capabilities/wrappers/filesystem.py b/ms_agent/capabilities/wrappers/filesystem.py index fcf4b944e..5becb5615 100644 --- a/ms_agent/capabilities/wrappers/filesystem.py +++ b/ms_agent/capabilities/wrappers/filesystem.py @@ -9,40 +9,44 @@ name='replace_file_contents', version='0.1.0', granularity='tool', - summary=( - 'Replace exact content in a file without line numbers. ' - 'Concurrent-safe: matches by content instead of line numbers, ' - 'so parallel edits on the same file do not conflict.' - ), - description=( - 'Performs an exact-string replacement inside a file. The caller supplies ' - 'the verbatim `source` text to find and the `target` text to replace it with. ' - 'An `occurrence` parameter controls which match to replace (1-based) or ' - '-1 for all. Because it relies on content matching rather than line numbers, ' - 'it is safe to use from multiple agents editing the same file concurrently.' - ), + summary=('Replace exact content in a file without line numbers. ' + 'Concurrent-safe: matches by content instead of line numbers, ' + 'so parallel edits on the same file do not conflict.'), + description= + ('Performs an exact-string replacement inside a file. The caller supplies ' + 'the verbatim `source` text to find and the `target` text to replace it with. ' + 'An `occurrence` parameter controls which match to replace (1-based) or ' + '-1 for all. Because it relies on content matching rather than line numbers, ' + 'it is safe to use from multiple agents editing the same file concurrently.' + ), input_schema={ 'type': 'object', 'properties': { 'path': { - 'type': 'string', - 'description': 'Path to the file to modify (relative to workspace or absolute)', + 'type': + 'string', + 'description': + 'Path to the file to modify (relative to workspace or absolute)', }, 'source': { - 'type': 'string', - 'description': ( - 'Exact content to find. Must match the file content verbatim ' - 'including whitespace, punctuation, and line breaks.' - ), + 'type': + 'string', + 'description': + ('Exact content to find. Must match the file content verbatim ' + 'including whitespace, punctuation, and line breaks.'), }, 'target': { 'type': 'string', 'description': 'New content to replace the source with', }, 'occurrence': { - 'type': 'integer', - 'description': ('Which occurrence to replace (1-based). Use -1 to replace all occurrences. Default: 1'), - 'default': 1, + 'type': + 'integer', + 'description': + ('Which occurrence to replace (1-based). ' + 'Use -1 to replace all occurrences. Default: 1'), + 'default': + 1, }, }, 'required': ['path', 'source', 'target'], @@ -50,7 +54,9 @@ output_schema={ 'type': 'object', 'properties': { - 'result': {'type': 'string'}, + 'result': { + 'type': 'string' + }, }, }, tags=['filesystem', 'edit', 'replace', 'diff', 'concurrent-safe'], @@ -61,10 +67,10 @@ name='replace_file_lines', version='0.1.0', granularity='tool', - summary=( - 'Replace, insert, or append content by line range. ' - 'Supports insert-at-beginning (start_line=0) and append-at-end (start_line=-1).' - ), + summary= + ('Replace, insert, or append content by line range. ' + 'Supports insert-at-beginning (start_line=0) and append-at-end (start_line=-1).' + ), description=( 'Replaces a range of lines in a file with new content. ' 'Special modes: start_line=0 inserts at the beginning, ' @@ -82,12 +88,17 @@ 'description': 'New content to insert or replace with', }, 'start_line': { - 'type': 'integer', - 'description': ('Start line (1-based inclusive). 0 = insert at beginning, -1 = append at end.'), + 'type': + 'integer', + 'description': + ('Start line (1-based inclusive). ' + '0 = insert at beginning, -1 = append at end.'), }, 'end_line': { - 'type': 'integer', - 'description': 'End line (1-based inclusive). Required unless start_line is 0 or -1.', + 'type': + 'integer', + 'description': + 'End line (1-based inclusive). Required unless start_line is 0 or -1.', }, }, 'required': ['path', 'content', 'start_line'], @@ -95,7 +106,9 @@ output_schema={ 'type': 'object', 'properties': { - 'result': {'type': 'string'}, + 'result': { + 'type': 'string' + }, }, }, tags=['filesystem', 'edit', 'replace', 'lines'], @@ -112,8 +125,10 @@ def _resolve_path(path: str, workspace: str | None) -> str: return os.path.abspath(path) -async def _handle_replace_contents(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: - workspace = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', '') +async def _handle_replace_contents(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: + workspace = kwargs.get('workspace') or os.environ.get( + 'MS_AGENT_OUTPUT_DIR', '') path = _resolve_path(args['path'], workspace) source: str = args['source'] target: str = args['target'] @@ -130,7 +145,9 @@ async def _handle_replace_contents(args: dict[str, Any], **kwargs: Any) -> dict[ content = f.read() if source not in content: - return {'error': f'Could not find the exact content to replace in {path}'} + return { + 'error': f'Could not find the exact content to replace in {path}' + } count = content.count(source) @@ -143,7 +160,8 @@ async def _handle_replace_contents(args: dict[str, Any], **kwargs: Any) -> dict[ return {'error': f'occurrence {occurrence} exceeds total ({count})'} else: parts = content.split(source, occurrence) - updated = source.join(parts[:occurrence]) + target + source.join(parts[occurrence:]) + updated = source.join(parts[:occurrence]) + target + source.join( + parts[occurrence:]) msg = f'Replaced occurrence {occurrence} of {count}' with open(path, 'w', encoding='utf-8') as f: @@ -152,8 +170,10 @@ async def _handle_replace_contents(args: dict[str, Any], **kwargs: Any) -> dict[ return {'result': f'{msg} in {path}'} -async def _handle_replace_lines(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: - workspace = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', '') +async def _handle_replace_lines(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: + workspace = kwargs.get('workspace') or os.environ.get( + 'MS_AGENT_OUTPUT_DIR', '') path = _resolve_path(args['path'], workspace) new_content: str = args['content'] start_line: int = args['start_line'] diff --git a/ms_agent/capabilities/wrappers/lsp_code_server.py b/ms_agent/capabilities/wrappers/lsp_code_server.py index 69b39bb3b..072c394db 100644 --- a/ms_agent/capabilities/wrappers/lsp_code_server.py +++ b/ms_agent/capabilities/wrappers/lsp_code_server.py @@ -9,25 +9,30 @@ name='lsp_check_directory', version='0.1.0', granularity='component', - summary=('Run LSP diagnostics on all code files in a directory. Supports TypeScript/JavaScript, Python, and Java.'), + summary=('Run LSP diagnostics on all code files in a directory. ' + 'Supports TypeScript/JavaScript, Python, and Java.'), description=( 'Starts the appropriate Language Server Protocol backend ' '(typescript-language-server, pyright, or jdtls) and runs ' 'diagnostics on every matching file in the given directory. ' 'Returns structured error/warning information. Useful for ' - 'validating generated code or checking a project for issues.' - ), + 'validating generated code or checking a project for issues.'), input_schema={ 'type': 'object', 'properties': { 'directory': { - 'type': 'string', - 'description': 'Path to the directory to check (relative to workspace or absolute)', + 'type': + 'string', + 'description': + 'Path to the directory to check (relative to workspace or absolute)', }, 'language': { - 'type': 'string', + 'type': + 'string', 'enum': ['typescript', 'python', 'java'], - 'description': ('Programming language to check. typescript covers .ts/.tsx/.js/.jsx/.mjs/.cjs files'), + 'description': + ('Programming language to check. ' + 'typescript covers .ts/.tsx/.js/.jsx/.mjs/.cjs files'), }, }, 'required': ['directory', 'language'], @@ -35,10 +40,16 @@ output_schema={ 'type': 'object', 'properties': { - 'result': {'type': 'string', 'description': 'Diagnostic summary'}, + 'result': { + 'type': 'string', + 'description': 'Diagnostic summary' + }, }, }, - tags=['code', 'lsp', 'diagnostics', 'validation', 'typescript', 'python', 'java'], + tags=[ + 'code', 'lsp', 'diagnostics', 'validation', 'typescript', 'python', + 'java' + ], estimated_duration='minutes', parent='lsp_code_server', requires={'bins': []}, @@ -50,19 +61,19 @@ granularity='tool', summary=( 'Incrementally update a file and check for LSP errors. ' - 'More efficient than a full directory check for single-file edits.' - ), + 'More efficient than a full directory check for single-file edits.'), description=( 'Updates a file with new content and runs LSP diagnostics on it. ' 'The LSP server is reused across calls, making repeated checks on ' - 'the same project very efficient.' - ), + 'the same project very efficient.'), input_schema={ 'type': 'object', 'properties': { 'file_path': { - 'type': 'string', - 'description': 'Path to the file (relative to workspace or absolute)', + 'type': + 'string', + 'description': + 'Path to the file (relative to workspace or absolute)', }, 'content': { 'type': 'string', @@ -79,7 +90,10 @@ output_schema={ 'type': 'object', 'properties': { - 'result': {'type': 'string', 'description': 'Diagnostic output'}, + 'result': { + 'type': 'string', + 'description': 'Diagnostic output' + }, }, }, tags=['code', 'lsp', 'diagnostics', 'validation'], @@ -91,17 +105,18 @@ name='lsp_code_server', version='0.1.0', granularity='component', - summary=( - 'LSP-based code validation server supporting TypeScript, Python, and Java. ' - 'Provides directory-wide and incremental file-level diagnostics.' - ), + summary= + ('LSP-based code validation server supporting TypeScript, Python, and Java. ' + 'Provides directory-wide and incremental file-level diagnostics.'), description=( 'A component that wraps Language Server Protocol backends to provide ' 'code diagnostics without requiring an IDE. Sub-capabilities: ' 'lsp_check_directory (full project scan) and lsp_update_and_check ' - '(incremental single-file validation).' - ), - input_schema={'type': 'object', 'properties': {}}, + '(incremental single-file validation).'), + input_schema={ + 'type': 'object', + 'properties': {} + }, tags=['code', 'lsp', 'diagnostics', 'validation'], estimated_duration='minutes', sub_capabilities=['lsp_check_directory', 'lsp_update_and_check'], @@ -141,13 +156,16 @@ def _resolve_workspace(directory: str, fallback: str) -> str: return fallback -async def _handle_check_directory(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: - fallback = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', os.getcwd()) +async def _handle_check_directory(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: + fallback = kwargs.get('workspace') or os.environ.get( + 'MS_AGENT_OUTPUT_DIR', os.getcwd()) directory = args['directory'] workspace = _resolve_workspace(directory, fallback) lsp = _get_lsp_server(workspace) - rel_dir = os.path.relpath(directory, workspace) if os.path.isabs(directory) else directory + rel_dir = os.path.relpath( + directory, workspace) if os.path.isabs(directory) else directory result = await lsp.call_tool( 'lsp_code_server', @@ -160,13 +178,18 @@ async def _handle_check_directory(args: dict[str, Any], **kwargs: Any) -> dict[s return {'result': result} -async def _handle_update_and_check(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: - fallback = kwargs.get('workspace') or os.environ.get('MS_AGENT_OUTPUT_DIR', os.getcwd()) +async def _handle_update_and_check(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: + fallback = kwargs.get('workspace') or os.environ.get( + 'MS_AGENT_OUTPUT_DIR', os.getcwd()) file_path = args['file_path'] - workspace = _resolve_workspace(os.path.dirname(file_path), fallback) if os.path.isabs(file_path) else fallback + workspace = _resolve_workspace( + os.path.dirname(file_path), + fallback) if os.path.isabs(file_path) else fallback lsp = _get_lsp_server(workspace) - rel_path = os.path.relpath(file_path, workspace) if os.path.isabs(file_path) else file_path + rel_path = os.path.relpath( + file_path, workspace) if os.path.isabs(file_path) else file_path result = await lsp.call_tool( 'lsp_code_server', diff --git a/ms_agent/capabilities/wrappers/web_search.py b/ms_agent/capabilities/wrappers/web_search.py index 60fba7d29..02a9e2b56 100644 --- a/ms_agent/capabilities/wrappers/web_search.py +++ b/ms_agent/capabilities/wrappers/web_search.py @@ -16,7 +16,6 @@ def _get_engine(engine_type: str) -> Any: """Return a cached :class:`SearchEngine` instance for *engine_type*.""" if engine_type not in _engines: from ms_agent.tools.search.websearch_tool import get_search_engine - _engines[engine_type] = get_search_engine(engine_type) return _engines[engine_type] @@ -26,7 +25,6 @@ def _get_fetcher() -> Any: global _fetcher if _fetcher is None: from ms_agent.tools.search.websearch_tool import get_content_fetcher - _fetcher = get_content_fetcher('jina_reader') return _fetcher @@ -35,13 +33,13 @@ def _get_fetcher() -> Any: name='web_search', version='0.1.0', granularity='tool', - summary=('Search the web using multiple engines (exa, serpapi, arxiv) and optionally fetch full page content.'), + summary=('Search the web using multiple engines (exa, serpapi, arxiv) ' + 'and optionally fetch full page content.'), description=( 'Performs a web search and returns structured results including ' 'title, URL, and summary for each hit. Supports exa, serpapi, ' 'and arxiv backends. Set fetch_content=true to additionally ' - 'retrieve and return page text (truncated to 10 000 chars).' - ), + 'retrieve and return page text (truncated to 10 000 chars).'), input_schema={ 'type': 'object', 'properties': { @@ -55,14 +53,22 @@ def _get_fetcher() -> Any: 'default': 5, }, 'engine_type': { - 'type': 'string', - 'description': ("Search engine to use: 'exa', 'serpapi', or 'arxiv' (default: 'arxiv')"), - 'default': 'arxiv', + 'type': + 'string', + 'description': + ("Search engine to use: 'exa', 'serpapi', or 'arxiv' " + "(default: 'arxiv')"), + 'default': + 'arxiv', }, 'fetch_content': { - 'type': 'boolean', - 'description': ('Whether to fetch full page content for each result (default: false)'), - 'default': False, + 'type': + 'boolean', + 'description': + ('Whether to fetch full page content for each result ' + '(default: false)'), + 'default': + False, }, }, 'required': ['query'], @@ -70,11 +76,21 @@ def _get_fetcher() -> Any: output_schema={ 'type': 'object', 'properties': { - 'status': {'type': 'string'}, - 'query': {'type': 'string'}, - 'engine': {'type': 'string'}, - 'count': {'type': 'integer'}, - 'results': {'type': 'array'}, + 'status': { + 'type': 'string' + }, + 'query': { + 'type': 'string' + }, + 'engine': { + 'type': 'string' + }, + 'count': { + 'type': 'integer' + }, + 'results': { + 'type': 'array' + }, }, }, tags=['search', 'web', 'research'], @@ -82,7 +98,8 @@ def _get_fetcher() -> Any: ) -async def _handle_web_search(args: dict[str, Any], **kwargs: Any) -> dict[str, Any]: +async def _handle_web_search(args: dict[str, Any], + **kwargs: Any) -> dict[str, Any]: """Execute a web search and return structured results.""" query = (args.get('query') or '').strip() if not query: @@ -96,7 +113,10 @@ async def _handle_web_search(args: dict[str, Any], **kwargs: Any) -> dict[str, A try: engine = _get_engine(engine_type) except Exception as exc: - return {'error': f'Failed to initialise search engine {engine_type!r}: {exc}'} + return { + 'error': + f'Failed to initialise search engine {engine_type!r}: {exc}' + } # Build request via the engine's class method engine_cls = type(engine) @@ -123,13 +143,11 @@ async def _handle_web_search(args: dict[str, Any], **kwargs: Any) -> dict[str, A raw_list = search_result.to_list() if search_result else [] results: list[dict[str, Any]] = [] for item in raw_list[:num_results]: - results.append( - { - 'title': item.get('title', ''), - 'url': item.get('url', ''), - 'summary': item.get('summary', ''), - } - ) + results.append({ + 'title': item.get('title', ''), + 'url': item.get('url', ''), + 'summary': item.get('summary', ''), + }) # Optional content fetching if fetch_content and results: diff --git a/ms_agent/cli/app.py b/ms_agent/cli/app.py index 70bed1229..d9a690d7d 100644 --- a/ms_agent/cli/app.py +++ b/ms_agent/cli/app.py @@ -5,7 +5,8 @@ def subparser_func(args): - """Function which will be called for a specific sub parser.""" + """ Function which will be called for a specific sub parser. + """ return AppCMD(args) @@ -29,29 +30,41 @@ def define_args(parsers: argparse.ArgumentParser): '--app_type', type=str, default='doc_research', - help='The app type, supported values: `doc_research`, `fin_research`', - ) + help= + 'The app type, supported values: `doc_research`, `fin_research`') - parser.add_argument('--server_name', type=str, default='0.0.0.0', help='The gradio server name to bind to.') + parser.add_argument( + '--server_name', + type=str, + default='0.0.0.0', + help='The gradio server name to bind to.') - parser.add_argument('--server_port', type=int, default=7860, help='The gradio server port to bind to.') + parser.add_argument( + '--server_port', + type=int, + default=7860, + help='The gradio server port to bind to.') - parser.add_argument('--share', action='store_true', help='Whether to share the gradio app publicly.') + parser.add_argument( + '--share', + action='store_true', + help='Whether to share the gradio app publicly.') parser.set_defaults(func=subparser_func) def execute(self): + if self.args.app_type == 'doc_research': from ms_agent.app.doc_research import launch_server as launch_doc_research - launch_doc_research( - server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share - ) + server_name=self.args.server_name, + server_port=self.args.server_port, + share=self.args.share) elif self.args.app_type == 'fin_research': from ms_agent.app.fin_research import launch_server as launch_fin_research - launch_fin_research( - server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share - ) + server_name=self.args.server_name, + server_port=self.args.server_port, + share=self.args.share) else: raise ValueError(f'Unsupported app type: {self.args.app_type}') diff --git a/ms_agent/cli/cli.py b/ms_agent/cli/cli.py index fafc6e849..da709e98d 100644 --- a/ms_agent/cli/cli.py +++ b/ms_agent/cli/cli.py @@ -10,9 +10,12 @@ def run_cmd(): This cmd imports all other sub commands, for example, `run` and `app`. """ - parser = argparse.ArgumentParser('ModelScope-agent Command Line tool', usage='ms-agent []') + parser = argparse.ArgumentParser( + 'ModelScope-agent Command Line tool', + usage='ms-agent []') - subparsers = parser.add_subparsers(help='ModelScope-agent commands helpers') + subparsers = parser.add_subparsers( + help='ModelScope-agent commands helpers') RunCMD.define_args(subparsers) AppCMD.define_args(subparsers) diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 855ce4144..2accdb40e 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,12 +4,11 @@ import os from importlib import resources as importlib_resources -from omegaconf import OmegaConf - from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII +from omegaconf import OmegaConf from .base import CLICommand @@ -17,7 +16,8 @@ def subparser_func(args): - """Function which will be called for a specific sub parser.""" + """ Function which will be called for a specific sub parser. + """ return RunCMD(args) @@ -36,7 +36,9 @@ def list_builtin_projects(): def project_help_text(): projects = list_builtin_projects() if projects: - return f'Built-in bundled project name under package ms_agent/projects. Available: {", ".join(projects)}' + return ( + 'Built-in bundled project name under package ms_agent/projects. ' + f'Available: {", ".join(projects)}') return 'Built-in bundled project name under package ms_agent/projects.' @@ -74,18 +76,22 @@ def define_args(parsers: argparse.ArgumentParser): type=str, default=None, metavar='PATH', - help='Path to a .env file. If omitted, loads ./.env from the current ' - 'working directory when present; missing file is ignored.', - ) + help= + 'Path to a .env file. If omitted, loads ./.env from the current ' + 'working directory when present; missing file is ignored.') parser.add_argument( '--query', required=False, type=str, - help='The query or prompt to send to the LLM. If not set, will enter an interactive mode.', + help= + 'The query or prompt to send to the LLM. If not set, will enter an interactive mode.' ) parser.add_argument( - '--config', required=False, type=str, default=None, help='The directory or the repo id of the config file' - ) + '--config', + required=False, + type=str, + default=None, + help='The directory or the repo id of the config file') parser.add_argument( '--project', required=False, @@ -99,64 +105,78 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default='false', - help='Trust the code belongs to the config file, default False', - ) + help='Trust the code belongs to the config file, default False') parser.add_argument( '--load_cache', required=False, type=str, default='false', - help='Load previous step histories from cache, this is useful when a query fails and retry', + help= + 'Load previous step histories from cache, this is useful when a query fails and retry' ) - parser.add_argument('--mcp_config', required=False, type=str, default=None, help='The extra mcp server config') parser.add_argument( - '--mcp_server_file', required=False, type=str, default=None, help='An extra mcp server file.' - ) + '--mcp_config', + required=False, + type=str, + default=None, + help='The extra mcp server config') + parser.add_argument( + '--mcp_server_file', + required=False, + type=str, + default=None, + help='An extra mcp server file.') parser.add_argument( '--openai_api_key', required=False, type=str, default=None, - help='API key for accessing an OpenAI-compatible service.', - ) + help='API key for accessing an OpenAI-compatible service.') parser.add_argument( '--modelscope_api_key', required=False, type=str, default=None, - help='API key for accessing ModelScope api-inference services.', - ) + help='API key for accessing ModelScope api-inference services.') parser.add_argument( '--animation_mode', required=False, type=str, choices=['auto', 'human'], default=None, - help='Animation mode for video_generate project: auto (default) or human.', + help= + 'Animation mode for video_generate project: auto (default) or human.' ) parser.add_argument( '--knowledge_search_paths', required=False, type=str, default=None, - help='Comma-separated list of paths for knowledge search.', + help= + 'Comma-separated list of paths for knowledge search.' ) parser.set_defaults(func=subparser_func) def execute(self): if getattr(self.args, 'project', None): if self.args.config: - raise ValueError('Please specify only one of --config or --project') + raise ValueError( + 'Please specify only one of --config or --project') project = self.args.project - project_trav = importlib_resources.files('ms_agent').joinpath('projects', project) + project_trav = importlib_resources.files('ms_agent').joinpath( + 'projects', project) if not project_trav.exists(): - projects_root = importlib_resources.files('ms_agent').joinpath('projects') + projects_root = importlib_resources.files('ms_agent').joinpath( + 'projects') available = [] if projects_root.exists(): - available = [p.name for p in projects_root.iterdir() if p.is_dir()] - raise ValueError(f'Unknown project: {project}. Available: {available}') + available = [ + p.name for p in projects_root.iterdir() if p.is_dir() + ] + raise ValueError( + f'Unknown project: {project}. Available: {available}') # as_file ensures we get a real filesystem path even if installed as zip with importlib_resources.as_file(project_trav) as project_dir: @@ -172,14 +192,16 @@ def _execute_with_config(self): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) else: # Use built-in default agent.yaml from package - default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) - with importlib_resources.as_file(default_config_path) as config_file: + default_config_path = importlib_resources.files( + 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file( + default_config_path) as config_file: self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download - self.args.config = snapshot_download(self.args.config) - self.args.trust_remote_code = strtobool(self.args.trust_remote_code) # noqa + self.args.trust_remote_code = strtobool( + self.args.trust_remote_code) # noqa self.args.load_cache = strtobool(self.args.load_cache) # Propagate animation mode via environment variable for downstream code agents @@ -197,19 +219,26 @@ def _execute_with_config(self): author = f.read() blue_color_prefix = '\033[34m' blue_color_suffix = '\033[0m' - print(blue_color_prefix + MS_AGENT_ASCII + blue_color_suffix, flush=True) + print( + blue_color_prefix + MS_AGENT_ASCII + blue_color_suffix, flush=True) line_start = '═════════════════════════Workflow Contributed By════════════════════════════' line_end = '════════════════════════════════════════════════════════════════════════════' if author: - print(blue_color_prefix + line_start + blue_color_suffix, flush=True) - print(blue_color_prefix + author.strip() + blue_color_suffix, flush=True) + print( + blue_color_prefix + line_start + blue_color_suffix, flush=True) + print( + blue_color_prefix + author.strip() + blue_color_suffix, + flush=True) print(blue_color_prefix + line_end + blue_color_suffix, flush=True) config = Config.from_task(self.args.config) # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): - paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + paths = [ + p.strip() for p in self.args.knowledge_search_paths.split(',') + if p.strip() + ] if paths: if not hasattr(config, 'tools') or config.tools is None: config['tools'] = OmegaConf.create({}) @@ -220,7 +249,8 @@ def _execute_with_config(self): 'work_path': './.sirchmunk', 'mode': 'FAST', } - config.tools['localsearch'] = OmegaConf.create(localsearch_config) + config.tools['localsearch'] = OmegaConf.create( + localsearch_config) else: existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths @@ -228,22 +258,18 @@ def _execute_with_config(self): if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader - engine = WorkflowLoader.build( config_dir_or_id=self.args.config, config=config, mcp_server_file=self.args.mcp_server_file, load_cache=self.args.load_cache, - trust_remote_code=self.args.trust_remote_code, - ) + trust_remote_code=self.args.trust_remote_code) else: from ms_agent.agent.loader import AgentLoader - engine = AgentLoader.build( config_dir_or_id=self.args.config, config=config, mcp_server_file=self.args.mcp_server_file, load_cache=self.args.load_cache, - trust_remote_code=self.args.trust_remote_code, - ) + trust_remote_code=self.args.trust_remote_code) asyncio.run(engine.run(self.args.query)) diff --git a/ms_agent/cli/ui.py b/ms_agent/cli/ui.py index 4d0c40be2..b10c9bb5d 100644 --- a/ms_agent/cli/ui.py +++ b/ms_agent/cli/ui.py @@ -11,7 +11,8 @@ def subparser_func(args): - """Function which will be called for a specific sub parser.""" + """ Function which will be called for a specific sub parser. + """ return UICMD(args) @@ -27,11 +28,28 @@ def __init__(self, args): def define_args(parsers: argparse.ArgumentParser): """Define args for the ui command.""" parser: argparse.ArgumentParser = parsers.add_parser(UICMD.name) - parser.add_argument('--host', type=str, default='0.0.0.0', help='The server host to bind to.') - parser.add_argument('--port', type=int, default=7860, help='The server port to bind to.') - parser.add_argument('--reload', action='store_true', help='Enable auto-reload for development.') - parser.add_argument('--production', action='store_true', help='Run in production mode (serve built frontend).') - parser.add_argument('--no-browser', action='store_true', help='Do not automatically open browser.') + parser.add_argument( + '--host', + type=str, + default='0.0.0.0', + help='The server host to bind to.') + parser.add_argument( + '--port', + type=int, + default=7860, + help='The server port to bind to.') + parser.add_argument( + '--reload', + action='store_true', + help='Enable auto-reload for development.') + parser.add_argument( + '--production', + action='store_true', + help='Run in production mode (serve built frontend).') + parser.add_argument( + '--no-browser', + action='store_true', + help='Do not automatically open browser.') parser.set_defaults(func=subparser_func) def execute(self): @@ -41,7 +59,6 @@ def execute(self): if not webui_dir.exists(): import ms_agent - ms_agent_path = Path(ms_agent.__file__).parent webui_dir = ms_agent_path / 'webui' @@ -56,10 +73,13 @@ def execute(self): sys.exit(1) frontend_dist = frontend_dir / 'dist' - frontend_built = frontend_dist.exists() and (frontend_dist / 'index.html').exists() + frontend_built = frontend_dist.exists() and (frontend_dist + / 'index.html').exists() if self.args.production and not frontend_built: - print('Error: Frontend not built. Please run "npm run build" in webui/frontend first.') + print( + 'Error: Frontend not built. Please run "npm run build" in webui/frontend first.' + ) sys.exit(1) if not self.args.production and not frontend_built: @@ -95,7 +115,8 @@ def open_browser(): time.sleep(1.5) webbrowser.open(browser_url) - browser_thread = threading.Thread(target=open_browser, daemon=True) + browser_thread = threading.Thread( + target=open_browser, daemon=True) browser_thread.start() main() @@ -105,7 +126,6 @@ def open_browser(): except Exception as e: print(f'Error starting WebUI: {e}') import traceback - traceback.print_exc() sys.exit(1) finally: @@ -116,33 +136,33 @@ def _build_frontend(self, frontend_dir: Path) -> bool: import subprocess try: - subprocess.run(['npm', '--version'], capture_output=True, check=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): + subprocess.run(['npm', '--version'], + capture_output=True, + check=True, + timeout=5) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, + FileNotFoundError): return False node_modules = frontend_dir / 'node_modules' if not node_modules.exists(): try: - subprocess.run( - ['npm', 'install'], - cwd=frontend_dir, - check=True, - timeout=300, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) + subprocess.run(['npm', 'install'], + cwd=frontend_dir, + check=True, + timeout=300, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) except (subprocess.TimeoutExpired, subprocess.CalledProcessError): return False try: - subprocess.run( - ['npm', 'run', 'build'], - cwd=frontend_dir, - check=True, - timeout=300, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) + subprocess.run(['npm', 'run', 'build'], + cwd=frontend_dir, + check=True, + timeout=300, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) return True except (subprocess.TimeoutExpired, subprocess.CalledProcessError): return False diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 93cf3a338..2f6175524 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -5,13 +5,12 @@ from copy import deepcopy from typing import Any, Dict, Union -from modelscope import snapshot_download -from omegaconf import DictConfig, ListConfig, OmegaConf -from omegaconf.basecontainer import BaseContainer - from ms_agent.prompting import apply_prompt_files from ms_agent.utils import get_logger +from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf.basecontainer import BaseContainer +from modelscope import snapshot_download from ..utils.constants import TOOL_PLUGIN_NAME from .env import Env @@ -19,6 +18,7 @@ class ConfigLifecycleHandler: + def task_begin(self, config: DictConfig, tag: str) -> DictConfig: """Modify config when the task begins. @@ -51,10 +51,14 @@ class Config: """All tasks begin from a config""" tag: str = '' - supported_config_names = ['workflow.yaml', 'workflow.yml', 'agent.yaml', 'agent.yml'] + supported_config_names = [ + 'workflow.yaml', 'workflow.yml', 'agent.yaml', 'agent.yml' + ] @classmethod - def from_task(cls, config_dir_or_id: str, env: Dict[str, str] = None) -> Union[DictConfig, ListConfig]: + def from_task(cls, + config_dir_or_id: str, + env: Dict[str, str] = None) -> Union[DictConfig, ListConfig]: """Read a task config file and return a config object. Args: @@ -84,8 +88,7 @@ def from_task(cls, config_dir_or_id: str, env: Dict[str, str] = None) -> Union[D assert config is not None, ( f'Cannot find any valid config file in {config_dir_or_id}, ' - f'supported configs are: {Config.supported_config_names}' - ) + f'supported configs are: {Config.supported_config_names}') envs = Env.load_env(env) cls._update_config(config, envs) _dict_config = cls.parse_args() @@ -114,7 +117,10 @@ def fill_missing_fields(config: DictConfig) -> DictConfig: @staticmethod def is_workflow(config: DictConfig) -> bool: assert config.name is not None, 'Cannot find a valid name in this config' - return config.name in ['workflow.yaml', 'workflow.yml', 'simple_workflow.yaml', 'simple_workflow.yml'] + return config.name in [ + 'workflow.yaml', 'workflow.yml', 'simple_workflow.yaml', + 'simple_workflow.yml' + ] @staticmethod def parse_args() -> Dict[str, Any]: @@ -125,16 +131,19 @@ def parse_args() -> Dict[str, Any]: for idx in range(1, len(unknown) - 1, 2): key = unknown[idx] value = unknown[idx + 1] - assert key.startswith('--'), f'Parameter not correct: {unknown}' + assert key.startswith( + '--'), f'Parameter not correct: {unknown}' _dict_config[key[2:]] = value return _dict_config @staticmethod - def _update_config(config: Union[DictConfig, ListConfig], extra: Dict[str, str] = None): + def _update_config(config: Union[DictConfig, ListConfig], + extra: Dict[str, str] = None): if not extra: return config - def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = ''): + def traverse_config(_config: Union[DictConfig, ListConfig, Any], + path: str = ''): if isinstance(_config, DictConfig): for name, value in _config.items(): current_path = f'{path}.{name}' if path else name @@ -143,45 +152,48 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = '') traverse_config(value, current_path) else: if current_path in extra: - logger.info(f'Replacing {current_path} with extra value.') + logger.info( + f'Replacing {current_path} with extra value.') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = extra[current_path] - if name == 'temperature' and isinstance(value_to_set, str): + if name == 'temperature' and isinstance( + value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif name == 'max_tokens' and isinstance(value_to_set, str): + elif name == 'max_tokens' and isinstance( + value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): pass setattr(_config, name, value_to_set) # Find the key in extra that matches name (case-insensitive) - elif ( - key_match := next((key for key in extra if key.lower() == name.lower()), None) - ) is not None: + elif (key_match := next( + (key + for key in extra if key.lower() == name.lower()), + None)) is not None: logger.info(f'Replacing {name} with extra value.') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = extra[key_match] - if name == 'temperature' and isinstance(value_to_set, str): + if name == 'temperature' and isinstance( + value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif name == 'max_tokens' and isinstance(value_to_set, str): + elif name == 'max_tokens' and isinstance( + value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): pass setattr(_config, name, value_to_set) # Handle placeholder replacement like - elif ( - isinstance(value, str) - and value.startswith('<') - and value.endswith('>') - and value[1:-1] in extra - ): + elif (isinstance(value, str) and value.startswith('<') + and value.endswith('>') + and value[1:-1] in extra): logger.info(f'Replacing {value} with extra value.') setattr(_config, name, extra[value[1:-1]]) @@ -191,12 +203,9 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = '') if isinstance(value, BaseContainer): traverse_config(value, path) else: - if ( - isinstance(value, str) - and value.startswith('<') - and value.endswith('>') - and value[1:-1] in extra - ): + if (isinstance(value, str) and value.startswith('<') + and value.endswith('>') + and value[1:-1] in extra): logger.info(f'Replacing {value} with extra value.') _config[idx] = extra[value[1:-1]] @@ -208,20 +217,24 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = '') current = config # Navigate/create nested structure for i, part in enumerate(parts[:-1]): - if not hasattr(current, part) or getattr(current, part) is None: + if not hasattr(current, + part) or getattr(current, part) is None: setattr(current, part, DictConfig({})) current = getattr(current, part) final_key = parts[-1] - if not hasattr(current, final_key) or getattr(current, final_key) is None: + if not hasattr(current, final_key) or getattr( + current, final_key) is None: logger.info(f'Adding new config key: {key}') # Convert temperature to float and max_tokens to int if they're numeric strings value_to_set = value - if final_key == 'temperature' and isinstance(value_to_set, str): + if final_key == 'temperature' and isinstance( + value_to_set, str): try: value_to_set = float(value_to_set) except (ValueError, TypeError): pass - elif final_key == 'max_tokens' and isinstance(value_to_set, str): + elif final_key == 'max_tokens' and isinstance( + value_to_set, str): try: value_to_set = int(value_to_set) except (ValueError, TypeError): @@ -231,7 +244,9 @@ def traverse_config(_config: Union[DictConfig, ListConfig, Any], path: str = '') return None @staticmethod - def convert_mcp_servers_to_json(config: Union[DictConfig, ListConfig]) -> Dict[str, Dict[str, Any]]: + def convert_mcp_servers_to_json( + config: Union[DictConfig, + ListConfig]) -> Dict[str, Dict[str, Any]]: """Convert the mcp servers to json mcp config.""" servers = {'mcpServers': {}} if getattr(config, 'tools', None): diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 658c58a18..83553254c 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -7,6 +7,7 @@ class Env: + @staticmethod def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: """Load key=value pairs from a .env file into ``os.environ``. @@ -28,7 +29,8 @@ def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: load_dotenv(default, override=False) @staticmethod - def load_env(envs: Dict[str, str] = None, dotenv_path: Optional[str] = None) -> Dict[str, str]: + def load_env(envs: Dict[str, str] = None, + dotenv_path: Optional[str] = None) -> Dict[str, str]: """Load .env into the process env, then merge with ``envs`` and return.""" Env.load_dotenv_into_environ(dotenv_path) _envs = copy(os.environ) diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py index 5df3c9cfa..5b35bfb5d 100644 --- a/ms_agent/llm/anthropic_llm.py +++ b/ms_agent/llm/anthropic_llm.py @@ -1,14 +1,13 @@ import inspect -import json from typing import Any, Dict, Generator, Iterator, List, Optional, Union import httpx -from omegaconf import DictConfig, OmegaConf - +import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import assert_package_exist, retry from ms_agent.utils.constants import get_service_config +from omegaconf import DictConfig, OmegaConf class _SSEEventInjector(httpx.SyncByteStream): @@ -62,7 +61,10 @@ class DashScopeAnthropicTransport(httpx.BaseTransport): rewrites URL, auth headers, and body so the Anthropic SDK works unmodified. """ - def __init__(self, dashscope_url: str, api_key: str, supplier: Optional[str] = None): + def __init__(self, + dashscope_url: str, + api_key: str, + supplier: Optional[str] = None): self.dashscope_url = dashscope_url self.api_key = api_key self.supplier = supplier @@ -81,7 +83,10 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: 'content-type': 'application/json', 'authorization': f'Bearer {self.api_key}', } - _skip = frozenset({'x-api-key', 'content-type', 'authorization', 'content-length', 'host', 'transfer-encoding'}) + _skip = frozenset({ + 'x-api-key', 'content-type', 'authorization', 'content-length', + 'host', 'transfer-encoding' + }) for key, value in request.headers.items(): k = key.lower() if k not in _skip and not k.startswith('anthropic'): @@ -111,6 +116,7 @@ def close(self): class Anthropic(LLM): + def __init__( self, config: DictConfig, @@ -123,7 +129,8 @@ def __init__( self.model: str = config.llm.model - base_url = base_url or config.llm.get('anthropic_base_url') or get_service_config('anthropic').base_url + base_url = base_url or config.llm.get( + 'anthropic_base_url') or get_service_config('anthropic').base_url api_key = api_key or config.llm.get('anthropic_api_key') if not api_key: @@ -155,28 +162,30 @@ def __init__( base_url=base_url, ) - self.args: Dict = OmegaConf.to_container(getattr(config, 'generation_config', DictConfig({}))) + self.args: Dict = OmegaConf.to_container( + getattr(config, 'generation_config', DictConfig({}))) - def format_tools(self, tools: Optional[List[Tool]]) -> Optional[List[Dict]]: + def format_tools(self, + tools: Optional[List[Tool]]) -> Optional[List[Dict]]: if not tools: return None formatted_tools = [] for tool in tools: - formatted_tools.append( - { - 'name': tool['tool_name'], - 'description': tool.get('description', ''), - 'input_schema': { - 'type': 'object', - 'properties': tool.get('parameters', {}).get('properties', {}), - 'required': tool.get('parameters', {}).get('required', []), - }, + formatted_tools.append({ + 'name': tool['tool_name'], + 'description': tool.get('description', ''), + 'input_schema': { + 'type': 'object', + 'properties': tool.get('parameters', + {}).get('properties', {}), + 'required': tool.get('parameters', {}).get('required', []), } - ) + }) return formatted_tools - def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: + def _format_input_message(self, + messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the Anthropic API. Args: @@ -194,30 +203,34 @@ def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]] if msg.tool_calls: for tool_call in msg.tool_calls: - content.append( - { - 'type': 'tool_use', - 'id': tool_call['id'], - 'name': tool_call['tool_name'], - 'input': tool_call.get('arguments', {}), - } - ) + content.append({ + 'type': 'tool_use', + 'id': tool_call['id'], + 'name': tool_call['tool_name'], + 'input': tool_call.get('arguments', {}) + }) if msg.role == 'tool': - formatted_messages.append( - { - 'role': 'user', - 'content': [{'type': 'tool_result', 'tool_use_id': msg.tool_call_id, 'content': msg.content}], - } - ) + formatted_messages.append({ + 'role': + 'user', + 'content': [{ + 'type': 'tool_result', + 'tool_use_id': msg.tool_call_id, + 'content': msg.content + }] + }) continue formatted_messages.append({'role': msg.role, 'content': content}) return formatted_messages - def _call_llm( - self, messages: List[Message], tools: Optional[List[Dict]] = None, stream: bool = False, **kwargs - ) -> Any: + def _call_llm(self, + messages: List[Message], + tools: Optional[List[Dict]] = None, + stream: bool = False, + **kwargs) -> Any: + formatted_messages = self._format_input_message(messages) formatted_messages = [m for m in formatted_messages if m['content']] @@ -233,14 +246,21 @@ def _call_llm( thinking_type = kwargs.pop('thinking_type', None) raw_extra_body = kwargs.pop('extra_body', {}) or {} - extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, dict) else {} - enable_thinking = bool(extra_body.pop('enable_thinking', enable_thinking)) - thinking_budget = extra_body.pop('thinking_budget', thinking_budget) or max_tokens + extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, + dict) else {} + enable_thinking = bool( + extra_body.pop('enable_thinking', enable_thinking)) + thinking_budget = extra_body.pop('thinking_budget', + thinking_budget) or max_tokens thinking_type = extra_body.pop('thinking_type', thinking_type) for _k in ('show_reasoning', 'reasoning_output'): extra_body.pop(_k, None) - params = {'model': self.model, 'messages': formatted_messages, 'max_tokens': max_tokens} + params = { + 'model': self.model, + 'messages': formatted_messages, + 'max_tokens': max_tokens + } if thinking_type == 'adaptive': params['thinking'] = {'type': 'adaptive'} @@ -264,13 +284,12 @@ def _call_llm( return self.client.messages.create(**params) @retry(max_attempts=LLM.retry_count, delay=1.0) - def generate( - self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - max_continue_runs: Optional[int] = None, - **kwargs, - ) -> Union[Message, Generator[Message, None, None]]: + def generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + max_continue_runs: Optional[int] = None, + **kwargs) -> Union[Message, Generator[Message, None, None]]: + formatted_tools = self.format_tools(tools) args = self.args.copy() args.update(kwargs) @@ -279,14 +298,16 @@ def generate( sig_params = inspect.signature(self.client.messages.create).parameters filtered_args = {k: v for k, v in args.items() if k in sig_params} - completion = self._call_llm(messages, formatted_tools, stream, **filtered_args) + completion = self._call_llm(messages, formatted_tools, stream, + **filtered_args) if stream: return self._stream_format_output_message(completion) else: return self._format_output_message(completion) - def _stream_format_output_message(self, stream_manager) -> Iterator[Message]: + def _stream_format_output_message(self, + stream_manager) -> Iterator[Message]: current_message = Message( role='assistant', content='', @@ -339,11 +360,11 @@ def _stream_format_output_message(self, stream_manager) -> Iterator[Message]: current_message.content = full_content current_message.partial = False current_message.completion_tokens = getattr( - final_msg.usage, 'output_tokens', current_message.completion_tokens - ) + final_msg.usage, 'output_tokens', + current_message.completion_tokens) current_message.prompt_tokens = getattr( - final_msg.usage, 'input_tokens', current_message.prompt_tokens - ) + final_msg.usage, 'input_tokens', + current_message.prompt_tokens) yield current_message @@ -371,11 +392,11 @@ def _format_output_message(completion) -> Message: ToolCall( id=block.id, index=len(tool_calls), # index based on appearance - type='function', # or "tool_use" depending on your schema + type= + 'function', # or "tool_use" depending on your schema arguments=block.input, tool_name=block.name, - ) - ) + )) # Anthropic does not have a native "reasoning_content" field reasoning_content = '' @@ -393,31 +414,34 @@ def _format_output_message(completion) -> Message: if __name__ == '__main__': import os - config = { 'llm': { 'model': 'Qwen/Qwen2.5-VL-72B-Instruct', 'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'), - 'anthropic_base_url': 'https://api-inference.modelscope.cn', + 'anthropic_base_url': 'https://api-inference.modelscope.cn' }, 'generation_config': { 'stream': True, - }, + } } - tools = [ - { - 'tool_name': 'get_weather', - 'description': 'Get the current weather in a given location', - 'parameters': { - 'type': 'object', - 'properties': { - 'location': {'type': 'string', 'description': 'City and state'}, - 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}, + tools = [{ + 'tool_name': 'get_weather', + 'description': 'Get the current weather in a given location', + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'City and state' }, - 'required': ['location'], + 'unit': { + 'type': 'string', + 'enum': ['celsius', 'fahrenheit'] + } }, + 'required': ['location'] } - ] + }] messages = [Message(role='user', content='描述杭州,300字')] # messages = [Message(role='user', content='去伦敦现在该带什么样的衣服?')] diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index d50ca7edb..b4a6ddaa8 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -1,24 +1,29 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List -from omegaconf import DictConfig - from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool from ms_agent.utils.constants import get_service_config +from omegaconf import DictConfig class DashScope(OpenAI): + def __init__(self, config: DictConfig): super().__init__( config, - base_url=config.llm.dashscope_base_url or get_service_config('dashscope').base_url, - api_key=config.llm.dashscope_api_key, - ) + base_url=config.llm.dashscope_base_url + or get_service_config('dashscope').base_url, + api_key=config.llm.dashscope_api_key) - def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools: List[Tool] = None, **kwargs): + def _call_llm_for_continue_gen(self, + messages: List[Message], + new_message, + tools: List[Tool] = None, + **kwargs): # ref: https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=https%3A%2F%2Fhelp.aliyun.com%2Fdocument_detail%2F2862210.html&renderType=iframe # noqa if messages and messages[-1].to_dict().get('partial', False): + messages[-1].reasoning_content += new_message.reasoning_content messages[-1].content += new_message.content if new_message.tool_calls: diff --git a/ms_agent/llm/deepseek_llm.py b/ms_agent/llm/deepseek_llm.py index d379debd3..e565308bc 100644 --- a/ms_agent/llm/deepseek_llm.py +++ b/ms_agent/llm/deepseek_llm.py @@ -1,21 +1,28 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List -from omegaconf import DictConfig - from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool +from omegaconf import DictConfig class DeepSeek(OpenAI): input_msg = {'role', 'content', 'tool_calls', 'prefix'} def __init__(self, config: DictConfig): - super().__init__(config, base_url=config.llm.deepseek_base_url, api_key=config.llm.deepseek_api_key) - - def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools: List[Tool] = None, **kwargs): + super().__init__( + config, + base_url=config.llm.deepseek_base_url, + api_key=config.llm.deepseek_api_key) + + def _call_llm_for_continue_gen(self, + messages: List[Message], + new_message, + tools: List[Tool] = None, + **kwargs): # ref: https://api-docs.deepseek.com/zh-cn/guides/chat_prefix_completion if messages and messages[-1].to_dict().get('prefix', False): + messages[-1].reasoning_content += new_message.reasoning_content messages[-1].content += new_message.content if new_message.tool_calls: @@ -29,30 +36,28 @@ def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools messages = self.format_input_message(messages) stop = kwargs.pop('stop', []).append('```') - return self._call_llm(messages=messages, tools=tools, stop=stop, **kwargs) + return self._call_llm( + messages=messages, tools=tools, stop=stop, **kwargs) if __name__ == '__main__': import os - from omegaconf import OmegaConf # 创建一个嵌套的字典结构 - conf: DictConfig = OmegaConf.create( - { - 'llm': { - 'model': 'deepseek-reasoner', - 'deepseek_base_url': 'https://api.deepseek.com/beta/v1', - 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY'), - 'openai_base_url': 'https://api-inference.modelscope.cn/v1', - 'openai_api_key': os.getenv('MODELSCOPE_API_KEY'), - 'generation_config': { - 'stream': True, - 'max_tokens': 500, - }, + conf: DictConfig = OmegaConf.create({ + 'llm': { + 'model': 'deepseek-reasoner', + 'deepseek_base_url': 'https://api.deepseek.com/beta/v1', + 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY'), + 'openai_base_url': 'https://api-inference.modelscope.cn/v1', + 'openai_api_key': os.getenv('MODELSCOPE_API_KEY'), + 'generation_config': { + 'stream': True, + 'max_tokens': 500, } } - ) + }) messages = [ Message(role='assistant', content='You are a helpful assistant.'), @@ -81,7 +86,11 @@ def _call_llm_for_continue_gen(self, messages: List[Message], new_message, tools # print(chunk) # kwargs覆盖conf - message = llm.generate(messages=messages, tools=tools, stream=False, extra_body={'enable_thinking': False}) + message = llm.generate( + messages=messages, + tools=tools, + stream=False, + extra_body={'enable_thinking': False}) print(message) messages.append(message) # messages.append(Message(role='tool', content='北京市朝阳区崔各庄阿里巴巴朝阳科技园')) diff --git a/ms_agent/llm/llm.py b/ms_agent/llm/llm.py index e6998c2de..72af53467 100644 --- a/ms_agent/llm/llm.py +++ b/ms_agent/llm/llm.py @@ -3,15 +3,15 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional -from omegaconf import DictConfig - from ms_agent.config import Config +from omegaconf import DictConfig from ..utils.constants import DEFAULT_RETRY_COUNT from .utils import Message, Tool class LLM: + retry_count = int(os.environ.get('LLM_RETRY_COUNT', DEFAULT_RETRY_COUNT)) def __init__(self, config: DictConfig): @@ -23,9 +23,11 @@ def __init__(self, config: DictConfig): self.config = config @abstractmethod - def generate( - self, messages: List[Message], model: Optional[str] = None, tools: Optional[List[Tool]] = None, **kwargs - ) -> Any: + def generate(self, + messages: List[Message], + model: Optional[str] = None, + tools: Optional[List[Tool]] = None, + **kwargs) -> Any: """Generate response by the given messages. Args: @@ -40,7 +42,10 @@ def generate( pass @classmethod - def from_task(cls, config_dir_or_id: str, *, env: Optional[Dict[str, str]] = None) -> Any: + def from_task(cls, + config_dir_or_id: str, + *, + env: Optional[Dict[str, str]] = None) -> Any: """Instantiate an LLM instance. Args: @@ -64,8 +69,7 @@ def from_config(cls, config: DictConfig) -> Any: Returns: The LLM instance. """ - from .model_mapping import OpenAI, all_services_mapping - + from .model_mapping import all_services_mapping, OpenAI if config.llm.get('service') in all_services_mapping: return all_services_mapping[config.llm.service](config) else: diff --git a/ms_agent/llm/modelscope_llm.py b/ms_agent/llm/modelscope_llm.py index e1ba329ab..7b761c5c0 100644 --- a/ms_agent/llm/modelscope_llm.py +++ b/ms_agent/llm/modelscope_llm.py @@ -1,17 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from omegaconf import DictConfig - from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils.constants import get_service_config +from omegaconf import DictConfig class ModelScope(OpenAI): + def __init__(self, config: DictConfig): - assert hasattr(config.llm, 'modelscope_api_key') and config.llm.modelscope_api_key is not None, ( - 'Please provide `modelscope_api_key` in env or cmd.' - ) + assert hasattr( + config.llm, 'modelscope_api_key' + ) and config.llm.modelscope_api_key is not None, 'Please provide `modelscope_api_key` in env or cmd.' super().__init__( config, - base_url=config.llm.modelscope_base_url or get_service_config('modelscope').base_url, - api_key=config.llm.modelscope_api_key, - ) + base_url=config.llm.modelscope_base_url + or get_service_config('modelscope').base_url, + api_key=config.llm.modelscope_api_key) diff --git a/ms_agent/llm/openai.py b/ms_agent/llm/openai.py index 6b44c61b9..390f1d4ab 100644 --- a/ms_agent/llm/openai.py +++ b/ms_agent/llm/openai.py @@ -1,11 +1,11 @@ # flake8: noqa -import json import uuid -from openai import OpenAI, Stream -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from typing import TYPE_CHECKING, Any, Dict, List, Literal +import json from ms_agent.utils.logger import get_logger +from openai import OpenAI, Stream +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall logger = get_logger() @@ -14,7 +14,12 @@ class OpenAIChat: - def __init__(self, api_key: str = None, base_url: str = None, model: str = None, **kwargs): + + def __init__(self, + api_key: str = None, + base_url: str = None, + model: str = None, + **kwargs): """ Initialize the OpenAIChat client. """ @@ -26,25 +31,31 @@ def __init__(self, api_key: str = None, base_url: str = None, model: str = None, self._model = model self._kwargs = kwargs - def chat(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + def chat(self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]] = None, + **kwargs) -> Dict[str, Any]: + completion: ChatCompletion = self._client.chat.completions.create( - messages=messages, model=self._model, tools=tools, **kwargs - ) + messages=messages, model=self._model, tools=tools, **kwargs) res_d: Dict[str, Any] = dict( role='assistant', reasoning_content='', content=completion.choices[0].message.content, - tool_calls=completion.choices[0].message.tool_calls - if hasattr(completion.choices[0].message, 'tool_calls') - else [], - finish_reason=completion.choices[0].finish_reason, # 'stop', 'tool_calls', 'length', None + tool_calls=completion.choices[0].message.tool_calls if hasattr( + completion.choices[0].message, 'tool_calls') else [], + finish_reason=completion.choices[0]. + finish_reason, # 'stop', 'tool_calls', 'length', None usage=completion.usage.to_dict(), ) return res_d - def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]] = None, **kwargs): + def chat_stream(self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]] = None, + **kwargs): """ Get chat response from OpenAI API using streaming. @@ -87,7 +98,9 @@ def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any] if 'stream' not in kwargs: kwargs['stream'] = True - assert kwargs.get('stream', True), "Streaming must be enabled by setting 'stream=True' in kwargs." + assert kwargs.get( + 'stream', True + ), "Streaming must be enabled by setting 'stream=True' in kwargs." logger.info(f"Temperature: {kwargs.get('temperature', -1)}") @@ -97,8 +110,7 @@ def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any] tools=tools, # Note: Gemini2.5-Pro does not support parallel_tool_calls # parallel_tool_calls=True, - **kwargs, - ) + **kwargs) res_d: Dict[str, Any] = dict( role='assistant', @@ -106,7 +118,11 @@ def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any] content='', tool_calls=[], finish_reason=None, # 'stop', 'tool_calls', 'length', None - usage={'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, + usage={ + 'completion_tokens': 0, + 'prompt_tokens': 0, + 'total_tokens': 0 + }, ) for chunk in completion: @@ -118,21 +134,24 @@ def chat_stream(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any] delta = chunk.choices[0].delta res_d['role'] = delta.role - res_d['reasoning_content'] = delta.reasoning_content if hasattr(delta, 'reasoning_content') else '' + res_d['reasoning_content'] = delta.reasoning_content if hasattr( + delta, 'reasoning_content') else '' res_d['content'] = delta.content - res_d['tool_calls'] = delta.tool_calls if hasattr(delta, 'tool_calls') else [] + res_d['tool_calls'] = delta.tool_calls if hasattr( + delta, 'tool_calls') else [] res_d['finish_reason'] = chunk.choices[0].finish_reason if hasattr(chunk, 'usage') and chunk.usage: res_d['usage'] = { 'completion_tokens': chunk.usage.completion_tokens, 'prompt_tokens': chunk.usage.prompt_tokens, - 'total_tokens': chunk.usage.total_tokens, + 'total_tokens': chunk.usage.total_tokens } yield res_d @staticmethod - def aggregate_stream_chunks(stream_chunks: List[Dict[str, Any]]) -> Dict[str, Any]: + def aggregate_stream_chunks( + stream_chunks: List[Dict[str, Any]]) -> Dict[str, Any]: """ Aggregate the streaming chunks into a single response dictionary within current round of chat. @@ -155,29 +174,40 @@ def aggregate_stream_chunks(stream_chunks: List[Dict[str, Any]]) -> Dict[str, An content='', tool_calls=[], finish_reason=None, # 'stop', 'tool_calls', 'length', None - usage={'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, + usage={ + 'completion_tokens': 0, + 'prompt_tokens': 0, + 'total_tokens': 0 + }, ) for chunk_d in stream_chunks: res_d['role'] = chunk_d.get('role') - res_d['reasoning_content'] += ( - chunk_d.get('reasoning_content', '') if chunk_d.get('reasoning_content') is not None else '' - ) - res_d['content'] += chunk_d.get('content', '') if chunk_d.get('content') is not None else '' + res_d['reasoning_content'] += chunk_d.get( + 'reasoning_content', + '') if chunk_d.get('reasoning_content') is not None else '' + res_d['content'] += chunk_d.get( + 'content', '') if chunk_d.get('content') is not None else '' if chunk_d.get('tool_calls') is not None: res_d['tool_calls'].extend(chunk_d.get('tool_calls', [])) - res_d['finish_reason'] = chunk_d.get('finish_reason', res_d['finish_reason']) + res_d['finish_reason'] = chunk_d.get('finish_reason', + res_d['finish_reason']) # Get the last usage information as final usage for current round (consider cache tokens) if chunk_d.get('usage') is not None: - res_d['usage']['completion_tokens'] = chunk_d['usage'].get('completion_tokens', 0) - res_d['usage']['prompt_tokens'] = chunk_d['usage'].get('prompt_tokens', 0) - res_d['usage']['total_tokens'] = chunk_d['usage'].get('total_tokens', 0) + res_d['usage']['completion_tokens'] = chunk_d['usage'].get( + 'completion_tokens', 0) + res_d['usage']['prompt_tokens'] = chunk_d['usage'].get( + 'prompt_tokens', 0) + res_d['usage']['total_tokens'] = chunk_d['usage'].get( + 'total_tokens', 0) return res_d @staticmethod - def convert_message(role: Literal['assistant', 'tool'], round_message: Dict[str, Any]) -> Dict[str, Any]: + def convert_message(role: Literal['assistant', 'tool'], + round_message: Dict[str, Any]) -> Dict[str, Any]: + if role == 'assistant': res_msg: Dict[str, Any] = { 'role': 'assistant', @@ -190,30 +220,34 @@ def convert_message(role: Literal['assistant', 'tool'], round_message: Dict[str, if isinstance(tool_call, ChoiceDeltaToolCall): if not tool_call.id: tool_call.id = f'tc_{uuid.uuid4().hex}' - tool_call = tool_call.model_dump(include=['id', 'index', 'type', 'function']) + tool_call = tool_call.model_dump( + include=['id', 'index', 'type', 'function']) else: - raise ValueError(f'Unsupported tool call type: {type(tool_call)}. Expected ChoiceDeltaToolCall.') + raise ValueError( + f'Unsupported tool call type: {type(tool_call)}. Expected ChoiceDeltaToolCall.' + ) tmp_tool_calls.append(tool_call) res_msg['tool_calls'] = tmp_tool_calls elif role == 'tool': # TODO: tbd ... - raise ValueError('`tool message` is to be implemented in the future.') + raise ValueError( + '`tool message` is to be implemented in the future.') else: - raise ValueError(f"Unsupported role: {role}. Supported roles are 'assistant' and 'tool' for now.") + raise ValueError( + f"Unsupported role: {role}. Supported roles are 'assistant' and 'tool' for now." + ) return res_msg - def chat_stream_mt( - self, - messages: List[Dict[str, Any]], - available_functions: Dict[str, Any], - tools: List[Dict[str, Any]] = None, - history: List[Dict[str, Any]] = None, - **kwargs, - ): + def chat_stream_mt(self, + messages: List[Dict[str, Any]], + available_functions: Dict[str, Any], + tools: List[Dict[str, Any]] = None, + history: List[Dict[str, Any]] = None, + **kwargs): """ Get chat response from OpenAI API using streaming for multi-turn chat. """ @@ -224,10 +258,15 @@ def chat_stream_mt( # Add a system message if not present roles: List[str] = [msg['role'] for msg in messages] if 'system' not in roles: - system_message: Dict[str, Any] = {'role': 'system', 'content': 'You are a helpful assistant.'} + system_message: Dict[str, Any] = { + 'role': 'system', + 'content': 'You are a helpful assistant.' + } messages.insert(0, system_message) - assert len(messages) >= 2, 'At least two messages are required: user and system' + assert len( + messages + ) >= 2, 'At least two messages are required: user and system' ## User Message history.extend(messages) @@ -243,13 +282,16 @@ def chat_stream_mt( for chunk_d in self.chat_stream(messages, tools, **kwargs): streaming_chunks.append(chunk_d) - round_d: Dict[str, Any] = self.aggregate_stream_chunks(streaming_chunks) + round_d: Dict[str, Any] = self.aggregate_stream_chunks( + streaming_chunks) yield round_d # Convert `round_d` to OpenAI's chat messages format ## Assistant Message if round_d['role'] == 'assistant': - assistant_message = self.convert_message(role='assistant', round_message=round_d) + + assistant_message = self.convert_message( + role='assistant', round_message=round_d) history.append(assistant_message) # Execute tool calls and append the tool messages @@ -257,9 +299,11 @@ def chat_stream_mt( for tool_call in assistant_message.get('tool_calls', []): if tool_call['type'] == 'function': function_name = tool_call['function']['name'] - function_args = json.loads(tool_call['function']['arguments']) + function_args = json.loads( + tool_call['function']['arguments']) # Call the function and get the result - tool_call_result = available_functions[function_name](**function_args) + tool_call_result = available_functions[ + function_name](**function_args) # Construct a tool message with the result # TODO: Check the `tool_call_id` is empty ? @@ -272,7 +316,9 @@ def chat_stream_mt( history.append(tool_message) # If the response is complete, break the loop - if round_d['finish_reason'] in ['stop', 'tool_calls', 'length']: + if round_d['finish_reason'] in [ + 'stop', 'tool_calls', 'length' + ]: break except Exception as e: @@ -281,9 +327,15 @@ def chat_stream_mt( # Note: must contain role=assistant(with tool_calls) and role=tool if history[-1]['role'] == 'tool': - messages = history + [{'role': 'user', 'content': 'Please output the tool calling results very briefly.'}] - round_item: dict = self.aggregate_stream_chunks( - [chunk_item for chunk_item in self.chat_stream(messages=messages, tools=tools, **kwargs)] - ) + messages = history + [{ + 'role': + 'user', + 'content': + 'Please output the tool calling results very briefly.' + }] + round_item: dict = self.aggregate_stream_chunks([ + chunk_item for chunk_item in self.chat_stream( + messages=messages, tools=tools, **kwargs) + ]) yield round_item diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index 883cafab4..fa2df6004 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -1,17 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect -import json from copy import deepcopy from typing import Any, Dict, Generator, Iterable, List, Optional import httpx -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function - +import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall -from ms_agent.utils import MAX_CONTINUE_RUNS, assert_package_exist, get_logger, retry +from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, + get_logger, retry) from ms_agent.utils.constants import get_service_config +from omegaconf import DictConfig, OmegaConf +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, Function) logger = get_logger() @@ -27,7 +28,8 @@ class _DashScopeResponsesTransport(httpx.HTTPTransport): def handle_request(self, request): if b'/v1/responses' in request.url.raw_path: - new_path = request.url.raw_path.replace(b'/v1/responses', b'/v1/chat/completions') + new_path = request.url.raw_path.replace(b'/v1/responses', + b'/v1/chat/completions') request.url = request.url.copy_with(raw_path=new_path) return super().handle_request(request) @@ -49,8 +51,9 @@ class OpenAI(LLM): base_url (`Optional[str]`): Custom base URL for the API endpoint. Defaults to None. api_key (`Optional[str]`): Authentication key for the API. Defaults to None. """ - - input_msg = {'role', 'content', 'tool_calls', 'partial', 'prefix', 'tool_call_id'} + input_msg = { + 'role', 'content', 'tool_calls', 'partial', 'prefix', 'tool_call_id' + } # Providers that support cache_control in structured content blocks CACHE_CONTROL_PROVIDERS = ['dashscope', 'anthropic'] @@ -64,10 +67,12 @@ def __init__( super().__init__(config) assert_package_exist('openai') import openai - self.model: str = config.llm.model - self.max_continue_runs = getattr(config.llm, 'max_continue_runs', None) or MAX_CONTINUE_RUNS - base_url = base_url or getattr(config.llm, 'openai_base_url', None) or get_service_config('openai').base_url + self.max_continue_runs = getattr(config.llm, 'max_continue_runs', + None) or MAX_CONTINUE_RUNS + base_url = base_url or getattr( + config.llm, 'openai_base_url', + None) or get_service_config('openai').base_url api_key = api_key or getattr(config.llm, 'openai_api_key', None) self.client = openai.OpenAI( @@ -75,17 +80,21 @@ def __init__( base_url=base_url, ) self.base_url = base_url or '' - self.args: Dict = OmegaConf.to_container(getattr(config, 'generation_config', DictConfig({}))) + self.args: Dict = OmegaConf.to_container( + getattr(config, 'generation_config', DictConfig({}))) # Responses API support - self._use_responses_api = bool(self.args.get('use_responses_api', False)) + self._use_responses_api = bool( + self.args.get('use_responses_api', False)) self._responses_client = None - self._responses_state_mode = str(self.args.get('responses_state_mode', 'stateless')).lower() + self._responses_state_mode = str( + self.args.get('responses_state_mode', 'stateless')).lower() if self._responses_state_mode == 'stateful': self._responses_state_mode = 'previous_response_id' if self._use_responses_api: - self._is_dashscope = bool(base_url and 'dashscope' in base_url.lower()) + self._is_dashscope = bool(base_url + and 'dashscope' in base_url.lower()) if self._is_dashscope: http_client = httpx.Client( transport=_DashScopeResponsesTransport(), @@ -107,7 +116,8 @@ def __init__( # - Special values: 'last_message' (only cache the last message in the list) # Default: ['system'] - system prompt is usually the longest stable prefix self._prefix_cache_enabled = self.args.get('force_prefix_cache', False) - self._prefix_cache_roles = set(self.args.get('prefix_cache_roles', ['system'])) + self._prefix_cache_roles = set( + self.args.get('prefix_cache_roles', ['system'])) self._prefix_cache_provider = self._detect_cache_provider() def _detect_cache_provider(self) -> Optional[str]: @@ -161,7 +171,8 @@ def _to_structured_content( # Add cache_control to text blocks that don't have it new_list = [] for item in content: - if isinstance(item, dict) and item.get('type') == 'text' and 'cache_control' not in item: + if (isinstance(item, dict) and item.get('type') == 'text' + and 'cache_control' not in item): new_item = dict(item) new_item['cache_control'] = {'type': 'ephemeral'} new_list.append(new_item) @@ -172,7 +183,9 @@ def _to_structured_content( # Other types: return as-is return content - def format_tools(self, tools: Optional[List[Tool]] = None) -> List[Dict[str, Any]]: + def format_tools(self, + tools: Optional[List[Tool]] = None + ) -> List[Dict[str, Any]]: """Formats a list of tools into the structure expected by the OpenAI API. If server_name is present in a tool, it will be used as a prefix for the function name. @@ -184,29 +197,24 @@ def format_tools(self, tools: Optional[List[Tool]] = None) -> List[Dict[str, Any List[Dict[str, Any]]: A list of formatted tool definitions suitable for OpenAI API. """ if tools: - tools = [ - { - 'type': 'function', - 'function': { - 'name': tool['tool_name'], - 'description': tool['description'], - 'parameters': tool['parameters'], - }, + tools = [{ + 'type': 'function', + 'function': { + 'name': tool['tool_name'], + 'description': tool['description'], + 'parameters': tool['parameters'] } - for tool in tools - ] + } for tool in tools] else: tools = None return tools @retry(max_attempts=LLM.retry_count, delay=1.0) - def generate( - self, - messages: List[Message], - tools: Optional[List[Tool]] = None, - max_continue_runs: Optional[int] = None, - **kwargs, - ) -> Message | Generator[Message, None, None]: + def generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + max_continue_runs: Optional[int] = None, + **kwargs) -> Message | Generator[Message, None, None]: """Generates a response based on the given conversation history and optional tools. Args: @@ -228,17 +236,24 @@ def generate( else: return self._responses_generate(messages, tools, **args) - parameters = inspect.signature(self.client.chat.completions.create).parameters + parameters = inspect.signature( + self.client.chat.completions.create).parameters args = {key: value for key, value in args.items() if key in parameters} completion = self._call_llm(messages, self.format_tools(tools), **args) max_continue_runs = max_continue_runs or self.max_continue_runs if stream: - return self._stream_continue_generate(messages, completion, tools, max_continue_runs - 1, **args) + return self._stream_continue_generate(messages, completion, tools, + max_continue_runs - 1, + **args) else: - return self._continue_generate(messages, completion, tools, max_continue_runs - 1, **args) + return self._continue_generate(messages, completion, tools, + max_continue_runs - 1, **args) - def _call_llm(self, messages: List[Message], tools: Optional[List[Tool]] = None, **kwargs) -> Any: + def _call_llm(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **kwargs) -> Any: """Calls the OpenAI chat completion API with the provided messages and tools. Args: @@ -258,7 +273,8 @@ def _call_llm(self, messages: List[Message], tools: Optional[List[Tool]] = None, if is_streaming and stream_options_config.get('include_usage', True): kwargs.setdefault('stream_options', {})['include_usage'] = True - return self.client.chat.completions.create(model=self.model, messages=messages, tools=tools, **kwargs) + return self.client.chat.completions.create( + model=self.model, messages=messages, tools=tools, **kwargs) @staticmethod def _extract_cache_info(usage_obj: Any) -> tuple: @@ -284,10 +300,12 @@ def _extract_cache_info(usage_obj: Any) -> tuple: created = int(details.get('cache_creation_input_tokens', 0) or 0) else: cached = int(getattr(details, 'cached_tokens', 0) or 0) - created = int(getattr(details, 'cache_creation_input_tokens', 0) or 0) + created = int( + getattr(details, 'cache_creation_input_tokens', 0) or 0) return cached, created - def _merge_stream_message(self, pre_message_chunk: Optional[Message], message_chunk: Message) -> Optional[Message]: + def _merge_stream_message(self, pre_message_chunk: Optional[Message], + message_chunk: Message) -> Optional[Message]: """Merges a new chunk of message into the previous chunks during streaming. Used to accumulate partial results into a complete Message object. @@ -312,17 +330,25 @@ def _merge_stream_message(self, pre_message_chunk: Optional[Message], message_ch message.content += message_chunk.content if message_chunk.tool_calls: if message.tool_calls: - if message.tool_calls[-1]['index'] == message_chunk.tool_calls[0]['index']: + if message.tool_calls[-1]['index'] == message_chunk.tool_calls[ + 0]['index']: if message_chunk.tool_calls[0]['id']: - message.tool_calls[-1]['id'] = message_chunk.tool_calls[0]['id'] + message.tool_calls[-1][ + 'id'] = message_chunk.tool_calls[0]['id'] if message_chunk.tool_calls[0]['arguments']: if message.tool_calls[-1]['arguments']: - message.tool_calls[-1]['arguments'] += message_chunk.tool_calls[0]['arguments'] + message.tool_calls[-1][ + 'arguments'] += message_chunk.tool_calls[0][ + 'arguments'] else: # message.tool_calls[-1]['arguments'] may be None - message.tool_calls[-1]['arguments'] = message_chunk.tool_calls[0]['arguments'] + message.tool_calls[-1][ + 'arguments'] = message_chunk.tool_calls[0][ + 'arguments'] if message_chunk.tool_calls[0]['tool_name']: - message.tool_calls[-1]['tool_name'] = message_chunk.tool_calls[0]['tool_name'] + message.tool_calls[-1][ + 'tool_name'] = message_chunk.tool_calls[0][ + 'tool_name'] else: message.tool_calls.append( ToolCall( @@ -330,21 +356,17 @@ def _merge_stream_message(self, pre_message_chunk: Optional[Message], message_ch arguments=message_chunk.tool_calls[0]['arguments'], type='function', tool_name=message_chunk.tool_calls[0]['tool_name'], - index=message_chunk.tool_calls[0]['index'], - ) - ) + index=message_chunk.tool_calls[0]['index'])) else: message.tool_calls = message_chunk.tool_calls return message - def _stream_continue_generate( - self, - messages: List[Message], - completion: Iterable, - tools: Optional[List[Tool]] = None, - max_runs: Optional[int] = None, - **kwargs, - ) -> Generator[Message, None, None]: + def _stream_continue_generate(self, + messages: List[Message], + completion: Iterable, + tools: Optional[List[Tool]] = None, + max_runs: Optional[int] = None, + **kwargs) -> Generator[Message, None, None]: """Recursively continues generating until the model finishes naturally in streaming mode. Args: @@ -366,7 +388,8 @@ def _stream_continue_generate( try: next_chunk = next(completion) message.prompt_tokens += next_chunk.usage.prompt_tokens - cached, created = self._extract_cache_info(getattr(next_chunk, 'usage', None)) + cached, created = self._extract_cache_info( + getattr(next_chunk, 'usage', None)) message.cached_tokens += cached message.cache_creation_input_tokens += created message.completion_tokens += next_chunk.usage.completion_tokens @@ -374,14 +397,21 @@ def _stream_continue_generate( # The stream may end without a final usage chunk, which is acceptable. pass first_run = not messages[-1].to_dict().get('partial', False) - if chunk.choices[0].finish_reason in ['length', 'null'] and (max_runs is None or max_runs != 0): - logger.info(f'finish_reason: {chunk.choices[0].finish_reason}, continue generate.') - completion = self._call_llm_for_continue_gen(messages, message, tools, **kwargs) + if chunk.choices[0].finish_reason in [ + 'length', 'null' + ] and (max_runs is None or max_runs != 0): + logger.info( + f'finish_reason: {chunk.choices[0].finish_reason}, continue generate.' + ) + completion = self._call_llm_for_continue_gen( + messages, message, tools, **kwargs) for chunk in self._stream_continue_generate( - messages, completion, tools, max_runs - 1 if max_runs is not None else None, **kwargs - ): + messages, completion, tools, + max_runs - 1 if max_runs is not None else None, + **kwargs): if first_run: - yield self._merge_stream_message(messages[-1], chunk) + yield self._merge_stream_message( + messages[-1], chunk) else: yield chunk elif not first_run: @@ -406,7 +436,8 @@ def _stream_format_output_message(completion_chunk) -> Message: content = '' if completion_chunk.choices and completion_chunk.choices[0].delta: content = completion_chunk.choices[0].delta.content - reasoning_content = getattr(completion_chunk.choices[0].delta, 'reasoning_content', '') + reasoning_content = getattr(completion_chunk.choices[0].delta, + 'reasoning_content', '') if completion_chunk.choices[0].delta.tool_calls: func = completion_chunk.choices[0].delta.tool_calls tool_calls = [ @@ -415,8 +446,7 @@ def _stream_format_output_message(completion_chunk) -> Message: index=tool_call.index, type=tool_call.type, arguments=tool_call.function.arguments, - tool_name=tool_call.function.name, - ) + tool_name=tool_call.function.name) for tool_call in func ] content = content or '' @@ -428,22 +458,23 @@ def _stream_format_output_message(completion_chunk) -> Message: tool_calls=tool_calls, id=completion_chunk.id, prompt_tokens=getattr(completion_chunk.usage, 'prompt_tokens', 0), - completion_tokens=getattr(completion_chunk.usage, 'completion_tokens', 0), - ) + completion_tokens=getattr(completion_chunk.usage, + 'completion_tokens', 0)) @staticmethod def _format_output_message(completion) -> Message: """Formats the full non-streaming response into a Message object. - Args: - completion: The raw response from the OpenAI API. + Args: + completion: The raw response from the OpenAI API. - Returns: - Message: A Message object containing the final response. - """ + Returns: + Message: A Message object containing the final response. + """ content = completion.choices[0].message.content or '' if hasattr(completion.choices[0].message, 'reasoning_content'): - reasoning_content = completion.choices[0].message.reasoning_content or '' + reasoning_content = completion.choices[ + 0].message.reasoning_content or '' else: reasoning_content = '' tool_calls = None @@ -454,11 +485,11 @@ def _format_output_message(completion) -> Message: index=getattr(tool_call, 'index', idx), type=tool_call.type, arguments=tool_call.function.arguments, - tool_name=tool_call.function.name, - ) - for idx, tool_call in enumerate(completion.choices[0].message.tool_calls) + tool_name=tool_call.function.name) for idx, tool_call in + enumerate(completion.choices[0].message.tool_calls) ] - cached, created = OpenAI._extract_cache_info(getattr(completion, 'usage', None)) + cached, created = OpenAI._extract_cache_info( + getattr(completion, 'usage', None)) return Message( role='assistant', content=content, @@ -468,8 +499,7 @@ def _format_output_message(completion) -> Message: prompt_tokens=completion.usage.prompt_tokens, cached_tokens=cached, cache_creation_input_tokens=created, - completion_tokens=completion.usage.completion_tokens, - ) + completion_tokens=completion.usage.completion_tokens) @staticmethod def _merge_partial_message(messages: List[Message], new_message: Message): @@ -483,7 +513,8 @@ def _merge_partial_message(messages: List[Message], new_message: Message): messages[-1].content += new_message.content messages[-1].prompt_tokens += new_message.prompt_tokens messages[-1].cached_tokens += new_message.cached_tokens - messages[-1].cache_creation_input_tokens += new_message.cache_creation_input_tokens + messages[ + -1].cache_creation_input_tokens += new_message.cache_creation_input_tokens messages[-1].completion_tokens += new_message.completion_tokens if new_message.tool_calls: if messages[-1].tool_calls: @@ -491,9 +522,11 @@ def _merge_partial_message(messages: List[Message], new_message: Message): else: messages[-1].tool_calls = new_message.tool_calls - def _call_llm_for_continue_gen( - self, messages: List[Message], new_message: Message, tools: List[Tool] = None, **kwargs - ) -> Any: + def _call_llm_for_continue_gen(self, + messages: List[Message], + new_message: Message, + tools: List[Tool] = None, + **kwargs) -> Any: """Prepares and calls the LLM for continuation when the response is unfinished. If the previous message marked as unfinished, it will be updated with the new content. @@ -522,9 +555,12 @@ def _call_llm_for_continue_gen( return self._call_llm(messages, tools, **kwargs) - def _continue_generate( - self, messages: List[Message], completion, tools: List[Tool] = None, max_runs: Optional[int] = None, **kwargs - ) -> Message: + def _continue_generate(self, + messages: List[Message], + completion, + tools: List[Tool] = None, + max_runs: Optional[int] = None, + **kwargs) -> Message: """Recursively continues generating until the model finishes naturally. This method checks whether the generation was stopped due to length limitations, @@ -540,12 +576,17 @@ def _continue_generate( Message: A fully formed Message object containing the complete response. """ new_message = self._format_output_message(completion) - if completion.choices[0].finish_reason in ['length', 'null'] and (max_runs is None or max_runs != 0): - logger.info(f'finish_reason: {completion.choices[0].finish_reason}, continue generate.') - completion = self._call_llm_for_continue_gen(messages, new_message, tools, **kwargs) - return self._continue_generate( - messages, completion, tools, max_runs - 1 if max_runs is not None else None, **kwargs + if completion.choices[0].finish_reason in [ + 'length', 'null' + ] and (max_runs is None or max_runs != 0): + logger.info( + f'finish_reason: {completion.choices[0].finish_reason}, continue generate.' ) + completion = self._call_llm_for_continue_gen( + messages, new_message, tools, **kwargs) + return self._continue_generate( + messages, completion, tools, + max_runs - 1 if max_runs is not None else None, **kwargs) elif messages[-1].to_dict().get('partial', False): self._merge_partial_message(messages, new_message) messages[-1].partial = False @@ -553,7 +594,8 @@ def _continue_generate( else: return new_message - def _build_responses_input(self, messages: List[Message]) -> List[Dict[str, Any]]: + def _build_responses_input( + self, messages: List[Message]) -> List[Dict[str, Any]]: """Convert internal Message list to the ``input`` format expected by the Responses API. @@ -568,66 +610,60 @@ def _build_responses_input(self, messages: List[Message]) -> List[Dict[str, Any] items: List[Dict[str, Any]] = [] for msg in messages: if msg.role == 'system': - items.append( - { - 'role': 'developer', - 'content': msg.content, - } - ) + items.append({ + 'role': 'developer', + 'content': msg.content, + }) elif msg.role == 'assistant': if self._responses_state_mode != 'previous_response_id': # Stateless mode needs explicit passback of opaque reasoning # items returned by the previous response. - for raw_item in getattr(msg, '_responses_output_items', []): + for raw_item in getattr(msg, '_responses_output_items', + []): items.append(raw_item) - if msg.content and not self._is_responses_tool_placeholder(msg): - items.append( - { - 'role': 'assistant', - 'content': msg.content, - } - ) + if msg.content and not self._is_responses_tool_placeholder( + msg): + items.append({ + 'role': 'assistant', + 'content': msg.content, + }) if msg.tool_calls: for tc in msg.tool_calls: arguments = tc.get('arguments', '{}') if not isinstance(arguments, str): - arguments = json.dumps(arguments, ensure_ascii=False) - items.append( - { - 'type': 'function_call', - 'call_id': tc.get('id', ''), - 'name': tc.get('tool_name', ''), - 'arguments': arguments, - } - ) + arguments = json.dumps( + arguments, ensure_ascii=False) + items.append({ + 'type': 'function_call', + 'call_id': tc.get('id', ''), + 'name': tc.get('tool_name', ''), + 'arguments': arguments, + }) elif msg.role == 'tool': content = msg.content if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) - items.append( - { - 'type': 'function_call_output', - 'call_id': msg.tool_call_id or '', - 'output': content, - } - ) + items.append({ + 'type': 'function_call_output', + 'call_id': msg.tool_call_id or '', + 'output': content, + }) else: - items.append( - { - 'role': msg.role, - 'content': msg.content, - } - ) + items.append({ + 'role': msg.role, + 'content': msg.content, + }) return items @staticmethod def _is_responses_tool_placeholder(message: Message) -> bool: """Return True for framework-generated assistant placeholder text.""" - return bool(message.tool_calls) and message.content == 'Let me do a tool calling.' + return bool(message.tool_calls + ) and message.content == 'Let me do a tool calling.' def _prepare_responses_request( - self, messages: List[Message], args: Dict[str, Any] - ) -> tuple[List[Message], Dict[str, Any]]: + self, messages: List[Message], + args: Dict[str, Any]) -> tuple[List[Message], Dict[str, Any]]: """Prepare message slice and request args for Responses API calls.""" request_args = dict(args) @@ -641,23 +677,22 @@ def _prepare_responses_request( msg = messages[idx] if msg.role == 'assistant' and msg.id: request_args['previous_response_id'] = msg.id - return messages[idx + 1 :], request_args + return messages[idx + 1:], request_args return messages, request_args - def _build_responses_tools(self, tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: + def _build_responses_tools( + self, + tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: """Convert internal Tool list to Responses API function tool format.""" if not tools: return None - return [ - { - 'type': 'function', - 'name': t['tool_name'], - 'description': t.get('description', ''), - 'parameters': t.get('parameters', {}), - } - for t in tools - ] + return [{ + 'type': 'function', + 'name': t['tool_name'], + 'description': t.get('description', ''), + 'parameters': t.get('parameters', {}), + } for t in tools] def _build_responses_kwargs(self, args: Dict) -> Dict: """Filter and reshape generation args for ``responses.create``.""" @@ -707,7 +742,8 @@ def _extract_reasoning_summaries_from_response(response) -> str: return '\n'.join(parts) @staticmethod - def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: + def _extract_tool_calls_from_response( + response) -> Optional[List[ToolCall]]: """Extract tool calls from a completed Responses API object.""" tool_calls: List[ToolCall] = [] for item in getattr(response, 'output', []) or []: @@ -717,13 +753,13 @@ def _extract_tool_calls_from_response(response) -> Optional[List[ToolCall]]: arguments = json.dumps(arguments, ensure_ascii=False) tool_calls.append( ToolCall( - id=getattr(item, 'call_id', '') or getattr(item, 'id', ''), + id=getattr(item, 'call_id', '') + or getattr(item, 'id', ''), index=len(tool_calls), type='function', tool_name=getattr(item, 'name', ''), arguments=arguments, - ) - ) + )) return tool_calls if tool_calls else None @staticmethod @@ -745,7 +781,10 @@ def _to_jsonable(value: Any) -> Any: if isinstance(value, list): return [OpenAI._to_jsonable(item) for item in value] if isinstance(value, dict): - return {key: OpenAI._to_jsonable(item) for key, item in value.items()} + return { + key: OpenAI._to_jsonable(item) + for key, item in value.items() + } if hasattr(value, 'model_dump'): return OpenAI._to_jsonable(value.model_dump()) if hasattr(value, 'to_dict'): @@ -763,8 +802,10 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: item_type = getattr(item, 'type', None) if item_type == 'reasoning': passback_item: Dict[str, Any] = { - 'type': 'reasoning', - 'summary': self._to_jsonable(getattr(item, 'summary', []) or []), + 'type': + 'reasoning', + 'summary': + self._to_jsonable(getattr(item, 'summary', []) or []), } encrypted_content = getattr(item, 'encrypted_content', None) if encrypted_content: @@ -776,9 +817,13 @@ def _collect_passback_items(self, response) -> List[Dict[str, Any]]: items.append(passback_item) return items - def _responses_generate(self, messages: List[Message], tools: Optional[List[Tool]] = None, **args) -> Message: + def _responses_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Message: """Non-streaming Responses API call.""" - request_messages, request_args = self._prepare_responses_request(messages, args) + request_messages, request_args = self._prepare_responses_request( + messages, args) input_items = self._build_responses_input(request_messages) resp_tools = self._build_responses_tools(tools) kwargs = self._build_responses_kwargs(request_args) @@ -793,7 +838,8 @@ def _responses_generate(self, messages: List[Message], tools: Optional[List[Tool text = getattr(response, 'output_text', '') or '' reasoning = self._extract_reasoning_summaries_from_response(response) resp_tool_calls = self._extract_tool_calls_from_response(response) - prompt_tokens, completion_tokens = self._extract_usage_from_response(response) + prompt_tokens, completion_tokens = self._extract_usage_from_response( + response) passback = self._collect_passback_items(response) return Message( @@ -817,9 +863,10 @@ def _extract_reasoning_from_item(item) -> str: parts.append(text) return '\n'.join(parts) - def _responses_stream_generate( - self, messages: List[Message], tools: Optional[List[Tool]] = None, **args - ) -> Generator[Message, None, None]: + def _responses_stream_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Generator[Message, None, None]: """Streaming Responses API call. Yields incremental ``Message`` objects. Reasoning summaries are @@ -827,7 +874,8 @@ def _responses_stream_generate( which arrive *before* the first text delta, so the agent layer can display the thinking header before content begins streaming. """ - request_messages, request_args = self._prepare_responses_request(messages, args) + request_messages, request_args = self._prepare_responses_request( + messages, args) input_items = self._build_responses_input(request_messages) resp_tools = self._build_responses_tools(tools) kwargs = self._build_responses_kwargs(request_args) @@ -860,7 +908,8 @@ def _responses_stream_generate( summary_text = self._extract_reasoning_from_item(item) if summary_text: reasoning_parts.append(summary_text) - current_message.reasoning_content = '\n'.join(reasoning_parts) + current_message.reasoning_content = '\n'.join( + reasoning_parts) yield current_message elif event_type == 'response.output_text.delta': @@ -883,29 +932,35 @@ def _responses_stream_generate( elif event_type == 'response.failed': failed_response = getattr(event, 'response', None) failed_error = getattr(failed_response, 'error', None) - response_error_msg = getattr(failed_error, 'message', '') or str(failed_error) + response_error_msg = getattr(failed_error, 'message', + '') or str(failed_error) if final_response: if not reasoning_parts: - reasoning = self._extract_reasoning_summaries_from_response(final_response) + reasoning = self._extract_reasoning_summaries_from_response( + final_response) if reasoning: current_message.reasoning_content = reasoning - resp_tool_calls = self._extract_tool_calls_from_response(final_response) + resp_tool_calls = self._extract_tool_calls_from_response( + final_response) if resp_tool_calls: current_message.tool_calls = resp_tool_calls passback = self._collect_passback_items(final_response) if passback: current_message._responses_output_items = passback - prompt_tokens, completion_tokens = self._extract_usage_from_response(final_response) + prompt_tokens, completion_tokens = self._extract_usage_from_response( + final_response) current_message.prompt_tokens = prompt_tokens current_message.completion_tokens = completion_tokens current_message.id = getattr(final_response, 'id', '') yield current_message elif response_error_msg: logger.error(f'Responses API failed: {response_error_msg}') - raise RuntimeError(f'Responses API call failed: {response_error_msg}') + raise RuntimeError( + f'Responses API call failed: {response_error_msg}') - def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: + def _format_input_message(self, + messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the OpenAI API. Args: @@ -927,7 +982,8 @@ def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]] # Check for role-based caching role_cache = self._prefix_cache_roles - {'last_message'} for idx, msg in enumerate(messages): - msg_role = msg.role if isinstance(msg, Message) else msg.get('role', '') + msg_role = msg.role if isinstance(msg, Message) else msg.get( + 'role', '') if msg_role in role_cache: cache_indices.add(idx) cache_indice = max(cache_indices) if cache_indices else None @@ -951,8 +1007,9 @@ def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]] # Only for string content, multimodal content is already structured if cache_indice is not None and idx == cache_indice: content = self._to_structured_content( - content, add_cache_control=True, provider=self._prefix_cache_provider - ) + content, + add_cache_control=True, + provider=self._prefix_cache_provider) # Build the message dict, handling both string and multimodal content formatted_message = {} diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index c2320d932..4ae5833bf 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional, Union +import json from typing_extensions import Literal, Required, TypedDict @@ -86,20 +86,19 @@ def to_dict_clean(self): 'function': { 'name': tool_call['tool_name'], 'arguments': tool_call['arguments'], - }, + } } required = ['content', 'role'] # Never send UI-only fields to model providers. rm = [ - 'completion_tokens', - 'prompt_tokens', - 'api_calls', - 'tool_detail', - 'searching_detail', - 'search_result', - '_responses_output_items', + 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', + 'searching_detail', 'search_result', '_responses_output_items', ] - return {key: value for key, value in raw_dict.items() if (value or key in required) and key not in rm} + return { + key: value + for key, value in raw_dict.items() + if (value or key in required) and key not in rm + } @dataclass @@ -128,6 +127,9 @@ def from_raw(raw): text=str(model_text), resources=raw.get('resources', []), tool_detail=None if td is None else str(td), - extra={k: v for k, v in raw.items() if k not in ['text', 'resources', 'result', 'tool_detail']}, - ) + extra={ + k: v + for k, v in raw.items() + if k not in ['text', 'resources', 'result', 'tool_detail'] + }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/memory/base.py b/ms_agent/memory/base.py index ed55e7fe5..fde42fb56 100644 --- a/ms_agent/memory/base.py +++ b/ms_agent/memory/base.py @@ -2,10 +2,9 @@ from abc import ABC, abstractmethod from typing import List -from omegaconf import DictConfig - from ms_agent.llm.utils import Message from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR +from omegaconf import DictConfig class Memory(ABC): @@ -13,7 +12,8 @@ class Memory(ABC): def __init__(self, config): self.config = config - self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', + DEFAULT_OUTPUT_DIR) self.base_config = None @abstractmethod diff --git a/ms_agent/memory/condenser/code_condenser.py b/ms_agent/memory/condenser/code_condenser.py index ff4a7d0a9..12fe626e5 100644 --- a/ms_agent/memory/condenser/code_condenser.py +++ b/ms_agent/memory/condenser/code_condenser.py @@ -1,17 +1,19 @@ -import json import os from typing import List +import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils import get_logger -from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_OUTPUT_WRAPPER +from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, + DEFAULT_OUTPUT_WRAPPER) from ms_agent.utils.utils import extract_code_blocks, file_lock logger = get_logger() class CodeCondenser(Memory): + system = """你是一个帮我简化代码并返回缩略信息的机器人。你缩略的文件会给与另一个LLM用来编写代码,因此你生成的缩略文件需要具有充足的供其他文件依赖的信息。 需要保留的信息: @@ -95,7 +97,7 @@ class CodeCondenser(Memory): 你的优化目标: 1. 【优先】保留充足的信息供其它代码使用 2. 【其次】保留尽量少的token数量 -""" # noqa +""" # noqa def __init__(self, config): super().__init__(config) @@ -106,7 +108,8 @@ def __init__(self, config): index_dir = getattr(config, 'index_cache_dir', DEFAULT_INDEX_DIR) self.index_dir = os.path.join(self.output_dir, index_dir) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) - self.code_wrapper = getattr(mem_config, 'code_wrapper', DEFAULT_OUTPUT_WRAPPER) + self.code_wrapper = getattr(mem_config, 'code_wrapper', + DEFAULT_OUTPUT_WRAPPER) def condense_code(self, message: Message): prefix = 'Your generated code was replaced by a index version:\n' @@ -119,17 +122,22 @@ def condense_code(self, message: Message): arguments = json.loads(arguments) code_file = arguments['path'] content = arguments['content'] - index_content = self.generate_index_file(code_file, content) + index_content = self.generate_index_file( + code_file, content) arguments['content'] = f'{prefix}{index_content}' - tool_call['arguments'] = json.dumps(arguments, ensure_ascii=False) - elif self.code_wrapper[0] in message.content and self.code_wrapper[1] in message.content: - result, remaining_text = extract_code_blocks(message.content, file_wrapper=self.code_wrapper) + tool_call['arguments'] = json.dumps( + arguments, ensure_ascii=False) + elif self.code_wrapper[0] in message.content and self.code_wrapper[ + 1] in message.content: + result, remaining_text = extract_code_blocks( + message.content, file_wrapper=self.code_wrapper) if result: final_content = remaining_text + prefix for code_block in result: code_file = code_block['filename'] content = code_block['code'] - index_content = self.generate_index_file(code_file, content) + index_content = self.generate_index_file( + code_file, content) final_content += index_content + '\n' message.content = final_content @@ -164,7 +172,8 @@ def generate_index_file(self, file: str, content: str = None): error = None for i in range(3): try: - response_message = self.llm.generate(messages, stream=False) + response_message = self.llm.generate( + messages, stream=False) content = response_message.content.split('\n') if '```' in content[0]: content = content[1:] @@ -174,11 +183,14 @@ def generate_index_file(self, file: str, content: str = None): os.makedirs(os.path.dirname(index_file), exist_ok=True) with open(index_file, 'w') as f: f.write(content) - json.loads(content) # try to load once to ensure the json format is ok + json.loads( + content + ) # try to load once to ensure the json format is ok break except Exception as e: error = e - logger.error(f'Code index file generate failed because of {e}') + logger.error( + f'Code index file generate failed because of {e}') if content is None: raise error return content diff --git a/ms_agent/memory/condenser/context_compressor.py b/ms_agent/memory/condenser/context_compressor.py index 035a12a6f..9bec9bcf1 100644 --- a/ms_agent/memory/condenser/context_compressor.py +++ b/ms_agent/memory/condenser/context_compressor.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json from typing import List, Optional +import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils.logger import logger @@ -44,7 +44,8 @@ def __init__(self, config): self.reserved_buffer = getattr(mem_config, 'reserved_buffer', 20000) # Summary prompt - self.summary_prompt = getattr(mem_config, 'summary_prompt', SUMMARY_PROMPT) + self.summary_prompt = getattr(mem_config, 'summary_prompt', + SUMMARY_PROMPT) # LLM for summarization self.llm: Optional[LLM] = None @@ -66,7 +67,9 @@ def _estimate_message_tokens_from_content(self, msg: Message) -> int: """Heuristic token count from message body (no API usage fields).""" total = 0 if msg.content: - content = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False) + content = msg.content if isinstance( + msg.content, str) else json.dumps( + msg.content, ensure_ascii=False) total += self.estimate_tokens(content) if msg.tool_calls: total += self.estimate_tokens(json.dumps(msg.tool_calls)) @@ -96,8 +99,11 @@ def estimate_total_tokens(self, messages: List[Message]) -> int: break if last_usage_idx >= 0: m = messages[last_usage_idx] - base = int(getattr(m, 'prompt_tokens', 0) or 0) + int(getattr(m, 'completion_tokens', 0) or 0) - tail = sum(self._estimate_message_tokens_from_content(x) for x in messages[last_usage_idx + 1 :]) + base = int(getattr(m, 'prompt_tokens', 0) or 0) + int( + getattr(m, 'completion_tokens', 0) or 0) + tail = sum( + self._estimate_message_tokens_from_content(x) + for x in messages[last_usage_idx + 1:]) return base + tail return sum(self.estimate_message_tokens(m) for m in messages) @@ -122,7 +128,9 @@ def prune_tool_outputs(self, messages: List[Message]) -> List[Message]: msg = messages[idx] if msg.role != 'tool' or not msg.content: continue - content_str = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False) + content_str = msg.content if isinstance( + msg.content, str) else json.dumps( + msg.content, ensure_ascii=False) tokens = self.estimate_tokens(content_str) total_tool_tokens += tokens @@ -144,7 +152,8 @@ def summarize(self, messages: List[Message]) -> Optional[str]: conv_parts = [] for msg in messages: role = msg.role.upper() - content = msg.content if isinstance(msg.content, str) else str(msg.content) + content = msg.content if isinstance(msg.content, str) else str( + msg.content) if content: conv_parts.append(f'{role}: {content[:2000]}') @@ -152,7 +161,8 @@ def summarize(self, messages: List[Message]) -> Optional[str]: query = f'{self.summary_prompt}\n\n---\n{conversation}' try: - response = self.llm.generate([Message(role='user', content=query)], stream=False) + response = self.llm.generate([Message(role='user', content=query)], + stream=False) return response.content except Exception as e: logger.error(f'Summary generation failed: {e}') @@ -189,8 +199,10 @@ def compress(self, messages: List[Message]) -> List[Message]: break result.append( - Message(role='user', content=f'[Conversation Summary]\n{summary}\n\nPlease continue based on this summary.') - ) + Message( + role='user', + content=f'[Conversation Summary]\n{summary}\n\n' + 'Please continue based on this summary.')) # Keep the most recent user message if different if messages and messages[-1].role == 'user': @@ -198,7 +210,8 @@ def compress(self, messages: List[Message]) -> List[Message]: if last_user.content and last_user.content != result[-1].content: result.append(last_user) - logger.info(f'Compressed {len(messages)} messages to {len(result)} messages') + logger.info( + f'Compressed {len(messages)} messages to {len(result)} messages') return result async def run(self, messages: List[Message]) -> List[Message]: diff --git a/ms_agent/memory/condenser/refine_condenser.py b/ms_agent/memory/condenser/refine_condenser.py index d779e4b73..557e2a1ff 100644 --- a/ms_agent/memory/condenser/refine_condenser.py +++ b/ms_agent/memory/condenser/refine_condenser.py @@ -1,6 +1,6 @@ -import json from typing import List +import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory @@ -68,7 +68,8 @@ def __init__(self, config): self.threshold = getattr(mem_config, 'threshold', 60000) async def condense_memory(self, messages): - if len(str(messages)) > self.threshold and messages[-1].role in ('user', 'tool'): + if len(str(messages)) > self.threshold and messages[-1].role in ( + 'user', 'tool'): keep_messages = messages[:2] # keep system and user keep_messages_tail = [] i = 0 @@ -79,23 +80,24 @@ async def condense_memory(self, messages): keep_messages_tail = reversed(keep_messages_tail) compress_messages = json.dumps( - [message.to_dict_clean() for message in messages[2 : -i - 1]], ensure_ascii=False, indent=2 - ) + [message.to_dict_clean() for message in messages[2:-i - 1]], + ensure_ascii=False, + indent=2) keep_messages_json = json.dumps( - [message.to_dict_clean() for message in keep_messages], ensure_ascii=False, indent=2 - ) + [message.to_dict_clean() for message in keep_messages], + ensure_ascii=False, + indent=2) keep_messages_tail_json = json.dumps( - [message.to_dict_clean() for message in keep_messages_tail], ensure_ascii=False, indent=2 - ) + [message.to_dict_clean() for message in keep_messages_tail], + ensure_ascii=False, + indent=2) - query = ( - f'# Messages to be retained\n' - f'## system and user: {keep_messages_json}\n' - f'## Last assistant response: {keep_messages_tail_json}\n' - f'# Messages to be compressed' - f'## These messages are located between system/user ' - f'and the last assistant response: {compress_messages}' - ) + query = (f'# Messages to be retained\n' + f'## system and user: {keep_messages_json}\n' + f'## Last assistant response: {keep_messages_tail_json}\n' + f'# Messages to be compressed' + f'## These messages are located between system/user ' + f'and the last assistant response: {compress_messages}') _messages = [ Message(role='system', content=self.system), @@ -106,21 +108,17 @@ async def condense_memory(self, messages): keep_messages.append( Message( role='user', - content=f'Intermediate messages are compressed, here is the compressed message:\n{content}\n', - ) - ) - messages = ( - keep_messages - + list(keep_messages_tail) - + [ - Message( - role='user', - content='History messages are compressed due to a long sequence, now ' - 'continue solve your problem according to ' - 'the messages and the tool calling:\n', - ) - ] - ) + content= + f'Intermediate messages are compressed, here is the compressed message:\n{content}\n' + )) + messages = keep_messages + list(keep_messages_tail) + [ + Message( + role='user', + content= + 'History messages are compressed due to a long sequence, now ' + 'continue solve your problem according to ' + 'the messages and the tool calling:\n') + ] return messages else: return messages diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index 943334c86..b087a2dd9 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -2,7 +2,6 @@ import asyncio import hashlib import importlib -import json import os import re import traceback @@ -11,14 +10,15 @@ from inspect import signature from typing import Any, Dict, List, Optional, Tuple +import json import json5 -from omegaconf import DictConfig, OmegaConf - from ms_agent.llm.utils import Message from ms_agent.memory import Memory from ms_agent.utils import get_fact_retrieval_prompt -from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, DEFAULT_USER, get_service_config +from ms_agent.utils.constants import (DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, + DEFAULT_USER, get_service_config) from ms_agent.utils.logger import logger +from omegaconf import DictConfig, OmegaConf class MemoryMapping: @@ -28,7 +28,8 @@ class MemoryMapping: enable_idxs: List[int] = [] disable_idx: int = -1 - def __init__(self, memory_id: str, value: str, enable_idxs: int or List[int]): + def __init__(self, memory_id: str, value: str, enable_idxs: int + or List[int]): self.memory_id = memory_id self.value = value self.valid = True @@ -58,15 +59,20 @@ def to_dict(self) -> Dict: 'memory_id': self.memory_id, 'value': self.value, 'valid': self.valid, - 'enable_idxs': self.enable_idxs.copy(), # Return a copy to prevent external modification - 'disable_idx': self.disable_idx, + 'enable_idxs': self.enable_idxs.copy( + ), # Return a copy to prevent external modification + 'disable_idx': self.disable_idx } @classmethod def from_dict(cls, data: Dict) -> 'MemoryMapping': - instance = cls(memory_id=data['memory_id'], value=data['value'], enable_idxs=data['enable_idxs']) + instance = cls( + memory_id=data['memory_id'], + value=data['value'], + enable_idxs=data['enable_idxs']) instance.valid = data['valid'] - instance.disable_idx = data.get('disable_idx', -1) # Compatible with old data + instance.disable_idx = data.get('disable_idx', + -1) # Compatible with old data return instance @@ -76,16 +82,22 @@ class DefaultMemory(Memory): def __init__(self, config: DictConfig): super().__init__(config) memory_config = config.memory.default_memory - self.user_id: Optional[str] = getattr(memory_config, 'user_id', DEFAULT_USER) + self.user_id: Optional[str] = getattr(memory_config, 'user_id', + DEFAULT_USER) self.agent_id: Optional[str] = getattr(memory_config, 'agent_id', None) self.run_id: Optional[str] = getattr(memory_config, 'run_id', None) self.compress: Optional[bool] = getattr(config, 'compress', True) self.is_retrieve: Optional[bool] = getattr(config, 'is_retrieve', True) - self.path: Optional[str] = getattr(memory_config, 'path', os.path.join(DEFAULT_OUTPUT_DIR, '.default_memory')) + self.path: Optional[str] = getattr( + memory_config, 'path', + os.path.join(DEFAULT_OUTPUT_DIR, '.default_memory')) self.history_mode = getattr(memory_config, 'history_mode', 'add') - self.ignore_roles: List[str] = getattr(memory_config, 'ignore_roles', ['tool', 'system']) - self.ignore_fields: List[str] = getattr(memory_config, 'ignore_fields', ['reasoning_content']) - self.search_limit: int = getattr(memory_config, 'search_limit', DEFAULT_SEARCH_LIMIT) + self.ignore_roles: List[str] = getattr(memory_config, 'ignore_roles', + ['tool', 'system']) + self.ignore_fields: List[str] = getattr(memory_config, 'ignore_fields', + ['reasoning_content']) + self.search_limit: int = getattr(memory_config, 'search_limit', + DEFAULT_SEARCH_LIMIT) # Add lock for thread safety in shared usage self._lock = asyncio.Lock() self.memory = self._init_memory_obj() @@ -111,7 +123,7 @@ def save_cache(self): str(k): ([msg.to_dict() for msg in msg_list], _hash) for k, (msg_list, _hash) in self.cache_messages.items() }, - 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot], + 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot] } with open(cache_file, 'w', encoding='utf-8') as f: @@ -145,7 +157,10 @@ def load_cache(self): self.cache_messages = cache_messages # Parse memory_snapshot - self.memory_snapshot = [MemoryMapping.from_dict(d) for d in data.get('memory_snapshot', [])] + self.memory_snapshot = [ + MemoryMapping.from_dict(d) + for d in data.get('memory_snapshot', []) + ] except (json.JSONDecodeError, KeyError, Exception) as e: logger.warning(f'Failed to load cache: {e}') @@ -164,6 +179,7 @@ def _delete_single(self, msg_id: int): idx = 0 while idx < len(self.memory_snapshot): + enable_ids = self.memory_snapshot[idx].enable_idxs disable_id = self.memory_snapshot[idx].disable_idx if msg_id == disable_id: @@ -175,8 +191,9 @@ def _delete_single(self, msg_id: int): metadata['run_id'] = self.run_id try: self.memory._create_memory( - data=self.memory_snapshot[idx].value, existing_embeddings={}, metadata=metadata - ) + data=self.memory_snapshot[idx].value, + existing_embeddings={}, + metadata=metadata) except Exception as e: logger.warning(f'Failed to recover memory: {e}') if msg_id in enable_ids: @@ -189,15 +206,13 @@ def _delete_single(self, msg_id: int): idx += 1 - async def add_single( - self, - messages: List[Message], - user_id: Optional[int] = None, - agent_id: Optional[int] = None, - run_id: Optional[int] = None, - memory_type: Optional[str] = None, - msg_id: Optional[int] = None, - ) -> None: + async def add_single(self, + messages: List[Message], + user_id: Optional[int] = None, + agent_id: Optional[int] = None, + run_id: Optional[int] = None, + memory_type: Optional[str] = None, + msg_id: Optional[int] = None) -> None: messages_dict = [] for message in messages: if isinstance(message, Message): @@ -218,16 +233,16 @@ async def add_single( user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id, - memory_type=memory_type, - ) + memory_type=memory_type) logger.info('Add memory success.') except Exception as e: logger.warning(f'Failed to add memory: {e}') if self.history_mode == 'overwrite': res = self.memory.get_all( - user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id - ) # sorted + user_id=user_id or self.user_id, + agent_id=agent_id or self.agent_id, + run_id=run_id or self.run_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] if len(res): logger.info('All memory info:') @@ -251,11 +266,14 @@ async def add_single( for item in self.memory_snapshot: if item.memory_id not in valids: item.disable(msg_id) - for id, memory in unmatched: - m = MemoryMapping(memory_id=id, value=memory, enable_idxs=msg_id) + for (id, memory) in unmatched: + m = MemoryMapping( + memory_id=id, value=memory, enable_idxs=msg_id) self.memory_snapshot.append(m) - def search(self, query: str, meta_infos: List[Dict[str, Any]] = None) -> List[str]: + def search(self, + query: str, + meta_infos: List[Dict[str, Any]] = None) -> List[str]: """ Search for relevant memories based on a query string and optional metadata filters. @@ -284,14 +302,12 @@ def search(self, query: str, meta_infos: List[Dict[str, Any]] = None) -> List[st (self.user_id, self.agent_id, etc.) is used as fallback. """ if meta_infos is None: - meta_infos = [ - { - 'user_id': self.user_id, - 'agent_id': self.agent_id, - 'run_id': self.run_id, - 'limit': self.search_limit, - } - ] + meta_infos = [{ + 'user_id': self.user_id, + 'agent_id': self.agent_id, + 'run_id': self.run_id, + 'limit': self.search_limit, + }] memories = [] for meta_info in meta_infos: user_id = meta_info.get('user_id', None) @@ -303,12 +319,13 @@ def search(self, query: str, meta_infos: List[Dict[str, Any]] = None) -> List[st user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id, - limit=limit, - ) - memories.extend([entry['memory'] for entry in relevant_memories['results']]) + limit=limit) + memories.extend( + [entry['memory'] for entry in relevant_memories['results']]) return memories - def _split_into_blocks(self, messages: List[Message]) -> List[List[Message]]: + def _split_into_blocks(self, + messages: List[Message]) -> List[List[Message]]: """ Split messages into blocks where each block starts with a 'user' message and includes all following non-user messages until the next 'user' (exclusive). @@ -345,20 +362,25 @@ def _hash_block(self, block: List[Message]) -> str: """Compute sha256 hash of a message block for comparison""" data = [message.to_dict_clean() for message in block] allow_role = ['user', 'system', 'assistant', 'tool'] - allow_role = [role for role in allow_role if role not in self.ignore_roles] + allow_role = [ + role for role in allow_role if role not in self.ignore_roles + ] allow_fields = ['reasoning_content', 'content', 'tool_calls', 'role'] - allow_fields = [field for field in allow_fields if field not in self.ignore_fields] - - data = [ - {field: value for field, value in msg.items() if field in allow_fields} - for msg in data - if msg['role'] in allow_role + allow_fields = [ + field for field in allow_fields if field not in self.ignore_fields ] + data = [{ + field: value + for field, value in msg.items() if field in allow_fields + } for msg in data if msg['role'] in allow_role] + block_data = json5.dumps(data) return hashlib.sha256(block_data.encode('utf-8')).hexdigest() - def _analyze_messages(self, messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: + def _analyze_messages( + self, + messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: """ Analyze incoming messages against cache. @@ -368,7 +390,8 @@ def _analyze_messages(self, messages: List[Message]) -> Tuple[List[List[Message] """ new_blocks = self._split_into_blocks(messages) self.cache_messages = dict(sorted(self.cache_messages.items())) - cache_messages = [(key, value) for key, value in self.cache_messages.items()] + cache_messages = [(key, value) + for key, value in self.cache_messages.items()] first_unmatched_idx = -1 @@ -376,7 +399,8 @@ def _analyze_messages(self, messages: List[Message]) -> Tuple[List[List[Message] block_hash = self._hash_block(new_blocks[idx]) # Must allow comparison up to the last cache entry - if idx < len(cache_messages) and str(block_hash) == str(cache_messages[idx][1][1]): + if idx < len(cache_messages) and str(block_hash) == str( + cache_messages[idx][1][1]): continue # mismatch @@ -386,12 +410,16 @@ def _analyze_messages(self, messages: List[Message]) -> Tuple[List[List[Message] # If all new_blocks match but the cache has extra entries → delete the extra cache entries if first_unmatched_idx == -1: should_add_messages = [] - should_delete = [item[0] for item in cache_messages[len(new_blocks) :]] + should_delete = [ + item[0] for item in cache_messages[len(new_blocks):] + ] return should_add_messages, should_delete # On mismatch: add all new blocks and delete all cache entries starting from the mismatch index should_add_messages = new_blocks[first_unmatched_idx:] - should_delete = [item[0] for item in cache_messages[first_unmatched_idx:]] + should_delete = [ + item[0] for item in cache_messages[first_unmatched_idx:] + ] return should_add_messages, should_delete @@ -417,8 +445,9 @@ async def add( for msg_id in should_delete: self._delete_single(msg_id=msg_id) res = self.memory.get_all( - user_id=user_id or self.user_id, agent_id=agent_id or self.agent_id, run_id=run_id or self.run_id - ) # sorted + user_id=user_id or self.user_id, + agent_id=agent_id or self.agent_id, + run_id=run_id or self.run_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] logger.info('Roll back success. All memory info:') for item in res: @@ -427,8 +456,11 @@ async def add( for messages in should_add_messages: messages = self.parse_messages(messages) await self.add_single( - messages, user_id=user_id, agent_id=agent_id, run_id=run_id, memory_type=memory_type - ) + messages, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + memory_type=memory_type) self.save_cache() def parse_messages(self, messages: List[Message]) -> List[Message]: @@ -448,17 +480,16 @@ def parse_messages(self, messages: List[Message]) -> List[Message]: return new_messages - def delete( - self, - user_id: Optional[str] = None, - agent_id: Optional[str] = None, - run_id: Optional[str] = None, - memory_ids: Optional[List[str]] = None, - ) -> Tuple[bool, str]: + def delete(self, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + memory_ids: Optional[List[str]] = None) -> Tuple[bool, str]: failed = {} if memory_ids is None: try: - self.memory.delete_all(user_id=user_id, agent_id=agent_id, run_id=run_id) + self.memory.delete_all( + user_id=user_id, agent_id=agent_id, run_id=run_id) return True, '' except Exception as e: return False, str(e) + '\n' + traceback.format_exc() @@ -466,7 +497,9 @@ def delete( try: self.memory.delete(memory_id=memory_id) except IndexError: - failed[memory_id] = 'This memory_id does not exist in the database.\n' + traceback.format_exc() # noqa + failed[ + memory_id] = 'This memory_id does not exist in the database.\n' + traceback.format_exc( + ) # noqa except Exception as e: failed[memory_id] = str(e) + '\n' + traceback.format_exc() if failed: @@ -474,42 +507,54 @@ def delete( else: return True, '' - def get_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None): + def get_all(self, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None): try: - res = self.memory.get_all(user_id=user_id or self.user_id, agent_id=agent_id, run_id=run_id) + res = self.memory.get_all( + user_id=user_id or self.user_id, + agent_id=agent_id, + run_id=run_id) return res['results'] except Exception: return [] - def _get_latest_user_message(self, messages: List[Message]) -> Optional[str]: + def _get_latest_user_message(self, + messages: List[Message]) -> Optional[str]: """Get the latest user message content.""" for message in reversed(messages): if message.role == 'user' and hasattr(message, 'content'): return message.content return None - def _inject_memories_into_messages( - self, messages: List[Message], memories: List[str], keep_details - ) -> List[Message]: + def _inject_memories_into_messages(self, messages: List[Message], + memories: List[str], + keep_details) -> List[Message]: """Inject relevant memories into the system message.""" # Format memories for injection - memories_str = 'User Memories:\n' + '\n'.join(f'- {memory}' for memory in memories) + memories_str = 'User Memories:\n' + '\n'.join(f'- {memory}' + for memory in memories) # Remove the messages section corresponding to memory, and add the related memory_str information if getattr(messages[0], 'role') == 'system': - system_prompt = getattr(messages[0], 'content') + f'\nUser Memories: {memories_str}' + system_prompt = getattr( + messages[0], 'content') + f'\nUser Memories: {memories_str}' remain_idx = 1 else: - system_prompt = ( - f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' - f'User Memories: {memories_str}' - ) + system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ + f'User Memories: {memories_str}' remain_idx = 0 if not keep_details: - should_add_messages, should_delete = self._analyze_messages(messages) - remain_idx = max(remain_idx, len(messages) - sum([len(block) for block in should_add_messages])) - - new_messages = [Message(role='system', content=system_prompt)] + messages[remain_idx:] + should_add_messages, should_delete = self._analyze_messages( + messages) + remain_idx = max( + remain_idx, + len(messages) + - sum([len(block) for block in should_add_messages])) + + new_messages = [Message(role='system', content=system_prompt) + ] + messages[remain_idx:] return new_messages async def run( @@ -531,46 +576,51 @@ async def run( logger.warning(f'Failed to search memories: {search_error}') memories = [] if memories: - messages = self._inject_memories_into_messages(messages, memories, keep_details) + messages = self._inject_memories_into_messages( + messages, memories, keep_details) return messages def _init_memory_obj(self): try: import mem0 except ImportError as e: - logger.error(f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.') + logger.error( + f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.' + ) raise capture_event_origin = mem0.memory.main.capture_event @wraps(capture_event_origin) - def patched_capture_event(event_name, memory_instance, additional_data=None): + def patched_capture_event(event_name, + memory_instance, + additional_data=None): pass - mem0.memory.main.capture_event = partial( - patched_capture_event, - ) + mem0.memory.main.capture_event = partial(patched_capture_event, ) # emb config embedder = None - embedder_config = getattr(self.config.memory.default_memory, 'embedder', OmegaConf.create({})) + embedder_config = getattr(self.config.memory.default_memory, + 'embedder', OmegaConf.create({})) service = getattr(embedder_config, 'service', 'modelscope') api_key = getattr(embedder_config, 'api_key', None) - emb_model = getattr(embedder_config, 'model', 'Qwen/Qwen3-Embedding-8B') - embedding_dims = getattr(embedder_config, 'embedding_dims', None) # for vector store config + emb_model = getattr(embedder_config, 'model', + 'Qwen/Qwen3-Embedding-8B') + embedding_dims = getattr(embedder_config, 'embedding_dims', + None) # for vector store config if self.is_retrieve: - embedder = OmegaConf.create( - { - 'provider': 'openai', - 'config': { - 'api_key': api_key or os.getenv(f'{service.upper()}_API_KEY'), - 'openai_base_url': get_service_config(service).base_url, - 'model': emb_model, - 'embedding_dims': embedding_dims, - }, + embedder = OmegaConf.create({ + 'provider': 'openai', + 'config': { + 'api_key': api_key + or os.getenv(f'{service.upper()}_API_KEY'), + 'openai_base_url': get_service_config(service).base_url, + 'model': emb_model, + 'embedding_dims': embedding_dims } - ) + }) # llm config llm = None @@ -578,25 +628,32 @@ def patched_capture_event(event_name, memory_instance, additional_data=None): llm_config = getattr(self.config, 'llm', None) if llm_config is not None: service = getattr(llm_config, 'service', 'modelscope') - llm_model = getattr(llm_config, 'model', 'Qwen/Qwen3-Coder-30B-A3B-Instruct') + llm_model = getattr(llm_config, 'model', + 'Qwen/Qwen3-Coder-30B-A3B-Instruct') api_key = getattr(llm_config, f'{service}_api_key', None) - openai_base_url = getattr(llm_config, f'{service}_base_url', None) + openai_base_url = getattr(llm_config, f'{service}_base_url', + None) gen_cfg = getattr(self.config, 'generation_config', None) max_tokens = getattr(gen_cfg, 'max_tokens', None) llm = { 'provider': 'openai', 'config': { - 'model': llm_model, - 'api_key': api_key or os.getenv(f'{service.upper()}_API_KEY'), - 'openai_base_url': openai_base_url or get_service_config(service).base_url, - }, + 'model': + llm_model, + 'api_key': + api_key or os.getenv(f'{service.upper()}_API_KEY'), + 'openai_base_url': + openai_base_url + or get_service_config(service).base_url, + } } if max_tokens is not None: llm['config']['max_tokens'] = max_tokens # vector_store config - def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: + def sanitize_database_name(ori_name: str, + default_name: str = 'default') -> str: if not ori_name or not isinstance(ori_name, str): return default_name sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', ori_name) @@ -608,8 +665,10 @@ def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: sanitized = f'col_{sanitized}' return sanitized - vector_store_config = getattr(self.config.memory.default_memory, 'vector_store', OmegaConf.create({})) - vector_store_provider = getattr(vector_store_config, 'service', 'qdrant') + vector_store_config = getattr(self.config.memory.default_memory, + 'vector_store', OmegaConf.create({})) + vector_store_provider = getattr(vector_store_config, 'service', + 'qdrant') on_disk = getattr(vector_store_config, 'on_disk', True) path = getattr(vector_store_config, 'path', self.path) db_name = getattr(vector_store_config, 'db_name', None) @@ -618,12 +677,13 @@ def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: collection_name = getattr(vector_store_config, 'collection_name', path) db_name = sanitize_database_name(db_name) if db_name else None - collection_name = sanitize_database_name(collection_name) if collection_name else None + collection_name = sanitize_database_name( + collection_name) if collection_name else None # check value from mem0.memory.main import VectorStoreFactory - - class_type = VectorStoreFactory.provider_to_class.get(vector_store_provider) + class_type = VectorStoreFactory.provider_to_class.get( + vector_store_provider) if class_type: module_path, class_name = class_type.rsplit('.', 1) module = importlib.import_module(module_path) @@ -637,10 +697,17 @@ def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: 'url': url, 'token': token, 'db_name': db_name, - 'embedding_model_dims': embedding_dims, + 'embedding_model_dims': embedding_dims + } + config_format = { + key: value + for key, value in config_raw.items() + if value and key in parameters + } + vector_store = { + 'provider': vector_store_provider, + 'config': config_format } - config_format = {key: value for key, value in config_raw.items() if value and key in parameters} - vector_store = {'provider': vector_store_provider, 'config': config_format} else: vector_store = {} @@ -652,14 +719,13 @@ def sanitize_database_name(ori_name: str, default_name: str = 'default') -> str: logger.info(f'Memory config: {mem0_config}') # Prompt content is too long, default logging reduces readability custom_fact_extraction_prompt = getattr( - self.config.memory.default_memory, - 'fact_retrieval_prompt', - getattr(self.config.memory.default_memory, 'custom_fact_extraction_prompt', None), - ) + self.config.memory.default_memory, 'fact_retrieval_prompt', + getattr(self.config.memory.default_memory, + 'custom_fact_extraction_prompt', None)) if custom_fact_extraction_prompt is not None: mem0_config['custom_fact_extraction_prompt'] = ( - custom_fact_extraction_prompt + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.' - ) + custom_fact_extraction_prompt + + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.') try: memory = mem0.Memory.from_config(mem0_config) memory._telemetry_vector_store = None diff --git a/ms_agent/memory/diversity.py b/ms_agent/memory/diversity.py index 75cb885ff..775e80da1 100644 --- a/ms_agent/memory/diversity.py +++ b/ms_agent/memory/diversity.py @@ -3,9 +3,8 @@ from copy import deepcopy from typing import List -from omegaconf import DictConfig - from ms_agent.utils import get_logger +from omegaconf import DictConfig from ..llm import LLM, Message from .base import Memory @@ -14,6 +13,7 @@ class Diversity(Memory): + div_system1 = """You are an inspiration bot. You will be given an original requirement, and you need to provide keywords that you associate with it. The keywords must meet the following conditions: 1. The keywords you provide should be terms, such as "security", "independent module", "aesthetics", "style", "examples", etc. @@ -27,7 +27,7 @@ class Diversity(Memory): 6. Your keywords must be in the same language as the original requirement Here is the original query: -""" # noqa +""" # noqa div_system2 = """You are an inspiration bot. You will be given a series of keywords, and you need to provide related words that you associate with based on these keywords. The words must meet the following conditions: @@ -37,7 +37,7 @@ class Diversity(Memory): 4. Your keywords must be in the same language as the input keywords Here are the keywords: -""" # noqa +""" # noqa div_system3 = """You are an inspiration bot. You will be given a series of keywords and an original requirement. You need to carefully analyze the relationship between the original requirement and the keywords, and provide your suggestions for completing the original requirement based on the keywords: @@ -53,7 +53,7 @@ class Diversity(Memory): 7. Wrap your final suggestions with only one wrapper Here are the original query and the keywords: -""" # noqa +""" # noqa def __init__(self, config): super().__init__(config) @@ -72,7 +72,6 @@ def __init__(self, config): async def _run_tasks_sequential(self, tasks: list) -> str: """Run a list of {system, query} tasks sequentially using LLMAgent.""" from ms_agent.agent import LLMAgent - res = [] for i, task in enumerate(tasks): system = task.get('system', '') @@ -131,7 +130,10 @@ async def run(self, messages: List[Message]): pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): - all_keywords.extend([keyword.strip() for keyword in keywords.split(',') if keyword.strip()]) + all_keywords.extend([ + keyword.strip() for keyword in keywords.split(',') + if keyword.strip() + ]) arguments = [] _query = ','.join(set(all_keywords)) @@ -147,11 +149,15 @@ async def run(self, messages: List[Message]): pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): - all_keywords.extend([keyword.strip() for keyword in keywords.split(',') if keyword.strip()]) + all_keywords.extend([ + keyword.strip() for keyword in keywords.split(',') + if keyword.strip() + ]) _query = ','.join(set(all_keywords)) logger.info(f'Diversity second round keywords: {_query}') - _query = f'Original query: {query}\nKeywords generated by LLMs: {all_keywords}' + _query = (f'Original query: {query}\n' + f'Keywords generated by LLMs: {all_keywords}') _messages = [ Message(role='system', content=self.div_system3), Message(role='user', content=_query), @@ -168,8 +174,7 @@ async def run(self, messages: List[Message]): suggestions = ( '\nNow Additional suggestions and findings are given to you, ' 'you need to consider these suggestions and carefully process the query:\n' - f'{suggestions}' - ) + f'{suggestions}') if system != query: system = system + suggestions messages[0].content = system diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 1c2622aee..5a203505d 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -1,22 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Dict -from omegaconf import DictConfig - from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER +from omegaconf import DictConfig logger = get_logger() class SharedMemoryManager: """Manager for shared memory instances across different agents.""" - _instances: Dict[str, Memory] = {} @classmethod - async def get_shared_memory(cls, config: DictConfig, mem_instance_type: str) -> Memory: + async def get_shared_memory(cls, config: DictConfig, + mem_instance_type: str) -> Memory: """Get or create a shared memory instance based on configuration.""" user_id: str = getattr(config, 'user_id', DEFAULT_USER) path: str = getattr(config, 'path', DEFAULT_OUTPUT_DIR) @@ -27,7 +26,8 @@ async def get_shared_memory(cls, config: DictConfig, mem_instance_type: str) -> logger.info(f'Creating new shared memory instance for key: {key}') cls._instances[key] = memory_mapping[mem_instance_type](config) else: - logger.info(f'Reusing existing shared memory instance for key: {key}') + logger.info( + f'Reusing existing shared memory instance for key: {key}') return cls._instances[key] @@ -46,4 +46,5 @@ def clear_shared_memory(cls, config: DictConfig, mem_instance_type: str): del cls._instances[key] logger.info(f'Cleared shared memory instance for key: {key}') else: - logger.warning(f'No shared memory instance found for key: {key}') + logger.warning( + f'No shared memory instance found for key: {key}') diff --git a/ms_agent/memory/utils.py b/ms_agent/memory/utils.py index b778327c6..b7e20ad30 100644 --- a/ms_agent/memory/utils.py +++ b/ms_agent/memory/utils.py @@ -16,7 +16,9 @@ } -def get_memory_meta_safe(config: DictConfig, key: str, default_user_id: str | None = None): +def get_memory_meta_safe(config: DictConfig, + key: str, + default_user_id: str | None = None): if not hasattr(config, key): return None, None, None, None trigger_config = getattr(config, key, OmegaConf.create({})) diff --git a/ms_agent/prompting/file_resolver.py b/ms_agent/prompting/file_resolver.py index 4c998b1b9..c04b8bdde 100644 --- a/ms_agent/prompting/file_resolver.py +++ b/ms_agent/prompting/file_resolver.py @@ -24,14 +24,15 @@ def candidate_paths(self) -> List[str]: paths = [] if family: - paths.extend( - [ - os.path.join(root, agent, lang, f'{family}.txt'), - os.path.join(root, agent, lang, f'{family}.md'), - ] - ) + paths.extend([ + os.path.join(root, agent, lang, f'{family}.txt'), + os.path.join(root, agent, lang, f'{family}.md'), + ]) # base fallback - paths.extend([os.path.join(root, agent, lang, 'base.txt'), os.path.join(root, agent, lang, 'base.md')]) + paths.extend([ + os.path.join(root, agent, lang, 'base.txt'), + os.path.join(root, agent, lang, 'base.md') + ]) return paths @@ -134,15 +135,18 @@ def _get_prompt_lang_and_family(config: DictConfig) -> Tuple[str, str]: prompt_cfg = getattr(config, 'prompt', None) # lang - env_lang = os.environ.get('MS_AGENT_PROMPT_LANG') or os.environ.get('MS_AGENT_LANG') - cfg_lang = getattr(prompt_cfg, 'lang', None) if isinstance(prompt_cfg, DictConfig) else None + env_lang = os.environ.get('MS_AGENT_PROMPT_LANG') or os.environ.get( + 'MS_AGENT_LANG') + cfg_lang = getattr(prompt_cfg, 'lang', None) if isinstance( + prompt_cfg, DictConfig) else None lang = _norm_lang(cfg_lang or env_lang or 'zh') # family env_family = os.environ.get('MS_AGENT_PROMPT_FAMILY') - cfg_family = getattr(prompt_cfg, 'family', None) if isinstance(prompt_cfg, DictConfig) else None + cfg_family = getattr(prompt_cfg, 'family', None) if isinstance( + prompt_cfg, DictConfig) else None - family = cfg_family or env_family or 'auto' + family = (cfg_family or env_family or 'auto') family = str(family).strip() if not family: family = 'auto' @@ -222,6 +226,7 @@ def apply_prompt_files(config: DictConfig) -> DictConfig: if not hasattr(config, 'prompt') or config.prompt is None: config.prompt = DictConfig({}) - if getattr(config.prompt, 'system', None) is None or not str(getattr(config.prompt, 'system', '')).strip(): + if getattr(config.prompt, 'system', None) is None or not str( + getattr(config.prompt, 'system', '')).strip(): config.prompt.system = prompt_text return config diff --git a/ms_agent/rag/base.py b/ms_agent/rag/base.py index 318b836cb..2d983a2c6 100644 --- a/ms_agent/rag/base.py +++ b/ms_agent/rag/base.py @@ -33,7 +33,11 @@ async def query(self, query: str) -> str: pass @abstractmethod - async def retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters) -> List[Any]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Any]: """Retrieve documents Args: diff --git a/ms_agent/rag/extraction.py b/ms_agent/rag/extraction.py index 0215d8b7b..c62b10067 100644 --- a/ms_agent/rag/extraction.py +++ b/ms_agent/rag/extraction.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable from abc import ABC, abstractmethod +from typing import Any, Dict, List + from docling_core.transforms.chunker import BaseChunk from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem, DocItemLabel -from typing import Any, Dict, List - from ms_agent.rag.schema import KeyInformation from ms_agent.tools.docling.chunker import HybridDocumentChunker from ms_agent.tools.docling.doc_loader import DocLoader diff --git a/ms_agent/rag/extraction_manager.py b/ms_agent/rag/extraction_manager.py index 909949205..dfee086cd 100644 --- a/ms_agent/rag/extraction_manager.py +++ b/ms_agent/rag/extraction_manager.py @@ -9,16 +9,14 @@ try: import ray # type: ignore - _RAY_AVAILABLE = True except Exception: # pragma: no cover - optional dependency ray = None # type: ignore _RAY_AVAILABLE = False logger.warning( 'Ray is not available. Install it for faster information extraction:\n' - ' pip install "ray[default]"\n' - 'Program will run without acceleration.' - ) + ' pip install \"ray[default]\"\n' + 'Program will run without acceleration.') class InformationExtractionManager: @@ -26,19 +24,19 @@ class InformationExtractionManager: Optimized key information extraction with optional Ray acceleration. """ - def __init__( - self, - verbose: bool = False, - use_ray: bool = False, - ray_num_workers: Optional[int] = None, - ray_cpus_per_task: float = 1.0, - ): + def __init__(self, + verbose: bool = False, + use_ray: bool = False, + ray_num_workers: Optional[int] = None, + ray_cpus_per_task: float = 1.0): self._verbose = verbose self._use_ray = use_ray and _RAY_AVAILABLE self._ray_num_workers = ray_num_workers self._ray_cpus_per_task = ray_cpus_per_task - def extract(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: + def extract( + self, urls_or_files: List[str] + ) -> Tuple[List[KeyInformation], Dict[str, str]]: """ Extract key information from URLs or files. @@ -53,27 +51,38 @@ def extract(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[ try: return self._extract_with_ray(urls_or_files) except Exception as e: - logger.warning(f'Ray extraction failed, falling back to sequential: {e}') + logger.warning( + f'Ray extraction failed, falling back to sequential: {e}') # Use sequential extraction if Ray is disabled or failed if not _RAY_AVAILABLE: - logger.warning('Ray is not available, falling back to sequential extraction.') + logger.warning( + 'Ray is not available, falling back to sequential extraction.') return self._extract_sequential(urls_or_files) - def _extract_sequential(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: + def _extract_sequential( + self, urls_or_files: List[str] + ) -> Tuple[List[KeyInformation], Dict[str, str]]: """Sequential extraction using the original implementation.""" - extractor = HierarchicalKeyInformationExtraction(urls_or_files=urls_or_files, verbose=self._verbose) + extractor = HierarchicalKeyInformationExtraction( + urls_or_files=urls_or_files, verbose=self._verbose) key_info_list = extractor.extract() return key_info_list, extractor.all_ref_items - def _extract_with_ray(self, urls_or_files: List[str]) -> Tuple[List[KeyInformation], Dict[str, str]]: + def _extract_with_ray( + self, urls_or_files: List[str] + ) -> Tuple[List[KeyInformation], Dict[str, str]]: """Ray-accelerated extraction.""" if not ray.is_initialized(): - ray.init(ignore_reinit_error=True, include_dashboard=False, log_to_driver=False) + ray.init( + ignore_reinit_error=True, + include_dashboard=False, + log_to_driver=False) # Determine optimal worker count - max_workers = self._ray_num_workers or min(len(urls_or_files), (os.cpu_count() or 4)) + max_workers = self._ray_num_workers or min( + len(urls_or_files), (os.cpu_count() or 4)) max_workers = max(1, max_workers) # Partition URLs/files among workers: should be balanced @@ -84,8 +93,7 @@ def _extract_with_ray(self, urls_or_files: List[str]) -> Tuple[List[KeyInformati # Create actors and dispatch tasks actors = [ _ExtractionWorker.options(num_cpus=self._ray_cpus_per_task).remote( - urls_or_files=partitions[i], verbose=self._verbose - ) + urls_or_files=partitions[i], verbose=self._verbose) for i in range(max_workers) ] @@ -93,7 +101,8 @@ def _extract_with_ray(self, urls_or_files: List[str]) -> Tuple[List[KeyInformati for exraction_actor in actors: futures.append(exraction_actor.process_partition.remote()) - results: List[Tuple[List[KeyInformation], Dict[str, str]]] = ray.get(futures) + results: List[Tuple[List[KeyInformation], + Dict[str, str]]] = ray.get(futures) # Merge results merged_infos: List[KeyInformation] = [] @@ -115,9 +124,11 @@ class _ExtractionWorker: def __init__(self, urls_or_files: List[str], verbose: bool = False): self._verbose = verbose self._urls_or_files = urls_or_files - self.extractor = HierarchicalKeyInformationExtraction(urls_or_files=self._urls_or_files, verbose=verbose) + self.extractor = HierarchicalKeyInformationExtraction( + urls_or_files=self._urls_or_files, verbose=verbose) - def process_partition(self) -> Tuple[List[KeyInformation], Dict[str, str]]: + def process_partition( + self) -> Tuple[List[KeyInformation], Dict[str, str]]: """Process a partition of URLs/files and return extracted information.""" try: key_info_list_partition = self.extractor.extract() @@ -134,7 +145,7 @@ def extract_key_information( use_ray: bool = False, verbose: bool = False, ray_num_workers: Optional[int] = None, - ray_cpus_per_task: float = 1.0, + ray_cpus_per_task: float = 1.0 ) -> Tuple[List[KeyInformation], Dict[str, str]]: """ High-level function to extract key information with optional Ray acceleration. @@ -150,7 +161,9 @@ def extract_key_information( Tuple of (key_info_list, resource_map) """ extractor = InformationExtractionManager( - verbose=verbose, use_ray=use_ray, ray_num_workers=ray_num_workers, ray_cpus_per_task=ray_cpus_per_task - ) + verbose=verbose, + use_ray=use_ray, + ray_num_workers=ray_num_workers, + ray_cpus_per_task=ray_cpus_per_task) return extractor.extract(urls_or_files) diff --git a/ms_agent/rag/llama_index_rag.py b/ms_agent/rag/llama_index_rag.py index 8156b284b..e4535d38d 100644 --- a/ms_agent/rag/llama_index_rag.py +++ b/ms_agent/rag/llama_index_rag.py @@ -2,11 +2,10 @@ import shutil from typing import Any, List, Optional -from modelscope import snapshot_download -from omegaconf import DictConfig - from ms_agent.utils import assert_package_exist +from omegaconf import DictConfig +from modelscope import snapshot_download from ..llm import LLM, Message from .base import RAG @@ -29,7 +28,8 @@ def __init__(self, config: DictConfig): super().__init__(config) self._validate_config(config) - self.embedding_model = getattr(config.rag, 'embedding', 'Qwen/Qwen3-Embedding-0.6B') + self.embedding_model = getattr(config.rag, 'embedding', + 'Qwen/Qwen3-Embedding-0.6B') self.llm_model = getattr(config.rag, 'llm', None) self.chunk_size = getattr(config.rag, 'chunk_size', 512) self.chunk_overlap = getattr(config.rag, 'chunk_overlap', 50) @@ -41,21 +41,22 @@ def __init__(self, config: DictConfig): from llama_index.core import Settings from llama_index.core.node_parser import SentenceSplitter - # Set node parser - Settings.node_parser = SentenceSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) + Settings.node_parser = SentenceSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) # If retrieve only, don't set LLM if self.retrieve_only: Settings.llm = None else: - from llama_index.core.base.llms.types import CompletionResponse, LLMMetadata from llama_index.core.llms import CustomLLM + from llama_index.core.base.llms.types import LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback - + from llama_index.core.base.llms.types import CompletionResponse self._llm_instance = LLM.from_config(self.config) class MSCustomLLM(CustomLLM): + @property def metadata(_self) -> LLMMetadata: return LLMMetadata( @@ -65,17 +66,23 @@ def metadata(_self) -> LLMMetadata: ) @llm_completion_callback() - def complete(_self, prompt: str, **kwargs) -> CompletionResponse: + def complete(_self, prompt: str, + **kwargs) -> CompletionResponse: message: Message = self._llm_instance.generate( - messages=[Message(role='user', content=prompt)], stream=False, **kwargs - ) + messages=[Message(role='user', content=prompt)], + stream=False, + **kwargs) return CompletionResponse(text=message.content) @llm_completion_callback() - def stream_complete(_self, prompt: str, formatted: bool = False, **kwargs: Any): + def stream_complete(_self, + prompt: str, + formatted: bool = False, + **kwargs: Any): for message in self._llm_instance.generate( - messages=[Message(role='user', content=prompt)], stream=True, **kwargs - ): + messages=[Message(role='user', content=prompt)], + stream=True, + **kwargs): yield CompletionResponse(text=message.content) Settings.llm = MSCustomLLM() @@ -88,28 +95,28 @@ def _validate_requirements(self): 'llama_index', 'Please install llama_index to support llama-index-rag:\n' '> pip install -U llama-index-core llama-index-embeddings-huggingface ' - 'llama-index-llms-openai llama-index-llms-replicate\n', - ) + 'llama-index-llms-openai llama-index-llms-replicate\n') def _validate_config(self, config: DictConfig): """Validate configuration parameters""" if not hasattr(config, 'rag') or not hasattr(config.rag, 'embedding'): - raise ValueError('Missing rag.embedding parameter in configuration') + raise ValueError( + 'Missing rag.embedding parameter in configuration') chunk_size = getattr(config.rag, 'chunk_size', 512) if chunk_size <= 0: raise ValueError('chunk_size must be greater than 0') def _setup_embedding_model(self, config: DictConfig): - from llama_index.core import Settings + from llama_index.core import (Settings) from llama_index.embeddings.huggingface import HuggingFaceEmbedding - try: use_hf = getattr(config, 'use_huggingface', False) if not use_hf: self.embedding_model = snapshot_download(self.embedding_model) - Settings.embed_model = HuggingFaceEmbedding(model_name=self.embedding_model, device='cpu') + Settings.embed_model = HuggingFaceEmbedding( + model_name=self.embedding_model, device='cpu') except Exception as e: raise RuntimeError(f'Failed to load embedding model: {e}') @@ -117,8 +124,7 @@ def _setup_embedding_model(self, config: DictConfig): async def add_documents(self, documents: List[str]): if not documents: raise ValueError('Document list cannot be empty') - from llama_index.core import Document, VectorStoreIndex - + from llama_index.core import (Document, VectorStoreIndex) docs = [Document(text=doc) for doc in documents] self.index = VectorStoreIndex.from_documents(docs) if not self.retrieve_only: @@ -130,7 +136,6 @@ async def add_documents_from_files(self, file_paths: List[str]): from llama_index.core import VectorStoreIndex from llama_index.core.readers import SimpleDirectoryReader - documents = [] for file_path in file_paths: if not os.path.exists(file_path): @@ -154,14 +159,18 @@ async def _setup_query_engine(self): return from llama_index.core import Settings - # Check if LLM is set if Settings.llm is None and not self.retrieve_only: return - self.query_engine = self.index.as_query_engine(similarity_top_k=5, response_mode='compact') + self.query_engine = self.index.as_query_engine( + similarity_top_k=5, response_mode='compact') - async def _retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.0, **filters) -> List[dict]: + async def _retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.0, + **filters) -> List[dict]: if self.index is None: return [] @@ -169,55 +178,58 @@ async def _retrieve(self, query: str, limit: int = 5, score_threshold: float = 0 return [] from llama_index.core.retrievers import VectorIndexRetriever - - retriever = VectorIndexRetriever(index=self.index, similarity_top_k=limit) + retriever = VectorIndexRetriever( + index=self.index, similarity_top_k=limit) nodes = retriever.retrieve(query) results = [] for node in nodes: if node.score >= score_threshold: - results.append( - { - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id, - } - ) + results.append({ + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id + }) return results - async def retrieve(self, query: str, limit: int = 5, score_threshold: float = 0.0, **filters) -> List[dict]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.0, + **filters) -> List[dict]: if self.retrieve_only: - return await self._retrieve(query, limit, score_threshold, **filters) + return await self._retrieve(query, limit, score_threshold, + **filters) from llama_index.core import Settings from llama_index.core.postprocessor import SimilarityPostprocessor from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.retrievers import VectorIndexRetriever - if self.index is None or Settings.llm is None: return [] - retriever = VectorIndexRetriever(index=self.index, similarity_top_k=limit) + retriever = VectorIndexRetriever( + index=self.index, similarity_top_k=limit) - postprocessor = SimilarityPostprocessor(similarity_cutoff=score_threshold) + postprocessor = SimilarityPostprocessor( + similarity_cutoff=score_threshold) - query_engine = RetrieverQueryEngine(retriever=retriever, node_postprocessors=[postprocessor]) + query_engine = RetrieverQueryEngine( + retriever=retriever, node_postprocessors=[postprocessor]) response = query_engine.query(query) results = [] for node in response.source_nodes: - results.append( - { - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id, - } - ) + results.append({ + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id + }) return results @@ -227,18 +239,17 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: return [] from llama_index.core.retrievers import VectorIndexRetriever - # Try to import BM25 related modules try: - from llama_index.core.retrievers import QueryFusionRetriever from llama_index.retrievers.bm25 import BM25Retriever - + from llama_index.core.retrievers import QueryFusionRetriever bm25_available = True except ImportError: bm25_available = False # Vector retriever - vector_retriever = VectorIndexRetriever(index=self.index, similarity_top_k=top_k) + vector_retriever = VectorIndexRetriever( + index=self.index, similarity_top_k=top_k) if not bm25_available: # Use vector retrieval only @@ -246,11 +257,13 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: else: # Use hybrid retrieval try: - bm25_retriever = BM25Retriever.from_defaults(docstore=self.index.docstore, similarity_top_k=top_k) + bm25_retriever = BM25Retriever.from_defaults( + docstore=self.index.docstore, similarity_top_k=top_k) fusion_retriever = QueryFusionRetriever( - retrievers=[vector_retriever, bm25_retriever], similarity_top_k=top_k, num_queries=1 - ) + retrievers=[vector_retriever, bm25_retriever], + similarity_top_k=top_k, + num_queries=1) nodes = fusion_retriever.retrieve(query) @@ -259,23 +272,25 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: results = [] for node in nodes: - results.append( - { - 'text': node.node.text, - 'score': float(node.score), - 'metadata': node.node.metadata, - 'node_id': node.node.node_id, - } - ) + results.append({ + 'text': node.node.text, + 'score': float(node.score), + 'metadata': node.node.metadata, + 'node_id': node.node.node_id + }) return results async def query(self, query: str) -> str: if self.query_engine is None: if self.retrieve_only: - raise ValueError('Current mode is retrieve only, question answering not supported') + raise ValueError( + 'Current mode is retrieve only, question answering not supported' + ) else: - raise ValueError('Query engine not initialized, please add documents and set LLM first') + raise ValueError( + 'Query engine not initialized, please add documents and set LLM first' + ) try: response = self.query_engine.query(query) @@ -298,10 +313,10 @@ async def load_index(self, persist_dir: Optional[str] = None): load_dir = persist_dir or self.storage_dir if not os.path.exists(load_dir): - raise FileNotFoundError(f'Index directory does not exist: {load_dir}') - - from llama_index.core import StorageContext, load_index_from_storage + raise FileNotFoundError( + f'Index directory does not exist: {load_dir}') + from llama_index.core import (StorageContext, load_index_from_storage) storage_context = StorageContext.from_defaults(persist_dir=load_dir) self.index = load_index_from_storage(storage_context) @@ -321,7 +336,7 @@ def get_index_info(self) -> dict: 'retrieve_only': self.retrieve_only, 'chunk_size': self.chunk_size, 'chunk_overlap': self.chunk_overlap, - 'embedding_model': self.embedding_model, + 'embedding_model': self.embedding_model } async def remove_all_documents(self): diff --git a/ms_agent/rag/schema.py b/ms_agent/rag/schema.py index f6f997153..9142d4945 100644 --- a/ms_agent/rag/schema.py +++ b/ms_agent/rag/schema.py @@ -18,7 +18,6 @@ class KeyInformation: including images, tables, or other relevant data. [{'id': 'doc_file_name@binary_hash@self_ref', 'content': PILImage.Image}, ...] """ - text: str resources: List[Dict[str, Any]] diff --git a/ms_agent/retriever/hybrid_retriever.py b/ms_agent/retriever/hybrid_retriever.py index 283366652..e84bc8398 100644 --- a/ms_agent/retriever/hybrid_retriever.py +++ b/ms_agent/retriever/hybrid_retriever.py @@ -6,7 +6,6 @@ import faiss import numpy as np - from ms_agent.utils.tokenizer_util import TokenizerUtil os.environ['OMP_NUM_THREADS'] = '1' @@ -19,7 +18,10 @@ class BM25Retriever: Sparse retriever based on BM25 algorithm. """ - def __init__(self, tokenized_corpus: List[List[str]], k1: float = 1.5, b: float = 0.75): + def __init__(self, + tokenized_corpus: List[List[str]], + k1: float = 1.5, + b: float = 0.75): self.k1 = k1 self.b = b self.corpus_size = len(tokenized_corpus) @@ -48,7 +50,8 @@ def _initialize(self, tokenized_corpus: List[List[str]]): self.avgdl = total_length / doc_count if doc_count > 0 else 0 for word, freq in self.idf.items(): - self.idf[word] = math.log((self.corpus_size - freq + 0.5) / (freq + 0.5) + 1) + self.idf[word] = math.log((self.corpus_size - freq + 0.5) + / (freq + 0.5) + 1) for doc_tokens in tokenized_corpus: freqs = {} @@ -65,11 +68,11 @@ def get_scores(self, tokenized_query: List[str]) -> List[float]: idf_score = self.idf[token] for index, doc_freqs in enumerate(self.doc_term_freqs): freq = doc_freqs.get(token, 0) - if freq == 0: - continue # noqa: E701 + if freq == 0: continue # noqa: E701 doc_len = self.doc_len[index] numerator = freq * (self.k1 + 1) - denominator = freq + self.k1 * (1 - self.b + self.b * (doc_len / self.avgdl)) # noqa: W504 + denominator = freq + self.k1 * ( + 1 - self.b + self.b * (doc_len / self.avgdl)) # noqa: W504 scores[index] += idf_score * (numerator / denominator) return scores @@ -80,13 +83,13 @@ class HybridRetriever: """ def __init__( - self, - corpus: List[str] = None, - embed_model: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', # noqa - tokenizer_model_id: str = 'Qwen/Qwen3-8B', - bm25_k1: float = 1.5, - bm25_b: float = 0.75, - ): + self, + corpus: List[str] = None, + embed_model: + str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', # noqa + tokenizer_model_id: str = 'Qwen/Qwen3-8B', + bm25_k1: float = 1.5, + bm25_b: float = 0.75): """ Initialize Hybrid Retriever with both Dense and Sparse indices. @@ -146,8 +149,10 @@ def __init__( # Initialize Dense Retriever (FAISS) embed_model_path: str = self._load_model( model_id=embed_model, - ignore_patterns=['openvino/*', 'onnx/*', 'pytorch_model.bin', 'rust_model.ot', 'tf_model.h5'], - ) + ignore_patterns=[ + 'openvino/*', 'onnx/*', 'pytorch_model.bin', 'rust_model.ot', + 'tf_model.h5' + ]) from sentence_transformers import SentenceTransformer @@ -166,11 +171,15 @@ def _load_model(model_id: str, ignore_patterns: List[str] = None) -> str: from modelscope import snapshot_download try: - return snapshot_download(model_id=model_id, ignore_patterns=ignore_patterns) + return snapshot_download( + model_id=model_id, ignore_patterns=ignore_patterns) except Exception as e: raise RuntimeError(f'Failed to load model {model_id}: {e}') from e - def _init_corpus(self, corpus: List[str], bm25_k1: float = 1.5, bm25_b: float = 0.75): + def _init_corpus(self, + corpus: List[str], + bm25_k1: float = 1.5, + bm25_b: float = 0.75): """ Initialize corpus and build both Dense and Sparse indices. @@ -195,7 +204,9 @@ def _init_corpus(self, corpus: List[str], bm25_k1: float = 1.5, bm25_b: float = # Initialize Sparse Retriever (BM25) print('Building BM25 index...') - self.tokenized_corpus = [self.tokenizer_util.segment(doc) for doc in self.corpus] + self.tokenized_corpus = [ + self.tokenizer_util.segment(doc) for doc in self.corpus + ] self.bm25 = BM25Retriever( tokenized_corpus=self.tokenized_corpus, k1=bm25_k1, @@ -212,17 +223,17 @@ def _build_dense_index(self, texts: List[str]): faiss.normalize_L2(embeddings) self.index = faiss.IndexFlatIP(embeddings.shape[1]) self.index.add(embeddings) - print(f'Successfully indexed {len(texts)} documents for Dense Retrieval.') + print( + f'Successfully indexed {len(texts)} documents for Dense Retrieval.' + ) @staticmethod def _z_score_normalization(scores: List[float]) -> List[float]: """Apply Z-score normalization: z = (x - mean) / std.""" - if not scores: - return [] # noqa: E701 + if not scores: return [] # noqa: E701 arr = np.array(scores) std = np.std(arr) - if std == 0: - return [0.0] * len(scores) # noqa: E701 + if std == 0: return [0.0] * len(scores) # noqa: E701 mean = np.mean(arr) return ((arr - mean) / std).tolist() @@ -246,7 +257,9 @@ def _validate_corpus(self, corpus: List[str] = None): if corpus is not None and corpus != self.corpus: self._init_corpus(corpus=corpus) elif self.corpus is None: - raise ValueError('Corpus is empty. Please provide a valid corpus for searching.') + raise ValueError( + 'Corpus is empty. Please provide a valid corpus for searching.' + ) if self.index is None: raise ValueError('Index not built.') @@ -265,7 +278,11 @@ def _compute_dense_scores(self, query: str) -> List[float]: search_k: int = min(len(self.corpus), 500) dense_dists, dense_indices = self.index.search(x=query_vec, k=search_k) - dense_scores_map = {idx: float(score) for idx, score in zip(dense_indices[0], dense_dists[0]) if idx != -1} + dense_scores_map = { + idx: float(score) + for idx, score in zip(dense_indices[0], dense_dists[0]) + if idx != -1 + } return [dense_scores_map.get(i, 0.0) for i in range(len(self.corpus))] def _compute_sparse_scores(self, query: str) -> List[float]: @@ -377,7 +394,8 @@ def search( raw_bm25_scores = self._compute_sparse_scores(query) # Fuse and normalize scores - candidates = self._fuse_and_normalize_scores(raw_dense_scores, raw_bm25_scores, alpha) + candidates = self._fuse_and_normalize_scores(raw_dense_scores, + raw_bm25_scores, alpha) # Filter and rank results return self._filter_and_rank(candidates, top_k, min_score) @@ -414,10 +432,12 @@ async def async_search( dense_task = asyncio.to_thread(self._compute_dense_scores, query) sparse_task = asyncio.to_thread(self._compute_sparse_scores, query) - raw_dense_scores, raw_bm25_scores = await asyncio.gather(dense_task, sparse_task) + raw_dense_scores, raw_bm25_scores = await asyncio.gather( + dense_task, sparse_task) # Fuse and normalize scores - candidates = self._fuse_and_normalize_scores(raw_dense_scores, raw_bm25_scores, alpha) + candidates = self._fuse_and_normalize_scores(raw_dense_scores, + raw_bm25_scores, alpha) # Filter and rank results return self._filter_and_rank(candidates, top_k, min_score) diff --git a/ms_agent/sandbox/sandbox.py b/ms_agent/sandbox/sandbox.py index 5974781ae..8a5753761 100644 --- a/ms_agent/sandbox/sandbox.py +++ b/ms_agent/sandbox/sandbox.py @@ -9,7 +9,8 @@ class Sandbox: Base class for sandbox environments. """ - def __init__(self): ... + def __init__(self): + ... async def async_execute(self, *args, **kwargs): """ @@ -45,7 +46,7 @@ def __init__(self, **kwargs): super().__init__() self._init() - from ms_enclave.sandbox import DockerSandboxConfig, SandboxConfig + from ms_enclave.sandbox import SandboxConfig, DockerSandboxConfig # Mount host directories into the sandbox container if provided _volumes = kwargs.pop('volumes', None) or [] @@ -54,12 +55,19 @@ def __init__(self, **kwargs): for host_path, container_path, mode in _volumes: host_path = str(host_path) container_path = str(container_path) - self.volume_dict[host_path] = {'bind': container_path, 'mode': mode} + self.volume_dict[host_path] = { + 'bind': container_path, + 'mode': mode + } self.sandbox_config: SandboxConfig = DockerSandboxConfig( image=kwargs.pop('image', None) or 'python:3.11-slim', memory_limit=kwargs.pop('memory_limit', None) or '512m', - tools_config={'python_executor': {}, 'file_operation': {}, 'shell_executor': {}}, + tools_config={ + 'python_executor': {}, + 'file_operation': {}, + 'shell_executor': {} + }, volumes=self.volume_dict, ) @@ -73,16 +81,17 @@ def _init(): """ logger.info('Installing ms-enclave package...') try: - install_package(package_name='ms-enclave', import_name='ms_enclave', extend_module='docker') + install_package( + package_name='ms-enclave', + import_name='ms_enclave', + extend_module='docker') except Exception as e: raise e - async def async_execute( - self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None, - ) -> Dict[str, Any]: + async def async_execute(self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None) -> Dict[str, Any]: """ Asynchronously execute Python code and shell commands within the sandbox. @@ -123,15 +132,20 @@ async def async_execute( 'shell_executor': [], } - async with SandboxFactory.create_sandbox(SandboxType.DOCKER, self.sandbox_config) as sandbox: + async with SandboxFactory.create_sandbox( + SandboxType.DOCKER, self.sandbox_config) as sandbox: + if requirements: requirements_file = f'/{str(uuid.uuid4())}/requirements.txt' await sandbox.execute_tool( - 'file_operation', - {'operation': 'write', 'file_path': f'{requirements_file}', 'content': '\n'.join(requirements)}, - ) - - result_requirements = await sandbox.execute_command(f'pip install -r {requirements_file}') + 'file_operation', { + 'operation': 'write', + 'file_path': f'{requirements_file}', + 'content': '\n'.join(requirements) + }) + + result_requirements = await sandbox.execute_command( + f'pip install -r {requirements_file}') logger.info(result_requirements.stdout) if python_code: @@ -139,11 +153,17 @@ async def async_execute( python_code = [python_code] for py_item in python_code: - py_result = await sandbox.execute_tool('python_executor', {'code': py_item}) - - results['python_executor'].append( - {'output': py_result.output, 'error': py_result.error, 'status': py_result.status} - ) + py_result = await sandbox.execute_tool( + 'python_executor', {'code': py_item}) + + results['python_executor'].append({ + 'output': + py_result.output, + 'error': + py_result.error, + 'status': + py_result.status + }) if shell_command: if isinstance(shell_command, str): @@ -152,18 +172,21 @@ async def async_execute( for shell_item in shell_command: shell_result = await sandbox.execute_command(shell_item) - results['shell_executor'].append( - {'output': shell_result.stdout, 'error': shell_result.stderr, 'status': shell_result.status} - ) + results['shell_executor'].append({ + 'output': + shell_result.stdout, + 'error': + shell_result.stderr, + 'status': + shell_result.status + }) return results - def execute( - self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None, - ) -> Dict[str, Any]: + def execute(self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None) -> Dict[str, Any]: """ Synchronously execute Python code and shell commands within the sandbox. @@ -180,5 +203,7 @@ def execute( import asyncio return asyncio.run( - self.async_execute(python_code=python_code, shell_command=shell_command, requirements=requirements) - ) + self.async_execute( + python_code=python_code, + shell_command=shell_command, + requirements=requirements)) diff --git a/ms_agent/skill/auto_skills.py b/ms_agent/skill/auto_skills.py index 36b09eb9d..170c49c91 100644 --- a/ms_agent/skill/auto_skills.py +++ b/ms_agent/skill/auto_skills.py @@ -625,7 +625,7 @@ async def _execute_with_progressive_analysis( skill_id=skill_id, success=False, error= - f'Skill cannot handle query: {context.plan.reasoning if context.plan else 'No plan'}' + f'Skill cannot handle query: {context.plan.reasoning if context.plan else "No plan"}' ) if not commands: @@ -872,7 +872,7 @@ async def _execute_command_with_retry( additional_reqs = analysis.get('additional_requirements', []) logger.info( - f'[{skill_id}] Error analysis: type={error_info.get('error_type')}, ' + f'[{skill_id}] Error analysis: type={error_info.get("error_type")}, ' f'fixable={is_fixable}') # Apply fix if available @@ -1402,13 +1402,13 @@ def _filter_skills( else: logger.info( f'Removing skill [{sid}]: cannot execute - ' - f'{analysis.get('reason', '')[:200]}' + f'{analysis.get("reason", "")[:200]}' ) filtered_ids = final_ids logger.info( f'Filter ({mode}): {len(skill_ids)} -> {len(filtered_ids)} skills. ' - f'Reason: {parsed.get('reasoning', '')[:1000]}' + f'Reason: {parsed.get("reasoning", "")[:1000]}' ) return set(filtered_ids) diff --git a/ms_agent/skill/container.py b/ms_agent/skill/container.py index d9806097e..51d96f6f3 100644 --- a/ms_agent/skill/container.py +++ b/ms_agent/skill/container.py @@ -10,7 +10,6 @@ - use_sandbox=True: Execute in Docker sandbox (default, recommended for untrusted code) - use_sandbox=False: Execute locally with security checks (for trusted code or no Docker) """ - import asyncio import os import platform @@ -55,7 +54,6 @@ class ExecutorType(Enum): """Supported executor types for skill execution.""" - PYTHON_SCRIPT = 'python_script' PYTHON_CODE = 'python_code' PYTHON_FUNCTION = 'python_function' @@ -65,7 +63,6 @@ class ExecutorType(Enum): class ExecutionStatus(Enum): """Execution status codes.""" - PENDING = 'pending' RUNNING = 'running' SUCCESS = 'success' @@ -89,7 +86,6 @@ class ExecutionInput: working_dir: Working directory for execution. requirements: Python packages to install before execution. """ - args: List[Any] = field(default_factory=list) kwargs: Dict[str, Any] = field(default_factory=dict) env_vars: Dict[str, str] = field(default_factory=dict) @@ -103,7 +99,8 @@ def to_dict(self) -> Dict[str, Any]: 'args': self.args, 'kwargs': self.kwargs, 'env_vars': self.env_vars, - 'input_files': {k: str(v) for k, v in self.input_files.items()}, + 'input_files': {k: str(v) + for k, v in self.input_files.items()}, 'stdin': self.stdin, 'working_dir': str(self.working_dir) if self.working_dir else None, 'requirements': self.requirements, @@ -124,7 +121,6 @@ class ExecutionOutput: artifacts: Any generated artifacts (data, objects, etc.). duration_ms: Execution duration in milliseconds. """ - return_value: Any = None stdout: str = '' stderr: str = '' @@ -135,11 +131,13 @@ class ExecutionOutput: def to_dict(self) -> Dict[str, Any]: return { - 'return_value': str(self.return_value) if self.return_value else None, + 'return_value': + str(self.return_value) if self.return_value else None, 'stdout': self.stdout, 'stderr': self.stderr, 'exit_code': self.exit_code, - 'output_files': {k: str(v) for k, v in self.output_files.items()}, + 'output_files': {k: str(v) + for k, v in self.output_files.items()}, 'artifacts': list(self.artifacts.keys()), 'duration_ms': self.duration_ms, } @@ -164,7 +162,6 @@ class ExecutionRecord: error_message: Error message if failed. sandbox_used: Whether sandbox was used for execution. """ - execution_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) skill_id: str = '' executor_type: ExecutorType = ExecutorType.PYTHON_SCRIPT @@ -212,7 +209,8 @@ def to_markdown(self) -> str: for name, path in self.input_spec.input_files.items(): lines.append(f' - `{name}`: `{path}`') if self.input_spec.requirements: - lines.append(f'- **Requirements**: `{self.input_spec.requirements}`') + lines.append( + f'- **Requirements**: `{self.input_spec.requirements}`') # Output section lines.extend(['', '#### Output', '']) @@ -230,7 +228,8 @@ def to_markdown(self) -> str: lines.append(f' - `{name}`: `{path}`') if self.error_message: - lines.extend(['', '#### Error', '', f'```\n{self.error_message}\n```']) + lines.extend( + ['', '#### Error', '', f'```\n{self.error_message}\n```']) lines.append('') return '\n'.join(lines) @@ -249,7 +248,6 @@ class ExecutionSpec: created_at: Creation timestamp. upstream_outputs: Outputs from upstream skills available as inputs. """ - spec_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) title: str = 'Skill Execution Spec' description: str = '' @@ -287,25 +285,26 @@ def to_markdown(self) -> str: # Summary total = len(self.records) - success = sum(1 for r in self.records if r.status == ExecutionStatus.SUCCESS) - failed = sum(1 for r in self.records if r.status == ExecutionStatus.FAILED) - blocked = sum(1 for r in self.records if r.status == ExecutionStatus.SECURITY_BLOCKED) - - lines.extend( - [ - '## Summary', - '', - f'- **Total Executions**: {total}', - f'- **Successful**: {success}', - f'- **Failed**: {failed}', - f'- **Security Blocked**: {blocked}', - '', - '---', - '', - '## Execution Records', - '', - ] - ) + success = sum(1 for r in self.records + if r.status == ExecutionStatus.SUCCESS) + failed = sum(1 for r in self.records + if r.status == ExecutionStatus.FAILED) + blocked = sum(1 for r in self.records + if r.status == ExecutionStatus.SECURITY_BLOCKED) + + lines.extend([ + '## Summary', + '', + f'- **Total Executions**: {total}', + f'- **Successful**: {success}', + f'- **Failed**: {failed}', + f'- **Security Blocked**: {blocked}', + '', + '---', + '', + '## Execution Records', + '', + ]) for record in self.records: lines.append(record.to_markdown()) @@ -343,16 +342,14 @@ class SkillContainer: SANDBOX_OUTPUT_DIR = '/sandbox/outputs' SANDBOX_WORK_DIR = '/sandbox/scripts' - def __init__( - self, - workspace_dir: Optional[Union[str, Path]] = None, - timeout: int = 300, - image: str = 'python:3.11-slim', - memory_limit: str = '512m', - enable_security_check: bool = True, - network_enabled: bool = False, - use_sandbox: bool = True, - ): + def __init__(self, + workspace_dir: Optional[Union[str, Path]] = None, + timeout: int = 300, + image: str = 'python:3.11-slim', + memory_limit: str = '512m', + enable_security_check: bool = True, + network_enabled: bool = False, + use_sandbox: bool = True): """ Initialize the skill container. @@ -369,7 +366,8 @@ def __init__( if workspace_dir: self.workspace_dir = Path(workspace_dir).resolve() else: - self.workspace_dir = Path(tempfile.mkdtemp(prefix='skill_container_')).resolve() + self.workspace_dir = Path( + tempfile.mkdtemp(prefix='skill_container_')).resolve() self.workspace_dir.mkdir(parents=True, exist_ok=True) self.timeout = timeout @@ -399,12 +397,10 @@ def __init__( logger.warning( 'SkillContainer running in LOCAL mode (use_sandbox=False). ' 'Scripts will execute directly on this machine. ' - 'Ensure you trust the code being executed!' - ) + 'Ensure you trust the code being executed!') - logger.info( - f'SkillContainer initialized at: {self.workspace_dir} [mode: {"sandbox" if self.use_sandbox else "local"}]' - ) + logger.info(f'SkillContainer initialized at: {self.workspace_dir} ' + f'[mode: {"sandbox" if self.use_sandbox else "local"}]') def _get_sandbox(self): """ @@ -427,7 +423,8 @@ def _get_sandbox(self): for skill_id, skill_dir in self._skill_dirs.items(): safe_id = skill_id.replace('@', '_').replace('/', '_') sandbox_path = f'{self.SANDBOX_ROOT}/skills/{safe_id}' - volumes.append((str(Path(skill_dir).resolve()), sandbox_path, 'ro')) + volumes.append( + (str(Path(skill_dir).resolve()), sandbox_path, 'ro')) self._sandbox = EnclaveSandbox( image=self.image, @@ -436,7 +433,8 @@ def _get_sandbox(self): ) return self._sandbox - def mount_skill_directory(self, skill_id: str, skill_dir: Union[str, Path]): + def mount_skill_directory(self, skill_id: str, skill_dir: Union[str, + Path]): """ Mount a skill directory for sandbox access. @@ -461,7 +459,9 @@ def get_skill_sandbox_path(self, skill_id: str) -> str: safe_id = skill_id.replace('@', '_').replace('/', '_') return f'{self.SANDBOX_ROOT}/skills/{safe_id}' - def _security_check(self, code: str, is_local: bool = False) -> tuple[bool, str]: + def _security_check(self, + code: str, + is_local: bool = False) -> tuple[bool, str]: """ Check code for potentially dangerous patterns. @@ -523,15 +523,13 @@ def _collect_output_files(self) -> Dict[str, Path]: outputs[f.name] = f return outputs - def _create_record( - self, - skill_id: str, - executor_type: ExecutorType, - input_spec: ExecutionInput, - script_path: str = None, - function_name: str = None, - sandbox_used: bool = None, - ) -> ExecutionRecord: + def _create_record(self, + skill_id: str, + executor_type: ExecutorType, + input_spec: ExecutionInput, + script_path: str = None, + function_name: str = None, + sandbox_used: bool = None) -> ExecutionRecord: """Create a new execution record.""" return ExecutionRecord( skill_id=skill_id, @@ -540,16 +538,18 @@ def _create_record( function_name=function_name, input_spec=input_spec, status=ExecutionStatus.PENDING, - sandbox_used=sandbox_used if sandbox_used is not None else self.use_sandbox, - ) + sandbox_used=sandbox_used + if sandbox_used is not None else self.use_sandbox) # ------------------------------------------------------------------------- # Local Execution Helpers (for use_sandbox=False mode) # ------------------------------------------------------------------------- - def _local_run_subprocess( - self, cmd: List[str], env: Dict[str, str] = None, cwd: Path = None, stdin_input: str = None - ) -> tuple[str, str, int]: + def _local_run_subprocess(self, + cmd: List[str], + env: Dict[str, str] = None, + cwd: Path = None, + stdin_input: str = None) -> tuple[str, str, int]: """ Run subprocess locally with security restrictions. @@ -607,7 +607,8 @@ def _get_node_executable(self) -> str: return 'node.exe' return 'node' - async def _local_install_requirements(self, requirements: List[str]) -> tuple[bool, str]: + async def _local_install_requirements( + self, requirements: List[str]) -> tuple[bool, str]: """ Install Python requirements locally using pip. @@ -622,12 +623,8 @@ async def _local_install_requirements(self, requirements: List[str]) -> tuple[bo try: cmd = [ - self._get_python_executable(), - '-m', - 'pip', - 'install', - '--quiet', - '--disable-pip-version-check', + self._get_python_executable(), '-m', 'pip', 'install', + '--quiet', '--disable-pip-version-check' ] + requirements stdout, stderr, exit_code = self._local_run_subprocess(cmd) @@ -642,7 +639,9 @@ async def _local_install_requirements(self, requirements: List[str]) -> tuple[bo logger.error(f'Error installing requirements: {e}') return False, str(e) - async def _local_execute_python_code(self, code: str, input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_python_code( + self, code: str, + input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute Python code locally. @@ -655,7 +654,8 @@ async def _local_execute_python_code(self, code: str, input_spec: ExecutionInput """ # Install requirements first if any if input_spec.requirements: - success, error = await self._local_install_requirements(input_spec.requirements) + success, error = await self._local_install_requirements( + input_spec.requirements) if not success: return '', f'Failed to install requirements: {error}', -1 @@ -677,8 +677,10 @@ async def _local_execute_python_code(self, code: str, input_spec: ExecutionInput cwd = input_spec.working_dir if input_spec.working_dir else None stdout, stderr, exit_code = self._local_run_subprocess( - cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin - ) + cmd, + env=input_spec.env_vars, + cwd=cwd, + stdin_input=input_spec.stdin) # Keep script in scripts folder for logging/debugging return stdout, stderr, exit_code @@ -686,7 +688,9 @@ async def _local_execute_python_code(self, code: str, input_spec: ExecutionInput logger.error(f'Local Python execution failed: {e}') raise - async def _local_execute_shell(self, command: str, input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_shell( + self, command: str, + input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute shell command locally. @@ -703,20 +707,30 @@ async def _local_execute_shell(self, command: str, input_spec: ExecutionInput) - if platform.system() == 'Windows': # Windows: use set for environment env_cmds = [f'set {k}={v}' for k, v in input_spec.env_vars.items()] - full_cmd = ' && '.join(env_cmds + [command]) if env_cmds else command + full_cmd = ' && '.join(env_cmds + + [command]) if env_cmds else command cmd = shell_exec + [full_cmd] else: # Unix: use export - env_cmds = [f"export {k}='{v}'" for k, v in input_spec.env_vars.items()] - full_cmd = ' && '.join(env_cmds + [command]) if env_cmds else command + env_cmds = [ + f"export {k}='{v}'" for k, v in input_spec.env_vars.items() + ] + full_cmd = ' && '.join(env_cmds + + [command]) if env_cmds else command cmd = shell_exec + [full_cmd] # Use working_dir from input_spec for proper resource access cwd = input_spec.working_dir if input_spec.working_dir else None - return self._local_run_subprocess(cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin) + return self._local_run_subprocess( + cmd, + env=input_spec.env_vars, + cwd=cwd, + stdin_input=input_spec.stdin) - async def _local_execute_javascript(self, js_code: str, input_spec: ExecutionInput) -> tuple[str, str, int]: + async def _local_execute_javascript( + self, js_code: str, + input_spec: ExecutionInput) -> tuple[str, str, int]: """ Execute JavaScript code locally via Node.js. @@ -745,7 +759,11 @@ async def _local_execute_javascript(self, js_code: str, input_spec: ExecutionInp cwd = input_spec.working_dir if input_spec.working_dir else None # Keep script in scripts folder for logging/debugging - return self._local_run_subprocess(cmd, env=input_spec.env_vars, cwd=cwd, stdin_input=input_spec.stdin) + return self._local_run_subprocess( + cmd, + env=input_spec.env_vars, + cwd=cwd, + stdin_input=input_spec.stdin) except Exception as e: logger.error(f'Local JavaScript execution failed: {e}') raise @@ -772,18 +790,16 @@ def _generate_local_env_setup(self, input_spec: ExecutionInput) -> str: # Add working directory to sys.path for imports and change to it if input_spec.working_dir: work_dir = str(input_spec.working_dir) - lines.extend( - [ - '', - '# Setup working directory for resource access (READ-ONLY for resources)', - f'_skill_dir = {repr(work_dir)}', - "os.environ['SKILL_DIR'] = _skill_dir", - 'SKILL_DIR = _skill_dir', - 'if _skill_dir not in sys.path:', - ' sys.path.insert(0, _skill_dir)', - 'os.chdir(_skill_dir)', - ] - ) + lines.extend([ + '', + '# Setup working directory for resource access (READ-ONLY for resources)', + f'_skill_dir = {repr(work_dir)}', + "os.environ['SKILL_DIR'] = _skill_dir", + 'SKILL_DIR = _skill_dir', + 'if _skill_dir not in sys.path:', + ' sys.path.insert(0, _skill_dir)', + 'os.chdir(_skill_dir)', + ]) # Add custom env vars for key, value in input_spec.env_vars.items(): @@ -814,7 +830,8 @@ def _generate_local_js_env_setup(self, input_spec: ExecutionInput) -> str: lines.append('') return '\n'.join(lines) - def _parse_sandbox_result(self, results: Dict[str, Any]) -> tuple[str, str, int]: + def _parse_sandbox_result(self, + results: Dict[str, Any]) -> tuple[str, str, int]: """Parse sandbox execution results into stdout, stderr, exit_code.""" stdout_parts = [] stderr_parts = [] @@ -833,20 +850,22 @@ def _parse_sandbox_result(self, results: Dict[str, Any]) -> tuple[str, str, int] return '\n'.join(stdout_parts), '\n'.join(stderr_parts), exit_code async def _execute_in_sandbox( - self, - python_code: Union[str, List[str]] = None, - shell_command: Union[str, List[str]] = None, - requirements: List[str] = None, - ) -> Dict[str, Any]: + self, + python_code: Union[str, List[str]] = None, + shell_command: Union[str, List[str]] = None, + requirements: List[str] = None) -> Dict[str, Any]: """Execute code in EnclaveSandbox.""" sandbox = self._get_sandbox() return await sandbox.async_execute( - python_code=python_code, shell_command=shell_command, requirements=requirements - ) + python_code=python_code, + shell_command=shell_command, + requirements=requirements) async def execute_python_script( - self, script_path: Union[str, Path], skill_id: str = 'unknown', input_spec: ExecutionInput = None - ) -> ExecutionOutput: + self, + script_path: Union[str, Path], + skill_id: str = 'unknown', + input_spec: ExecutionInput = None) -> ExecutionOutput: """ Execute a Python script file. @@ -867,8 +886,7 @@ async def execute_python_script( skill_id=skill_id, executor_type=ExecutorType.PYTHON_SCRIPT, input_spec=input_spec, - script_path=str(script_path), - ) + script_path=str(script_path)) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING @@ -879,11 +897,13 @@ async def execute_python_script( code = f.read() # Security check (stricter for local mode) - is_safe, reason = self._security_check(code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check( + code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput( + stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -896,11 +916,14 @@ async def execute_python_script( env_setup = self._generate_env_setup(input_spec, {}) full_code = env_setup + '\n' + code - results = await self._execute_in_sandbox(python_code=full_code, requirements=input_spec.requirements) + results = await self._execute_in_sandbox( + python_code=full_code, + requirements=input_spec.requirements) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode: execute directly - stdout, stderr, exit_code = await self._local_execute_python_code(code, input_spec) + stdout, stderr, exit_code = await self._local_execute_python_code( + code, input_spec) end_time = datetime.now() @@ -909,10 +932,11 @@ async def execute_python_script( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000, - ) + duration_ms=(end_time - start_time).total_seconds() * 1000) - record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED + record.status = ( + ExecutionStatus.SUCCESS + if exit_code == 0 else ExecutionStatus.FAILED) except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -926,8 +950,10 @@ async def execute_python_script( return output async def execute_python_code( - self, code: str, skill_id: str = 'unknown', input_spec: ExecutionInput = None - ) -> ExecutionOutput: + self, + code: str, + skill_id: str = 'unknown', + input_spec: ExecutionInput = None) -> ExecutionOutput: """ Execute Python code string. @@ -944,19 +970,23 @@ async def execute_python_code( input_spec = input_spec or ExecutionInput() record = self._create_record( - skill_id=skill_id, executor_type=ExecutorType.PYTHON_CODE, input_spec=input_spec, script_path='' - ) + skill_id=skill_id, + executor_type=ExecutorType.PYTHON_CODE, + input_spec=input_spec, + script_path='') record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING try: # Security check (stricter for local mode) - is_safe, reason = self._security_check(code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check( + code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput( + stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -969,11 +999,14 @@ async def execute_python_code( env_setup = self._generate_env_setup(input_spec, {}) full_code = env_setup + '\n' + code - results = await self._execute_in_sandbox(python_code=full_code, requirements=input_spec.requirements) + results = await self._execute_in_sandbox( + python_code=full_code, + requirements=input_spec.requirements) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_python_code(code, input_spec) + stdout, stderr, exit_code = await self._local_execute_python_code( + code, input_spec) end_time = datetime.now() @@ -982,10 +1015,11 @@ async def execute_python_code( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000, - ) + duration_ms=(end_time - start_time).total_seconds() * 1000) - record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED + record.status = ( + ExecutionStatus.SUCCESS + if exit_code == 0 else ExecutionStatus.FAILED) except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -998,7 +1032,8 @@ async def execute_python_code( self.spec.add_record(record) return output - def _generate_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict[str, str]) -> str: + def _generate_env_setup(self, input_spec: ExecutionInput, + sandbox_files: Dict[str, str]) -> str: """Generate Python code to setup environment variables and paths.""" sandbox_logs_dir = f'{self.SANDBOX_ROOT}/logs' lines = [ @@ -1036,8 +1071,10 @@ def _generate_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict[st return '\n'.join(lines) def execute_python_function( - self, func: Callable, skill_id: str = 'unknown', input_spec: ExecutionInput = None - ) -> ExecutionOutput: + self, + func: Callable, + skill_id: str = 'unknown', + input_spec: ExecutionInput = None) -> ExecutionOutput: """ Execute a Python function directly (local execution, not sandboxed). @@ -1058,8 +1095,7 @@ def execute_python_function( skill_id=skill_id, executor_type=ExecutorType.PYTHON_FUNCTION, input_spec=input_spec, - function_name=func.__name__, - ) + function_name=func.__name__) record.sandbox_used = False # Local execution record.start_time = datetime.now() @@ -1078,8 +1114,7 @@ def execute_python_function( return_value=return_value, exit_code=0, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000, - ) + duration_ms=(end_time - start_time).total_seconds() * 1000) record.status = ExecutionStatus.SUCCESS @@ -1095,8 +1130,10 @@ def execute_python_function( return output async def execute_shell( - self, command: Union[str, List[str]], skill_id: str = 'unknown', input_spec: ExecutionInput = None - ) -> ExecutionOutput: + self, + command: Union[str, List[str]], + skill_id: str = 'unknown', + input_spec: ExecutionInput = None) -> ExecutionOutput: """ Execute a shell command. @@ -1115,19 +1152,23 @@ async def execute_shell( cmd_str = command if isinstance(command, str) else ' && '.join(command) record = self._create_record( - skill_id=skill_id, executor_type=ExecutorType.SHELL, input_spec=input_spec, script_path=cmd_str[:200] - ) + skill_id=skill_id, + executor_type=ExecutorType.SHELL, + input_spec=input_spec, + script_path=cmd_str[:200]) record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING try: # Security check (stricter for local mode) - is_safe, reason = self._security_check(cmd_str, is_local=not self.use_sandbox) + is_safe, reason = self._security_check( + cmd_str, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput( + stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -1146,11 +1187,13 @@ async def execute_shell( full_cmd = ' && '.join(env_exports + [cmd_str]) - results = await self._execute_in_sandbox(shell_command=full_cmd) + results = await self._execute_in_sandbox(shell_command=full_cmd + ) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_shell(cmd_str, input_spec) + stdout, stderr, exit_code = await self._local_execute_shell( + cmd_str, input_spec) end_time = datetime.now() @@ -1159,10 +1202,11 @@ async def execute_shell( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000, - ) + duration_ms=(end_time - start_time).total_seconds() * 1000) - record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED + record.status = ( + ExecutionStatus.SUCCESS + if exit_code == 0 else ExecutionStatus.FAILED) except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -1175,14 +1219,12 @@ async def execute_shell( self.spec.add_record(record) return output - async def execute_javascript( - self, - script_path: Union[str, Path] = None, - code: str = None, - skill_id: str = 'unknown', - input_spec: ExecutionInput = None, - runtime: str = 'node', - ) -> ExecutionOutput: + async def execute_javascript(self, + script_path: Union[str, Path] = None, + code: str = None, + skill_id: str = 'unknown', + input_spec: ExecutionInput = None, + runtime: str = 'node') -> ExecutionOutput: """ Execute JavaScript code via Node.js. @@ -1204,8 +1246,7 @@ async def execute_javascript( skill_id=skill_id, executor_type=ExecutorType.JAVASCRIPT, input_spec=input_spec, - script_path=str(script_path) if script_path else '', - ) + script_path=str(script_path) if script_path else '') record.start_time = datetime.now() record.status = ExecutionStatus.RUNNING @@ -1221,11 +1262,13 @@ async def execute_javascript( raise ValueError('Either script_path or code must be provided') # Security check (stricter for local mode) - is_safe, reason = self._security_check(js_code, is_local=not self.use_sandbox) + is_safe, reason = self._security_check( + js_code, is_local=not self.use_sandbox) if not is_safe: record.status = ExecutionStatus.SECURITY_BLOCKED record.error_message = reason - output = ExecutionOutput(stderr=f'Security check failed: {reason}', exit_code=-1) + output = ExecutionOutput( + stderr=f'Security check failed: {reason}', exit_code=-1) record.end_time = datetime.now() record.output_spec = output self.spec.add_record(record) @@ -1250,11 +1293,13 @@ async def execute_javascript( args_str = ' '.join(f'"{arg}"' for arg in input_spec.args) shell_cmd = f'{runtime} {sandbox_js_path} {args_str}' - results = await self._execute_in_sandbox(shell_command=shell_cmd) + results = await self._execute_in_sandbox( + shell_command=shell_cmd) stdout, stderr, exit_code = self._parse_sandbox_result(results) else: # Local mode - stdout, stderr, exit_code = await self._local_execute_javascript(js_code, input_spec) + stdout, stderr, exit_code = await self._local_execute_javascript( + js_code, input_spec) end_time = datetime.now() @@ -1263,10 +1308,11 @@ async def execute_javascript( stderr=stderr, exit_code=exit_code, output_files=self._collect_output_files(), - duration_ms=(end_time - start_time).total_seconds() * 1000, - ) + duration_ms=(end_time - start_time).total_seconds() * 1000) - record.status = ExecutionStatus.SUCCESS if exit_code == 0 else ExecutionStatus.FAILED + record.status = ( + ExecutionStatus.SUCCESS + if exit_code == 0 else ExecutionStatus.FAILED) except Exception as e: output = ExecutionOutput(stderr=str(e), exit_code=-1) @@ -1279,7 +1325,8 @@ async def execute_javascript( self.spec.add_record(record) return output - def _generate_js_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict[str, str]) -> str: + def _generate_js_env_setup(self, input_spec: ExecutionInput, + sandbox_files: Dict[str, str]) -> str: """Generate JavaScript code to setup environment.""" lines = [ '// Environment setup', @@ -1293,17 +1340,15 @@ def _generate_js_env_setup(self, input_spec: ExecutionInput, sandbox_files: Dict lines.append('') return '\n'.join(lines) - async def execute( - self, - executor_type: ExecutorType, - skill_id: str = 'unknown', - script_path: Union[str, Path] = None, - func: Callable = None, - command: Union[str, List[str]] = None, - code: str = None, - input_spec: ExecutionInput = None, - **kwargs, - ) -> ExecutionOutput: + async def execute(self, + executor_type: ExecutorType, + skill_id: str = 'unknown', + script_path: Union[str, Path] = None, + func: Callable = None, + command: Union[str, List[str]] = None, + code: str = None, + input_spec: ExecutionInput = None, + **kwargs) -> ExecutionOutput: """ Unified async execution interface. @@ -1321,25 +1366,40 @@ async def execute( ExecutionOutput with results. """ if executor_type == ExecutorType.PYTHON_SCRIPT: - return await self.execute_python_script(script_path=script_path, skill_id=skill_id, input_spec=input_spec) + return await self.execute_python_script( + script_path=script_path, + skill_id=skill_id, + input_spec=input_spec) elif executor_type == ExecutorType.PYTHON_CODE: - return await self.execute_python_code(code=code, skill_id=skill_id, input_spec=input_spec) + return await self.execute_python_code( + code=code, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.PYTHON_FUNCTION: - return self.execute_python_function(func=func, skill_id=skill_id, input_spec=input_spec) + return self.execute_python_function( + func=func, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.SHELL: - return await self.execute_shell(command=command, skill_id=skill_id, input_spec=input_spec) + return await self.execute_shell( + command=command, skill_id=skill_id, input_spec=input_spec) elif executor_type == ExecutorType.JAVASCRIPT: return await self.execute_javascript( - script_path=script_path, code=code, skill_id=skill_id, input_spec=input_spec, **kwargs - ) + script_path=script_path, + code=code, + skill_id=skill_id, + input_spec=input_spec, + **kwargs) else: raise ValueError(f'Unsupported executor type: {executor_type}') - def execute_sync(self, executor_type: ExecutorType, skill_id: str = 'unknown', **kwargs) -> ExecutionOutput: + def execute_sync(self, + executor_type: ExecutorType, + skill_id: str = 'unknown', + **kwargs) -> ExecutionOutput: """Synchronous wrapper for execute().""" return asyncio.run(self.execute(executor_type, skill_id, **kwargs)) - def link_skills(self, upstream_skill_id: str, downstream_input_key: str, output_key: str = None) -> Optional[Any]: + def link_skills(self, + upstream_skill_id: str, + downstream_input_key: str, + output_key: str = None) -> Optional[Any]: """ Link output from upstream skill to downstream skill input. diff --git a/ms_agent/skill/loader.py b/ms_agent/skill/loader.py index 0aae6cd1a..1f5dca2a7 100644 --- a/ms_agent/skill/loader.py +++ b/ms_agent/skill/loader.py @@ -21,7 +21,9 @@ def __init__(self): self.loaded_skills: Dict[str, SkillSchema] = {} self.parser = SkillSchemaParser() - def load_skills(self, skills: Union[str, List[str], List[SkillSchema]]) -> Dict[str, SkillSchema]: + def load_skills( + self, skills: Union[str, List[str], List[SkillSchema]] + ) -> Dict[str, SkillSchema]: """ Load agent skills from various sources. @@ -40,27 +42,30 @@ def load_skills(self, skills: Union[str, List[str], List[SkillSchema]]) -> Dict[ return all_skills def is_skill_id(s: str) -> bool: - return '/' in s and len(s.split('/')) == 2 and all(s.split('/')) and not os.path.exists(s) + return '/' in s and len(s.split('/')) == 2 and all( + s.split('/')) and not os.path.exists(s) if isinstance(skills, str): # Could be a single skill path, root path of skills, or skill ID on ModelScope hub skill_list = [skills] - elif all(isinstance(s, str) for s in skills) or all(isinstance(s, SkillSchema) for s in skills): + elif all(isinstance(s, str) for s in skills) or all( + isinstance(s, SkillSchema) for s in skills): skill_list = skills else: raise ValueError('Invalid skills input type.') for skill in skill_list: + if is_skill_id(skill): from modelscope import snapshot_download - skill_path: str = snapshot_download(repo_id=skill) skill = skill_path if isinstance(skill, SkillSchema): skill_key = self._get_skill_key(skill=skill) all_skills[skill_key] = skill - logger.info(f'Loaded skill from SkillSchema object: {skill_key}') + logger.info( + f'Loaded skill from SkillSchema object: {skill_key}') continue skill_dir: Path = Path(skill) @@ -76,7 +81,8 @@ def is_skill_id(s: str) -> bool: all_skills[skill_key] = skill_schema # logger.info(f'Successfully loaded skill: {skill_key}') else: - skill_schema_dict: Dict[str, SkillSchema] = self._scan_and_load_skills(skill_dir) + skill_schema_dict: Dict[ + str, SkillSchema] = self._scan_and_load_skills(skill_dir) all_skills.update(skill_schema_dict) self.loaded_skills.update(all_skills) @@ -221,7 +227,9 @@ def reload_skill(self, skill_path: str) -> Optional[SkillSchema]: return skill -def load_skills(skills: Union[str, List[str], List[SkillSchema]]) -> Dict[str, SkillSchema]: +def load_skills( + skills: Union[str, List[str], + List[SkillSchema]]) -> Dict[str, SkillSchema]: """ Convenience function to load skills without creating a SkillLoader instance. diff --git a/ms_agent/skill/schema.py b/ms_agent/skill/schema.py index 312556885..722e0acc4 100644 --- a/ms_agent/skill/schema.py +++ b/ms_agent/skill/schema.py @@ -5,20 +5,19 @@ Defines the data structure and validation logic for Agent Skills. Each Skill is represented as a self-contained directory with metadata. """ - import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union import yaml - from ms_agent.utils.logger import logger from .spec import Spec SUPPORTED_SCRIPT_EXT = ('.py', '.sh', '.js') -SUPPORTED_READ_EXT = ('.md', '.txt', '.py', '.json', '.yaml', '.yml', '.sh', '.js', '.html', '.xml') +SUPPORTED_READ_EXT = ('.md', '.txt', '.py', '.json', '.yaml', '.yml', '.sh', + '.js', '.html', '.xml') @dataclass @@ -32,7 +31,6 @@ class SkillFile: path: Relative path within Skill directory required: Whether this file is required """ - name: str type: str path: Path @@ -57,7 +55,12 @@ def to_dict(self): Returns: Dictionary containing file information """ - return {'name': self.name, 'type': self.type, 'path': str(self.path), 'required': self.required} + return { + 'name': self.name, + 'type': self.type, + 'path': str(self.path), + 'required': self.required + } @dataclass @@ -78,7 +81,6 @@ class SkillSchema: scripts: List of script files (optional) references: List of reference documents (optional) """ - skill_id: str name: str description: str @@ -139,7 +141,8 @@ def validate(self) -> bool: return True except Exception as e: - logger.error(f'Skill validation failed with an unexpected error: {e}') + logger.error( + f'Skill validation failed with an unexpected error: {e}') return False def get_file_by_name(self, name: str) -> Optional[SkillFile]: @@ -165,17 +168,32 @@ def to_dict(self) -> Dict[str, Any]: Dictionary containing all schema information """ return { - 'skill_id': self.skill_id, - 'name': self.name, - 'description': self.description, - 'version': self.version, - 'author': self.author, - 'tags': self.tags, - 'skill_path': str(self.skill_path), - 'files': [{'name': f.name, 'type': f.type, 'path': f.path, 'required': f.required} for f in self.files], - 'scripts': self.scripts, - 'references': self.references, - 'resources': self.resources, + 'skill_id': + self.skill_id, + 'name': + self.name, + 'description': + self.description, + 'version': + self.version, + 'author': + self.author, + 'tags': + self.tags, + 'skill_path': + str(self.skill_path), + 'files': [{ + 'name': f.name, + 'type': f.type, + 'path': f.path, + 'required': f.required + } for f in self.files], + 'scripts': + self.scripts, + 'references': + self.references, + 'resources': + self.resources, } @@ -217,7 +235,10 @@ def is_ignored_path(p: Path) -> bool: Returns: True if path should be ignored, False otherwise """ - ignored_names = {'.DS_Store', '__pycache__', '.git', '.gitignore', '.pytest_cache', '.mypy_cache'} + ignored_names = { + '.DS_Store', '__pycache__', '.git', '.gitignore', '.pytest_cache', + '.mypy_cache' + } ignored_suffixes = {'.pyc', '.pyo'} return (p.name in ignored_names) or (p.suffix in ignored_suffixes) @@ -266,14 +287,17 @@ def parse_skill_directory(directory_path: Path) -> Optional[SkillSchema]: file_type = file_path.suffix if file_path.suffix else '.unknown' skill_file = SkillFile( - name=file_path.name, type=file_type, path=file_path, required=(file_path.name == 'SKILL.md') - ) + name=file_path.name, + type=file_type, + path=file_path, + required=(file_path.name == 'SKILL.md')) files.append(skill_file) # Get scripts, references and resources if skill_file.type in SUPPORTED_SCRIPT_EXT: scripts.append(skill_file) - elif skill_file.type in ['.md'] and skill_file.name != 'SKILL.md': + elif skill_file.type in ['.md' + ] and skill_file.name != 'SKILL.md': references.append(skill_file) else: resources.append(skill_file) @@ -346,7 +370,6 @@ class SkillExecutionPlan: parameters: Parameters extracted from user query. reasoning: Explanation of the plan. """ - can_handle: bool = False plan_summary: str = '' steps: List[Dict[str, Any]] = field(default_factory=list) @@ -373,7 +396,8 @@ class SkillContext: query: str = '' # The working directory (absolute path to skills folder's parent directory) - root_path: Path = field(default_factory=lambda: Path.cwd().parent.resolve()) + root_path: Path = field( + default_factory=lambda: Path.cwd().parent.resolve()) # Execution plan from progressive analysis plan: Optional[SkillExecutionPlan] = None @@ -440,7 +464,10 @@ def get_references_list(self) -> List[str]: def get_resources_list(self) -> List[str]: """Get list of available resource names without loading content.""" - return [r.name for r in self.skill.resources if r.name not in ['SKILL.md', 'LICENSE.txt']] + return [ + r.name for r in self.skill.resources + if r.name not in ['SKILL.md', 'LICENSE.txt'] + ] def _get_resource_path(self, file_path: Path) -> str: """ @@ -478,15 +505,13 @@ def load_scripts(self, names: List[str] = None) -> List[Dict[str, Any]]: loaded = [] for script in target_scripts: abs_path = script.path.resolve() - loaded.append( - { - 'name': script.name, - 'file': script.to_dict(), - 'path': self._get_resource_path(script.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - } - ) + loaded.append({ + 'name': script.name, + 'file': script.to_dict(), + 'path': self._get_resource_path(script.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + }) self.scripts.extend(loaded) return loaded @@ -507,15 +532,13 @@ def load_references(self, names: List[str] = None) -> List[Dict[str, Any]]: loaded = [] for ref in target_refs: abs_path = ref.path.resolve() - loaded.append( - { - 'name': ref.name, - 'file': ref.to_dict(), - 'path': self._get_resource_path(ref.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - } - ) + loaded.append({ + 'name': ref.name, + 'file': ref.to_dict(), + 'path': self._get_resource_path(ref.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + }) self.references.extend(loaded) return loaded @@ -529,22 +552,23 @@ def load_resources(self, names: List[str] = None) -> List[Dict[str, Any]]: Returns: List of loaded resource dictionaries with content. """ - target_res = [r for r in self.skill.resources if r.name not in ['SKILL.md', 'LICENSE.txt']] + target_res = [ + r for r in self.skill.resources + if r.name not in ['SKILL.md', 'LICENSE.txt'] + ] if names: target_res = [r for r in target_res if r.name in names] loaded = [] for res in target_res: abs_path = res.path.resolve() - loaded.append( - { - 'name': res.name, - 'file': res.to_dict(), - 'path': self._get_resource_path(res.path), - 'abs_path': str(abs_path), - 'content': self._read_file_content(abs_path), - } - ) + loaded.append({ + 'name': res.name, + 'file': res.to_dict(), + 'path': self._get_resource_path(res.path), + 'abs_path': str(abs_path), + 'content': self._read_file_content(abs_path), + }) self.resources.extend(loaded) return loaded diff --git a/ms_agent/skill/spec.py b/ms_agent/skill/spec.py index 676ed219c..3c666b8d4 100644 --- a/ms_agent/skill/spec.py +++ b/ms_agent/skill/spec.py @@ -18,6 +18,7 @@ class Spec: implementation: str = '' def __post_init__(self): + if not self.plan: self.plan = DEFAULT_PLAN @@ -40,13 +41,20 @@ def dump(self, output_dir: str) -> str: output_path: str = os.path.join(output_dir, '.spec') os.makedirs(output_path, exist_ok=True) - with open(os.path.join(output_path, 'plan.md'), 'w', encoding='utf-8') as f: + with open( + os.path.join(output_path, 'plan.md'), 'w', + encoding='utf-8') as f: f.write(self.plan) - with open(os.path.join(output_path, 'tasks.md'), 'w', encoding='utf-8') as f: + with open( + os.path.join(output_path, 'tasks.md'), 'w', + encoding='utf-8') as f: f.write(self.tasks) - with open(os.path.join(output_path, 'implementation.md'), 'w', encoding='utf-8') as f: + with open( + os.path.join(output_path, 'implementation.md'), + 'w', + encoding='utf-8') as f: f.write(self.implementation) return output_path diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 1b4e2c38a..02a57b8d5 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio -import json import multiprocessing as mp import os import threading @@ -12,14 +11,16 @@ from queue import Full as QueueFull from typing import Any, Callable, Dict, List, Optional, Union -from omegaconf import DictConfig, ListConfig, OmegaConf - +import json from ms_agent.agent.loader import AgentLoader from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from ms_agent.utils.stats import append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage +from ms_agent.utils.stats import (append_stats, build_timing_record, + get_stats_path, monotonic, now_iso, + summarize_usage) from ms_agent.utils.stream_writer import SubAgentStreamWriter +from omegaconf import DictConfig, ListConfig, OmegaConf logger = get_logger() @@ -75,7 +76,8 @@ def _message_from_data(data: Any) -> Message: def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): if spec.inline_config is not None: container = _to_container(spec.inline_config) - base_override = OmegaConf.create(container) if isinstance(container, dict) else OmegaConf.create({}) + base_override = OmegaConf.create(container) if isinstance( + container, dict) else OmegaConf.create({}) else: base_override = OmegaConf.create({}) # Sub-agents default snapshots off in LLMAgent unless enable_snapshots is set @@ -140,11 +142,18 @@ async def _runner(): history = chunk if stream_events and event_queue is not None: serialized_chunk = { - 'kind': 'messages', - 'messages': [_message_from_data(msg).to_dict() for msg in (history or [])], + 'kind': + 'messages', + 'messages': [ + _message_from_data(msg).to_dict() + for msg in (history or []) + ], } try: - event_queue.put_nowait({'type': 'chunk', 'history': serialized_chunk}) + event_queue.put_nowait({ + 'type': 'chunk', + 'history': serialized_chunk + }) except QueueFull: # Avoid blocking sub-agent progress if UI/event consumer # is temporarily slower than chunk production. @@ -153,11 +162,16 @@ async def _runner(): result = history if isinstance(result, list): return { - 'kind': 'messages', - 'messages': [_message_from_data(msg).to_dict() for msg in result], - 'streamed_chunks': chunk_count, - 'agent_tag': getattr(sub_agent, 'tag', None), - 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), + 'kind': + 'messages', + 'messages': + [_message_from_data(msg).to_dict() for msg in result], + 'streamed_chunks': + chunk_count, + 'agent_tag': + getattr(sub_agent, 'tag', None), + 'agent_type': + getattr(sub_agent, 'AGENT_NAME', None), } return { 'kind': 'raw', @@ -169,15 +183,13 @@ async def _runner(): result_queue.put({'ok': True, 'result': asyncio.run(_runner())}) except BaseException as exc: # pragma: no cover - result_queue.put( - { - 'ok': False, - 'error': str(exc), - 'traceback': traceback.format_exc(), - 'agent_tag': getattr(sub_agent, 'tag', None), - 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), - } - ) + result_queue.put({ + 'ok': False, + 'error': str(exc), + 'traceback': traceback.format_exc(), + 'agent_tag': getattr(sub_agent, 'tag', None), + 'agent_type': getattr(sub_agent, 'AGENT_NAME', None), + }) class AgentTool(ToolBase): @@ -214,8 +226,7 @@ def enabled(self) -> bool: 'split a website generation task into sub tasks, ' 'you plan the framework, include code files and classes and functions, and give the detail ' 'information to the system and query field of the subtask, then ' - 'let each subtask to write a single file' - ) + 'let each subtask to write a single file') _SPLIT_TASK_PARAMETERS = { 'type': 'object', @@ -224,8 +235,7 @@ def enabled(self) -> bool: 'type': 'array', 'description': ( 'MANDATORY: Each element is a dict, which must contains two fields: ' - '`system`(str) and `query`(str) to start one sub task.' - ), + '`system`(str) and `query`(str) to start one sub task.'), }, 'execution_mode': { 'type': 'string', @@ -274,10 +284,13 @@ def _load_specs(self): self._build_server_index() return - if isinstance(agent_tools_cfg, DictConfig) and hasattr(agent_tools_cfg, 'definitions'): + if isinstance(agent_tools_cfg, DictConfig) and hasattr( + agent_tools_cfg, 'definitions'): definitions = agent_tools_cfg.definitions - server_name = getattr(agent_tools_cfg, 'server_name', self.DEFAULT_SERVER) - self._enable_stats = bool(getattr(agent_tools_cfg, 'enable_stats', False)) + server_name = getattr(agent_tools_cfg, 'server_name', + self.DEFAULT_SERVER) + self._enable_stats = bool( + getattr(agent_tools_cfg, 'enable_stats', False)) else: definitions = agent_tools_cfg server_name = self.DEFAULT_SERVER @@ -301,22 +314,25 @@ def _load_specs(self): continue if spec.tool_name in self._specs: logger.warning( - 'Duplicate agent tool name detected: %s, overriding previous definition.', spec.tool_name - ) + 'Duplicate agent tool name detected: %s, overriding previous definition.', + spec.tool_name) self._specs[spec.tool_name] = spec self._build_server_index() - def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, idx: int) -> Optional[_AgentToolSpec]: + def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], + default_server, idx: int) -> Optional[_AgentToolSpec]: cfg = cfg or {} cfg = cfg if isinstance(cfg, DictConfig) else DictConfig(cfg) - tool_name = getattr(cfg, 'tool_name', None) or getattr(cfg, 'name', None) + tool_name = getattr(cfg, 'tool_name', None) or getattr( + cfg, 'name', None) if not tool_name: - logger.warning('agent_tools[%s] missing tool_name/name field, skip.', idx) + logger.warning( + 'agent_tools[%s] missing tool_name/name field, skip.', idx) return None mode = getattr(cfg, 'mode', None) - is_dynamic = mode == 'dynamic' + is_dynamic = (mode == 'dynamic') agent_cfg = getattr(cfg, 'agent', None) config_path = getattr(cfg, 'config_path', None) @@ -324,13 +340,17 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, id if agent_cfg is not None: config_path = getattr(agent_cfg, 'config_path', config_path) inline_cfg = getattr(agent_cfg, 'config', inline_cfg) - inline_cfg = _to_container(inline_cfg) if inline_cfg is not None else None + inline_cfg = _to_container( + inline_cfg) if inline_cfg is not None else None if not is_dynamic and not config_path and inline_cfg is None: - logger.warning('agent_tools[%s] (%s) missing config_path/config definition.', idx, tool_name) + logger.warning( + 'agent_tools[%s] (%s) missing config_path/config definition.', + idx, tool_name) return None - description = getattr(cfg, 'description', f'Invoke agent "{tool_name}" as a tool.') + description = getattr(cfg, 'description', + f'Invoke agent "{tool_name}" as a tool.') parameters = getattr(cfg, 'parameters', None) if parameters is None: if is_dynamic: @@ -340,8 +360,10 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, id 'type': 'object', 'properties': { 'request': { - 'type': 'string', - 'description': f'Task description forwarded to the sub-agent {tool_name}.', + 'type': + 'string', + 'description': + f'Task description forwarded to the sub-agent {tool_name}.' }, }, 'required': ['request'], @@ -350,7 +372,9 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, id else: parameters = _to_container(parameters) - tag_prefix = getattr(cfg, 'tag_prefix', f'{getattr(self.config, "tag", "agent")}-{tool_name}-') + tag_prefix = getattr( + cfg, 'tag_prefix', + f'{getattr(self.config, "tag", "agent")}-{tool_name}-') request_field = getattr(cfg, 'request_field', 'request') input_template = getattr(cfg, 'input_template', None) @@ -376,7 +400,8 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], default_server, id if config_path and not os.path.isabs(config_path): base_dir = getattr(self.config, 'local_dir', None) if base_dir: - config_path = os.path.normpath(os.path.join(base_dir, config_path)) + config_path = os.path.normpath( + os.path.join(base_dir, config_path)) return _AgentToolSpec( tool_name=tool_name, @@ -411,8 +436,7 @@ def _build_server_index(self): server_name=spec.server_name, description=spec.description, parameters=spec.parameters, - ) - ) + )) self._server_tools = server_map async def connect(self): @@ -457,7 +481,8 @@ def _stream_file_enabled(self) -> bool: 2. ``config.agent_stream_file`` Defaults to ``False``. """ - agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: val = getattr(agent_tools_cfg, 'enable_stream_file', None) if val is not None: @@ -470,7 +495,8 @@ def _stream_file_dir(self) -> str: Checks ``config.tools.agent_tools.stream_file_dir`` first, then falls back to ``config.output_dir``. """ - agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: override = getattr(agent_tools_cfg, 'stream_file_dir', None) if override: @@ -483,19 +509,23 @@ def _stream_include_in_result(self) -> bool: Controlled by ``config.tools.agent_tools.stream_include_in_result`` (defaults to ``True`` when stream files are enabled). """ - agent_tools_cfg = getattr(getattr(self.config, 'tools', None), 'agent_tools', None) + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) if agent_tools_cfg is not None: val = getattr(agent_tools_cfg, 'stream_include_in_result', None) if val is not None: return bool(val) return True - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: if tool_name not in self._specs: raise ValueError(f'Agent tool "{tool_name}" not registered.') spec = self._specs[tool_name] if spec.server_name != server_name: - raise ValueError(f'Agent tool "{tool_name}" is not part of server "{server_name}".') + raise ValueError( + f'Agent tool "{tool_name}" is not part of server "{server_name}".' + ) call_id = None if isinstance(tool_args, dict) and '__call_id' in tool_args: @@ -515,7 +545,8 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) use_subprocess = spec.run_in_thread and spec.run_in_process if use_subprocess: - messages = await self._run_agent(None, payload, spec, call_id=effective_call_id) + messages = await self._run_agent( + None, payload, spec, call_id=effective_call_id) result_str = self._format_output(messages, spec) return self._maybe_append_stream_path(result_str, effective_call_id) @@ -551,7 +582,8 @@ def _maybe_append_stream_path(self, result_str: str, effective_call_id: str) -> def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) - async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, call_id: Optional[str]) -> Any: + async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> Any: """Run sub-agent inline (pure async/await). If spec.sync_timeout_s is set, the call auto-escapes to background after @@ -564,14 +596,17 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, call_id: escape_event = asyncio.Event() effective_call_id = call_id or uuid.uuid4().hex[:12] - run_task = asyncio.create_task(self._run_agent(None, payload, spec, call_id=effective_call_id)) + run_task = asyncio.create_task( + self._run_agent(None, payload, spec, call_id=effective_call_id)) - self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, escape_event) + self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, + escape_event) try: if spec.sync_timeout_s and spec.sync_timeout_s > 0: escape_wait_task = asyncio.create_task(escape_event.wait()) - sleep_task = asyncio.create_task(asyncio.sleep(spec.sync_timeout_s)) + sleep_task = asyncio.create_task( + asyncio.sleep(spec.sync_timeout_s)) _, pending = await asyncio.wait( [run_task, escape_wait_task, sleep_task], return_when=asyncio.FIRST_COMPLETED, @@ -579,7 +614,8 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, call_id: for t in pending: t.cancel() if not run_task.done(): - return await self._escape_running_task(effective_call_id, run_task, spec, payload) + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) else: # No timeout: wait for completion or explicit escape signal. escape_task = asyncio.create_task(escape_event.wait()) @@ -590,20 +626,22 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, call_id: for t in pending: t.cancel() if not run_task.done(): - return await self._escape_running_task(effective_call_id, run_task, spec, payload) + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) return run_task.result() finally: self._active_sync_tasks.pop(effective_call_id, None) - async def _escape_running_task( - self, call_id: str, run_task: 'asyncio.Task[Any]', spec: _AgentToolSpec, payload: Any - ) -> str: + async def _escape_running_task(self, call_id: str, + run_task: 'asyncio.Task[Any]', + spec: _AgentToolSpec, + payload: Any) -> str: """Cancel the in-progress sync task and re-launch it as a background subprocess.""" if self._task_manager is None: raise RuntimeError( - f'AgentTool "{spec.tool_name}" tried to escape to background but no TaskManager is attached.' - ) + f'AgentTool "{spec.tool_name}" tried to escape to background but ' + 'no TaskManager is attached.') run_task.cancel() try: @@ -634,21 +672,22 @@ def escape_to_background(self, call_id: str) -> bool: escape_event.set() return True - async def _launch_background(self, payload: Any, spec: _AgentToolSpec, call_id: Optional[str]) -> str: + async def _launch_background(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> str: """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" if self._task_manager is None: raise RuntimeError( f'AgentTool "{spec.tool_name}" has run_in_background=true but ' 'no TaskManager is attached. Ensure LLMAgent wires task_manager ' - 'into AgentTool via set_task_manager().' - ) + 'into AgentTool via set_task_manager().') ctx = mp.get_context('spawn') result_queue = ctx.Queue(maxsize=1) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, - args=(spec, self._trust_remote_code, process_payload, False, None, result_queue), + args=(spec, self._trust_remote_code, process_payload, False, None, + result_queue), name=f'agent_tool_bg_{spec.tool_name}', ) proc.start() @@ -689,14 +728,11 @@ async def _watcher(): self._watcher_tasks.add(t) t.add_done_callback(self._watcher_tasks.discard) - return json.dumps( - { - 'status': 'async_launched', - 'task_id': task_id, - 'tool_name': spec.tool_name, - }, - ensure_ascii=False, - ) + return json.dumps({ + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': spec.tool_name, + }, ensure_ascii=False) async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: tasks = tool_args.get('tasks', []) @@ -763,7 +799,7 @@ async def _run_one(i: int, task: dict) -> str: formatted = '' for i, content in enumerate(res_list): if len(content) > spec.max_subtask_output_chars: - content = content[: spec.max_subtask_output_chars] + content = content[:spec.max_subtask_output_chars] formatted += f'SubTask{i}:{content}\n' return formatted @@ -817,9 +853,11 @@ def _terminate_all_active_processes(self, *, reason: str) -> None: for _, proc in active: self._terminate_process(proc, reason=reason) - async def _wait_process_result( - self, proc: mp.Process, result_queue: Any, on_poll: Optional[Callable[[], None]] = None - ): + async def _wait_process_result(self, + proc: mp.Process, + result_queue: Any, + on_poll: Optional[Callable[[], + None]] = None): exited_at = None while True: if on_poll is not None: @@ -834,13 +872,16 @@ async def _wait_process_result( if not proc.is_alive(): if exited_at is None: exited_at = monotonic() - elif (monotonic() - exited_at) >= self._PROCESS_EXIT_RESULT_GRACE_S: + elif (monotonic() + - exited_at) >= self._PROCESS_EXIT_RESULT_GRACE_S: return None await asyncio.sleep(self._PROCESS_POLL_INTERVAL_S) @staticmethod - def _drain_process_event_queue(event_queue: Any, on_event: Callable[[Dict[str, Any]], None]) -> None: + def _drain_process_event_queue( + event_queue: Any, on_event: Callable[[Dict[str, Any]], + None]) -> None: if event_queue is None: return while True: @@ -864,7 +905,11 @@ def _restore_process_result(result_payload: Dict[str, Any]) -> Any: return [_message_from_data(msg) for msg in messages] return result_payload.get('raw', '') - async def _run_agent(self, agent, payload, spec: _AgentToolSpec, call_id: Optional[str] = None): + async def _run_agent(self, + agent, + payload, + spec: _AgentToolSpec, + call_id: Optional[str] = None): runtime_agent = agent runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) @@ -880,9 +925,7 @@ async def _run_agent(self, agent, payload, spec: _AgentToolSpec, call_id: Option ) logger.info( '[stream] %s (call_id=%s) streaming to %s', - spec.tool_name, - _effective_call_id, - _writer.stream_path, + spec.tool_name, _effective_call_id, _writer.stream_path, ) # ─────────────────────────────────────────────────────────────────── @@ -898,67 +941,53 @@ async def _run_and_collect(): result = await runtime_agent.run(payload) if hasattr(result, '__aiter__'): history = None - self._emit_chunk_event( - 'start', - { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }, - ) + self._emit_chunk_event('start', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }) if _writer is not None: _writer.on_start(runtime_agent_tag) async for chunk in result: history = chunk self._emit_chunk_event( - 'chunk', - { + 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': chunk, - }, - ) + }) if _writer is not None: _writer.on_chunk(chunk) if history is not None: self._emit_chunk_event( - 'end', - { + 'end', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': history, - }, - ) + }) if _writer is not None: _writer.on_end(history) result = history else: - self._emit_chunk_event( - 'start', - { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }, - ) + self._emit_chunk_event('start', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }) if _writer is not None: _writer.on_start(runtime_agent_tag) self._emit_chunk_event( - 'chunk', - { + 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, - }, - ) + }) if _writer is not None: _writer.on_chunk(result) self._emit_chunk_event( - 'end', - { + 'end', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, - }, - ) + }) if _writer is not None: _writer.on_end(result) return result @@ -981,45 +1010,48 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: history = self._restore_process_result(history_payload) if self._chunk_cb: self._emit_chunk_event( - 'chunk', - { + 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': history, - }, - ) + }) if _writer is not None: _writer.on_chunk(history) try: if self._chunk_cb: - self._emit_chunk_event( - 'start', - { - 'call_id': call_id, - 'tool_name': spec.tool_name, - }, - ) + self._emit_chunk_event('start', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + }) if _writer is not None: # agent_tag unknown until subprocess completes; pass None _writer.on_start(None) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, - args=(spec, self._trust_remote_code, process_payload, need_events, event_queue, result_queue), + args=(spec, self._trust_remote_code, process_payload, + need_events, event_queue, result_queue), name=f'agent_tool_{spec.tool_name}', ) proc.start() self._register_process(run_id, proc) result = await self._wait_process_result( - proc, result_queue, on_poll=lambda: self._drain_process_event_queue(event_queue, _emit_stream_event) - ) + proc, + result_queue, + on_poll=lambda: self._drain_process_event_queue( + event_queue, _emit_stream_event)) if result is None: - raise RuntimeError(f'AgentTool subprocess exited without result: {spec.tool_name}') - self._drain_process_event_queue(event_queue, _emit_stream_event) + raise RuntimeError( + f'AgentTool subprocess exited without result: {spec.tool_name}' + ) + self._drain_process_event_queue(event_queue, + _emit_stream_event) if not result.get('ok'): - runtime_agent_tag = result.get('agent_tag') or runtime_agent_tag - runtime_agent_type = result.get('agent_type') or runtime_agent_type + runtime_agent_tag = result.get( + 'agent_tag') or runtime_agent_tag + runtime_agent_type = result.get( + 'agent_type') or runtime_agent_type tb = result.get('traceback', '') if tb: logger.warning(tb) @@ -1028,28 +1060,27 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: _writer.on_error(err_msg) raise RuntimeError(err_msg) result_payload = result.get('result', {}) or {} - runtime_agent_tag = result_payload.get('agent_tag') or runtime_agent_tag - runtime_agent_type = result_payload.get('agent_type') or runtime_agent_type + runtime_agent_tag = result_payload.get( + 'agent_tag') or runtime_agent_tag + runtime_agent_type = result_payload.get( + 'agent_type') or runtime_agent_type restored = self._restore_process_result(result_payload) - streamed_chunks = int(result_payload.get('streamed_chunks', 0) or 0) + streamed_chunks = int( + result_payload.get('streamed_chunks', 0) or 0) if self._chunk_cb: if streamed_chunks <= 0: self._emit_chunk_event( - 'chunk', - { + 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': restored, - }, - ) + }) self._emit_chunk_event( - 'end', - { + 'end', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': restored, - }, - ) + }) # Always finalise the writer regardless of _chunk_cb. if _writer is not None: _writer.on_end(restored) @@ -1072,7 +1103,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: except Exception: pass if proc.is_alive(): - self._terminate_process(proc, reason='did not exit after result handling') + self._terminate_process( + proc, reason='did not exit after result handling') try: result_queue.close() result_queue.join_thread() @@ -1113,7 +1145,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: self._stream_paths[store_key] = _writer.stream_path return result except BaseException as exc: - status = 'cancelled' if isinstance(exc, asyncio.CancelledError) else 'error' + status = 'cancelled' if isinstance( + exc, asyncio.CancelledError) else 'error' raise finally: end_ts = now_iso() @@ -1137,7 +1170,9 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: try: await append_stats(get_stats_path(self.config), record) except Exception as exc: - logger.warning(f'Failed to write agent tool stats for {spec.tool_name}: {exc}') + logger.warning( + f'Failed to write agent tool stats for {spec.tool_name}: {exc}' + ) def _save_transcript(self, messages: Any, agent_tag: Optional[str]) -> None: if not isinstance(messages, list) or not agent_tag: @@ -1159,7 +1194,9 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): field = spec.request_field or 'messages' raw_messages = tool_args.get(field) if not isinstance(raw_messages, list): - raise ValueError(f'Agent tool "{spec.tool_name}" expects "{field}" to be a list of messages.') + raise ValueError( + f'Agent tool "{spec.tool_name}" expects "{field}" to be a list of messages.' + ) return [ Message( role=msg.get('role', 'user'), @@ -1168,8 +1205,7 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): tool_call_id=msg.get('tool_call_id'), name=msg.get('name'), reasoning_content=msg.get('reasoning_content', ''), - ) - for msg in raw_messages # TODO: Change role to user or not + ) for msg in raw_messages # TODO: Change role to user or not ] if spec.input_template: @@ -1179,9 +1215,7 @@ def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): except Exception as exc: logger.warning( 'Failed to render input template for tool %s: %s. Falling back to JSON payload.', - spec.tool_name, - exc, - ) + spec.tool_name, exc) field = spec.request_field or 'request' if field in tool_args and isinstance(tool_args[field], str): @@ -1195,25 +1229,31 @@ def _format_output(self, messages: Any, spec: _AgentToolSpec) -> str: if spec.output_mode == 'history': serialized = [self._serialize_message(msg) for msg in messages] - return self._truncate(json.dumps(serialized, ensure_ascii=False, indent=2), spec.max_output_chars) + return self._truncate( + json.dumps(serialized, ensure_ascii=False, indent=2), + spec.max_output_chars) if spec.output_mode == 'raw_json': serialized = [msg.to_dict() for msg in messages] # type: ignore - return self._truncate(json.dumps(serialized, ensure_ascii=False), spec.max_output_chars) + return self._truncate( + json.dumps(serialized, ensure_ascii=False), + spec.max_output_chars) # Default: return final assistant message text for msg in reversed(messages): if getattr(msg, 'role', '') == 'assistant': return self._truncate(msg.content or '', spec.max_output_chars) - return self._truncate(messages[-1].content or '', spec.max_output_chars) + return self._truncate(messages[-1].content or '', + spec.max_output_chars) def _serialize_message(self, message: Message) -> Dict[str, Any]: data = message.to_dict() if data.get('tool_calls'): for call in data['tool_calls']: if isinstance(call.get('arguments'), dict): - call['arguments'] = json.dumps(call['arguments'], ensure_ascii=False) + call['arguments'] = json.dumps( + call['arguments'], ensure_ascii=False) return data @staticmethod diff --git a/ms_agent/tools/audio_generator/audio_gen.py b/ms_agent/tools/audio_generator/audio_gen.py index e9897d0e7..2b533c08b 100644 --- a/ms_agent/tools/audio_generator/audio_gen.py +++ b/ms_agent/tools/audio_generator/audio_gen.py @@ -6,14 +6,15 @@ class AudioGenerator(ToolBase): + def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', 'audio_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', + 'audio_generator') os.makedirs(self.temp_dir, exist_ok=True) audio_generator = self.config.audio_generator if audio_generator.type == 'edge_tts': from .edge_tts import EdgeTTSGenerator - self.generator = EdgeTTSGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -27,21 +28,25 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_audio', server_name='audio_generator', - description='Generate audio with a prompt, and return the audio file path.', + description= + 'Generate audio with a prompt, and return the audio file path.', parameters={ 'type': 'object', 'properties': { - 'text': {'type': 'string', 'description': 'The text to generate speech'}, + 'text': { + 'type': 'string', + 'description': 'The text to generate speech' + }, }, 'required': ['text'], - 'additionalProperties': False, - }, - ) + 'additionalProperties': False + }) ] } async def generate_audio(self, text, **kwargs): return await self.generator.generate_audio(text, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await self.generate_audio(**tool_args) diff --git a/ms_agent/tools/audio_generator/edge_tts.py b/ms_agent/tools/audio_generator/edge_tts.py index 3e05a7e40..7efa5ff77 100644 --- a/ms_agent/tools/audio_generator/edge_tts.py +++ b/ms_agent/tools/audio_generator/edge_tts.py @@ -3,6 +3,7 @@ class EdgeTTSGenerator: + def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir @@ -14,19 +15,25 @@ async def generate_audio(self, text, **kwargs): return output_file @staticmethod - async def edge_tts_generate(text, output_file, speaker='zh-CN-YunjianNeural', rate='+0%', pitch='+0Hz'): + async def edge_tts_generate(text, + output_file, + speaker='zh-CN-YunjianNeural', + rate='+0%', + pitch='+0Hz'): import edge_tts - output_dir = os.path.dirname(output_file) or '.' os.makedirs(output_dir, exist_ok=True) text = text.replace('[', '').replace(']', '') - communicate = edge_tts.Communicate(text=text, voice=speaker, rate=rate, pitch=pitch) + communicate = edge_tts.Communicate( + text=text, voice=speaker, rate=rate, pitch=pitch) audio_data = b'' async for chunk in communicate.stream(): if chunk['type'] == 'audio': audio_data += chunk['data'] - assert len(audio_data) > 0, 'Audio generation failed: no data received from edge_tts.' + assert len( + audio_data + ) > 0, 'Audio generation failed: no data received from edge_tts.' with open(output_file, 'wb') as f: f.write(audio_data) diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index ad867522e..12ece9948 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -2,9 +2,8 @@ from abc import abstractmethod from typing import Any, Dict -from omegaconf import DictConfig - from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR +from omegaconf import DictConfig class ToolBase: @@ -17,16 +16,17 @@ def __init__(self, config): self.config = config self.exclude_functions = [] self.include_functions = [] - self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', + DEFAULT_OUTPUT_DIR) def exclude_func(self, tool_config: DictConfig): if tool_config is not None: self.exclude_functions = getattr(tool_config, 'exclude', []) self.include_functions = getattr(tool_config, 'include', []) - assert (not self.exclude_functions) or (not self.include_functions), ( - 'Set either `include` or `exclude` in tools config.' - ) + assert (not self.exclude_functions) or ( + not self.include_functions + ), 'Set either `include` or `exclude` in tools config.' @abstractmethod async def connect(self) -> None: @@ -76,7 +76,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: pass @abstractmethod - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: """Call a tool. Args: diff --git a/ms_agent/tools/code/code_executor.py b/ms_agent/tools/code/code_executor.py index 6cce39b3a..1df89e39c 100644 --- a/ms_agent/tools/code/code_executor.py +++ b/ms_agent/tools/code/code_executor.py @@ -1,18 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio -import json import socket from pathlib import Path from typing import Any, Dict, Optional, Union -from omegaconf import DictConfig - +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.code.sandbox_manager import SandboxManagerFactory from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package +from omegaconf import DictConfig logger = get_logger() @@ -41,7 +40,9 @@ def check_port_available(port: int, host: str = '127.0.0.1') -> bool: if e.errno == 98 or e.errno == 48: # Address already in use return False # Port is occupied except Exception as e: - logger.warning(f'Bind test failed for port {port}, falling back to connection test: {e}') + logger.warning( + f'Bind test failed for port {port}, falling back to connection test: {e}' + ) # Second try: connection test (fallback method) try: @@ -54,7 +55,9 @@ def check_port_available(port: int, host: str = '127.0.0.1') -> bool: return False # Be conservative: assume occupied if we can't check reliably -def find_available_port(start_port: int = 8888, max_attempts: int = 100, host: str = '127.0.0.1') -> Optional[int]: +def find_available_port(start_port: int = 8888, + max_attempts: int = 100, + host: str = '127.0.0.1') -> Optional[int]: """ Find an available port starting from start_port. @@ -71,7 +74,9 @@ def find_available_port(start_port: int = 8888, max_attempts: int = 100, host: s logger.info(f'Found available port: {port}') return port - logger.error(f'Could not find available port in range {start_port}-{start_port + max_attempts - 1}') + logger.error( + f'Could not find available port in range {start_port}-{start_port + max_attempts - 1}' + ) return None @@ -89,7 +94,8 @@ class CodeExecutionTool(ToolBase): def __init__(self, config): logger.info('Installing ms-enclave package...') try: - install_package(package_name='ms-enclave', import_name='ms_enclave') + install_package( + package_name='ms-enclave', import_name='ms_enclave') except Exception as e: raise e @@ -107,15 +113,19 @@ def __init__(self, config): logger.info('CodeExecutionTool initialized (ms-enclave based)') - def _build_sandbox_config(self, config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: + def _build_sandbox_config( + self, + config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: """Build sandbox configuration from agent config""" from ms_enclave.sandbox.model import DockerNotebookConfig, DockerSandboxConfig, SandboxType # Get sandbox-specific config or use defaults - if isinstance(config, DictConfig) and hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): + if isinstance(config, DictConfig) and hasattr( + config, 'tools') and hasattr(config.tools, 'code_executor'): sandbox_cfg = getattr(config.tools.code_executor, 'sandbox', {}) else: - sandbox_cfg = getattr(config, 'sandbox', {}) or getattr(config, 'tools', {}).get('sandbox', {}) + sandbox_cfg = getattr(config, 'sandbox', {}) or getattr( + config, 'tools', {}).get('sandbox', {}) # Get output directory for data mounting output_dir = Path(getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)) @@ -135,39 +145,51 @@ def _build_sandbox_config(self, config) -> Union['DockerNotebookConfig', 'Docker 'env_vars': env_vars, } if hasattr(sandbox_cfg, '__getitem__'): - self.sandbox_type = sandbox_cfg.get('type', SandboxType.DOCKER_NOTEBOOK) - - config_dict.update( - { - 'image': sandbox_cfg.get('image', 'jupyter-kernel-gateway'), - 'command': sandbox_cfg.get('command', None), - 'ports': sandbox_cfg.get('ports', {}), - 'network': sandbox_cfg.get('network', 'bridge'), - 'memory_limit': sandbox_cfg.get('memory_limit', '2g'), - 'cpu_limit': sandbox_cfg.get('cpu_limit', 2.0), - 'network_enabled': sandbox_cfg.get('network_enabled', True), - 'privileged': sandbox_cfg.get('privileged', False), - 'remove_on_exit': sandbox_cfg.get('remove_on_exit', True), - 'timeout': sandbox_cfg.get('timeout', 30), - 'tools_config': sandbox_cfg.get('tools_config', {}), - 'working_dir': sandbox_cfg.get('working_dir', '/workspace'), - 'resource_limits': sandbox_cfg.get('resource_limits', {}), - } - ) + self.sandbox_type = sandbox_cfg.get('type', + SandboxType.DOCKER_NOTEBOOK) + + config_dict.update({ + 'image': + sandbox_cfg.get('image', 'jupyter-kernel-gateway'), + 'command': + sandbox_cfg.get('command', None), + 'ports': + sandbox_cfg.get('ports', {}), + 'network': + sandbox_cfg.get('network', 'bridge'), + 'memory_limit': + sandbox_cfg.get('memory_limit', '2g'), + 'cpu_limit': + sandbox_cfg.get('cpu_limit', 2.0), + 'network_enabled': + sandbox_cfg.get('network_enabled', True), + 'privileged': + sandbox_cfg.get('privileged', False), + 'remove_on_exit': + sandbox_cfg.get('remove_on_exit', True), + 'timeout': + sandbox_cfg.get('timeout', 30), + 'tools_config': + sandbox_cfg.get('tools_config', {}), + 'working_dir': + sandbox_cfg.get('working_dir', '/workspace'), + 'resource_limits': + sandbox_cfg.get('resource_limits', {}), + }) if self.sandbox_type == SandboxType.DOCKER_NOTEBOOK: - config_dict.update( - { - 'host': sandbox_cfg.get('host', '127.0.0.1'), - 'port': sandbox_cfg.get('port', 8888), - 'token': sandbox_cfg.get('token', None), - } - ) + config_dict.update({ + 'host': sandbox_cfg.get('host', '127.0.0.1'), + 'port': sandbox_cfg.get('port', 8888), + 'token': sandbox_cfg.get('token', None), + }) # Store original port for retry logic self._original_port = config_dict['port'] - self._port_retry_enabled = sandbox_cfg.get('port_retry_enabled', True) - self._max_port_retries = sandbox_cfg.get('max_port_retries', 10) + self._port_retry_enabled = sandbox_cfg.get( + 'port_retry_enabled', True) + self._max_port_retries = sandbox_cfg.get( + 'max_port_retries', 10) logger.info(f'Sandbox config: type={self.sandbox_type}') @@ -190,7 +212,8 @@ async def connect(self) -> None: logger.info('Initializing sandbox manager...') # Create manager using factory - self.manager = await SandboxManagerFactory.create_manager(self.config) + self.manager = await SandboxManagerFactory.create_manager( + self.config) await self.manager.start() logger.info('Creating sandbox instance...') @@ -203,8 +226,8 @@ async def connect(self) -> None: while retry_count < max_retries: try: self.sandbox_id = await self.manager.create_sandbox( - sandbox_type=self.sandbox_type, config=self.sandbox_config - ) + sandbox_type=self.sandbox_type, + config=self.sandbox_config) logger.info(f'Sandbox created: {self.sandbox_id}') @@ -220,55 +243,68 @@ async def connect(self) -> None: last_error = e # Check if it's a port conflict error - is_port_conflict = any( - keyword in error_msg - for keyword in [ - 'address already in use', - 'port is already allocated', - 'bind: address already in use', - 'port already in use', - ] - ) - - if is_port_conflict and self._port_retry_enabled and retry_count < (max_retries - 1): + is_port_conflict = any(keyword in error_msg + for keyword in [ + 'address already in use', + 'port is already allocated', + 'bind: address already in use', + 'port already in use' + ]) + + if is_port_conflict and self._port_retry_enabled and retry_count < ( + max_retries - 1): retry_count += 1 - logger.warning(f'Port conflict detected (attempt {retry_count}/{max_retries}): {e}') + logger.warning( + f'Port conflict detected (attempt {retry_count}/{max_retries}): {e}' + ) # Try to find a new available port if self.sandbox_type == SandboxType.DOCKER_NOTEBOOK: new_port = find_available_port( - start_port=self.sandbox_config.port + 1, max_attempts=100, host=self.sandbox_config.host - ) + start_port=self.sandbox_config.port + 1, + max_attempts=100, + host=self.sandbox_config.host) if new_port: - logger.info(f'Retrying with new port: {new_port} (was {self.sandbox_config.port})') + logger.info( + f'Retrying with new port: {new_port} (was {self.sandbox_config.port})' + ) # Update the config with new port self.sandbox_config.port = new_port # Clean up failed sandbox if it was created if self.sandbox_id: try: - await self.manager.delete_sandbox(self.sandbox_id) + await self.manager.delete_sandbox( + self.sandbox_id) self.sandbox_id = None except Exception as cleanup_error: - logger.warning(f'Failed to cleanup sandbox: {cleanup_error}') + logger.warning( + f'Failed to cleanup sandbox: {cleanup_error}' + ) # Wait a bit before retry await asyncio.sleep(1) continue else: - logger.error('Could not find available port for retry') - raise RuntimeError(f'Port conflict and no available ports found: {e}') from e + logger.error( + 'Could not find available port for retry') + raise RuntimeError( + f'Port conflict and no available ports found: {e}' + ) from e else: # For non-notebook sandbox, just retry - logger.info(f'Retrying sandbox creation (attempt {retry_count}/{max_retries})...') + logger.info( + f'Retrying sandbox creation (attempt {retry_count}/{max_retries})...' + ) await asyncio.sleep(1) continue else: # Not a port conflict or retries exhausted raise - logger.error(f'Failed to create sandbox after {max_retries} attempts') + logger.error( + f'Failed to create sandbox after {max_retries} attempts') raise RuntimeError( f'Sandbox initialization failed after {max_retries} attempts: {last_error}' ) from last_error @@ -302,130 +338,166 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='notebook_executor', server_name='code_executor', - description=( - 'Execute Python code in an isolated Docker sandbox with state ' - 'persistence in a Jupyter kernel environment. Variables, imports, and ' - 'data are preserved across multiple calls within the same session. ' - 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' - 'Data files in the output directory are accessible at /data/ path. ' - 'Use print() to output results.' - ), + description= + ('Execute Python code in an isolated Docker sandbox with state ' + 'persistence in a Jupyter kernel environment. Variables, imports, and ' + 'data are preserved across multiple calls within the same session. ' + 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' + 'Data files in the output directory are accessible at /data/ path. ' + 'Use print() to output results.'), parameters={ 'type': 'object', 'properties': { 'code': { - 'type': 'string', - 'description': ( - 'Python code to execute. Can access previously defined variables. ' - 'Data files are at /data/ (e.g., pd.read_csv(\'/data/file.csv\')). ' - 'Use print() for output.' - ), + 'type': + 'string', + 'description': + ('Python code to execute. Can access previously defined variables. ' + 'Data files are at /data/ (e.g., pd.read_csv(\'/data/file.csv\')). ' + 'Use print() for output.') + }, + 'description': { + 'type': + 'string', + 'description': + 'Brief description of what the code does' }, - 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'minimum': 1, 'maximum': 600, 'description': 'Execution timeout in seconds', - 'default': 60, - }, + 'default': 60 + } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='python_executor', server_name='code_executor', - description=( - 'Execute Python code in an isolated environment. ' - 'Supports pandas, numpy, matplotlib, seaborn and other libraries you need for data analysis. ' - 'Data files in the output directory are accessible at /data/ path. ' - 'Use print() to output results.' - ), + description= + ('Execute Python code in an isolated environment. ' + 'Supports pandas, numpy, matplotlib, seaborn and other libraries you need for data analysis. ' + 'Data files in the output directory are accessible at /data/ path. ' + 'Use print() to output results.'), parameters={ 'type': 'object', 'properties': { - 'code': {'type': 'string', 'description': 'Python code to execute'}, - 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, + 'code': { + 'type': 'string', + 'description': 'Python code to execute' + }, + 'description': { + 'type': + 'string', + 'description': + 'Brief description of what the code does' + }, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': 30, - }, + 'default': 30 + } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='shell_executor', server_name='code_executor', - description=( - 'Execute one shell command in an isolated environment. ' - 'Commands will be executed directly without shell parsing. ' - 'For shell syntax (cd, &&, ||, pipes, redirection), use explicit wrapper like sh -lc "...". ' - 'Supports basic operations like ls, mkdir, rm, mv, npm, pip, etc. ' - 'Data files in the output directory are accessible at /data/ path. ' - ), + description= + ('Execute one shell command in an isolated environment. ' + 'Commands will be executed directly without shell parsing. ' + 'For shell syntax (cd, &&, ||, pipes, redirection), use explicit wrapper like sh -lc "...". ' + 'Supports basic operations like ls, mkdir, rm, mv, npm, pip, etc. ' + 'Data files in the output directory are accessible at /data/ path. ' + ), parameters={ 'type': 'object', 'properties': { - 'command': {'type': 'string', 'description': 'Shell command to execute'}, + 'command': { + 'type': 'string', + 'description': 'Shell command to execute' + }, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': 900, - }, + 'default': 900 + } }, 'required': ['command'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='file_operation', server_name='code_executor', - description='Perform file operations like read, write, delete, and list files', + description= + 'Perform file operations like read, write, delete, and list files', parameters={ 'type': 'object', 'properties': { 'operation': { + 'type': + 'string', + 'description': + 'Type of file operation to perform', + 'enum': [ + 'create', 'read', 'write', 'delete', + 'list', 'exists' + ] + }, + 'file_path': { 'type': 'string', - 'description': 'Type of file operation to perform', - 'enum': ['create', 'read', 'write', 'delete', 'list', 'exists'], + 'description': 'Path to the file or directory' }, - 'file_path': {'type': 'string', 'description': 'Path to the file or directory'}, 'content': { - 'type': 'string', - 'description': 'Content to write to file (only for write operation)', + 'type': + 'string', + 'description': + 'Content to write to file (only for write operation)' }, - 'encoding': {'type': 'string', 'description': 'File encoding', 'default': 'utf-8'}, + 'encoding': { + 'type': 'string', + 'description': 'File encoding', + 'default': 'utf-8' + } }, 'required': ['operation', 'file_path'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='reset_executor', server_name='code_executor', - description=( - 'Reset the sandbox state by restarting the kernel. ' - 'All variables, imports, and session state will be cleared.' - ), - parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, + description= + ('Reset the sandbox state by restarting the kernel. ' + 'All variables, imports, and session state will be cleared.' + ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False + }, ), Tool( tool_name='get_executor_info', server_name='code_executor', description='Get current sandbox status and information', - parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, - ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False + }, + ) ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: """Route tool calls to appropriate methods""" if not self._initialized: await self.connect() @@ -434,12 +506,26 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) method = getattr(self, tool_name) return await method(**tool_args) except AttributeError: - return json.dumps({'success': False, 'error': f'Unknown tool: {tool_name}'}, indent=2) + return json.dumps( + { + 'success': False, + 'error': f'Unknown tool: {tool_name}' + }, + indent=2) except Exception as e: - logger.error(f'Tool execution error ({tool_name}): {e}', exc_info=True) - return json.dumps({'success': False, 'error': f'Tool execution error: {str(e)}'}, indent=2) + logger.error( + f'Tool execution error ({tool_name}): {e}', exc_info=True) + return json.dumps( + { + 'success': False, + 'error': f'Tool execution error: {str(e)}' + }, + indent=2) - async def notebook_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: + async def notebook_executor(self, + code: str, + description: str = '', + timeout: Optional[int] = None) -> str: """ Execute Python code in the sandbox using notebook_executor. @@ -460,8 +546,10 @@ async def notebook_executor(self, code: str, description: str = '', timeout: Opt result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='notebook_executor', - parameters={'code': code, 'timeout': timeout or 60}, - ) + parameters={ + 'code': code, + 'timeout': timeout or 60 + }) success = result.status == ExecutionStatus.SUCCESS @@ -475,16 +563,24 @@ async def notebook_executor(self, code: str, description: str = '', timeout: Opt 'success': success, 'description': description, 'output': result.output or '', - 'error': result.error if result.error else None, + 'error': result.error if result.error else None }, - indent=2, - ) + indent=2) except Exception as e: logger.error(f'Execute python failed: {e}', exc_info=True) - return json.dumps({'success': False, 'description': description, 'error': str(e)}, indent=2) + return json.dumps( + { + 'success': False, + 'description': description, + 'error': str(e) + }, + indent=2) - async def python_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: + async def python_executor(self, + code: str, + description: str = '', + timeout: Optional[int] = None) -> str: """ Execute Python code in the sandbox. @@ -505,8 +601,10 @@ async def python_executor(self, code: str, description: str = '', timeout: Optio result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='python_executor', - parameters={'code': code, 'timeout': timeout or 60}, - ) + parameters={ + 'code': code, + 'timeout': timeout or 60 + }) success = result.status == ExecutionStatus.SUCCESS @@ -520,16 +618,23 @@ async def python_executor(self, code: str, description: str = '', timeout: Optio 'success': success, 'description': description, 'output': result.output or '', - 'error': result.error if result.error else None, + 'error': result.error if result.error else None }, - indent=2, - ) + indent=2) except Exception as e: logger.error(f'Execute python failed: {e}', exc_info=True) - return json.dumps({'success': False, 'description': description, 'error': str(e)}, indent=2) + return json.dumps( + { + 'success': False, + 'description': description, + 'error': str(e) + }, + indent=2) - async def shell_executor(self, command: str, timeout: Optional[int] = None) -> str: + async def shell_executor(self, + command: str, + timeout: Optional[int] = None) -> str: """ Execute shell commands in the sandbox. @@ -545,19 +650,23 @@ async def shell_executor(self, command: str, timeout: Optional[int] = None) -> s try: logger.info(f'Executing command: {command[:50]}...') - shell_meta = ('&&', '||', '|', ';', '>', '<', '`', '$(', 'cd ', 'export ') - already_wrapped = command.lstrip().startswith(('sh ', 'bash ', '/bin/sh ', '/bin/bash ')) - if not already_wrapped and any(meta in command for meta in shell_meta): + shell_meta = ('&&', '||', '|', ';', '>', '<', '`', '$(', 'cd ', + 'export ') + already_wrapped = command.lstrip().startswith( + ('sh ', 'bash ', '/bin/sh ', '/bin/bash ')) + if not already_wrapped and any(meta in command + for meta in shell_meta): import shlex - command = f'sh -lc {shlex.quote(command)}' # Execute via shell_executor result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='shell_executor', - parameters={'command': command, 'timeout': timeout or 900}, - ) + parameters={ + 'command': command, + 'timeout': timeout or 900 + }) success = result.status == ExecutionStatus.SUCCESS if success: @@ -566,17 +675,22 @@ async def shell_executor(self, command: str, timeout: Optional[int] = None) -> s logger.warning(f'Command execution failed: {result.error}') return json.dumps( - {'success': success, 'output': result.output or '', 'error': result.error if result.error else None}, - indent=2, - ) + { + 'success': success, + 'output': result.output or '', + 'error': result.error if result.error else None + }, + indent=2) except Exception as e: logger.error(f'Execute shell failed: {e}', exc_info=True) return json.dumps({'success': False, 'error': str(e)}, indent=2) - async def file_operation( - self, operation: str, file_path: str, content: Optional[str] = None, encoding: Optional[str] = 'utf-8' - ) -> str: + async def file_operation(self, + operation: str, + file_path: str, + content: Optional[str] = None, + encoding: Optional[str] = 'utf-8') -> str: """ Perform file operations like read, write, delete, and list files in the sandbox. @@ -595,29 +709,41 @@ async def file_operation( result = await self.manager.execute_tool( sandbox_id=self.sandbox_id, tool_name='file_operation', - parameters={'operation': operation, 'file_path': file_path, 'content': content, 'encoding': encoding}, - ) + parameters={ + 'operation': operation, + 'file_path': file_path, + 'content': content, + 'encoding': encoding + }) success = result.status == ExecutionStatus.SUCCESS if success: - logger.info(f'File operation {operation} successful for {file_path}') + logger.info( + f'File operation {operation} successful for {file_path}') else: - logger.warning(f'File operation {operation} failed for {file_path}: {result.error}') + logger.warning( + f'File operation {operation} failed for {file_path}: {result.error}' + ) return json.dumps( { 'success': success, 'file_path': file_path, 'output': result.output if success else '', - 'error': result.error if result.error else None, + 'error': result.error if result.error else None }, - indent=2, - ) + indent=2) except Exception as e: logger.error(f'Read file failed: {e}', exc_info=True) - return json.dumps({'success': False, 'file_path': file_path, 'error': str(e)}, indent=2) + return json.dumps( + { + 'success': False, + 'file_path': file_path, + 'error': str(e) + }, + indent=2) async def reset_executor(self) -> str: """ @@ -637,8 +763,8 @@ async def reset_executor(self) -> str: # Create new sandbox self.sandbox_id = await self.manager.create_sandbox( - sandbox_type=SandboxType.DOCKER_NOTEBOOK, config=self.sandbox_config - ) + sandbox_type=SandboxType.DOCKER_NOTEBOOK, + config=self.sandbox_config) # Wait for it to be ready await self._wait_for_sandbox_ready() @@ -648,11 +774,11 @@ async def reset_executor(self) -> str: return json.dumps( { 'success': True, - 'message': 'Sandbox reset successfully. All variables and state cleared.', - 'new_sandbox_id': self.sandbox_id, + 'message': + 'Sandbox reset successfully. All variables and state cleared.', + 'new_sandbox_id': self.sandbox_id }, - indent=2, - ) + indent=2) except Exception as e: logger.error(f'Reset sandbox failed: {e}', exc_info=True) @@ -681,14 +807,18 @@ async def get_executor_info(self) -> str: 'config': { 'memory_limit': self.sandbox_config.memory_limit, 'cpu_limit': self.sandbox_config.cpu_limit, - 'timeout': self.sandbox_config.timeout, - }, + 'timeout': self.sandbox_config.timeout + } }, indent=2, - default=str, - ) + default=str) else: - return json.dumps({'success': False, 'error': 'Sandbox info not available'}, indent=2) + return json.dumps( + { + 'success': False, + 'error': 'Sandbox info not available' + }, + indent=2) except Exception as e: logger.error(f'Get sandbox info failed: {e}', exc_info=True) @@ -716,12 +846,16 @@ async def _wait_for_sandbox_ready(self, max_wait: int = 60) -> None: logger.info('Sandbox is running and ready') return elif info.status == SandboxStatus.ERROR: - error_msg = info.metadata.get('error') or f'Unknown error: {info.metadata}' + error_msg = info.metadata.get( + 'error') or f'Unknown error: {info.metadata}' raise RuntimeError(f'Sandbox failed to start: {error_msg}') if i % 5 == 0: - logger.debug(f'Waiting for sandbox... ({i}/{max_wait}s, status={info.status.value})') + logger.debug( + f'Waiting for sandbox... ({i}/{max_wait}s, status={info.status.value})' + ) await asyncio.sleep(1) - raise TimeoutError(f'Sandbox failed to become ready within {max_wait} seconds') + raise TimeoutError( + f'Sandbox failed to become ready within {max_wait} seconds') diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index 3ceb6d1a4..d1e2104c5 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -2,7 +2,6 @@ import asyncio.subprocess as ai_subprocess import inspect import io -import json import os import shlex import shutil @@ -11,6 +10,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger @@ -41,13 +41,11 @@ def _coerce_str(value: Optional[bytes]) -> str: class LocalKernelSession: """Manage a local ipykernel instance for stateful notebook execution.""" - def __init__( - self, - working_dir: Path, - env: Optional[Dict[str, str]] = None, - kernel_name: str = 'python3', - extra_arguments: Optional[List[str]] = None, - ): + def __init__(self, + working_dir: Path, + env: Optional[Dict[str, str]] = None, + kernel_name: str = 'python3', + extra_arguments: Optional[List[str]] = None): self.working_dir = working_dir self.env = env or {} self.kernel_name = kernel_name @@ -69,8 +67,9 @@ async def start(self) -> None: logger.info('Starting local ipykernel session...') self._km = AsyncKernelManager( - kernel_name=self.kernel_name, env=self.env, cwd=str(self.working_dir) - ) # cwd may be ignored here + kernel_name=self.kernel_name, + env=self.env, + cwd=str(self.working_dir)) # cwd may be ignored here start_kernel_result = self._km.start_kernel( extra_arguments=self.extra_arguments, @@ -148,8 +147,10 @@ async def execute(self, code: str, timeout: int) -> Dict[str, Any]: if not self._client: raise RuntimeError('Kernel client not initialized') - execute_call = self._client.execute(code=code, allow_stdin=False, stop_on_error=False) - msg_id = await execute_call if inspect.isawaitable(execute_call) else execute_call + execute_call = self._client.execute( + code=code, allow_stdin=False, stop_on_error=False) + msg_id = await execute_call if inspect.isawaitable( + execute_call) else execute_call stdout_parts: List[str] = [] stderr_parts: List[str] = [] @@ -173,7 +174,8 @@ async def _drain() -> None: msg_type = msg['msg_type'] content = msg.get('content', {}) - if msg_type == 'status' and content.get('execution_state') == 'idle': + if msg_type == 'status' and content.get( + 'execution_state') == 'idle': break if msg_type == 'stream': name = content.get('name', 'stdout') @@ -189,7 +191,8 @@ async def _drain() -> None: elif 'text/html' in data: display_parts.append(data['text/html']) elif data: - display_parts.append(json.dumps(data, ensure_ascii=False)) + display_parts.append( + json.dumps(data, ensure_ascii=False)) elif msg_type == 'error': error_payload = { 'ename': content.get('ename'), @@ -206,15 +209,23 @@ async def _drain() -> None: except asyncio.TimeoutError as exc: logger.warning('Notebook execution timed out, interrupting kernel') await self.interrupt() - raise TimeoutError(f'Notebook execution timed out after {timeout} seconds') from exc + raise TimeoutError( + f'Notebook execution timed out after {timeout} seconds' + ) from exc self.execution_count += 1 stdout = ''.join(stdout_parts).strip('\n') stderr = ''.join(stderr_parts).strip('\n') displays = '\n'.join(display_parts).strip('\n') - output_segments = [segment for segment in [stdout, displays] if segment] - - return {'output': '\n'.join(output_segments), 'stderr': stderr, 'error': error_payload} + output_segments = [ + segment for segment in [stdout, displays] if segment + ] + + return { + 'output': '\n'.join(output_segments), + 'stderr': stderr, + 'error': error_payload + } class LocalCodeExecutionTool(ToolBase): @@ -222,17 +233,23 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) - self.output_dir = Path(getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() + self.output_dir = Path( + getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() self.output_dir.mkdir(parents=True, exist_ok=True) - self.tool_config = getattr(getattr(config, 'tools', None), 'code_executor', None) - self._notebook_timeout = getattr(self.tool_config, 'notebook_timeout', 60) if self.tool_config else 60 - self._python_timeout = getattr(self.tool_config, 'python_timeout', 30) if self.tool_config else 30 - self._shell_timeout = getattr(self.tool_config, 'shell_timeout', 60) if self.tool_config else 60 + self.tool_config = getattr( + getattr(config, 'tools', None), 'code_executor', None) + self._notebook_timeout = getattr(self.tool_config, 'notebook_timeout', + 60) if self.tool_config else 60 + self._python_timeout = getattr(self.tool_config, 'python_timeout', + 30) if self.tool_config else 30 + self._shell_timeout = getattr(self.tool_config, 'shell_timeout', + 60) if self.tool_config else 60 kernel_env = self._build_env('kernel_env', inherit=False) shell_env = self._build_env('shell_env', inherit=False) - self.kernel_session = LocalKernelSession(working_dir=self.output_dir, env=kernel_env) + self.kernel_session = LocalKernelSession( + working_dir=self.output_dir, env=kernel_env) self.shell_env = shell_env self._kernel_lock = asyncio.Lock() self._initialized = False @@ -248,9 +265,12 @@ def __init__(self, config): if dg: deny_globs = list(dg) shell_cfg = getattr(self.tool_config, 'shell', None) if self.tool_config else None - shell_mode = getattr(shell_cfg, 'default_mode', 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False)) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192)) if shell_cfg else 8192 + shell_mode = getattr(shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False) + ) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) + ) if shell_cfg else 8192 self._policy = WorkspacePolicyKernel( self.output_dir, extra_allow_roots=extra_allow, @@ -262,14 +282,19 @@ def __init__(self, config): max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) - self._artifacts = ArtifactManager(self.output_dir, max_combined_bytes=max_kb * 1024) + self._artifacts = ArtifactManager( + self.output_dir, max_combined_bytes=max_kb * 1024) - self.exclude_func(getattr(getattr(config, 'tools', None), 'code_executor', None)) + self.exclude_func( + getattr(getattr(config, 'tools', None), 'code_executor', None)) if 'file_operation' not in self.exclude_functions: - logger.warning('file_operation is not suggested to be included in local code execution tool.') + logger.warning( + 'file_operation is not suggested to be included in local code execution tool.' + ) results = self._check_dependencies() - logger.info(f'Dependency check results: {results}\nMake sure to install the missing dependencies.') + logger.info(f'Dependency check results: {results}\n' + f'Make sure to install the missing dependencies.') logger.info('LocalCodeExecutionTool initialized (ipykernel based)') @@ -303,11 +328,13 @@ def _check_dependencies(self) -> None: install_package(pip_name, import_name) module = importlib.import_module(import_name) except Exception as e: - logger.error(f'Failed to install or import {pip_name}: {e}') + logger.error( + f'Failed to install or import {pip_name}: {e}') results[pip_name] = None continue except Exception as e: - logger.error(f'Unexpected error when importing {pip_name}: {e}') + logger.error( + f'Unexpected error when importing {pip_name}: {e}') results[pip_name] = None continue @@ -318,7 +345,8 @@ def _check_dependencies(self) -> None: def _build_env(self, field: str, inherit: bool = False) -> Dict[str, str]: if inherit: env: Dict[str, str] = dict(os.environ) - logger.warning("It's not safe to inherit from the parent environment.") + logger.warning( + "It's not safe to inherit from the parent environment.") else: env: Dict[str, str] = { 'INHERITED_FROM_LOCAL': 'False', @@ -366,63 +394,72 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='notebook_executor', server_name='code_executor', - description=( - 'Execute Python code locally with state ' - 'persistence in a Jupyter kernel environment. Variables, imports, and ' - 'data are preserved across multiple calls within the same session. ' - 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' - 'Use print() to output results.' - ), + description= + ('Execute Python code locally with state ' + 'persistence in a Jupyter kernel environment. Variables, imports, and ' + 'data are preserved across multiple calls within the same session. ' + 'Supports pandas, numpy, matplotlib, seaborn for data analysis. ' + 'Use print() to output results.'), parameters={ 'type': 'object', 'properties': { 'code': { - 'type': 'string', - 'description': ( - 'Python code to execute in the notebook session. ' - 'Can access previously defined variables. ' - 'Use print() for output.' - ), + 'type': + 'string', + 'description': + ('Python code to execute in the notebook session. ' + 'Can access previously defined variables. ' + 'Use print() for output.') + }, + 'description': { + 'type': + 'string', + 'description': + 'Brief description of what the code does' }, - 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, 'timeout': { 'type': 'integer', 'minimum': 1, 'maximum': 600, 'description': 'Execution timeout in seconds', - 'default': self._notebook_timeout, - }, + 'default': self._notebook_timeout + } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='python_executor', server_name='code_executor', - description=( - 'Execute stateless Python code locally. ' - 'Each call runs in an isolated environment without ' - 'persisting context between invocations. ' - 'Supports pandas, numpy, matplotlib, seaborn, and other ' - 'libraries you need for data analysis. ' - 'Use print() to output results.' - ), + description= + ('Execute stateless Python code locally. ' + 'Each call runs in an isolated environment without ' + 'persisting context between invocations. ' + 'Supports pandas, numpy, matplotlib, seaborn, and other ' + 'libraries you need for data analysis. ' + 'Use print() to output results.'), parameters={ 'type': 'object', 'properties': { - 'code': {'type': 'string', 'description': 'Python code to execute'}, - 'description': {'type': 'string', 'description': 'Brief description of what the code does'}, + 'code': { + 'type': 'string', + 'description': 'Python code to execute' + }, + 'description': { + 'type': + 'string', + 'description': + 'Brief description of what the code does' + }, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': self._python_timeout, - }, + 'default': self._python_timeout + } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='shell_executor', server_name='code_executor', @@ -435,15 +472,19 @@ async def _get_tools_inner(self) -> Dict[str, Any]: parameters={ 'type': 'object', 'properties': { - 'command': {'type': 'string', 'description': 'Shell command to execute'}, + 'command': { + 'type': 'string', + 'description': 'Shell command to execute' + }, 'timeout': { 'type': 'integer', 'description': 'Execution timeout in seconds', - 'default': self._shell_timeout, + 'default': self._shell_timeout }, 'run_in_background': { 'type': 'boolean', - 'description': 'If true, start the command asynchronously and return task_id (requires TaskManager).', + 'description': + 'If true, start the command asynchronously and return task_id (requires TaskManager).', 'default': False, }, '__call_id': { @@ -452,50 +493,76 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, }, 'required': ['command'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='file_operation', server_name='code_executor', - description='Perform file operations inside the local output directory', + description= + 'Perform file operations inside the local output directory', parameters={ 'type': 'object', 'properties': { 'operation': { + 'type': + 'string', + 'description': + 'Type of file operation to perform', + 'enum': [ + 'create', 'read', 'write', 'delete', + 'list', 'exists' + ] + }, + 'file_path': { 'type': 'string', - 'description': 'Type of file operation to perform', - 'enum': ['create', 'read', 'write', 'delete', 'list', 'exists'], + 'description': 'Path to the file or directory' }, - 'file_path': {'type': 'string', 'description': 'Path to the file or directory'}, - 'content': {'type': 'string', 'description': 'Content for write/create operations'}, - 'encoding': {'type': 'string', 'description': 'File encoding to use', 'default': 'utf-8'}, + 'content': { + 'type': + 'string', + 'description': + 'Content for write/create operations' + }, + 'encoding': { + 'type': 'string', + 'description': 'File encoding to use', + 'default': 'utf-8' + } }, 'required': ['operation', 'file_path'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='reset_executor', server_name='code_executor', - description=( - 'Restart the local ipykernel session to clear state. ' - 'All variables, imports, and session state will be cleared.' - ), - parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, - ), + description= + ('Restart the local ipykernel session to clear state. ' + 'All variables, imports, and session state will be cleared.' + ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False + }), Tool( tool_name='get_executor_info', server_name='code_executor', - description='Get information about the local execution environment.', - parameters={'type': 'object', 'properties': {}, 'required': [], 'additionalProperties': False}, - ), + description= + 'Get information about the local execution environment.', + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False + }), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: if not self._initialized: await self.connect() @@ -503,12 +570,28 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) method = getattr(self, tool_name) return await method(**tool_args) except AttributeError: - return json.dumps({'success': False, 'error': f'Unknown tool: {tool_name}'}, ensure_ascii=False, indent=2) + return json.dumps( + { + 'success': False, + 'error': f'Unknown tool: {tool_name}' + }, + ensure_ascii=False, + indent=2) except Exception as exc: - logger.error(f'Tool execution error ({tool_name}): {exc}', exc_info=True) - return json.dumps({'success': False, 'error': f'Tool execution error: {exc}'}, ensure_ascii=False, indent=2) + logger.error( + f'Tool execution error ({tool_name}): {exc}', exc_info=True) + return json.dumps( + { + 'success': False, + 'error': f'Tool execution error: {exc}' + }, + ensure_ascii=False, + indent=2) - async def notebook_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: + async def notebook_executor(self, + code: str, + description: str = '', + timeout: Optional[int] = None) -> str: exec_timeout = timeout or self._notebook_timeout try: @@ -516,8 +599,13 @@ async def notebook_executor(self, code: str, description: str = '', timeout: Opt result = await self.kernel_session.execute(code, exec_timeout) except Exception as exc: return json.dumps( - {'success': False, 'description': description, 'error': str(exc)}, ensure_ascii=False, indent=2 - ) + { + 'success': False, + 'description': description, + 'error': str(exc) + }, + ensure_ascii=False, + indent=2) error_payload = result.get('error') stderr = result.get('stderr') or '' @@ -534,13 +622,15 @@ async def notebook_executor(self, code: str, description: str = '', timeout: Opt 'success': error_payload is None, 'description': description, 'output': result.get('output', ''), - 'error': stderr or None, + 'error': stderr or None }, ensure_ascii=False, - indent=2, - ) + indent=2) - async def python_executor(self, code: str, description: str = '', timeout: Optional[int] = None) -> str: + async def python_executor(self, + code: str, + description: str = '', + timeout: Optional[int] = None) -> str: exec_timeout = timeout or self._python_timeout def _exec_code(): @@ -548,26 +638,35 @@ def _exec_code(): stderr_buffer = io.StringIO() globals_dict: Dict[str, Any] = {'__builtins__': __builtins__} locals_dict: Dict[str, Any] = {} - with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer): + with redirect_stdout(stdout_buffer), redirect_stderr( + stderr_buffer): exec(code, globals_dict, locals_dict) return stdout_buffer.getvalue(), stderr_buffer.getvalue() try: - stdout, stderr = await asyncio.wait_for(asyncio.to_thread(_exec_code), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for( + asyncio.to_thread(_exec_code), timeout=exec_timeout) except asyncio.TimeoutError: return json.dumps( { - 'success': False, - 'description': description, - 'error': f'Python execution timed out after {exec_timeout} seconds', + 'success': + False, + 'description': + description, + 'error': + f'Python execution timed out after {exec_timeout} seconds' }, ensure_ascii=False, - indent=2, - ) + indent=2) except Exception as exc: return json.dumps( - {'success': False, 'description': description, 'error': str(exc)}, ensure_ascii=False, indent=2 - ) + { + 'success': False, + 'description': description, + 'error': str(exc) + }, + ensure_ascii=False, + indent=2) if not stderr: logger.info('Python code executed successfully') @@ -579,19 +678,16 @@ def _exec_code(): 'success': not stderr, 'description': description, 'output': stdout.strip('\n'), - 'error': stderr.strip('\n') or None, + 'error': stderr.strip('\n') or None }, ensure_ascii=False, - indent=2, - ) + indent=2) - async def shell_executor( - self, - command: str, - timeout: Optional[int] = None, - run_in_background: bool = False, - __call_id: Optional[str] = None, - ) -> str: + async def shell_executor(self, + command: str, + timeout: Optional[int] = None, + run_in_background: bool = False, + __call_id: Optional[str] = None) -> str: exec_timeout = timeout or self._shell_timeout call_id = __call_id or f'shell-{os.urandom(4).hex()}' @@ -599,7 +695,10 @@ async def shell_executor( self._policy.assert_shell_command_allowed(command) except WorkspacePolicyError as e: return json.dumps( - {'success': False, 'error': str(e)}, + { + 'success': False, + 'error': str(e) + }, ensure_ascii=False, indent=2, ) @@ -611,7 +710,8 @@ async def shell_executor( return json.dumps( { 'success': False, - 'error': 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', + 'error': + 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', }, ensure_ascii=False, indent=2, @@ -626,7 +726,10 @@ async def shell_executor( ) except FileNotFoundError as exc: return json.dumps( - {'success': False, 'error': f'Shell not available: {exc}'}, + { + 'success': False, + 'error': f'Shell not available: {exc}' + }, ensure_ascii=False, indent=2, ) @@ -640,7 +743,8 @@ async def shell_executor( async def _watcher() -> None: try: - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=exec_timeout) stdout_text = _coerce_str(stdout).strip('\n') stderr_text = _coerce_str(stderr).strip('\n') success = process.returncode == 0 @@ -694,13 +798,17 @@ async def _watcher() -> None: ) except FileNotFoundError as exc: return json.dumps( - {'success': False, 'error': f'Shell not available: {exc}'}, + { + 'success': False, + 'error': f'Shell not available: {exc}' + }, ensure_ascii=False, indent=2, ) try: - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=exec_timeout) + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=exec_timeout) except asyncio.TimeoutError: process.kill() try: @@ -708,7 +816,12 @@ async def _watcher() -> None: except Exception: # noqa: B902 pass return json.dumps( - {'success': False, 'error': f'Shell command timed out after {exec_timeout} seconds'}, + { + 'success': + False, + 'error': + f'Shell command timed out after {exec_timeout} seconds' + }, ensure_ascii=False, indent=2, ) @@ -728,15 +841,22 @@ async def _watcher() -> None: payload=payload, ) - async def file_operation( - self, operation: str, file_path: str, content: Optional[str] = None, encoding: Optional[str] = 'utf-8' - ) -> str: + async def file_operation(self, + operation: str, + file_path: str, + content: Optional[str] = None, + encoding: Optional[str] = 'utf-8') -> str: try: target = self._resolve_path(file_path) except ValueError as exc: return json.dumps( - {'success': False, 'error': str(exc), 'file_path': file_path}, ensure_ascii=False, indent=2 - ) + { + 'success': False, + 'error': str(exc), + 'file_path': file_path + }, + ensure_ascii=False, + indent=2) op = operation.lower() @@ -744,40 +864,69 @@ async def file_operation( if op == 'create': target.parent.mkdir(parents=True, exist_ok=True) target.touch(exist_ok=True) - result = {'success': True, 'file_path': str(target), 'message': 'File created'} + result = { + 'success': True, + 'file_path': str(target), + 'message': 'File created' + } elif op == 'read': data = target.read_text(encoding=encoding or 'utf-8') - result = {'success': True, 'file_path': str(target), 'output': data} + result = { + 'success': True, + 'file_path': str(target), + 'output': data + } elif op == 'write': if content is None: raise ValueError('Content is required for write operation') target.parent.mkdir(parents=True, exist_ok=True) target.write_text(content, encoding=encoding or 'utf-8') - result = {'success': True, 'file_path': str(target), 'message': 'File written'} + result = { + 'success': True, + 'file_path': str(target), + 'message': 'File written' + } elif op == 'delete': if target.is_dir(): shutil.rmtree(target) else: target.unlink(missing_ok=True) - result = {'success': True, 'file_path': str(target), 'message': 'Deleted successfully'} + result = { + 'success': True, + 'file_path': str(target), + 'message': 'Deleted successfully' + } elif op == 'list': if not target.is_dir(): - raise ValueError('List operation requires a directory path') - entries = [ - { - 'name': child.name, - 'is_dir': child.is_dir(), - 'size': child.stat().st_size if child.is_file() else None, - } - for child in sorted(target.iterdir()) - ] - result = {'success': True, 'file_path': str(target), 'entries': entries} + raise ValueError( + 'List operation requires a directory path') + entries = [{ + 'name': + child.name, + 'is_dir': + child.is_dir(), + 'size': + child.stat().st_size if child.is_file() else None + } for child in sorted(target.iterdir())] + result = { + 'success': True, + 'file_path': str(target), + 'entries': entries + } elif op == 'exists': - result = {'success': True, 'file_path': str(target), 'exists': target.exists()} + result = { + 'success': True, + 'file_path': str(target), + 'exists': target.exists() + } else: raise ValueError(f'Unsupported file operation: {operation}') except Exception as exc: - result = {'success': False, 'file_path': str(target), 'error': str(exc)} + result = { + 'success': False, + 'file_path': str(target), + 'error': str(exc) + } return json.dumps(result, ensure_ascii=False, indent=2, default=str) @@ -786,10 +935,14 @@ async def reset_executor(self) -> str: async with self._kernel_lock: await self.kernel_session.restart() return json.dumps( - {'success': True, 'message': 'Local kernel session restarted. State has been cleared.'}, + { + 'success': + True, + 'message': + 'Local kernel session restarted. State has been cleared.' + }, ensure_ascii=False, - indent=2, - ) + indent=2) except Exception as exc: return json.dumps({'success': False, 'error': str(exc)}, ensure_ascii=False, indent=2) # yapf: disable @@ -811,5 +964,6 @@ def _resolve_path(self, file_path: str) -> Path: else: raw_path = raw_path.resolve() if not _is_relative_to(raw_path, self.output_dir): - raise ValueError('Access outside the output directory is not permitted') + raise ValueError( + 'Access outside the output directory is not permitted') return raw_path diff --git a/ms_agent/tools/code/sandbox_manager.py b/ms_agent/tools/code/sandbox_manager.py index 8a3880a90..9744ca7f8 100644 --- a/ms_agent/tools/code/sandbox_manager.py +++ b/ms_agent/tools/code/sandbox_manager.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Union -from omegaconf import DictConfig - from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() @@ -31,18 +30,22 @@ def ensure_local_image_exists(image: str) -> bool: try: client = docker.from_env() - image_exists = any(image in img.tags for img in client.images.list() if img.tags) + image_exists = any(image in img.tags + for img in client.images.list() if img.tags) if image_exists: logger.info(f'Image exists in local Docker registry: {image}') else: - logger.info(f'Image does not exist in local Docker registry: {image}') + logger.info( + f'Image does not exist in local Docker registry: {image}') return image_exists except Exception as e: logger.error(f'Error checking if image exists: {e}') raise RuntimeError(f'Failed to check image existence: {e}') from e @staticmethod - async def create_manager(config: Union[DictConfig, dict]) -> Union['LocalSandboxManager', 'HttpSandboxManager']: + async def create_manager( + config: Union[DictConfig, dict] + ) -> Union['LocalSandboxManager', 'HttpSandboxManager']: """ Create and initialize a sandbox manager based on configuration. @@ -58,12 +61,13 @@ async def create_manager(config: Union[DictConfig, dict]) -> Union['LocalSandbox from ms_enclave.sandbox.manager import HttpSandboxManager, LocalSandboxManager # Extract sandbox configuration - if isinstance(config, DictConfig) and hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): + if isinstance(config, DictConfig) and hasattr( + config, 'tools') and hasattr(config.tools, 'code_executor'): sandbox_config = getattr(config.tools.code_executor, 'sandbox', {}) elif isinstance(config, (DictConfig, dict)): - sandbox_config = config.get('tools', {}).get('code_executor', {}).get('sandbox', {}) or config.get( - 'sandbox', {} - ) + sandbox_config = config.get('tools', {}).get( + 'code_executor', {}).get('sandbox', {}) or config.get( + 'sandbox', {}) else: raise ValueError(f'Unknown config type: {type(config)}') @@ -74,16 +78,24 @@ async def create_manager(config: Union[DictConfig, dict]) -> Union['LocalSandbox if mode == 'local': cleanup_interval = sandbox_config.get('cleanup_interval', 300) manager = LocalSandboxManager(cleanup_interval=cleanup_interval) - logger.info(f'Created LocalSandboxManager with cleanup_interval={cleanup_interval}s') + logger.info( + f'Created LocalSandboxManager with cleanup_interval={cleanup_interval}s' + ) if image: try: - if not SandboxManagerFactory.ensure_local_image_exists(image): - raise ValueError(f'Image "{image}" does not exist in local Docker registry') + if not SandboxManagerFactory.ensure_local_image_exists( + image): + raise ValueError( + f'Image "{image}" does not exist in local Docker registry' + ) except RuntimeError as e: - raise ValueError(f'Error checking if image exists: {e}') from e + raise ValueError( + f'Error checking if image exists: {e}') from e else: - logger.warning('No image specified for LocalSandboxManager, using default') + logger.warning( + 'No image specified for LocalSandboxManager, using default' + ) elif mode == 'http': base_url = sandbox_config.get('http_url', 'http://localhost:8000') @@ -91,6 +103,7 @@ async def create_manager(config: Union[DictConfig, dict]) -> Union['LocalSandbox logger.info(f'Created HttpSandboxManager with base_url={base_url}') else: - raise ValueError(f"Unknown sandbox mode: {mode}. Must be 'local' or 'http'") + raise ValueError( + f"Unknown sandbox mode: {mode}. Must be 'local' or 'http'") return manager diff --git a/ms_agent/tools/code_server/lsp_code_server.py b/ms_agent/tools/code_server/lsp_code_server.py index 64431c84a..df1e043e5 100644 --- a/ms_agent/tools/code_server/lsp_code_server.py +++ b/ms_agent/tools/code_server/lsp_code_server.py @@ -1,14 +1,15 @@ import asyncio -import json import os import shutil import sys from pathlib import Path from typing import Any, Dict, List, Optional +import json from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_OUTPUT_DIR +from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, + DEFAULT_OUTPUT_DIR) logger = get_logger() @@ -23,7 +24,8 @@ def __init__(self, config): self.stdout = None self.message_id = 0 self.initialized = False - self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) + self.output_dir = getattr(self.config, 'output_dir', + DEFAULT_OUTPUT_DIR) self.workspace_dir = Path(self.output_dir).resolve() self.index_dir = os.path.join(self.output_dir, DEFAULT_INDEX_DIR) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) @@ -57,7 +59,12 @@ async def send_request(self, method: str, params: dict = None) -> dict: self.message_id += 1 request_id = self.message_id - request = {'jsonrpc': '2.0', 'id': request_id, 'method': method, 'params': params or {}} + request = { + 'jsonrpc': '2.0', + 'id': request_id, + 'method': method, + 'params': params or {} + } content = json.dumps(request) message = f'Content-Length: {len(content)}\r\n\r\n{content}' @@ -77,10 +84,14 @@ async def send_request(self, method: str, params: dict = None) -> dict: # It's a notification (no id) or response for different request # Log and continue reading if 'method' in msg: - logger.debug(f"Received notification during request: {msg.get('method')}") + logger.debug( + f"Received notification during request: {msg.get('method')}" + ) continue - logger.warning(f'No response received for request {request_id} after {max_retries} attempts') + logger.warning( + f'No response received for request {request_id} after {max_retries} attempts' + ) return {'error': 'No response received'} except Exception as e: @@ -92,7 +103,11 @@ async def send_notification(self, method: str, params: dict = None): if not self.process or not self.stdin: raise RuntimeError('LSP server not started') - notification = {'jsonrpc': '2.0', 'method': method, 'params': params or {}} + notification = { + 'jsonrpc': '2.0', + 'method': method, + 'params': params or {} + } content = json.dumps(notification) message = f'Content-Length: {len(content)}\r\n\r\n{content}' @@ -132,45 +147,55 @@ async def _read_message(self) -> dict: async def initialize(self): """Initialize the LSP server and wait for it to be ready""" response = await self.send_request( - 'initialize', - { - 'processId': os.getpid(), - 'rootUri': self.workspace_dir.as_uri(), - 'rootPath': str(self.workspace_dir), - 'workspaceFolders': [{'uri': self.workspace_dir.as_uri(), 'name': self.workspace_dir.name}], + 'initialize', { + 'processId': + os.getpid(), + 'rootUri': + self.workspace_dir.as_uri(), + 'rootPath': + str(self.workspace_dir), + 'workspaceFolders': [{ + 'uri': self.workspace_dir.as_uri(), + 'name': self.workspace_dir.name + }], 'capabilities': { 'textDocument': { 'publishDiagnostics': {}, - 'synchronization': {'didOpen': True, 'didChange': True, 'didClose': True}, + 'synchronization': { + 'didOpen': True, + 'didChange': True, + 'didClose': True + } } - }, - }, - ) + } + }) if 'result' in response: await self.send_notification('initialized', {}) # CRITICAL: Wait for server to be fully ready # Read and discard any startup messages - await asyncio.sleep(1.0) # Give server time to complete initialization + await asyncio.sleep( + 1.0) # Give server time to complete initialization await self.send_notification( - 'workspace/didChangeConfiguration', - { + 'workspace/didChangeConfiguration', { 'settings': { 'python': { 'pythonPath': sys.executable, }, - 'pyright': {'extraPaths': [str(self.workspace_dir)]}, + 'pyright': { + 'extraPaths': [str(self.workspace_dir)] + }, } - }, - ) + }) # Consume any pending messages (like "starting" notifications) try: for _ in range(10): try: - await asyncio.wait_for(self._read_message(), timeout=2.0) + await asyncio.wait_for( + self._read_message(), timeout=2.0) except asyncio.TimeoutError: break except Exception as e: @@ -183,11 +208,13 @@ async def initialize(self): logger.error(f'LSP initialization failed: {response}') return False - async def open_document(self, file_path: str, content: str, language_id: str): + async def open_document(self, file_path: str, content: str, + language_id: str): """Open a document in the LSP server""" file_uri = Path(file_path).resolve().as_uri() changes = [{'uri': file_uri, 'type': 1}] - await self.send_notification('workspace/didChangeWatchedFiles', {'changes': changes}) + await self.send_notification('workspace/didChangeWatchedFiles', + {'changes': changes}) if file_path.endswith('.tsx'): language_id = 'typescriptreact' @@ -199,25 +226,45 @@ async def open_document(self, file_path: str, content: str, language_id: str): language_id = 'javascript' await self.send_notification( - 'textDocument/didOpen', - {'textDocument': {'uri': file_uri, 'languageId': language_id, 'version': 1, 'text': content}}, - ) + 'textDocument/didOpen', { + 'textDocument': { + 'uri': file_uri, + 'languageId': language_id, + 'version': 1, + 'text': content + } + }) await asyncio.sleep(2.0) async def close_document(self, file_path: str): """Close a document to clean up old index""" file_uri = Path(file_path).resolve().as_uri() - await self.send_notification('textDocument/didClose', {'textDocument': {'uri': file_uri}}) - - async def update_document(self, file_path: str, content: str, version: int = 2): + await self.send_notification('textDocument/didClose', + {'textDocument': { + 'uri': file_uri + }}) + + async def update_document(self, + file_path: str, + content: str, + version: int = 2): """Update a document in the LSP server""" file_uri = Path(file_path).resolve().as_uri() await self.send_notification( - 'textDocument/didChange', - {'textDocument': {'uri': file_uri, 'version': version}, 'contentChanges': [{'text': content}]}, - ) - - async def get_diagnostics(self, file_path: str, wait_time: float = 2.0, use_cache: bool = True) -> List[dict]: + 'textDocument/didChange', { + 'textDocument': { + 'uri': file_uri, + 'version': version + }, + 'contentChanges': [{ + 'text': content + }] + }) + + async def get_diagnostics(self, + file_path: str, + wait_time: float = 2.0, + use_cache: bool = True) -> List[dict]: await asyncio.sleep(wait_time) file_uri = Path(file_path).resolve().as_uri() @@ -234,7 +281,8 @@ async def get_diagnostics(self, file_path: str, wait_time: float = 2.0, use_cach if msg.get('method') == 'textDocument/publishDiagnostics': current_uri = msg.get('params', {}).get('uri') - current_diags = msg.get('params', {}).get('diagnostics', []) + current_diags = msg.get('params', + {}).get('diagnostics', []) self.diagnostics_cache[current_uri] = current_diags logger.debug(f'Cached diagnostics for {current_uri}') @@ -242,12 +290,15 @@ async def get_diagnostics(self, file_path: str, wait_time: float = 2.0, use_cach if current_uri == file_uri: diagnostics = current_diags found_target = True - logger.debug(f'Found target diagnostics for {file_uri}') + logger.debug( + f'Found target diagnostics for {file_uri}') except asyncio.TimeoutError: consecutive_timeouts += 1 if consecutive_timeouts >= 3: - logger.debug(f'Stopped after {consecutive_timeouts} consecutive timeouts') + logger.debug( + f'Stopped after {consecutive_timeouts} consecutive timeouts' + ) break else: continue @@ -278,12 +329,17 @@ async def start(self) -> bool: try: # Check if typescript is installed check_process = await asyncio.create_subprocess_exec( - 'npx', 'tsc', '--version', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) + 'npx', + 'tsc', + '--version', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) await check_process.communicate() if check_process.returncode != 0: - logger.error('TypeScript not found. Install with: npm install -g typescript') + logger.error( + 'TypeScript not found. Install with: npm install -g typescript' + ) return False # Start typescript-language-server @@ -294,8 +350,7 @@ async def start(self) -> bool: stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self.workspace_dir), - ) + cwd=str(self.workspace_dir)) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -339,7 +394,8 @@ def _clean_env_for_node() -> dict[str, str]: env = dict(os.environ) removed = env.pop('PYTHONPATH', None) if removed: - logger.debug('Removed PYTHONPATH=%r from pyright subprocess env', removed) + logger.debug('Removed PYTHONPATH=%r from pyright subprocess env', + removed) return env async def start(self) -> bool: @@ -349,12 +405,15 @@ async def start(self) -> bool: # Check if pyright is installed check_process = await asyncio.create_subprocess_exec( - 'pyright', '--version', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) + 'pyright', + '--version', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) await check_process.communicate() if check_process.returncode != 0: - logger.warning('Pyright not found. Install with: pip install pyright') + logger.warning( + 'Pyright not found. Install with: pip install pyright') return False # Start pyright langserver @@ -365,8 +424,7 @@ async def start(self) -> bool: stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=str(self.workspace_dir), - env=clean_env, - ) + env=clean_env) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -376,7 +434,8 @@ async def _read_server_stderr(process): line = await process.stderr.readline() if not line: break - logger.error(f"LSP: {line.decode(errors='ignore').rstrip()}") + logger.error( + f"LSP: {line.decode(errors='ignore').rstrip()}") asyncio.create_task(_read_server_stderr(self.process)) @@ -384,7 +443,9 @@ async def _read_server_stderr(process): return await self.initialize() except FileNotFoundError: - logger.error('pyright-langserver not found. Install with: pip install pyright') + logger.error( + 'pyright-langserver not found. Install with: pip install pyright' + ) return False except Exception as e: logger.error(f'Failed to start Python LSP server: {e}') @@ -414,8 +475,10 @@ async def start(self) -> bool: if not jdtls_cmd: # Try to find in PATH check_process = await asyncio.create_subprocess_exec( - 'which', 'jdtls', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) + 'which', + 'jdtls', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) stdout, _ = await check_process.communicate() if check_process.returncode == 0: jdtls_cmd = stdout.decode('utf-8').strip() @@ -440,8 +503,7 @@ async def start(self) -> bool: stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self.workspace_dir), - ) + cwd=str(self.workspace_dir)) self.stdin = self.process.stdin self.stdout = self.process.stdout @@ -462,20 +524,12 @@ async def start(self) -> bool: class LSPCodeServer(ToolBase): + skip_files = [ - 'vite.config.ts', - 'vite.config.js', - 'webpack.config.js', - 'webpack.config.ts', - 'rollup.config.js', - 'rollup.config.ts', - 'next.config.js', - 'next.config.ts', - 'tsconfig.json', - 'jsconfig.json', - 'package.json', - 'pom.xml', - 'build.gradle', + 'vite.config.ts', 'vite.config.js', 'webpack.config.js', + 'webpack.config.ts', 'rollup.config.js', 'rollup.config.ts', + 'next.config.js', 'next.config.ts', 'tsconfig.json', 'jsconfig.json', + 'package.json', 'pom.xml', 'build.gradle' ] language_mapping = { @@ -490,8 +544,10 @@ def __init__(self, config): super().__init__(config) self.servers: Dict[str, LSPServer] = {} self.file_versions: Dict[str, int] = {} - self.opened_documents: Dict[str, str] = {} # Track opened documents: file_path -> language - self.output_dir = getattr(self.config, 'output_dir', DEFAULT_OUTPUT_DIR) + self.opened_documents: Dict[str, str] = { + } # Track opened documents: file_path -> language + self.output_dir = getattr(self.config, 'output_dir', + DEFAULT_OUTPUT_DIR) self.workspace_dir = self.output_dir self.index_dir = os.path.join(self.output_dir, DEFAULT_INDEX_DIR) self.lock_dir = os.path.join(self.output_dir, DEFAULT_LOCK_DIR) @@ -504,8 +560,10 @@ async def connect(self) -> None: def cleanup_lsp_index_dirs(self): cleanup_dirs = [ os.path.join(self.output_dir, '.jdtls_workspace'), # Java LSP - os.path.join(self.output_dir, '.pyright'), # Python LSP (if exists) - os.path.join(self.output_dir, 'node_modules', '.cache'), # TypeScript LSP cache + os.path.join(self.output_dir, + '.pyright'), # Python LSP (if exists) + os.path.join(self.output_dir, 'node_modules', + '.cache'), # TypeScript LSP cache ] for dir_path in cleanup_dirs: @@ -513,7 +571,9 @@ def cleanup_lsp_index_dirs(self): try: shutil.rmtree(dir_path, ignore_errors=True) except Exception as e: # noqa - logger.warning(f'Failed to cleanup LSP index directory {dir_path}: {e}') + logger.warning( + f'Failed to cleanup LSP index directory {dir_path}: {e}' + ) async def cleanup(self) -> None: """Stop all LSP servers and clear indexes""" @@ -539,65 +599,81 @@ async def cleanup(self) -> None: async def _get_tools_inner(self) -> Dict[str, Any]: """Get available tools""" return { - 'lsp_code_server': [ - { - 'tool_name': 'check_directory', - 'description': ( - 'Check all code files in a directory for errors and issues. ' - 'Supports TypeScript/JavaScript, Python, Java files. ' - 'Returns a summary of all diagnostics found.' - ), - 'parameters': { - 'type': 'object', - 'properties': { - 'directory': { - 'type': 'string', - 'description': 'Path to the directory to check (relative to workspace)', - }, - 'language': { - 'type': 'string', - 'enum': ['typescript', 'python', 'java'], - 'description': 'Programming language to check (typescript for JS/TS, python for Python, java for Java)', - }, + 'lsp_code_server': [{ + 'tool_name': + 'check_directory', + 'description': + ('Check all code files in a directory for errors and issues. ' + 'Supports TypeScript/JavaScript, Python, Java files. ' + 'Returns a summary of all diagnostics found.'), + 'parameters': { + 'type': 'object', + 'properties': { + 'directory': { + 'type': + 'string', + 'description': + 'Path to the directory to check (relative to workspace)' }, - 'required': ['directory', 'language'], + 'language': { + 'type': + 'string', + 'enum': ['typescript', 'python', 'java'], + 'description': + 'Programming language to check (typescript for JS/TS, python for Python, java for Java)' + } }, - }, - { - 'tool_name': 'update_and_check', - 'description': ( - "Incrementally update a file's content and check for errors. " - 'Used during code generation to validate each N lines. ' - 'More efficient than checking from scratch each time.' - ), - 'parameters': { - 'type': 'object', - 'properties': { - 'file_path': {'type': 'string', 'description': 'Path to the file (relative to workspace)'}, - 'content': {'type': 'string', 'description': 'Updated file content'}, - 'language': { - 'type': 'string', - 'enum': ['typescript', 'python', 'java'], - 'description': 'Programming language to check (typescript for JS/TS, python for Python, java for Java)', - }, + 'required': ['directory', 'language'] + } + }, { + 'tool_name': + 'update_and_check', + 'description': + ("Incrementally update a file's content and check for errors. " + 'Used during code generation to validate each N lines. ' + 'More efficient than checking from scratch each time.'), + 'parameters': { + 'type': 'object', + 'properties': { + 'file_path': { + 'type': + 'string', + 'description': + 'Path to the file (relative to workspace)' + }, + 'content': { + 'type': 'string', + 'description': 'Updated file content' }, - 'required': ['file_path', 'content', 'language'], + 'language': { + 'type': + 'string', + 'enum': ['typescript', 'python', 'java'], + 'description': + 'Programming language to check (typescript for JS/TS, python for Python, java for Java)' + } }, - }, - ] + 'required': ['file_path', 'content', 'language'] + } + }] } - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: """Call a tool""" if tool_name == 'check_directory': - return await self._check_directory(tool_args['directory'], tool_args['language']) + return await self._check_directory(tool_args['directory'], + tool_args['language']) elif tool_name == 'update_and_check': - return await self._update_and_check(tool_args['file_path'], tool_args['content'], tool_args['language']) + return await self._update_and_check(tool_args['file_path'], + tool_args['content'], + tool_args['language']) else: return json.dumps({'error': f'Unknown tool: {tool_name}'}) - async def _get_or_create_server(self, language: str) -> Optional[LSPServer]: + async def _get_or_create_server(self, + language: str) -> Optional[LSPServer]: """Get or create an LSP server for the given language""" if language in self.servers: return self.servers[language] @@ -623,16 +699,19 @@ async def _check_directory(self, directory: str, language: str) -> str: language = language.lower() server = await self._get_or_create_server(language) if not server: - return json.dumps({'error': f'Failed to start LSP server for {language}'}) + return json.dumps( + {'error': f'Failed to start LSP server for {language}'}) dir_path = Path(self.workspace_dir) / directory if not dir_path.exists() or not dir_path.is_dir(): - return json.dumps({'error': f'Directory not found: {directory}'}) + return json.dumps( + {'error': f'Directory not found: {directory}'}) extensions = self.language_mapping.get(language) if not extensions: - return json.dumps({'error': f'No extensions found for language: {language}'}) + return json.dumps( + {'error': f'No extensions found for language: {language}'}) all_files = [] for ext in extensions: @@ -647,18 +726,26 @@ async def _check_directory(self, directory: str, language: str) -> str: configs = ('xml', 'json', 'yaml', 'yml', 'txt', 'md', 'gradle') if filename.endswith(configs): continue - if any([filename.startswith(prefix) for prefix in self.skip_prefixes]): + if any([ + filename.startswith(prefix) + for prefix in self.skip_prefixes + ]): continue rel_path = file.relative_to(dir_path) - if any([part.startswith(prefix) for part in rel_path.parts for prefix in self.skip_prefixes]): + if any([ + part.startswith(prefix) for part in rel_path.parts + for prefix in self.skip_prefixes + ]): continue cleaned_files.append(file) all_files = cleaned_files if not all_files: - return json.dumps( - {'message': f'No {language} files found in {directory}', 'file_count': 0, 'diagnostics': []} - ) + return json.dumps({ + 'message': f'No {language} files found in {directory}', + 'file_count': 0, + 'diagnostics': [] + }) all_diagnostics = [] for file_path in all_files: @@ -666,7 +753,8 @@ async def _check_directory(self, directory: str, language: str) -> str: content = file_path.read_text(encoding='utf-8') rel_path = file_path.relative_to(Path(self.workspace_dir)) self.file_versions[str(rel_path)] = 1 - await server.open_document(str(file_path), content, language) + await server.open_document( + str(file_path), content, language) self.opened_documents[str(file_path)] = language # Skip diagnostics for index-only mode (trust existing files) @@ -687,10 +775,9 @@ async def _check_directory(self, directory: str, language: str) -> str: 'file_count': len(all_files), 'diagnostics': all_diagnostics, 'files_indexed': len(all_files) - len(all_diagnostics), - 'status': 'indexed', + 'status': 'indexed' }, - indent=2, - ) + indent=2) except Exception as e: logger.error(f'Error checking directory: {e}') @@ -698,6 +785,7 @@ async def _check_directory(self, directory: str, language: str) -> str: @staticmethod def _format_diag_results(diagnostics_result): + ignored_errors = [ # 'cannot be assigned to', 'is not assignable to', 'cannot assign to', '"none"', @@ -705,17 +793,17 @@ def _format_diag_results(diagnostics_result): 'unused', 'never used', 'never read', - 'implicitly has', + 'implicitly has' ] if diagnostics_result.get('has_errors'): issues = diagnostics_result.get('diagnostics', []) # Filter critical errors only critical_errors = [ - d - for d in issues - if d.get('severity') == 'Error' - and not any([ignore in d.get('message', '').lower() for ignore in ignored_errors]) + d for d in issues if d.get('severity') == 'Error' and not any([ + ignore in d.get('message', '').lower() + for ignore in ignored_errors + ]) ] if critical_errors: @@ -730,12 +818,14 @@ def _format_diag_results(diagnostics_result): else: return '' - async def _update_and_check(self, file_path: str, content: str, language: str) -> str: + async def _update_and_check(self, file_path: str, content: str, + language: str) -> str: """Update file content and check for errors""" try: server = await self._get_or_create_server(language) if not server: - return json.dumps({'error': f'Failed to start LSP server for {language}'}) + return json.dumps( + {'error': f'Failed to start LSP server for {language}'}) full_path = Path(self.workspace_dir) / file_path full_path_str = str(full_path) @@ -746,7 +836,10 @@ async def _update_and_check(self, file_path: str, content: str, language: str) - self.opened_documents[full_path_str] = language else: self.file_versions[file_path] += 1 - await server.update_document(full_path_str, content, version=self.file_versions[file_path]) + await server.update_document( + full_path_str, + content, + version=self.file_versions[file_path]) diagnostics = await server.get_diagnostics(str(full_path)) @@ -756,7 +849,7 @@ async def _update_and_check(self, file_path: str, content: str, language: str) - 'version': self.file_versions[file_path], 'has_errors': len(diagnostics) > 0, 'diagnostic_count': len(diagnostics), - 'diagnostics': self._format_diagnostics(diagnostics), + 'diagnostics': self._format_diagnostics(diagnostics) } return self._format_diag_results(diagnostics_result) @@ -770,17 +863,26 @@ def _format_diagnostics(diagnostics: List[dict]) -> List[dict]: """Format diagnostics for better readability""" formatted = [] for diag in diagnostics: - severity_map = {1: 'Error', 2: 'Warning', 3: 'Information', 4: 'Hint'} + severity_map = { + 1: 'Error', + 2: 'Warning', + 3: 'Information', + 4: 'Hint' + } - formatted.append( - { - 'severity': severity_map.get(diag.get('severity', 1), 'Error'), - 'message': diag.get('message', ''), - 'line': diag.get('range', {}).get('start', {}).get('line', 0) + 1, - 'column': diag.get('range', {}).get('start', {}).get('character', 0) + 1, - 'source': diag.get('source', ''), - 'code': diag.get('code', ''), - } - ) + formatted.append({ + 'severity': + severity_map.get(diag.get('severity', 1), 'Error'), + 'message': + diag.get('message', ''), + 'line': + diag.get('range', {}).get('start', {}).get('line', 0) + 1, + 'column': + diag.get('range', {}).get('start', {}).get('character', 0) + 1, + 'source': + diag.get('source', ''), + 'code': + diag.get('code', '') + }) return formatted diff --git a/ms_agent/tools/docling/chunker.py b/ms_agent/tools/docling/chunker.py index 9b30d27b4..3781cb76b 100644 --- a/ms_agent/tools/docling/chunker.py +++ b/ms_agent/tools/docling/chunker.py @@ -1,18 +1,20 @@ from typing import Iterable, Iterator, List, Union from docling_core.transforms.chunker import BaseChunk, DocChunk -from docling_core.transforms.chunker.hierarchical_chunker import ChunkingDocSerializer, ChunkingSerializerProvider +from docling_core.transforms.chunker.hierarchical_chunker import ( + ChunkingDocSerializer, ChunkingSerializerProvider) from docling_core.transforms.chunker.hybrid_chunker import HybridChunker from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer -from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer +from docling_core.transforms.chunker.tokenizer.huggingface import \ + HuggingFaceTokenizer from docling_core.transforms.serializer.markdown import MarkdownParams from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel -from modelscope import AutoTokenizer +from ms_agent.utils.logger import get_logger from rich.console import Console from rich.panel import Panel -from ms_agent.utils.logger import get_logger +from modelscope import AutoTokenizer logger = get_logger() @@ -22,12 +24,11 @@ class ImgPlaceholderSerializerProvider(ChunkingSerializerProvider): + def get_serializer(self, doc): return ChunkingDocSerializer( doc=doc, - params=MarkdownParams( - image_placeholder='', - ), + params=MarkdownParams(image_placeholder='', ), ) @@ -35,10 +36,13 @@ def get_serializer(self, doc): class HybridDocumentChunker: + EMBED_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2' MAX_TOKENS = 1024 - def __init__(self, embed_model_id: str = EMBED_MODEL_ID, max_tokens: int = MAX_TOKENS): + def __init__(self, + embed_model_id: str = EMBED_MODEL_ID, + max_tokens: int = MAX_TOKENS): """ Hybrid chunker that splits interleaved picture, table, and text into chunks. @@ -62,7 +66,9 @@ def __init__(self, embed_model_id: str = EMBED_MODEL_ID, max_tokens: int = MAX_T ) @staticmethod - def find_n_th_chunk_with_label(chunks: List[BaseChunk], n: int, label: DocItemLabel) -> tuple[int, BaseChunk]: + def find_n_th_chunk_with_label( + chunks: List[BaseChunk], n: int, + label: DocItemLabel) -> tuple[int, BaseChunk]: """ Find the n-th chunk with the specified label in an iterable of chunks. @@ -82,7 +88,8 @@ def find_n_th_chunk_with_label(chunks: List[BaseChunk], n: int, label: DocItemLa return None, None @staticmethod - def find_all_chunks_with_label(chunks: List[BaseChunk], label: DocItemLabel) -> List[BaseChunk]: + def find_all_chunks_with_label(chunks: List[BaseChunk], + label: DocItemLabel) -> List[BaseChunk]: """ Find all chunks with the specified label in an iterable of chunks. @@ -94,11 +101,15 @@ def find_all_chunks_with_label(chunks: List[BaseChunk], label: DocItemLabel) -> List[BaseChunk]: A list of BaseChunk objects that match the label. """ return [ - chunk for chunk in chunks if any(it.label == label for it in DocChunk.model_validate(chunk).meta.doc_items) + chunk for chunk in chunks + if any(it.label == label + for it in DocChunk.model_validate(chunk).meta.doc_items) ] @staticmethod - def find_all_chunks_with_labels(chunks: List[BaseChunk], labels: List[DocItemLabel]) -> List[BaseChunk]: + def find_all_chunks_with_labels( + chunks: List[BaseChunk], + labels: List[DocItemLabel]) -> List[BaseChunk]: """ Find all chunks with any of the specified labels in an iterable of chunks. @@ -110,7 +121,9 @@ def find_all_chunks_with_labels(chunks: List[BaseChunk], labels: List[DocItemLab List[BaseChunk]: A list of BaseChunk objects that match any of the labels. """ return [ - chunk for chunk in chunks if any(it.label in labels for it in DocChunk.model_validate(chunk).meta.doc_items) + chunk for chunk in chunks if any( + it.label in labels + for it in DocChunk.model_validate(chunk).meta.doc_items) ] def print_chunk(self, chunks: List[BaseChunk], chunk_pos: int) -> None: @@ -144,7 +157,6 @@ def chunk(self, docs: Iterable[DoclingDocument]) -> Iterator[BaseChunk]: if __name__ == '__main__': from ms_agent.tools.docling.doc_loader import DocLoader - urls = [ 'https://arxiv.org/pdf/2408.09869', 'https://arxiv.org/pdf/2502.15214', diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py index 18e30e635..5daf1652b 100644 --- a/ms_agent/tools/docling/doc_loader.py +++ b/ms_agent/tools/docling/doc_loader.py @@ -2,6 +2,9 @@ # yapf: disable import ast import os +from typing import Dict, Iterator, List, Optional, Tuple, Union +from unittest.mock import patch as mock_patch + from docling.backend.html_backend import HTMLDocumentBackend from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.base_models import InputFormat @@ -9,17 +12,19 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.datamodel.settings import DEFAULT_PAGE_RANGE, PageRange from docling.document_converter import DocumentConverter, PdfFormatOption -from docling.models.document_picture_classifier import DocumentPictureClassifier +from docling.models.document_picture_classifier import \ + DocumentPictureClassifier from docling.models.layout_model import LayoutModel from docling.models.table_structure_model import TableStructureModel from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem -from typing import Dict, Iterator, List, Optional, Tuple, Union -from unittest.mock import patch as mock_patch - from ms_agent.tools.docling.doc_postprocess import PostProcess -from ms_agent.tools.docling.patches import (download_models_ms, download_models_pic_classifier_ms, html_handle_figure, - html_handle_image, patch_easyocr_models, requests_get_with_timeout,) +from ms_agent.tools.docling.patches import (download_models_ms, + download_models_pic_classifier_ms, + html_handle_figure, + html_handle_image, + patch_easyocr_models, + requests_get_with_timeout) from ms_agent.utils.logger import get_logger from ms_agent.utils.patcher import patch from ms_agent.utils.utils import normalize_url_or_file, txt_to_html diff --git a/ms_agent/tools/docling/doc_postprocess.py b/ms_agent/tools/docling/doc_postprocess.py index b1dbf091b..2e21051d9 100644 --- a/ms_agent/tools/docling/doc_postprocess.py +++ b/ms_agent/tools/docling/doc_postprocess.py @@ -4,9 +4,11 @@ class PostProcess: + MIN_PICTURE_SIZE = 200.0 * 200.0 # Minimum size for pictures in pixels - def __init__(self): ... + def __init__(self): + ... @staticmethod def filter(doc: DoclingDocument) -> Union[DoclingDocument, None]: diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py index a3a873da7..b2ec77afd 100644 --- a/ms_agent/tools/docling/patches.py +++ b/ms_agent/tools/docling/patches.py @@ -1,12 +1,13 @@ # flake8: noqa import sys +from pathlib import Path + from bs4 import Tag from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel, ImageRef -from pathlib import Path - from ms_agent.utils.logger import get_logger -from ms_agent.utils.utils import load_image_from_uri_to_pil, load_image_from_url_to_pil, validate_url +from ms_agent.utils.utils import (load_image_from_uri_to_pil, + load_image_from_url_to_pil, validate_url) logger = get_logger() @@ -15,7 +16,9 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: """ Patch the `docling.backend.html_backend.HTMLDocumentBackend.handle_figure` method. """ - logger.debug(f'Patching HTMLDocumentBackend.handle_figure for {doc.origin.filename}') + logger.debug( + f'Patching HTMLDocumentBackend.handle_figure for {doc.origin.filename}' + ) img_element: Tag = element.find('img') if isinstance(img_element, Tag): @@ -29,7 +32,8 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: else: if not img_url.startswith('http'): img_url = validate_url(img_url=img_url, backend=self) - img_pil = load_image_from_url_to_pil(img_url) if img_url.startswith('http') else None + img_pil = load_image_from_url_to_pil( + img_url) if img_url.startswith('http') else None else: img_pil = None @@ -73,7 +77,8 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None: """ Patch the `docling.backend.html_backend.HTMLDocumentBackend.handle_image` method to use the custom. """ - logger.debug(f'Patching HTMLDocumentBackend.handle_image for {doc.origin.filename}') + logger.debug( + f'Patching HTMLDocumentBackend.handle_image for {doc.origin.filename}') # Get the image from element img_url: str = element.attrs.get('src', None) @@ -142,41 +147,30 @@ def patch_easyocr_models(): logger.info('Patching EasyOCR models URLs for ModelScope...') # Patch detection models - detection_models['craft']['url'] = ( - 'https://modelscope.cn/models/ms-agent/craft_mlt_25k/resolve/master/craft_mlt_25k.zip' - ) - detection_models['dbnet18']['url'] = ( - 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res18/resolve/master/pretrained_ic15_res18.zip' - ) - detection_models['dbnet50']['url'] = ( - 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res50/resolve/master/pretrained_ic15_res50.zip' - ) + detection_models['craft'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/craft_mlt_25k/resolve/master/craft_mlt_25k.zip' + detection_models['dbnet18'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res18/resolve/master/pretrained_ic15_res18.zip' + detection_models['dbnet50'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/pretrained_ic15_res50/resolve/master/pretrained_ic15_res50.zip' # Patch recognition models - recognition_models['gen2']['english_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/english_g2/resolve/master/english_g2.zip' - ) - recognition_models['gen2']['latin_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/latin_g2/resolve/master/latin_g2.zip' - ) - recognition_models['gen2']['zh_sim_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/zh_sim_g2/resolve/master/zh_sim_g2.zip' - ) - recognition_models['gen2']['japanese_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/japanese_g2/resolve/master/japanese_g2.zip' - ) - recognition_models['gen2']['korean_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/korean_g2/resolve/master/korean_g2.zip' - ) - recognition_models['gen2']['telugu_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/telugu_g2/resolve/master/telugu_g2.zip' - ) - recognition_models['gen2']['kannada_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip' - ) - recognition_models['gen2']['cyrillic_g2']['url'] = ( - 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip' - ) + recognition_models['gen2']['english_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/english_g2/resolve/master/english_g2.zip' + recognition_models['gen2']['latin_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/latin_g2/resolve/master/latin_g2.zip' + recognition_models['gen2']['zh_sim_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/zh_sim_g2/resolve/master/zh_sim_g2.zip' + recognition_models['gen2']['japanese_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/japanese_g2/resolve/master/japanese_g2.zip' + recognition_models['gen2']['korean_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/korean_g2/resolve/master/korean_g2.zip' + recognition_models['gen2']['telugu_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/telugu_g2/resolve/master/telugu_g2.zip' + recognition_models['gen2']['kannada_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip' + recognition_models['gen2']['cyrillic_g2'][ + 'url'] = 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip' def requests_get_with_timeout( diff --git a/ms_agent/tools/fetch_playwright_fallback.py b/ms_agent/tools/fetch_playwright_fallback.py index 62b14b231..0ff89a91d 100644 --- a/ms_agent/tools/fetch_playwright_fallback.py +++ b/ms_agent/tools/fetch_playwright_fallback.py @@ -10,7 +10,6 @@ We keep **one browser per thread** (e.g. each ``ThreadPoolExecutor`` worker) and reuse it across URLs instead of launching Chromium for every fetch. """ - from __future__ import annotations import atexit @@ -37,9 +36,9 @@ def _chromium_launch_args() -> List[str]: '--blink-settings=imagesEnabled=false', ] if os.getenv('MS_AGENT_PLAYWRIGHT_NO_SANDBOX', '').lower() in ( - '1', - 'true', - 'yes', + '1', + 'true', + 'yes', ): args.extend(('--no-sandbox', '--disable-setuid-sandbox')) return args @@ -101,8 +100,7 @@ def _thread_browser() -> object: except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium' - ) + 'Install with: pip install playwright && playwright install chromium') raise RuntimeError('playwright not installed') from None pw = sync_playwright().start() @@ -137,8 +135,7 @@ def try_playwright_inner_text( except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium' - ) + 'Install with: pip install playwright && playwright install chromium') return '' text = '' @@ -150,11 +147,13 @@ def try_playwright_inner_text( page.goto(url, wait_until='domcontentloaded', timeout=timeout_ms) if settle_ms: page.wait_for_timeout(settle_ms) - raw = page.evaluate("""() => { + raw = page.evaluate( + """() => { const b = document.body; if (!b) return ''; return b.innerText || ''; - }""") + }""" + ) if isinstance(raw, str): text = raw[:_MAX_INNER_TEXT_CHARS] finally: @@ -180,7 +179,10 @@ def looks_like_spa_shell_html(raw_html: str) -> bool: if not raw_html or len(raw_html) < 80: return False low = raw_html.lower() - if any(x in low for x in ('enable javascript', 'javascript is required', 'you need to enable javascript')): + if any( + x in low + for x in ('enable javascript', 'javascript is required', + 'you need to enable javascript')): return True if re.search(r']+\bid=["\']root["\'][^>]*>\s*
', low): return True diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index 987973adc..caa58c58b 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -23,46 +23,11 @@ _FS_TOOL_ALIASES = {'read': 'read_file', 'edit': 'edit_file', 'write': 'write_file'} _TEXT_SUFFIXES = { - '.py', - '.md', - '.txt', - '.yaml', - '.yml', - '.json', - '.toml', - '.cfg', - '.ini', - '.sh', - '.bash', - '.js', - '.ts', - '.tsx', - '.jsx', - '.css', - '.html', - '.xml', - '.rs', - '.go', - '.java', - '.c', - '.h', - '.cpp', - '.hpp', - '.cs', - '.rb', - '.php', - '.sql', - '.vue', - '.svelte', - '.m', - '.swift', - '.kt', - '.gradle', - '.properties', - '.env', - '.gitignore', - '.dockerignore', - 'Dockerfile', + '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', + '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', + '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', + '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', + '.env', '.gitignore', '.dockerignore', 'Dockerfile', } @@ -73,10 +38,8 @@ class FileSystemTool(ToolBase): IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) # Curly quote → straight quote mapping for fuzzy matching CURLY_QUOTE_MAP = { - '\u2018': "'", - '\u2019': "'", # ' ' - '\u201c': '"', - '\u201d': '"', # " " + '\u2018': "'", '\u2019': "'", # ' ' + '\u201c': '"', '\u201d': '"', # " " } SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -97,12 +60,18 @@ def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) if self.include_functions: - self.include_functions = [_FS_TOOL_ALIASES.get(n, n) for n in self.include_functions] + self.include_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.include_functions + ] if self.exclude_functions: - self.exclude_functions = [_FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions] + self.exclude_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions + ] self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) self.trust_remote_code = kwargs.get('trust_remote_code', False) - self.allow_read_all_files = getattr(getattr(config.tools, 'file_system', {}), 'allow_read_all_files', False) + self.allow_read_all_files = getattr( + getattr(config.tools, 'file_system', {}), 'allow_read_all_files', + False) if not self.trust_remote_code: self.allow_read_all_files = False if hasattr(self.config, 'llm'): @@ -117,17 +86,22 @@ def __init__(self, config, **kwargs): fs_cfg = getattr(config.tools, 'file_system', None) self._grep_timeout = int(getattr(fs_cfg, 'grep_timeout_s', 120) or 120) - self._default_grep_head = int(getattr(fs_cfg, 'grep_head_limit', 250) or 250) + self._default_grep_head = int( + getattr(fs_cfg, 'grep_head_limit', 250) or 250) self._glob_max_files = int(getattr(fs_cfg, 'glob_max_files', 100) or 100) wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] - shell_cfg = getattr(getattr(config.tools, 'code_executor', None), 'shell', None) - shell_mode = getattr(shell_cfg, 'default_mode', 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False)) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192)) if shell_cfg else 8192 + shell_cfg = getattr( + getattr(config.tools, 'code_executor', None), 'shell', None) + shell_mode = getattr(shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', False) + ) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) + ) if shell_cfg else 8192 _out_p = Path(self.output_dir).expanduser().resolve() try: @@ -145,13 +119,13 @@ def __init__(self, config, **kwargs): max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) - self._fs_artifacts = ArtifactManager(_out_p, max_combined_bytes=max_kb * 1024) + self._fs_artifacts = ArtifactManager( + _out_p, max_combined_bytes=max_kb * 1024) async def connect(self): logger.warning_once( '[IMPORTANT]FileSystemTool is not implemented with sandbox, please consider other similar ' - 'tools if you want to run dangerous code.' - ) + 'tools if you want to run dangerous code.') async def _get_tools_inner(self): tools = { @@ -181,9 +155,8 @@ async def _get_tools_inner(self): }, }, 'required': ['path', 'content'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='read_file', server_name='file_system', @@ -206,33 +179,37 @@ async def _get_tools_inner(self): 'paths': { 'type': 'array', 'items': {'type': 'string'}, - 'description': 'List of relative file path(s) to read. ' + 'description': + 'List of relative file path(s) to read. ' 'Use this OR `path` (single file).', }, 'path': { 'type': 'string', - 'description': 'Single relative file path to read (alias for `paths` of length 1).', + 'description': + 'Single relative file path to read (alias for `paths` of length 1).', }, 'offset': { 'type': 'integer', - 'description': 'Line number to start reading from (1-based). ' + 'description': + 'Line number to start reading from (1-based). ' 'Only provide if the file is too large to read at once.', }, 'limit': { 'type': 'integer', - 'description': 'Number of lines to read. ' + 'description': + 'Number of lines to read. ' 'Only provide if the file is too large to read at once.', }, 'abbreviate': { 'type': 'boolean', - 'description': 'If true, return an LLM-generated summary instead of raw content. ' + 'description': + 'If true, return an LLM-generated summary instead of raw content. ' 'Useful for large files or quick structural overview.', }, }, 'required': [], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='edit_file', server_name='file_system', @@ -268,13 +245,13 @@ async def _get_tools_inner(self): }, 'replace_all': { 'type': 'boolean', - 'description': 'If true, replace all occurrences. Default is false (replace only the first).', + 'description': + 'If true, replace all occurrences. Default is false (replace only the first).', }, }, 'required': ['path', 'old_string', 'new_string'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='grep', server_name='file_system', @@ -292,7 +269,8 @@ async def _get_tools_inner(self): }, 'path': { 'type': 'string', - 'description': 'Directory or file to search (relative to output_dir if not absolute). Default ".".', + 'description': + 'Directory or file to search (relative to output_dir if not absolute). Default ".".', }, 'glob': { 'type': 'string', @@ -301,7 +279,8 @@ async def _get_tools_inner(self): 'output_mode': { 'type': 'string', 'enum': ['content', 'files_with_matches', 'count'], - 'description': 'content: matching lines; files_with_matches: paths only; count: per-file counts', + 'description': + 'content: matching lines; files_with_matches: paths only; count: per-file counts', }, 'head_limit': { 'type': 'integer', @@ -336,18 +315,21 @@ async def _get_tools_inner(self): }, 'path': { 'type': 'string', - 'description': 'Base directory (relative to output_dir if not absolute).', + 'description': + 'Base directory (relative to output_dir if not absolute).', }, }, 'required': ['pattern'], 'additionalProperties': False, }, ), + ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) async def grep( @@ -361,7 +343,8 @@ async def grep( case_insensitive: bool = False, ) -> str: call_id = f'grep-{pattern[:40]}' - head_limit = head_limit if head_limit is not None else self._default_grep_head + head_limit = (head_limit if head_limit is not None else + self._default_grep_head) offset = offset or 0 path = path or '.' try: @@ -395,13 +378,13 @@ async def grep( try: rg = shutil.which('rg') if rg and root.is_file(): - lines = await self._grep_rg_file( - rg, pattern, root, case_insensitive, output_mode, head_limit, offset, glob - ) + lines = await self._grep_rg_file(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) elif rg and root.is_dir(): - lines = await self._grep_rg_dir( - rg, pattern, root, case_insensitive, output_mode, head_limit, offset, glob - ) + lines = await self._grep_rg_dir(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) else: lines = self._grep_python( pattern, @@ -415,7 +398,12 @@ async def grep( except Exception as e: err = str(e) # Expected user/tooling failures (bad regex, rg rules) — log without traceback noise. - _quiet = 'rg:' in err or 'exited' in err.lower() or 'regex' in err.lower() or 'pattern' in err.lower() + _quiet = ( + 'rg:' in err + or 'exited' in err.lower() + or 'regex' in err.lower() + or 'pattern' in err.lower() + ) logger.warning('grep failed: %s', e, exc_info=not _quiet) return json.dumps({'success': False, 'error': str(e)}, indent=2) @@ -461,7 +449,8 @@ async def _grep_rg_file( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -497,7 +486,8 @@ async def _grep_rg_dir( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for(proc.communicate(), + timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -526,7 +516,8 @@ def _grep_python( def consider_file(fp: Path) -> bool: if glob_pat: rel = str(fp.relative_to(root)) if root.is_dir() else fp.name - if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch(rel, glob_pat): + if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( + rel, glob_pat): return False suf = fp.suffix.lower() if suf not in _TEXT_SUFFIXES and fp.suffix == '': @@ -538,7 +529,8 @@ def consider_file(fp: Path) -> bool: if root.is_file(): files = [root] else: - for fp in _walk_files_limited(root, self._fs_policy.deny_globs, 50_000): + for fp in _walk_files_limited(root, self._fs_policy.deny_globs, + 50_000): if consider_file(fp): files.append(fp) @@ -547,11 +539,9 @@ def consider_file(fp: Path) -> bool: text = fp.read_text(encoding='utf-8', errors='replace') except OSError: continue - rel = ( - str(fp.relative_to(self._fs_policy.workspace_root)) - if _is_relative(fp, self._fs_policy.workspace_root) - else str(fp) - ) + rel = str(fp.relative_to(self._fs_policy.workspace_root) + ) if _is_relative(fp, self._fs_policy.workspace_root) else str( + fp) if output_mode == 'files_with_matches': if rx.search(text): lines_out.append(rel) @@ -599,11 +589,9 @@ async def glob(self, pattern: str, path: str = '') -> str: continue if _is_denied_path(rp, base, deny): continue - rel = ( - str(p.relative_to(self._fs_policy.workspace_root)) - if _is_relative(p, self._fs_policy.workspace_root) - else str(p) - ) + rel = str(p.relative_to(self._fs_policy.workspace_root) + ) if _is_relative(p, self._fs_policy.workspace_root + ) else str(p) matches.append(rel) if len(matches) >= self._glob_max_files: truncated = True @@ -709,17 +697,21 @@ def get_real_path(self, path): # Check if path is absolute or already starts with output_dir if os.path.isabs(path): target_path = path - elif path.startswith(self.output_dir + os.sep) or path.startswith(self.output_dir): + elif path.startswith(self.output_dir + os.sep) or path.startswith( + self.output_dir): # Path already includes output_dir as prefix target_path = path else: target_path = os.path.join(self.output_dir, path) target_path_real = os.path.realpath(target_path) output_dir_real = os.path.realpath(self.output_dir) - is_in_output_dir = target_path_real.startswith(output_dir_real + os.sep) or target_path_real == output_dir_real + is_in_output_dir = target_path_real.startswith( + output_dir_real + os.sep) or target_path_real == output_dir_real if not is_in_output_dir and not self.allow_read_all_files: - logger.warning(f'Attempt to read file outside output directory blocked: {path} -> {target_path_real}') + logger.warning( + f'Attempt to read file outside output directory blocked: {path} -> {target_path_real}' + ) return None else: return target_path_real @@ -731,19 +723,20 @@ def _normalize_read_paths(self, paths, path) -> List[str]: if isinstance(paths, str) and paths.strip(): out = [paths.strip()] elif isinstance(paths, list): - out = [p.strip() for p in paths if isinstance(p, str) and p.strip()] + out = [ + p.strip() for p in paths + if isinstance(p, str) and p.strip() + ] if not out and path is not None and isinstance(path, str) and path.strip(): out = [path.strip()] return out - async def read_file( - self, - paths: Optional[List[str]] = None, - path: Optional[str] = None, - offset: int = None, - limit: int = None, - abbreviate: bool = False, - ): + async def read_file(self, + paths: Optional[List[str]] = None, + path: Optional[str] = None, + offset: int = None, + limit: int = None, + abbreviate: bool = False): """Read the content of file(s). Args: @@ -772,7 +765,8 @@ async def read_file( return await self._read_files_abbreviated(paths) results = {} - use_line_range = len(paths) == 1 and (offset is not None or limit is not None) + use_line_range = len(paths) == 1 and (offset is not None + or limit is not None) for path in paths: try: @@ -780,8 +774,7 @@ async def read_file( if target_path_real is None: results[path] = ( f'Access denied: Reading file <{path}> outside output directory is not allowed. ' - f'Set allow_read_all_files=true in config to enable.' - ) + f'Set allow_read_all_files=true in config to enable.') continue ext = os.path.splitext(path)[1].lstrip('.').lower() @@ -803,14 +796,16 @@ async def read_file( if file_size > self.MAX_READ_BYTES and not use_line_range: results[path] = ( f'Error: File <{path}> is too large ({file_size} bytes). ' - f'Use offset and limit to read specific portions.' - ) + f'Use offset and limit to read specific portions.') continue # Dedup: return stub if file unchanged since last read mtime = os.path.getmtime(target_path_real) cached = self._read_cache.get(target_path_real) - if cached and cached['mtime'] == mtime and cached['offset'] == offset and cached['limit'] == limit: + if (cached + and cached['mtime'] == mtime + and cached['offset'] == offset + and cached['limit'] == limit): results[path] = { 'type': 'file_unchanged', 'message': 'File has not changed since last read.', @@ -824,8 +819,8 @@ async def read_file( content = raw_bytes.decode('utf-8') except UnicodeDecodeError: results[path] = ( - f'Error: File <{path}> appears to be binary. Only text and image files are supported.' - ) + f'Error: File <{path}> appears to be binary. ' + f'Only text and image files are supported.') continue # Normalize line endings @@ -840,13 +835,16 @@ async def read_file( if actual_start > total_lines: results[path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' continue - selected = lines[actual_start - 1 : actual_end] + selected = lines[actual_start - 1:actual_end] start_lineno = actual_start else: selected = lines start_lineno = 1 - results[path] = ''.join(f'{start_lineno + i}\t{line}' for i, line in enumerate(selected)) + results[path] = ''.join( + f'{start_lineno + i}\t{line}' + for i, line in enumerate(selected) + ) # Update dedup cache self._read_cache[target_path_real] = { @@ -903,9 +901,11 @@ def process_file(path): return json.dumps(results, indent=2, ensure_ascii=False) - async def edit_file( - self, path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False - ): + async def edit_file(self, + path: str = None, + old_string: str = None, + new_string: str = None, + replace_all: bool = False): """Edit a file by replacing an exact string with new content. Args: @@ -973,7 +973,7 @@ async def edit_file( norm_content = self._normalize_quotes(content) idx = norm_content.find(norm_old) if idx != -1: - actual_old = content[idx : idx + len(old_string)] + actual_old = content[idx:idx + len(old_string)] if actual_old is None: return ( @@ -1014,7 +1014,8 @@ async def edit_file( return f'Edit file <{path}> failed, error: ' + str(e) -def _apply_offset_limit(lines: List[str], offset: int, head_limit: int) -> List[str]: +def _apply_offset_limit(lines: List[str], offset: int, + head_limit: int) -> List[str]: if offset: lines = lines[offset:] if head_limit and head_limit > 0: @@ -1043,9 +1044,11 @@ def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: return False -def _walk_files_limited(root: Path, deny: tuple[str, ...], max_files: int) -> List[Path]: +def _walk_files_limited(root: Path, deny: tuple[str, ...], + max_files: int) -> List[Path]: out: List[Path] = [] - for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=False): + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): dp = Path(dirpath) pruned = [] for d in list(dirnames): diff --git a/ms_agent/tools/findata/__init__.py b/ms_agent/tools/findata/__init__.py index e7c91afd9..9ae760a1c 100644 --- a/ms_agent/tools/findata/__init__.py +++ b/ms_agent/tools/findata/__init__.py @@ -1,6 +1,7 @@ from .akshare_source import AKShareDataSource from .baostock_source import BaoStockDataSource -from .data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError +from .data_source_base import (DataSourceError, FinancialDataSource, + NoDataFoundError) from .findata_fetcher import FinancialDataFetcher from .hybrid_source import HybridDataSource diff --git a/ms_agent/tools/findata/akshare_source.py b/ms_agent/tools/findata/akshare_source.py index 01f86bb4c..6e68aaf41 100644 --- a/ms_agent/tools/findata/akshare_source.py +++ b/ms_agent/tools/findata/akshare_source.py @@ -3,8 +3,9 @@ from typing import Any, Dict, List, Optional import pandas as pd - -from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError +from ms_agent.tools.findata.data_source_base import (DataSourceError, + FinancialDataSource, + NoDataFoundError) from ms_agent.utils import get_logger from ms_agent.utils.utils import install_package @@ -66,7 +67,8 @@ def _convert_date(self, date: str) -> str: """Convert date to AKShare format""" return date.replace('-', '') - def _standardize_columns(self, df: pd.DataFrame, code: str) -> pd.DataFrame: + def _standardize_columns(self, df: pd.DataFrame, + code: str) -> pd.DataFrame: """Standardize column names for compatibility with BaoStock format""" if df.empty: return df @@ -145,21 +147,31 @@ def get_historical_k_data( ed_date = self._convert_date(end_date) # Route by market heuristics - if code.startswith('sh.') or code.startswith('sz.') or code.startswith('bj'): + if code.startswith('sh.') or code.startswith( + 'sz.') or code.startswith('bj'): clean_code = self._convert_code(code, market='A') df = akshare.stock_zh_a_hist( - symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust - ) + symbol=clean_code, + period=period, + start_date=st_date, + end_date=ed_date, + adjust=adjust) elif code.startswith('hk'): clean_code = self._convert_code(code, market='HK') df = akshare.stock_hk_hist( - symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust - ) + symbol=clean_code, + period=period, + start_date=st_date, + end_date=ed_date, + adjust=adjust) else: clean_code = self._convert_code(code, market='US') df = akshare.stock_us_hist( - symbol=clean_code, period=period, start_date=st_date, end_date=ed_date, adjust=adjust - ) + symbol=clean_code, + period=period, + start_date=st_date, + end_date=ed_date, + adjust=adjust) if df.empty: raise NoDataFoundError(f'No K-data found for {code}') @@ -204,16 +216,14 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame: df_base_info = akshare.stock_hk_spot_em() stock_info = df_base_info[df_base_info['代码'] == clean_code] if not stock_info.empty: - df_stock_info = pd.DataFrame( - { - 'code': [code], - 'code_name': [stock_info['名称'].iloc[0]], - 'listingDate': [''], # listing date might not be available - 'outDate': [''], - 'type': ['2'], # type of stock - 'status': ['1'], - } - ) + df_stock_info = pd.DataFrame({ + 'code': [code], + 'code_name': [stock_info['名称'].iloc[0]], + 'listingDate': [''], # listing date might not be available + 'outDate': [''], + 'type': ['2'], # type of stock + 'status': ['1'] + }) except Exception: logger.warning(f'Failed to fetch HK stock base info for {code}') @@ -239,11 +249,11 @@ def _get_hk_basic_info(self, code: str) -> pd.DataFrame: '联系电话': 'contact number', '核数师': 'auditor', '传真': 'fax', - '公司介绍': 'company description', - } - ) + '公司介绍': 'company description' + }) except Exception: - logger.warning(f'Failed to fetch HK stock business info for {code}') + logger.warning( + f'Failed to fetch HK stock business info for {code}') if df_stock_info.empty and df_business_info.empty: raise NoDataFoundError(f'No basic info found for {code}') @@ -259,23 +269,23 @@ def _get_us_basic_info(self, code: str) -> pd.DataFrame: stock_info = df[df['代码'] == symbol] if stock_info.empty: - raise NoDataFoundError(f'No US stock basic info found for {code}') - - result_df = pd.DataFrame( - { - 'code': [code], - 'code_name': [stock_info['名称'].iloc[0]], - 'listingDate': [''], - 'outDate': [''], - 'type': ['3'], - 'status': ['1'], - } - ) + raise NoDataFoundError( + f'No US stock basic info found for {code}') + + result_df = pd.DataFrame({ + 'code': [code], + 'code_name': [stock_info['名称'].iloc[0]], + 'listingDate': [''], + 'outDate': [''], + 'type': ['3'], + 'status': ['1'] + }) return result_df except Exception as e: - raise DataSourceError(f'Error fetching US stock basic info for {code}: {e}') + raise DataSourceError( + f'Error fetching US stock basic info for {code}: {e}') def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: """Get A-share stock basic information""" @@ -287,24 +297,24 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: if df_base_info.empty: raise NoDataFoundError(f'No basic info found for {code}') - result_df = pd.DataFrame( - { - 'code': [code], - 'code_name': [ - df_base_info.loc[df_base_info['item'] == '股票简称', 'value'].iloc[0] - if not df_base_info.loc[df_base_info['item'] == '股票简称', 'value'].empty - else '' - ], - 'listingDate': [ - df_base_info.loc[df_base_info['item'] == '上市时间', 'value'].iloc[0] - if not df_base_info.loc[df_base_info['item'] == '上市时间', 'value'].empty - else '' - ], - 'outDate': [''], - 'type': ['1'], - 'status': ['1'], - } - ) + result_df = pd.DataFrame({ + 'code': [code], + 'code_name': [ + df_base_info.loc[df_base_info['item'] == '股票简称', + 'value'].iloc[0] + if not df_base_info.loc[df_base_info['item'] == '股票简称', + 'value'].empty else '' + ], + 'listingDate': [ + df_base_info.loc[df_base_info['item'] == '上市时间', + 'value'].iloc[0] + if not df_base_info.loc[df_base_info['item'] == '上市时间', + 'value'].empty else '' + ], + 'outDate': [''], + 'type': ['1'], + 'status': ['1'] + }) df_business_info = akshare.stock_zyjs_ths(symbol=clean_code) if df_business_info.empty: @@ -316,35 +326,46 @@ def _get_a_share_basic_info(self, code: str) -> pd.DataFrame: '主营业务': 'main business', '产品类型': 'product type', '产品名称': 'product name', - '经营范围': 'business scope', - } - ) + '经营范围': 'business scope' + }) return pd.concat([result_df, df_business_info], axis=1) except Exception as e: - raise DataSourceError(f'Error fetching A-share basic info for {code}: {e}') + raise DataSourceError( + f'Error fetching A-share basic info for {code}: {e}') - def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, + code: str, + year: Optional[str] = None, + year_type: str = 'report') -> pd.DataFrame: """Dividend info is not provided via a unified endpoint across markets in AKShare.""" - raise DataSourceError('get_dividend_data is not supported by AKShareDataSource; use BaoStock or Hybrid') + raise DataSourceError( + 'get_dividend_data is not supported by AKShareDataSource; use BaoStock or Hybrid' + ) - def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, + end_date: str) -> pd.DataFrame: """Adjust factor via AKShare varies by function; not standardized here.""" - raise DataSourceError('get_adjust_factor_data is not supported by AKShareDataSource; use BaoStock or Hybrid') + raise DataSourceError( + 'get_adjust_factor_data is not supported by AKShareDataSource; use BaoStock or Hybrid' + ) - def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> dict: + def get_financial_data(self, code: str, year: str, quarter: int, + data_types: List[str]) -> dict: """ Get financial data for multiple categories in one call. """ - logger.info(f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}') + logger.info( + f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}' + ) if code.startswith(('hk.', 'us.')): logger.warning( 'For U.S. and Hong Kong stocks, only a single complete financial indicators table is ' - 'currently supported, covering all data types.' - ) - clean_code = self._convert_code(code, market='HK' if code.startswith('hk.') else 'US') + 'currently supported, covering all data types.') + clean_code = self._convert_code( + code, market='HK' if code.startswith('hk.') else 'US') elif code.startswith(('sh.', 'sz.', 'bj.')): clean_code = self._convert_code(code, market='A') else: @@ -373,7 +394,8 @@ def _select_row_by_report(df: pd.DataFrame) -> pd.DataFrame: # convert REPORT_DATE to date if 'REPORT_DATE' in d.columns: - d['_dt'] = pd.to_datetime(d['REPORT_DATE']).dt.date.astype('str') + d['_dt'] = pd.to_datetime( + d['REPORT_DATE']).dt.date.astype('str') hit = d[d['_dt'] == target_date] if not hit.empty: return hit.drop(columns=['_dt']) @@ -381,22 +403,29 @@ def _select_row_by_report(df: pd.DataFrame) -> pd.DataFrame: # match report period name (中报/一季报/三季报/年报) for col in ('REPORT_DATE_NAME', 'REPORT_TYPE'): if col in d.columns: - hit = d[d[col].astype(str).str.contains(str(year)) & d[col].astype(str).str.contains(target_qname)] + hit = d[d[col].astype(str).str.contains(str(year)) + & d[col].astype(str).str.contains(target_qname)] if not hit.empty: return hit # Fallback: Select the row closest to the target_date if 'REPORT_DATE' in d.columns: - d['_dt'] = pd.to_datetime(d['REPORT_DATE']).dt.date.astype('str') - d['_diff'] = (pd.to_datetime(d['_dt']) - pd.to_datetime(target_date)).abs() + d['_dt'] = pd.to_datetime( + d['REPORT_DATE']).dt.date.astype('str') + d['_diff'] = (pd.to_datetime(d['_dt']) + - pd.to_datetime(target_date)).abs() d = d.sort_values('_diff') return d.drop(columns=['_dt', '_diff']).head(1) return d.head(1) - META_KEEP = ['REPORT_DATE', 'REPORT_TYPE', 'REPORT_DATE_NAME', 'NOTICE_DATE', 'UPDATE_DATE'] + META_KEEP = [ + 'REPORT_DATE', 'REPORT_TYPE', 'REPORT_DATE_NAME', 'NOTICE_DATE', + 'UPDATE_DATE' + ] - def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: + def _filter_columns(row_df: pd.DataFrame, + category: str) -> pd.DataFrame: if row_df.empty: return row_df cols = list(row_df.columns) @@ -431,10 +460,13 @@ def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: ], } - keep = set(c for c in cols if any(re.match(p, c) for p in PATTERNS[category])) + keep = set( + c for c in cols if any( + re.match(p, c) for p in PATTERNS[category])) keep |= set(c for c in META_KEEP if c in cols) - out = row_df.loc[:, [c for c in row_df.columns if c in keep]].copy() + out = row_df.loc[:, + [c for c in row_df.columns if c in keep]].copy() # Dupont net profit margin fallback calculation: PARENTNETPROFIT / TOTALOPERATEREVE if category == 'dupont' and 'XSJLL' not in out.columns: @@ -445,7 +477,8 @@ def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: den = float(row_df.iloc[0]['TOTALOPERATEREVE']) out['XSJLL_calc'] = (val / den) if den else pd.NA except Exception as e: - logger.warning(f'Failed to calculate XSJLL_calc for {code}: {e}') + logger.warning( + f'Failed to calculate XSJLL_calc for {code}: {e}') out['XSJLL_calc'] = pd.NA out.insert(0, 'code', code) @@ -455,11 +488,9 @@ def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: ind_df = pd.DataFrame() if code.startswith(('hk.', 'us.')): try: - ind_df = ( - akshare.stock_financial_hk_analysis_indicator_em(symbol=clean_code) - if code.startswith('hk.') - else akshare.stock_financial_us_analysis_indicator_em(symbol=clean_code) - ) + ind_df = akshare.stock_financial_hk_analysis_indicator_em( + symbol=clean_code) if code.startswith('hk.') else \ + akshare.stock_financial_us_analysis_indicator_em(symbol=clean_code) ind_df = _select_row_by_report(ind_df) except Exception as e: logger.warning( @@ -468,34 +499,43 @@ def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: result['financial_indicators'] = ind_df elif code.startswith(('sh.', 'sz.', 'bj.')): - needs_indicator = any(dt in ('profit', 'operation', 'growth', 'dupont') for dt in data_types) + needs_indicator = any( + dt in ('profit', 'operation', 'growth', 'dupont') + for dt in data_types) if needs_indicator: try: - ind_df = akshare.stock_financial_analysis_indicator(symbol=clean_code) + ind_df = akshare.stock_financial_analysis_indicator( + symbol=clean_code) ind_df = _select_row_by_report(ind_df) except Exception as e: - logger.warning(f'Failed to fetch financial_analysis_indicator: {e}') + logger.warning( + f'Failed to fetch financial_analysis_indicator: {e}') ind_df = pd.DataFrame() for data_type in data_types: try: result[data_type] = pd.DataFrame() - if data_type in ('profit', 'operation', 'growth', 'dupont'): + if data_type in ('profit', 'operation', 'growth', + 'dupont'): if ind_df.empty: - logger.warning(f'No indicator row for {code} {year}Q{quarter}') + logger.warning( + f'No indicator row for {code} {year}Q{quarter}' + ) continue result[data_type] = _filter_columns(ind_df, data_type) continue elif data_type == 'balance': - df = akshare.stock_balance_sheet_by_report_em(symbol=code.replace('.', '').upper()) + df = akshare.stock_balance_sheet_by_report_em( + symbol=code.replace('.', '').upper()) row = _select_row_by_report(df) if not row.empty: result[data_type] = row elif data_type == 'cash_flow': - df = akshare.stock_cash_flow_sheet_by_report_em(symbol=code.replace('.', '').upper()) + df = akshare.stock_cash_flow_sheet_by_report_em( + symbol=code.replace('.', '').upper()) row = _select_row_by_report(df) if not row.empty: result[data_type] = row @@ -509,23 +549,34 @@ def _filter_columns(row_df: pd.DataFrame, category: str) -> pd.DataFrame: continue if not result or all(df.empty for df in result.values()): - raise NoDataFoundError(f'No financial data found for {code} ({year}Q{quarter})') + raise NoDataFoundError( + f'No financial data found for {code} ({year}Q{quarter})') return result - def get_report( - self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' - ) -> pd.DataFrame: + def get_report(self, + code: str, + start_date: str, + end_date: str, + report_type: str = 'performance_express') -> pd.DataFrame: """Report data is not supported by AKShare.""" - raise DataSourceError('get_report is not supported by AKShareDataSource; use BaoStock or Hybrid') + raise DataSourceError( + 'get_report is not supported by AKShareDataSource; use BaoStock or Hybrid' + ) def get_stock_industry(self, code: str, date: str) -> pd.DataFrame: """Industry classification is not supported by AKShare.""" - raise DataSourceError('get_stock_industry is not supported by AKShareDataSource; use BaoStock or Hybrid') + raise DataSourceError( + 'get_stock_industry is not supported by AKShareDataSource; use BaoStock or Hybrid' + ) - def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, + date: str, + data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list (A-shares only, index constituents not supported).""" - logger.info(f'Fetching stock list for {data_type}, only support a_share and latest data') + logger.info( + f'Fetching stock list for {data_type}, only support a_share and latest data' + ) try: if data_type == 'sse50': @@ -545,7 +596,9 @@ def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFr except Exception as e: raise DataSourceError(f'Failed to fetch stock list: {e}') - def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: + def get_trade_dates(self, + start_date: Optional[str] = None, + end_date: Optional[str] = None) -> pd.DataFrame: """Get trading calendar""" logger.info(f'Fetching trade dates ({start_date} to {end_date})') @@ -571,15 +624,17 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Dict[str, pd.DataFrame]: """Macroeconomic data.""" if data_types is None: data_types = [] if extra_kwargs is None: extra_kwargs = {} - logger.info(f'Fetching macroeconomic data ({start_date} to {end_date}) {data_types}') + logger.info( + f'Fetching macroeconomic data ({start_date} to {end_date}) {data_types}' + ) if not data_types: raise ValueError('data_types cannot be empty') @@ -590,11 +645,14 @@ def get_macro_data( if data_type in ('deposit_rate', 'loan_rate'): result[data_type] = akshare.rate_interbank() elif data_type in ('required_reserve_ratio'): - raise DataSourceError('Required reserve ratio is not supported by AKShare') + raise DataSourceError( + 'Required reserve ratio is not supported by AKShare') elif data_type == 'money_supply_year': - result[data_type] = self._get_money_supply_data_year(start_date, end_date) + result[data_type] = self._get_money_supply_data_year( + start_date, end_date) elif data_type == 'money_supply_month': - result[data_type] = self._get_money_supply_data_month(start_date, end_date) + result[data_type] = self._get_money_supply_data_month( + start_date, end_date) else: raise ValueError(f'Invalid data type: {data_type}') @@ -604,32 +662,41 @@ def get_macro_data( continue if not result: - raise NoDataFoundError('No macro data found for the specified criteria') + raise NoDataFoundError( + 'No macro data found for the specified criteria') return result def _get_money_supply_data_month( - self, start_date: Optional[str] = None, end_date: Optional[str] = None - ) -> pd.DataFrame: + self, + start_date: Optional[str] = None, + end_date: Optional[str] = None) -> pd.DataFrame: try: df = akshare.macro_china_money_supply() # from 2008-01 to now - df['月份'] = pd.to_datetime(df['月份'].str.replace('月份', '').str.replace('年', '-')) + df['月份'] = pd.to_datetime(df['月份'].str.replace('月份', + '').str.replace( + '年', '-')) df['月份'] = df['月份'].dt.to_period('M') if start_date: - df = df[df['月份'] >= pd.to_datetime(start_date).strftime('%Y-%m')] + df = df[ + df['月份'] >= pd.to_datetime(start_date).strftime('%Y-%m')] if end_date: df = df[df['月份'] <= pd.to_datetime(end_date).strftime('%Y-%m')] return df.sort_values('月份').reset_index(drop=True) except Exception as e: - raise DataSourceError(f'Error fetching monthly money supply data: {e}') + raise DataSourceError( + f'Error fetching monthly money supply data: {e}') def _get_money_supply_data_year( - self, start_date: Optional[str] = None, end_date: Optional[str] = None - ) -> pd.DataFrame: + self, + start_date: Optional[str] = None, + end_date: Optional[str] = None) -> pd.DataFrame: month_df = self._get_money_supply_data_month() # Take the last issue of each year (usually December; if missing, take the last available entry of that year). month_df['年'] = month_df['月份'].dt.year - last_in_year = month_df.sort_values('月份').groupby('年', as_index=False).tail(1).reset_index(drop=True) + last_in_year = ( + month_df.sort_values('月份').groupby( + '年', as_index=False).tail(1).reset_index(drop=True)) cols = [ '货币和准货币(M2)-数量(亿元)', '货币和准货币(M2)-同比增长', @@ -638,7 +705,8 @@ def _get_money_supply_data_year( '流通中的现金(M0)-数量(亿元)', '流通中的现金(M0)-同比增长', ] - year_df = last_in_year[['年'] + [c for c in cols if c in last_in_year.columns]] + year_df = last_in_year[ + ['年'] + [c for c in cols if c in last_in_year.columns]] if start_date: year_df = year_df[year_df['年'] >= pd.to_datetime(start_date).year] diff --git a/ms_agent/tools/findata/baostock_source.py b/ms_agent/tools/findata/baostock_source.py index f14437604..dd05b7a1f 100644 --- a/ms_agent/tools/findata/baostock_source.py +++ b/ms_agent/tools/findata/baostock_source.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List, Optional import pandas as pd - -from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError +from ms_agent.tools.findata.data_source_base import (DataSourceError, + FinancialDataSource, + NoDataFoundError) from ms_agent.utils import get_logger from ms_agent.utils.utils import install_package @@ -15,7 +16,6 @@ class BaoStockSessionManager: """Thread-safe BaoStock session manager with connection reuse""" - _instance = None _lock = threading.Lock() _session_lock = threading.Lock() @@ -57,7 +57,8 @@ def ensure_login(self): if not self._is_logged_in: lg = baostock.login() if lg.error_code != '0': - raise DataSourceError(f'BaoStock login failed: {lg.error_msg}') + raise DataSourceError( + f'BaoStock login failed: {lg.error_msg}') self._is_logged_in = True self._login_count = 1 logger.debug('BaoStock session established') @@ -65,7 +66,8 @@ def ensure_login(self): self._login_count += 1 # Someone reused the session within idle timeout; cancel scheduled logout self._cancel_logout() - logger.debug(f'BaoStock session reused (count: {self._login_count})') + logger.debug( + f'BaoStock session reused (count: {self._login_count})') def release(self): """Release session (logout only when no active users)""" @@ -99,24 +101,9 @@ class BaoStockDataSource(FinancialDataSource): """ DEFAULT_K_FIELDS = [ - 'date', - 'code', - 'open', - 'high', - 'low', - 'close', - 'preclose', - 'volume', - 'amount', - 'adjustflag', - 'turn', - 'tradestatus', - 'pctChg', - 'peTTM', - 'pbMRQ', - 'psTTM', - 'pcfNcfTTM', - 'isST', + 'date', 'code', 'open', 'high', 'low', 'close', 'preclose', 'volume', + 'amount', 'adjustflag', 'turn', 'tradestatus', 'pctChg', 'peTTM', + 'pbMRQ', 'psTTM', 'pcfNcfTTM', 'isST' ] def __init__(self): @@ -137,9 +124,11 @@ def __init__(self): def _query_to_dataframe(self, rs, data_type: str = 'data') -> pd.DataFrame: """Convert BaoStock query result to DataFrame""" if rs.error_code != '0': - if 'no record found' in rs.error_msg.lower() or rs.error_code == '10002': + if 'no record found' in rs.error_msg.lower( + ) or rs.error_code == '10002': raise NoDataFoundError(f'No {data_type} found: {rs.error_msg}') - raise DataSourceError(f'BaoStock API error: {rs.error_msg} (code: {rs.error_code})') + raise DataSourceError( + f'BaoStock API error: {rs.error_msg} (code: {rs.error_code})') data_list = [] while rs.next(): @@ -171,8 +160,7 @@ def get_historical_k_data( start_date=start_date, end_date=end_date, frequency=frequency, - adjustflag=adjust_flag, - ) + adjustflag=adjust_flag) return self._query_to_dataframe(rs, f'K-data for {code}') def get_stock_basic_info(self, code: str) -> pd.DataFrame: @@ -183,25 +171,40 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame: rs = baostock.query_stock_basic(code=code) return self._query_to_dataframe(rs, f'basic info for {code}') - def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, + code: str, + year: Optional[str] = None, + year_type: str = 'report') -> pd.DataFrame: """Get dividend data""" logger.info(f'Fetching dividend data for {code} ({year} {year_type})') with baostock_session(): - rs = baostock.query_dividend_data(code=code, year=year, yearType=year_type) - return self._query_to_dataframe(rs, f'dividend data for {code} ({year} {year_type})') + rs = baostock.query_dividend_data( + code=code, year=year, yearType=year_type) + return self._query_to_dataframe( + rs, f'dividend data for {code} ({year} {year_type})') - def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, + end_date: str) -> pd.DataFrame: """Get adjustment factor data""" - logger.info(f'Fetching adjustment factor data for {code} ({start_date} to {end_date})') + logger.info( + f'Fetching adjustment factor data for {code} ({start_date} to {end_date})' + ) with baostock_session(): - rs = baostock.query_adjust_factor(code=code, start_date=start_date, end_date=end_date) - return self._query_to_dataframe(rs, f'adjustment factor data for {code} ({start_date} to {end_date})') + rs = baostock.query_adjust_factor( + code=code, start_date=start_date, end_date=end_date) + return self._query_to_dataframe( + rs, + f'adjustment factor data for {code} ({start_date} to {end_date})' + ) - def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, + data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data""" - logger.info(f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}') + logger.info( + f'Fetching financial data for {code} ({year}Q{quarter}) {data_types}' + ) if not data_types: raise ValueError('data_types cannot be empty') @@ -224,49 +227,72 @@ def get_financial_data(self, code: str, year: str, quarter: int, data_types: Lis else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_financial_data(query_func, data_type, code, year, quarter) + df = self._query_financial_data(query_func, data_type, code, + year, quarter) result[data_type] = df if not result: - raise NoDataFoundError(f'No financial data found for {code} ({year}Q{quarter})') + raise NoDataFoundError( + f'No financial data found for {code} ({year}Q{quarter})') return result - def _query_financial_data(self, query_func, data_type: str, code: str, year: str, quarter: int) -> pd.DataFrame: + def _query_financial_data(self, query_func, data_type: str, code: str, + year: str, quarter: int) -> pd.DataFrame: """Query financial data using provided function (assumes session is already active)""" logger.info(f'Fetching {data_type} for {code} ({year}Q{quarter})') rs = query_func(code=code, year=year, quarter=quarter) return self._query_to_dataframe(rs, f'{data_type} for {code}') - def get_report(self, code: str, start_date: str, end_date: str, report_type: str = '') -> pd.DataFrame: + def get_report(self, + code: str, + start_date: str, + end_date: str, + report_type: str = '') -> pd.DataFrame: """Get report data""" - logger.info(f'Fetching report data for {code} ({start_date} to {end_date}) {report_type}') + logger.info( + f'Fetching report data for {code} ({start_date} to {end_date}) {report_type}' + ) if not report_type: raise ValueError('report_type cannot be empty') with baostock_session(): if report_type == 'performance_express': - rs = baostock.query_performance_express_report(code=code, start_date=start_date, end_date=end_date) + rs = baostock.query_performance_express_report( + code=code, start_date=start_date, end_date=end_date) elif report_type == 'performance_forecast': - rs = baostock.query_forecast_report(code=code, start_date=start_date, end_date=end_date) + rs = baostock.query_forecast_report( + code=code, start_date=start_date, end_date=end_date) else: raise ValueError(f'Invalid report type: {report_type}') - return self._query_to_dataframe(rs, f'report data for {code} ({start_date} to {end_date}) {report_type}') + return self._query_to_dataframe( + rs, + f'report data for {code} ({start_date} to {end_date}) {report_type}' + ) - def get_stock_industry(self, code: Optional[str] = None, date: Optional[str] = None) -> pd.DataFrame: + def get_stock_industry(self, + code: Optional[str] = None, + date: Optional[str] = None) -> pd.DataFrame: """Get stock industry""" - logger.info(f"Fetching stock industry for code={code or 'all'}, date={date or 'latest'}") + logger.info( + f"Fetching stock industry for code={code or 'all'}, date={date or 'latest'}" + ) with baostock_session(): rs = baostock.query_stock_industry(code=code, date=date) - return self._query_to_dataframe(rs, f'stock industry for {code or "all"} ({date or "latest"})') + return self._query_to_dataframe( + rs, f'stock industry for {code or "all"} ({date or "latest"})') - def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, + date: str, + data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list or index constituents""" - logger.info(f'Fetching stock list for {date} {data_type}, only support a_share') + logger.info( + f'Fetching stock list for {date} {data_type}, only support a_share' + ) with baostock_session(): if data_type == 'sse50': @@ -280,17 +306,23 @@ def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFr else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_to_dataframe(rs, f'stock list for {date} {data_type}') + df = self._query_to_dataframe( + rs, f'stock list for {date} {data_type}') logger.info(f'Stock list for {date} {data_type}: {df.head()}') return df - def get_trade_dates(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame: + def get_trade_dates(self, + start_date: Optional[str] = None, + end_date: Optional[str] = None) -> pd.DataFrame: """Get trading calendar""" - logger.info(f"Fetching trade dates ({start_date or 'default'} to {end_date or 'default'})") + logger.info( + f"Fetching trade dates ({start_date or 'default'} to {end_date or 'default'})" + ) with baostock_session(): - rs = baostock.query_trade_dates(start_date=start_date, end_date=end_date) + rs = baostock.query_trade_dates( + start_date=start_date, end_date=end_date) return self._query_to_dataframe(rs, 'trade dates') def get_macro_data( @@ -298,15 +330,16 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Dict[str, pd.DataFrame]: """Fetch macroeconomic data""" if data_types is None: data_types = [] if extra_kwargs is None: extra_kwargs = {} - logger.info(f'Fetching macro data ({start_date} to {end_date}) {data_types}') + logger.info( + f'Fetching macro data ({start_date} to {end_date}) {data_types}') result = {} with baostock_session(): @@ -331,20 +364,25 @@ def get_macro_data( elif data_type == 'money_supply_month': query_func = baostock.query_money_supply_data_month - parsed_start_date = pd.to_datetime(start_date).strftime('%Y-%m') - parsed_end_date = pd.to_datetime(end_date).strftime('%Y-%m') + parsed_start_date = pd.to_datetime( + start_date).strftime('%Y-%m') + parsed_end_date = pd.to_datetime(end_date).strftime( + '%Y-%m') elif data_type == 'money_supply_year': query_func = baostock.query_money_supply_data_year - parsed_start_date = pd.to_datetime(start_date).strftime('%Y') - parsed_end_date = pd.to_datetime(end_date).strftime('%Y') + parsed_start_date = pd.to_datetime( + start_date).strftime('%Y') + parsed_end_date = pd.to_datetime(end_date).strftime( + '%Y') else: raise ValueError(f'Invalid data type: {data_type}') - df = self._query_macro_data( - query_func, data_type, parsed_start_date, parsed_end_date, **parsed_extra_kwargs - ) + df = self._query_macro_data(query_func, data_type, + parsed_start_date, + parsed_end_date, + **parsed_extra_kwargs) result[data_type] = df except Exception as e: @@ -353,16 +391,19 @@ def get_macro_data( continue if not result: - raise NoDataFoundError('No macro data found for the specified criteria') + raise NoDataFoundError( + 'No macro data found for the specified criteria') return result - def _query_macro_data(self, query_func, data_type: str, start_date: str, end_date: str, **kwargs) -> pd.DataFrame: + def _query_macro_data(self, query_func, data_type: str, start_date: str, + end_date: str, **kwargs) -> pd.DataFrame: """Query macro data using provided function (assumes session is already active)""" logger.info(f'Fetching {data_type} for {start_date} to {end_date}') try: rs = query_func(start_date=start_date, end_date=end_date, **kwargs) - return self._query_to_dataframe(rs, f'{data_type} for {start_date} to {end_date}') + return self._query_to_dataframe( + rs, f'{data_type} for {start_date} to {end_date}') except Exception as e: logger.warning(f'Failed to fetch {data_type} data: {e}') diff --git a/ms_agent/tools/findata/data_source_base.py b/ms_agent/tools/findata/data_source_base.py index ecd1cc64f..d287aebdb 100644 --- a/ms_agent/tools/findata/data_source_base.py +++ b/ms_agent/tools/findata/data_source_base.py @@ -7,13 +7,11 @@ class DataSourceError(Exception): """Base data source error class""" - pass class NoDataFoundError(DataSourceError): """Data not found exception""" - pass @@ -65,17 +63,22 @@ def get_stock_basic_info(self, code: str) -> pd.DataFrame: pass @abstractmethod - def get_dividend_data(self, code: str, year: str, year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, + code: str, + year: str, + year_type: str = 'report') -> pd.DataFrame: """Get dividend information""" pass @abstractmethod - def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, + end_date: str) -> pd.DataFrame: """Get adjustment factor data""" pass @abstractmethod - def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, + data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data for multiple categories in one call Returns: @@ -84,9 +87,11 @@ def get_financial_data(self, code: str, year: str, quarter: int, data_types: Lis pass @abstractmethod - def get_report( - self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' - ) -> pd.DataFrame: + def get_report(self, + code: str, + start_date: str, + end_date: str, + report_type: str = 'performance_express') -> pd.DataFrame: """Get report data (performance express/forecast)""" pass @@ -111,8 +116,8 @@ def get_macro_data( start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Dict[str, pd.DataFrame]: """Get macroeconomic data for multiple categories in one call""" pass diff --git a/ms_agent/tools/findata/findata_fetcher.py b/ms_agent/tools/findata/findata_fetcher.py index d18c4cd23..996f28781 100644 --- a/ms_agent/tools/findata/findata_fetcher.py +++ b/ms_agent/tools/findata/findata_fetcher.py @@ -1,24 +1,25 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio -import json from concurrent.futures import ThreadPoolExecutor from datetime import date, datetime from functools import partial from pathlib import Path from typing import Any, Dict, Optional, Union +import json import numpy as np import pandas as pd -from omegaconf import DictConfig - from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource -from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource, NoDataFoundError +from ms_agent.tools.findata.data_source_base import (DataSourceError, + FinancialDataSource, + NoDataFoundError) from ms_agent.tools.findata.hybrid_source import HybridDataSource from ms_agent.utils import get_logger from ms_agent.utils.rate_limiter import AdaptiveRateLimiter, RateLimiter +from omegaconf import DictConfig logger = get_logger() @@ -54,7 +55,8 @@ class FinancialDataFetcher(ToolBase): def __init__(self, config: Optional[DictConfig] = None): super().__init__(config) - tools_cfg = getattr(config, 'tools', None) if config is not None else None + tools_cfg = getattr(config, 'tools', + None) if config is not None else None self.exclude_func(getattr(tools_cfg, 'financial_data_fetcher', None)) self.save_dir = getattr(config, 'output_dir', './output') @@ -76,21 +78,25 @@ def __init__(self, config: Optional[DictConfig] = None): thread_name_prefix='financial_data_fetcher_', ) - logger.info(f'Initializing FinancialDataFetcher with source: {self.source_type}') - logger.info(f'Financial data will be saved to: {self.financial_data_dir}') + logger.info( + f'Initializing FinancialDataFetcher with source: {self.source_type}' + ) + logger.info( + f'Financial data will be saved to: {self.financial_data_dir}') def _get_source_type(self, config: Optional[DictConfig]) -> str: """Get data source type from config""" - if ( - isinstance(config, DictConfig) - and hasattr(config, 'tools') - and hasattr(config.tools, 'financial_data_fetcher') - ): - return getattr(config.tools.financial_data_fetcher, 'source_type', 'hybrid') + if isinstance(config, + DictConfig) and hasattr(config, 'tools') and hasattr( + config.tools, 'financial_data_fetcher'): + return getattr(config.tools.financial_data_fetcher, 'source_type', + 'hybrid') return 'hybrid' - def _create_rate_limiter(self, config: Optional[DictConfig]) -> Optional[Union[RateLimiter, AdaptiveRateLimiter]]: + def _create_rate_limiter( + self, config: Optional[DictConfig] + ) -> Optional[Union[RateLimiter, AdaptiveRateLimiter]]: """ Create rate limiter from config. @@ -116,17 +122,16 @@ def _create_rate_limiter(self, config: Optional[DictConfig]) -> Optional[Union[R ``` """ # Check if rate limiter is configured - if not ( - isinstance(config, DictConfig) - and hasattr(config, 'tools') - and hasattr(config.tools, 'financial_data_fetcher') - ): - logger.info('No rate limiter configured, running without rate limiting') + if not (isinstance(config, DictConfig) and hasattr(config, 'tools') + and hasattr(config.tools, 'financial_data_fetcher')): + logger.info( + 'No rate limiter configured, running without rate limiting') return None fetcher_config = config.tools.financial_data_fetcher if not hasattr(fetcher_config, 'rate_limiter'): - logger.info('No rate limiter configured, running without rate limiting') + logger.info( + 'No rate limiter configured, running without rate limiting') return None rl_config = fetcher_config.rate_limiter @@ -141,15 +146,24 @@ def _create_rate_limiter(self, config: Optional[DictConfig]) -> Optional[Union[R if limiter_type == 'adaptive': # Create AdaptiveRateLimiter params = { - 'initial_requests_per_second': getattr(rl_config, 'initial_requests_per_second', 2), - 'min_requests_per_second': getattr(rl_config, 'min_requests_per_second', 1), - 'max_requests_per_second': getattr(rl_config, 'max_requests_per_second', 5), - 'min_request_interval': getattr(rl_config, 'min_request_interval', 0.5), - 'max_concurrent': getattr(rl_config, 'max_concurrent', 3), - 'backoff_factor': getattr(rl_config, 'backoff_factor', 0.5), - 'recovery_factor': getattr(rl_config, 'recovery_factor', 1.2), - 'error_threshold': getattr(rl_config, 'error_threshold', 3), - 'success_threshold': getattr(rl_config, 'success_threshold', 10), + 'initial_requests_per_second': + getattr(rl_config, 'initial_requests_per_second', 2), + 'min_requests_per_second': + getattr(rl_config, 'min_requests_per_second', 1), + 'max_requests_per_second': + getattr(rl_config, 'max_requests_per_second', 5), + 'min_request_interval': + getattr(rl_config, 'min_request_interval', 0.5), + 'max_concurrent': + getattr(rl_config, 'max_concurrent', 3), + 'backoff_factor': + getattr(rl_config, 'backoff_factor', 0.5), + 'recovery_factor': + getattr(rl_config, 'recovery_factor', 1.2), + 'error_threshold': + getattr(rl_config, 'error_threshold', 3), + 'success_threshold': + getattr(rl_config, 'success_threshold', 10), } logger.info(f'Creating AdaptiveRateLimiter with params: {params}') return AdaptiveRateLimiter(**params) @@ -157,15 +171,20 @@ def _create_rate_limiter(self, config: Optional[DictConfig]) -> Optional[Union[R elif limiter_type == 'basic': # Create basic RateLimiter params = { - 'max_requests_per_second': getattr(rl_config, 'max_requests_per_second', 2), - 'min_request_interval': getattr(rl_config, 'min_request_interval', 0.5), - 'max_concurrent': getattr(rl_config, 'max_concurrent', 3), + 'max_requests_per_second': + getattr(rl_config, 'max_requests_per_second', 2), + 'min_request_interval': + getattr(rl_config, 'min_request_interval', 0.5), + 'max_concurrent': + getattr(rl_config, 'max_concurrent', 3), } logger.info(f'Creating RateLimiter with params: {params}') return RateLimiter(**params) else: - logger.warning(f'Unknown rate limiter type: {limiter_type}, running without rate limiting') + logger.warning( + f'Unknown rate limiter type: {limiter_type}, running without rate limiting' + ) return None def _create_data_source(self) -> FinancialDataSource: @@ -178,7 +197,8 @@ def _create_data_source(self) -> FinancialDataSource: source_class = source_map.get(self.source_type.lower()) if not source_class: - logger.warning(f'Unknown source type: {self.source_type}, using hybrid') + logger.warning( + f'Unknown source type: {self.source_type}, using hybrid') source_class = HybridDataSource return source_class() @@ -214,7 +234,8 @@ async def _execute_with_rate_limit(self, func, *args, **kwargs): func_with_args = partial(func, *args, **kwargs) async with self.rate_limiter: - result = await loop.run_in_executor(self.thread_pool, func_with_args) + result = await loop.run_in_executor(self.thread_pool, + func_with_args) # Record success if using adaptive rate limiter if isinstance(self.rate_limiter, AdaptiveRateLimiter): @@ -225,9 +246,9 @@ async def _execute_with_rate_limit(self, func, *args, **kwargs): except Exception as e: if isinstance(self.rate_limiter, AdaptiveRateLimiter): error_msg = str(e).lower() - is_rate_limit_error = any( - keyword in error_msg for keyword in ['rate limit', 'too many requests', 'quota exceeded', '429'] - ) + is_rate_limit_error = any(keyword in error_msg for keyword in [ + 'rate limit', 'too many requests', 'quota exceeded', '429' + ]) self.rate_limiter.record_error(is_rate_limit_error) raise @@ -249,10 +270,14 @@ def _save_dataframe(self, df, filename: str) -> str: logger.info(f'Data saved to: {filepath}') return str(filepath) except Exception as e: - logger.error(f'Failed to save data to {filename}: {e}', exc_info=True) + logger.error( + f'Failed to save data to {filename}: {e}', exc_info=True) return '' - def _create_success_response(self, df, saved_path: str, metadata: Optional[Dict] = None) -> str: + def _create_success_response(self, + df, + saved_path: str, + metadata: Optional[Dict] = None) -> str: """ Create success response with sample data. @@ -277,8 +302,8 @@ def _create_success_response(self, df, saved_path: str, metadata: Optional[Dict] response['example_data'] = sample_df.to_dict(orient='records') if len(df) > self.sample_rows: response['note'] = ( - f'Showing {self.sample_rows} sample rows out of {len(df)} total rows. Full data saved to file.' - ) + f'Showing {self.sample_rows} sample rows out of {len(df)} ' + f'total rows. Full data saved to file.') else: response['example_data'] = [] response['note'] = 'No data returned' @@ -287,9 +312,11 @@ def _create_success_response(self, df, saved_path: str, metadata: Optional[Dict] if metadata: response.update(metadata) - return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps( + response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) - def _create_error_response(self, error: Exception, operation: str, params: Dict) -> str: + def _create_error_response(self, error: Exception, operation: str, + params: Dict) -> str: """ Create standardized error response. @@ -309,7 +336,7 @@ def _create_error_response(self, error: Exception, operation: str, params: Dict) 'operation': operation, 'error_type': error_type, 'error': error_msg, - 'parameters': params, + 'parameters': params } # Only log with traceback for unexpected errors @@ -317,9 +344,12 @@ def _create_error_response(self, error: Exception, operation: str, params: Dict) if isinstance(error, (DataSourceError, NoDataFoundError)): logger.warning(f'{operation}: {error_msg}') else: - logger.error(f"Operation '{operation}' failed: {error_type} - {error_msg}", exc_info=True) + logger.error( + f"Operation '{operation}' failed: {error_type} - {error_msg}", + exc_info=True) - return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps( + response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) async def _get_tools_inner(self) -> Dict[str, Any]: """Return tool definitions""" @@ -328,56 +358,66 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='get_historical_k_data', server_name='financial_data_fetcher', - description='Get historical K-line data (daily, weekly, monthly, etc.)', + description= + 'Get historical K-line data (daily, weekly, monthly, etc.)', parameters={ 'type': 'object', 'properties': { 'code': { + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') + }, + 'start_date': { 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'description': 'Start date, format: YYYY-MM-DD' }, - 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, - 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, - 'frequency': { + 'end_date': { 'type': 'string', - 'description': 'Data frequency: d(daily), w(weekly), m(monthly), 5/15/30/60(minutes)', - 'default': 'd', + 'description': 'End date, format: YYYY-MM-DD' }, - 'adjust_flag': { + 'frequency': { 'type': 'string', - 'description': ( - 'Adjustment flag for historical data.' - 'Adjust type: 1(backward adjusted), 2(forward adjusted), 3(non-adjusted)' - ), - 'default': '3', + 'description': + 'Data frequency: d(daily), w(weekly), m(monthly), 5/15/30/60(minutes)', + 'default': 'd' }, + 'adjust_flag': { + 'type': + 'string', + 'description': + ('Adjustment flag for historical data.' + 'Adjust type: 1(backward adjusted), 2(forward adjusted), 3(non-adjusted)' + ), + 'default': + '3' + } }, - 'required': ['code', 'start_date', 'end_date', 'frequency'], - 'additionalProperties': False, - }, - ), + 'required': + ['code', 'start_date', 'end_date', 'frequency'], + 'additionalProperties': False + }), Tool( tool_name='get_stock_basic_info', server_name='financial_data_fetcher', - description='Get stock basic information (name, industry, listing date, etc.)', + description= + 'Get stock basic information (name, industry, listing date, etc.)', parameters={ 'type': 'object', 'properties': { 'code': { - 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_dividend_data', server_name='financial_data_fetcher', @@ -386,30 +426,33 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'code': { - 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') }, 'year': { - 'type': 'string', - 'description': 'Year, e.g. 2023. If not provided, the current year will be used', + 'type': + 'string', + 'description': + 'Year, e.g. 2023. If not provided, the current year will be used' }, 'year_type': { - 'type': 'string', - 'description': ( - 'Year category, default is "report": Year of the preliminary ' - 'announcement, optional "operate": Year of ex-dividend and ex-rights' - ), - 'default': 'report', - 'enum': ['report', 'operate'], - }, + 'type': + 'string', + 'description': + ('Year category, default is "report": Year of the preliminary ' + 'announcement, optional "operate": Year of ex-dividend and ex-rights' + ), + 'default': + 'report', + 'enum': ['report', 'operate'] + } }, 'required': ['code'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_adjust_factor_data', server_name='financial_data_fetcher', @@ -421,124 +464,156 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'string', 'description': 'Stock code', }, - 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD.'}, - 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD.'}, + 'start_date': { + 'type': 'string', + 'description': + 'Start date, format: YYYY-MM-DD.' + }, + 'end_date': { + 'type': 'string', + 'description': 'End date, format: YYYY-MM-DD.' + } }, 'required': ['code', 'start_date', 'end_date'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_financial_data', server_name='financial_data_fetcher', - description=( - 'Get quarterly financial data for a given stock.' - 'Supported data types: profit, operation, growth, balance, cash_flow, dupont.' - 'You can specify one or multiple data types to get the corresponding data.' - ), + description= + ('Get quarterly financial data for a given stock.' + 'Supported data types: profit, operation, growth, balance, cash_flow, dupont.' + 'You can specify one or multiple data types to get the corresponding data.' + ), parameters={ 'type': 'object', 'properties': { 'code': { + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') + }, + 'year': { 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'description': 'Year, e.g. 2023' }, - 'year': {'type': 'string', 'description': 'Year, e.g. 2023'}, 'quarter': { - 'type': 'integer', - 'description': ( - 'Quarter, 1-4, e.g. 1 for first quarter, 2 for second ' - 'quarter, 3 for third quarter, 4 for fourth quarter' - ), + 'type': + 'integer', + 'description': + ('Quarter, 1-4, e.g. 1 for first quarter, 2 for second ' + 'quarter, 3 for third quarter, 4 for fourth quarter' + ) }, 'data_types': { 'type': 'array', 'description': 'Data types to get.', 'items': { - 'type': 'string', - 'enum': ['profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont'], - }, - }, + 'type': + 'string', + 'enum': [ + 'profit', 'operation', 'growth', + 'balance', 'cash_flow', 'dupont' + ] + } + } }, 'required': ['code', 'year', 'quarter'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_report', server_name='financial_data_fetcher', - description=( - 'Get report data for a given stock. Support for performance express ' - 'reports and performance forecast reports' - ), + description= + ('Get report data for a given stock. Support for performance express ' + 'reports and performance forecast reports'), parameters={ - 'type': 'object', + 'type': + 'object', 'properties': { 'code': { + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') + }, + 'start_date': { 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'description': 'Start date, format: YYYY-MM-DD' }, - 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, - 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, - 'report_type': { + 'end_date': { 'type': 'string', - 'description': 'Report type', - 'default': 'performance_express', - 'enum': ['performance_express', 'performance_forecast'], + 'description': 'End date, format: YYYY-MM-DD' }, + 'report_type': { + 'type': + 'string', + 'description': + 'Report type', + 'default': + 'performance_express', + 'enum': [ + 'performance_express', + 'performance_forecast' + ] + } }, - 'required': ['code', 'start_date', 'end_date', 'report_type'], - 'additionalProperties': False, - }, - ), + 'required': + ['code', 'start_date', 'end_date', 'report_type'], + 'additionalProperties': + False + }), Tool( tool_name='get_stock_industry', server_name='financial_data_fetcher', - description='Get industry classification for a given stock and date', + description= + 'Get industry classification for a given stock and date', parameters={ 'type': 'object', 'properties': { 'code': { - 'type': 'string', - 'description': ( - 'Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' - 'hk.03690 (Hong Kong), us.AAPL (US)' - ), + 'type': + 'string', + 'description': + ('Stock code, e.g. sh.600000 (Shanghai), sz.000001 (Shenzhen)' + 'hk.03690 (Hong Kong), us.AAPL (US)') }, - 'date': {'type': 'string', 'description': 'Query date, format: YYYY-MM-DD'}, + 'date': { + 'type': 'string', + 'description': 'Query date, format: YYYY-MM-DD' + } }, 'required': ['code', 'date'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_stock_list', server_name='financial_data_fetcher', - description=( - 'Get stock list for a given date, support for SSE 50 index constituents (sse50), ' - 'CSI 300 index constituents (hs300), CSI 500 index constituents (zz500) ' - 'and all a-share stocks (all_a_share)' - ), + description= + ('Get stock list for a given date, support for SSE 50 index constituents (sse50), ' + 'CSI 300 index constituents (hs300), CSI 500 index constituents (zz500) ' + 'and all a-share stocks (all_a_share)'), parameters={ 'type': 'object', 'properties': { - 'date': {'type': 'string', 'description': 'Query date, format: YYYY-MM-DD'}, - 'data_type': { + 'date': { 'type': 'string', - 'description': 'Data type to get. Default is "all_a_share"', - 'enum': ['sse50', 'hs300', 'zz500', 'all_a_share'], + 'description': 'Query date, format: YYYY-MM-DD' }, + 'data_type': { + 'type': 'string', + 'description': + 'Data type to get. Default is "all_a_share"', + 'enum': + ['sse50', 'hs300', 'zz500', 'all_a_share'] + } }, 'required': ['date'], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_trade_dates', server_name='financial_data_fetcher', @@ -546,86 +621,104 @@ async def _get_tools_inner(self) -> Dict[str, Any]: parameters={ 'type': 'object', 'properties': { - 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, - 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, + 'start_date': { + 'type': 'string', + 'description': 'Start date, format: YYYY-MM-DD' + }, + 'end_date': { + 'type': 'string', + 'description': 'End date, format: YYYY-MM-DD' + } }, 'required': [], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), Tool( tool_name='get_macro_data', server_name='financial_data_fetcher', - description=( - 'Get macro data for a given range of dates' - 'Supported data types: deposit_rate, loan_rate, required_reserve_ratio, money_supply_month, ' - 'money_supply_year' - ), + description= + ('Get macro data for a given range of dates' + 'Supported data types: deposit_rate, loan_rate, required_reserve_ratio, money_supply_month, ' + 'money_supply_year'), parameters={ 'type': 'object', 'properties': { - 'start_date': {'type': 'string', 'description': 'Start date, format: YYYY-MM-DD'}, - 'end_date': {'type': 'string', 'description': 'End date, format: YYYY-MM-DD'}, + 'start_date': { + 'type': 'string', + 'description': 'Start date, format: YYYY-MM-DD' + }, + 'end_date': { + 'type': 'string', + 'description': 'End date, format: YYYY-MM-DD' + }, 'data_types': { 'type': 'array', - 'description': 'Data types to get. Default is all data types', + 'description': + 'Data types to get. Default is all data types', 'items': { - 'type': 'string', + 'type': + 'string', 'enum': [ - 'deposit_rate', - 'loan_rate', + 'deposit_rate', 'loan_rate', 'required_reserve_ratio', 'money_supply_month', - 'money_supply_year', - ], + 'money_supply_year' + ] }, }, 'extra_kwargs': { 'type': 'object', - 'description': 'Extra keyword arguments for the macro data', + 'description': + 'Extra keyword arguments for the macro data', 'properties': { 'yearType': { - 'type': 'string', - 'description': ( - 'Year Type, default value 0 means "announcement date," ' - 'and 1 means "effective date".' - ), - 'default': '0', + 'type': + 'string', + 'description': + ('Year Type, default value 0 means "announcement date," ' + 'and 1 means "effective date".'), + 'default': + '0' } }, 'required': [], # yearType is optional - 'additionalProperties': False, - }, + 'additionalProperties': False + } }, 'required': ['start_date', 'end_date', 'data_types'], - 'additionalProperties': False, + 'additionalProperties': False }, - ), + ) ] } # Update tools by source type - if self.data_source is not None and hasattr(self.data_source, 'get_extra_tools'): + if self.data_source is not None and hasattr(self.data_source, + 'get_extra_tools'): tools.update(self.data_source.get_extra_tools()) return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: """Call tool method""" if self.data_source is None: await self.connect() return await getattr(self, tool_name)(**tool_args) - async def get_historical_k_data( - self, code: str, start_date: str, end_date: str, frequency: str = 'd', adjust_flag: str = '3' - ) -> str: + async def get_historical_k_data(self, + code: str, + start_date: str, + end_date: str, + frequency: str = 'd', + adjust_flag: str = '3') -> str: """Get historical K-line data""" params = { 'code': code, 'start_date': start_date, 'end_date': end_date, 'frequency': frequency, - 'adjust_flag': adjust_flag, + 'adjust_flag': adjust_flag } try: @@ -635,8 +728,7 @@ async def get_historical_k_data( start_date=start_date, end_date=end_date, frequency=frequency, - adjust_flag=adjust_flag, - ) + adjust_flag=adjust_flag) # Generate filename with key parameters clean_code = code.replace('.', '_') @@ -650,19 +742,21 @@ async def get_historical_k_data( 'code': code, 'date_range': f'{start_date} to {end_date}', 'frequency': frequency, - 'adjust_flag': adjust_flag, + 'adjust_flag': adjust_flag } return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_historical_k_data', params) + return self._create_error_response(e, 'get_historical_k_data', + params) async def get_stock_basic_info(self, code: str) -> str: """Get stock basic information""" params = {'code': code} try: - df = await self._execute_with_rate_limit(self.data_source.get_stock_basic_info, code=code) + df = await self._execute_with_rate_limit( + self.data_source.get_stock_basic_info, code=code) # Generate filename clean_code = code.replace('.', '_') @@ -676,16 +770,22 @@ async def get_stock_basic_info(self, code: str) -> str: return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_stock_basic_info', params) + return self._create_error_response(e, 'get_stock_basic_info', + params) - async def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> str: + async def get_dividend_data(self, + code: str, + year: Optional[str] = None, + year_type: str = 'report') -> str: """Get dividend information (BaoStock).""" params = {'code': code, 'year': year, 'year_type': year_type} try: df = await self._execute_with_rate_limit( - self.data_source.get_dividend_data, code=code, year=year, year_type=year_type - ) + self.data_source.get_dividend_data, + code=code, + year=year, + year_type=year_type) # Generate filename clean_code = code.replace('.', '_') @@ -702,14 +802,17 @@ async def get_dividend_data(self, code: str, year: Optional[str] = None, year_ty except Exception as e: return self._create_error_response(e, 'get_dividend_data', params) - async def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> str: + async def get_adjust_factor_data(self, code: str, start_date: str, + end_date: str) -> str: """Get adjustment factor data (BaoStock).""" params = {'code': code, 'start_date': start_date, 'end_date': end_date} try: df = await self._execute_with_rate_limit( - self.data_source.get_adjust_factor_data, code=code, start_date=start_date, end_date=end_date - ) + self.data_source.get_adjust_factor_data, + code=code, + start_date=start_date, + end_date=end_date) # Generate filename clean_code = code.replace('.', '_') @@ -719,21 +822,39 @@ async def get_adjust_factor_data(self, code: str, start_date: str, end_date: str saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = {'code': code, 'date_range': f'{start_date} to {end_date}'} + metadata = { + 'code': code, + 'date_range': f'{start_date} to {end_date}' + } return self._create_success_response(df, saved_path, metadata) except Exception as e: - return self._create_error_response(e, 'get_adjust_factor_data', params) - - async def get_financial_data(self, code: str, year: str, quarter: int, data_types: Optional[list] = None) -> str: + return self._create_error_response(e, 'get_adjust_factor_data', + params) + + async def get_financial_data(self, + code: str, + year: str, + quarter: int, + data_types: Optional[list] = None) -> str: """Get multiple categories of financial data in one call.""" - data_types = data_types or ['profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont'] - params = {'code': code, 'year': year, 'quarter': quarter, 'data_types': data_types} + data_types = data_types or [ + 'profit', 'operation', 'growth', 'balance', 'cash_flow', 'dupont' + ] + params = { + 'code': code, + 'year': year, + 'quarter': quarter, + 'data_types': data_types + } try: result = await self._execute_with_rate_limit( - self.data_source.get_financial_data, code=code, year=year, quarter=quarter, data_types=data_types - ) + self.data_source.get_financial_data, + code=code, + year=year, + quarter=quarter, + data_types=data_types) # Save each data type and prepare response clean_code = code.replace('.', '_') @@ -754,26 +875,42 @@ async def get_financial_data(self, code: str, year: str, quarter: int, data_type example_data[key] = value response = { - 'success': True, - 'code': code, - 'year': year, - 'quarter': quarter, - 'data_types': list(result.keys()), - 'saved_files': saved_files, - 'example_data': example_data, - 'note': 'Financial data saved to separate files. Showing sample rows for each data type.', + 'success': + True, + 'code': + code, + 'year': + year, + 'quarter': + quarter, + 'data_types': + list(result.keys()), + 'saved_files': + saved_files, + 'example_data': + example_data, + 'note': + 'Financial data saved to separate files. Showing sample rows for each data type.' } - return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps( + response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) except Exception as e: return self._create_error_response(e, 'get_financial_data', params) - async def get_report( - self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' - ) -> str: + async def get_report(self, + code: str, + start_date: str, + end_date: str, + report_type: str = 'performance_express') -> str: """Get performance express/forecast reports (BaoStock).""" - params = {'code': code, 'start_date': start_date, 'end_date': end_date, 'report_type': report_type} + params = { + 'code': code, + 'start_date': start_date, + 'end_date': end_date, + 'report_type': report_type + } try: df = await self._execute_with_rate_limit( @@ -781,8 +918,7 @@ async def get_report( code=code, start_date=start_date, end_date=end_date, - report_type=report_type, - ) + report_type=report_type) # Generate filename clean_code = code.replace('.', '_') @@ -792,7 +928,11 @@ async def get_report( saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = {'code': code, 'date_range': f'{start_date} to {end_date}', 'report_type': report_type} + metadata = { + 'code': code, + 'date_range': f'{start_date} to {end_date}', + 'report_type': report_type + } return self._create_success_response(df, saved_path, metadata) except Exception as e: @@ -804,8 +944,9 @@ async def get_trade_dates(self, start_date: str, end_date: str) -> str: try: df = await self._execute_with_rate_limit( - self.data_source.get_trade_dates, start_date=start_date, end_date=end_date - ) + self.data_source.get_trade_dates, + start_date=start_date, + end_date=end_date) # Generate filename filename = f'trade_dates_{start_date}_{end_date}' @@ -825,7 +966,8 @@ async def get_stock_industry(self, code: str, date: str) -> str: params = {'code': code, 'date': date} try: - df = await self._execute_with_rate_limit(self.data_source.get_stock_industry, code=code, date=date) + df = await self._execute_with_rate_limit( + self.data_source.get_stock_industry, code=code, date=date) # Generate filename clean_code = code.replace('.', '_') @@ -841,12 +983,17 @@ async def get_stock_industry(self, code: str, date: str) -> str: except Exception as e: return self._create_error_response(e, 'get_stock_industry', params) - async def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> str: + async def get_stock_list(self, + date: str, + data_type: str = 'all_a_share') -> str: """Get index constituents or all stocks.""" params = {'date': date, 'data_type': data_type} try: - df = await self._execute_with_rate_limit(self.data_source.get_stock_list, date=date, data_type=data_type) + df = await self._execute_with_rate_limit( + self.data_source.get_stock_list, + date=date, + data_type=data_type) # Generate filename filename = f'stock_list_{data_type}_{date}' @@ -855,28 +1002,31 @@ async def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> str saved_path = self._save_dataframe(df, filename) # Return response with sample data - metadata = {'date': date, 'data_type': data_type, 'total_stocks': len(df)} + metadata = { + 'date': date, + 'data_type': data_type, + 'total_stocks': len(df) + } return self._create_success_response(df, saved_path, metadata) except Exception as e: return self._create_error_response(e, 'get_stock_list', params) - async def get_macro_data( - self, start_date: str, end_date: str, data_types: Optional[list] = None, extra_kwargs: Optional[dict] = None - ) -> str: + async def get_macro_data(self, + start_date: str, + end_date: str, + data_types: Optional[list] = None, + extra_kwargs: Optional[dict] = None) -> str: """Get macroeconomic data (BaoStock).""" data_types = data_types or [ - 'deposit_rate', - 'loan_rate', - 'required_reserve_ratio', - 'money_supply_month', - 'money_supply_year', + 'deposit_rate', 'loan_rate', 'required_reserve_ratio', + 'money_supply_month', 'money_supply_year' ] params = { 'start_date': start_date, 'end_date': end_date, 'data_types': data_types, - 'extra_kwargs': extra_kwargs, + 'extra_kwargs': extra_kwargs } try: @@ -885,8 +1035,7 @@ async def get_macro_data( start_date=start_date, end_date=end_date, data_types=data_types, - extra_kwargs=extra_kwargs, - ) + extra_kwargs=extra_kwargs) # Save each data type and prepare response saved_files = {} @@ -906,15 +1055,22 @@ async def get_macro_data( example_data[key] = value response = { - 'success': True, - 'date_range': f'{start_date} to {end_date}', - 'data_types': list(result.keys()), - 'saved_files': saved_files, - 'example_data': example_data, - 'note': 'Macro data saved to separate files. Showing sample rows for each data type.', + 'success': + True, + 'date_range': + f'{start_date} to {end_date}', + 'data_types': + list(result.keys()), + 'saved_files': + saved_files, + 'example_data': + example_data, + 'note': + 'Macro data saved to separate files. Showing sample rows for each data type.' } - return json.dumps(response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + return json.dumps( + response, ensure_ascii=False, indent=2, cls=DateTimeEncoder) except Exception as e: return self._create_error_response(e, 'get_macro_data', params) diff --git a/ms_agent/tools/findata/hybrid_source.py b/ms_agent/tools/findata/hybrid_source.py index 57ad3c374..79380fd14 100644 --- a/ms_agent/tools/findata/hybrid_source.py +++ b/ms_agent/tools/findata/hybrid_source.py @@ -3,10 +3,10 @@ from typing import Any, Callable, Dict, List, Optional import pandas as pd - from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource -from ms_agent.tools.findata.data_source_base import DataSourceError, FinancialDataSource +from ms_agent.tools.findata.data_source_base import (DataSourceError, + FinancialDataSource) from ms_agent.utils import get_logger logger = get_logger() @@ -27,7 +27,9 @@ def __init__(self): logger.info('Initializing Hybrid data source') self.baostock = BaoStockDataSource() self.akshare = AKShareDataSource() - logger.info('Hybrid data source initialized (A-shares: BaoStock, Others: AKShare)') + logger.info( + 'Hybrid data source initialized (A-shares: BaoStock, Others: AKShare)' + ) def _detect_market(self, code: str) -> str: """ @@ -54,7 +56,9 @@ def _detect_market(self, code: str) -> str: logger.warning(f'Unknown market type for code: {code}') return 'unknown' - def _get_source(self, code: str, market: str = None) -> List[FinancialDataSource]: + def _get_source(self, + code: str, + market: str = None) -> List[FinancialDataSource]: """Select data source based on stock code""" market = market if market else self._detect_market(code) @@ -65,7 +69,8 @@ def _get_source(self, code: str, market: str = None) -> List[FinancialDataSource logger.debug(f'Using AKShare for {market}: {code}') return [self.akshare] - def _call_sources(self, sources: List[FinancialDataSource], query_func: Callable) -> pd.DataFrame: + def _call_sources(self, sources: List[FinancialDataSource], + query_func: Callable) -> pd.DataFrame: """Call query function for multiple data sources""" for source in sources: try: @@ -75,7 +80,9 @@ def _call_sources(self, sources: List[FinancialDataSource], query_func: Callable if isinstance(result, dict) and result: return result except Exception as e: - logger.warning(f'Data source {source.__class__.__name__} failed, continue to next source: {e}') + logger.warning( + f'Data source {source.__class__.__name__} failed, continue to next source: {e}' + ) continue source_names = [s.__class__.__name__ for s in sources] @@ -93,58 +100,78 @@ def get_historical_k_data( """Get historical K-line data""" sources = self._get_source(code) return self._call_sources( - sources, lambda s: s.get_historical_k_data(code, start_date, end_date, frequency, adjust_flag, fields) - ) + sources, lambda s: s.get_historical_k_data( + code, start_date, end_date, frequency, adjust_flag, fields)) def get_stock_basic_info(self, code: str) -> pd.DataFrame: """Get stock basic information""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_stock_basic_info(code)) + return self._call_sources(sources, + lambda s: s.get_stock_basic_info(code)) - def get_dividend_data(self, code: str, year: Optional[str] = None, year_type: str = 'report') -> pd.DataFrame: + def get_dividend_data(self, + code: str, + year: Optional[str] = None, + year_type: str = 'report') -> pd.DataFrame: """Get dividend data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_dividend_data(code, year, year_type)) + return self._call_sources( + sources, lambda s: s.get_dividend_data(code, year, year_type)) - def get_adjust_factor_data(self, code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_adjust_factor_data(self, code: str, start_date: str, + end_date: str) -> pd.DataFrame: """Get adjustment factor data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_adjust_factor_data(code, start_date, end_date)) + return self._call_sources( + sources, + lambda s: s.get_adjust_factor_data(code, start_date, end_date)) - def get_financial_data(self, code: str, year: str, quarter: int, data_types: List[str]) -> Dict[str, pd.DataFrame]: + def get_financial_data(self, code: str, year: str, quarter: int, + data_types: List[str]) -> Dict[str, pd.DataFrame]: """Get financial data for multiple categories in one call""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_financial_data(code, year, quarter, data_types)) - - def get_report( - self, code: str, start_date: str, end_date: str, report_type: str = 'performance_express' - ) -> pd.DataFrame: + return self._call_sources( + sources, + lambda s: s.get_financial_data(code, year, quarter, data_types)) + + def get_report(self, + code: str, + start_date: str, + end_date: str, + report_type: str = 'performance_express') -> pd.DataFrame: """Get report data (BaoStock only)""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_report(code, start_date, end_date, report_type)) + return self._call_sources( + sources, + lambda s: s.get_report(code, start_date, end_date, report_type)) def get_stock_industry(self, code: str, date: str) -> pd.DataFrame: """Get industry classification (BaoStock only)""" sources = self._get_source(code) - return self._call_sources(sources, lambda s: s.get_stock_industry(code, date)) + return self._call_sources(sources, + lambda s: s.get_stock_industry(code, date)) - def get_stock_list(self, date: str, data_type: str = 'all_a_share') -> pd.DataFrame: + def get_stock_list(self, + date: str, + data_type: str = 'all_a_share') -> pd.DataFrame: """Get stock list or index constituents (BaoStock only)""" sources = self._get_source('', market='a_share') - return self._call_sources(sources, lambda s: s.get_stock_list(date, data_type)) + return self._call_sources(sources, + lambda s: s.get_stock_list(date, data_type)) def get_trade_dates(self, start_date: str, end_date: str) -> pd.DataFrame: """Get trading calendar (BaoStock only)""" sources = self._get_source('', market='a_share') - return self._call_sources(sources, lambda s: s.get_trade_dates(start_date, end_date)) + return self._call_sources( + sources, lambda s: s.get_trade_dates(start_date, end_date)) def get_macro_data( self, start_date: str, end_date: str, data_types: Optional[List[str]] = None, - extra_kwargs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, pd.DataFrame]: + extra_kwargs: Optional[Dict[str, + Any]] = None) -> Dict[str, pd.DataFrame]: """Get macroeconomic data for multiple categories in one call (BaoStock only)""" if data_types is None: data_types = [] @@ -152,4 +179,6 @@ def get_macro_data( extra_kwargs = {} sources = self._get_source('', market='a_share') - return self._call_sources(sources, lambda s: s.get_macro_data(start_date, end_date, data_types, extra_kwargs)) + return self._call_sources( + sources, lambda s: s.get_macro_data(start_date, end_date, + data_types, extra_kwargs)) diff --git a/ms_agent/tools/image_generator/ds_image_gen.py b/ms_agent/tools/image_generator/ds_image_gen.py index 0a3f482d9..b7286b218 100644 --- a/ms_agent/tools/image_generator/ds_image_gen.py +++ b/ms_agent/tools/image_generator/ds_image_gen.py @@ -6,18 +6,23 @@ class DSImageGenerator: + def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, positive_prompt, negative_prompt=None, size=None, ratio=None, **kwargs): + async def generate_image(self, + positive_prompt, + negative_prompt=None, + size=None, + ratio=None, + **kwargs): import aiohttp - image_generator = self.config.tools.image_generator base_url = ( - getattr(image_generator, 'base_url', None) or 'https://dashscope.aliyuncs.com/compatible-mode' - ).strip('/') + getattr(image_generator, 'base_url', None) + or 'https://dashscope.aliyuncs.com/compatible-mode').strip('/') api_key = image_generator.api_key model_id = image_generator.model assert api_key is not None @@ -33,24 +38,34 @@ async def generate_image(self, positive_prompt, negative_prompt=None, size=None, request_body = { 'model': model_id, - 'dashscope_extend_params': {'provider': 'b', 'using_native_protocol': True}, + 'dashscope_extend_params': { + 'provider': 'b', + 'using_native_protocol': True + }, 'stream': False, - 'contents': {'role': 'USER', 'parts': {'text': positive_prompt}}, + 'contents': { + 'role': 'USER', + 'parts': { + 'text': positive_prompt + } + }, 'generationConfig': { 'responseModalities': ['TEXT', 'IMAGE'], 'image_config': { 'aspect_ratio': ratio, }, - }, + } } async with aiohttp.ClientSession() as session: - async with session.post(base_url, headers=headers, json=request_body) as resp: + async with session.post( + base_url, headers=headers, json=request_body) as resp: resp.raise_for_status() data = await resp.json() try: - image_url = data['candidates'][0]['content']['parts'][-1]['inlineData']['data'] + image_url = data['candidates'][0]['content']['parts'][-1][ + 'inlineData']['data'] async with session.get(image_url) as img_resp: img_content = await img_resp.read() image = Image.open(BytesIO(img_content)) diff --git a/ms_agent/tools/image_generator/google_image_gen.py b/ms_agent/tools/image_generator/google_image_gen.py index 103d52cc2..fe3f6f011 100644 --- a/ms_agent/tools/image_generator/google_image_gen.py +++ b/ms_agent/tools/image_generator/google_image_gen.py @@ -3,15 +3,18 @@ class GoogleImageGenerator: + def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, positive_prompt, negative_prompt=None, **kwargs): + async def generate_image(self, + positive_prompt, + negative_prompt=None, + **kwargs): # TODO not tested from google import genai - image_generator = self.config.tools.image_generator api_key = image_generator.api_key model_id = image_generator.model diff --git a/ms_agent/tools/image_generator/image_gen.py b/ms_agent/tools/image_generator/image_gen.py index c9e5f12b4..3ab8a71df 100644 --- a/ms_agent/tools/image_generator/image_gen.py +++ b/ms_agent/tools/image_generator/image_gen.py @@ -6,22 +6,21 @@ class ImageGenerator(ToolBase): + def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', 'image_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', + 'image_generator') os.makedirs(self.temp_dir, exist_ok=True) image_generator = self.config.image_generator if image_generator.type == 'modelscope': from .ms_image_gen import MSImageGenerator - self.generator = MSImageGenerator(self.config, self.temp_dir) elif image_generator.type == 'dashscope': from .ds_image_gen import DSImageGenerator - self.generator = DSImageGenerator(self.config, self.temp_dir) elif image_generator.type == 'google': from .google_image_gen import GoogleImageGenerator - self.generator = GoogleImageGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -35,21 +34,30 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_image', server_name='image_generator', - description='Generate an image with a positive prompt, and return the image file path.', + description= + 'Generate an image with a positive prompt, and return the image file path.', parameters={ 'type': 'object', 'properties': { - 'positive_prompt': {'type': 'string', 'description': 'The prompt to generate the image.'} + 'positive_prompt': { + 'type': 'string', + 'description': + 'The prompt to generate the image.' + } }, 'required': ['positive_prompt'], - 'additionalProperties': False, - }, - ) + 'additionalProperties': False + }) ] } - async def generate_image(self, positive_prompt, negative_prompt=None, **kwargs): - return await self.generator.generate_image(positive_prompt, negative_prompt, **kwargs) + async def generate_image(self, + positive_prompt, + negative_prompt=None, + **kwargs): + return await self.generator.generate_image(positive_prompt, + negative_prompt, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await self.generate_image(**tool_args) diff --git a/ms_agent/tools/image_generator/ms_image_gen.py b/ms_agent/tools/image_generator/ms_image_gen.py index 9000e26b4..b121458e6 100644 --- a/ms_agent/tools/image_generator/ms_image_gen.py +++ b/ms_agent/tools/image_generator/ms_image_gen.py @@ -1,27 +1,33 @@ import asyncio -import json import os import uuid from io import BytesIO +import json from PIL import Image class MSImageGenerator: + def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_image(self, positive_prompt, negative_prompt=None, size=None, **kwargs): + async def generate_image(self, + positive_prompt, + negative_prompt=None, + size=None, + **kwargs): import aiohttp - image_generator = self.config.tools.image_generator - base_url = (getattr(image_generator, 'base_url', None) or 'https://api-inference.modelscope.cn').strip('/') + base_url = (getattr(image_generator, 'base_url', None) + or 'https://api-inference.modelscope.cn').strip('/') api_key = image_generator.api_key model_id = image_generator.model assert api_key is not None - output_file = os.path.join(self.temp_dir, f'{str(uuid.uuid4())[:8]}.png') + output_file = os.path.join(self.temp_dir, + f'{str(uuid.uuid4())[:8]}.png') headers = { 'Authorization': f'Bearer {api_key}', @@ -30,18 +36,18 @@ async def generate_image(self, positive_prompt, negative_prompt=None, size=None, async with aiohttp.ClientSession() as session: async with session.post( - f'{base_url}/v1/images/generations', - headers={**headers, 'X-ModelScope-Async-Mode': 'true'}, - data=json.dumps( - { - 'model': model_id, - 'prompt': positive_prompt, - 'negative_prompt': negative_prompt or '', - 'size': size or '', + f'{base_url}/v1/images/generations', + headers={ + **headers, 'X-ModelScope-Async-Mode': 'true' }, - ensure_ascii=False, - ), - ) as resp: + data=json.dumps( + { + 'model': model_id, + 'prompt': positive_prompt, + 'negative_prompt': negative_prompt or '', + 'size': size or '', + }, + ensure_ascii=False)) as resp: resp.raise_for_status() task_id = (await resp.json())['task_id'] @@ -55,8 +61,11 @@ async def generate_image(self, positive_prompt, negative_prompt=None, size=None, elapsed_time += poll_interval async with session.get( - f'{base_url}/v1/tasks/{task_id}', headers={**headers, 'X-ModelScope-Task-Type': 'image_generation'} - ) as result: + f'{base_url}/v1/tasks/{task_id}', + headers={ + **headers, 'X-ModelScope-Task-Type': + 'image_generation' + }) as result: result.raise_for_status() data = await result.json() @@ -73,5 +82,5 @@ async def generate_image(self, positive_prompt, negative_prompt=None, size=None, poll_interval = min(poll_interval * 1.5, max_poll_interval) return ( - f'Retrieval timeout, consider retry the task, or waiting for longer time(current is {max_wait_time}s).' - ) + f'Retrieval timeout, consider retry the task, or waiting for ' + f'longer time(current is {max_wait_time}s).') diff --git a/ms_agent/tools/jina_reader.py b/ms_agent/tools/jina_reader.py index acb9a7336..b3f663971 100644 --- a/ms_agent/tools/jina_reader.py +++ b/ms_agent/tools/jina_reader.py @@ -10,13 +10,15 @@ from urllib.parse import quote, urlparse from urllib.request import Request, urlopen -from ms_agent.tools.fetch_playwright_fallback import looks_like_spa_shell_html, try_playwright_inner_text +from ms_agent.tools.fetch_playwright_fallback import (looks_like_spa_shell_html, + try_playwright_inner_text) from ms_agent.utils.logger import get_logger logger = get_logger() DEFAULT_HEADERS: Dict[str, str] = { - 'User-Agent': 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', + 'User-Agent': + 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', 'Accept': 'text/plain; charset=utf-8', 'Accept-Language': 'en-US,en;q=0.9', } @@ -26,7 +28,8 @@ _DIRECT_FETCH_HEADERS: Dict[str, str] = { 'User-Agent': DEFAULT_HEADERS['User-Agent'], - 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', + 'Accept': + 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', 'Accept-Language': DEFAULT_HEADERS['Accept-Language'], } @@ -38,7 +41,8 @@ class JinaReaderConfig: retries: int = 3 backoff_base: float = 0.8 backoff_max: float = 8.0 - headers: Dict[str, str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) + headers: Dict[str, + str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) # When Jina Reader returns empty after retries, try HTTP GET on the target URL. direct_fetch_fallback: bool = True # Tier 2 (urllib): shorter than Jina timeout — fail fast on slow origins. @@ -53,7 +57,8 @@ class JinaReaderConfig: def _build_reader_url(target_url: str, base_endpoint: str) -> str: encoded_target = quote(target_url, safe=":/?&=%#@!$'*+,;[]()") - base = base_endpoint if base_endpoint.endswith('/') else f'{base_endpoint}/' + base = base_endpoint if base_endpoint.endswith( + '/') else f'{base_endpoint}/' return f'{base}{encoded_target}' @@ -127,11 +132,8 @@ def _fetch_direct_http_pair(url: str, timeout: float) -> Tuple[str, str]: content_type = (resp.headers.get('Content-Type') or '').lower() content_type_main = content_type.split(';')[0].strip() text = raw.decode(charset, errors='replace') - if ( - 'html' in content_type_main - or text.lstrip().lower().startswith(' Tuple[str, str]: return '', '' -def _should_try_playwright_after_direct(plain: str, raw_html: str, min_chars: int) -> bool: +def _should_try_playwright_after_direct(plain: str, raw_html: str, + min_chars: int) -> bool: """Whether tier-3 headless fetch is worth attempting.""" p = plain.strip() if raw_html: @@ -165,29 +168,34 @@ def _fetch_via_jina(url: str, config: JinaReaderConfig) -> str: return data.decode('utf-8', errors='replace') except HTTPError as e: status = getattr(e, 'code', None) - if status in (429, 500, 502, 503, 504) and attempt <= config.retries: - sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) + if status in (429, 500, 502, 503, + 504) and attempt <= config.retries: + sleep_s = min(config.backoff_max, + config.backoff_base * (2**(attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' except URLError: if attempt <= config.retries: - sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) + sleep_s = min(config.backoff_max, + config.backoff_base * (2**(attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' except Exception: if attempt <= config.retries: - sleep_s = min(config.backoff_max, config.backoff_base * (2 ** (attempt - 1))) + sleep_s = min(config.backoff_max, + config.backoff_base * (2**(attempt - 1))) sleep_s *= random.uniform(0.7, 1.4) time.sleep(sleep_s) continue return '' -def fetch_single_text_with_meta(url: str, config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: +def fetch_single_text_with_meta(url: str, + config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: """ Tiered fetch: Jina Reader → direct HTTP → optional Playwright (empty / short / SPA shell). @@ -201,17 +209,15 @@ def fetch_single_text_with_meta(url: str, config: JinaReaderConfig) -> Tuple[str return jina_text, {'content_source': 'jina_reader'} if not config.direct_fetch_fallback: return '', {'content_source': 'none'} - d_timeout = ( - float(config.timeout) if float(config.direct_fetch_timeout or 0) <= 0 else float(config.direct_fetch_timeout) - ) + d_timeout = (float(config.timeout) if float(config.direct_fetch_timeout or 0) + <= 0 else float(config.direct_fetch_timeout)) direct_plain, raw_html = _fetch_direct_http_pair(url, d_timeout) direct_text = _postprocess_text(direct_plain) try_playwright = ( - bool(config.playwright_fetch_fallback) - and _is_direct_http_allowed(url) - and _should_try_playwright_after_direct(direct_text, raw_html, config.playwright_retry_min_chars) - ) + bool(config.playwright_fetch_fallback) and _is_direct_http_allowed(url) + and _should_try_playwright_after_direct(direct_text, raw_html, + config.playwright_retry_min_chars)) if try_playwright: pw_text = _postprocess_text( @@ -219,14 +225,17 @@ def fetch_single_text_with_meta(url: str, config: JinaReaderConfig) -> Tuple[str url, int(config.playwright_timeout_ms), settle_ms=int(config.playwright_settle_ms), - ) - ) + )) if pw_text.strip(): - logger.info(f'Using headless Chromium fallback after Jina/direct HTTP (url prefix): {url[:80]}') + logger.info( + 'Using headless Chromium fallback after Jina/direct HTTP ' + f'(url prefix): {url[:80]}') return pw_text, {'content_source': 'playwright_fallback'} if direct_text: - logger.info(f'Jina Reader returned no body for URL; using direct HTTP fallback (url prefix): {url[:80]}') + logger.info( + 'Jina Reader returned no body for URL; using direct HTTP fallback ' + f'(url prefix): {url[:80]}') return direct_text, {'content_source': 'direct_http_fallback'} return '', {'content_source': 'none'} @@ -241,11 +250,10 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: async def fetch_texts_via_jina( - urls: List[str], - config: Optional[JinaReaderConfig] = None, - semaphore: Optional[asyncio.Semaphore] = None, - executor: Optional[ThreadPoolExecutor] = None, -) -> List[str]: + urls: List[str], + config: Optional[JinaReaderConfig] = None, + semaphore: Optional[asyncio.Semaphore] = None, + executor: Optional[ThreadPoolExecutor] = None) -> List[str]: """ Asynchronously fetch a list of URLs via Jina Reader. Allows caller-provided concurrency controls (semaphore/executor) to integrate with pipeline resource management. @@ -259,7 +267,8 @@ async def fetch_texts_via_jina( async def _bound(u: str) -> str: async with local_sem: - return await loop.run_in_executor(executor, fetch_single_text, u, cfg) + return await loop.run_in_executor(executor, fetch_single_text, u, + cfg) tasks = [_bound(u) for u in urls] results = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index e6074fe2a..d6b971fbb 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -8,13 +8,12 @@ from mcp import ClientSession, ListToolsResult, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client -from omegaconf import DictConfig - from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import enhance_error, get_logger +from omegaconf import DictConfig logger = get_logger() @@ -52,14 +51,18 @@ def __init__( self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}} if config is not None: config_from_file = Config.convert_mcp_servers_to_json(config) - self.mcp_config['mcpServers'].update(config_from_file.get('mcpServers', {})) + self.mcp_config['mcpServers'].update( + config_from_file.get('mcpServers', {})) self.exclude_functions = {} self.include_functions = {} if mcp_config is not None: - self.mcp_config['mcpServers'].update(mcp_config.get('mcpServers', {})) + self.mcp_config['mcpServers'].update( + mcp_config.get('mcpServers', {})) - async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): - response = await self.sessions[server_name].call_tool(tool_name, tool_args) + async def call_tool(self, server_name: str, tool_name: str, + tool_args: dict): + response = await self.sessions[server_name].call_tool( + tool_name, tool_args) texts = [] resources = [] @@ -77,7 +80,6 @@ async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): texts.append(content.text) elif content.type == 'resource': import json5 - json_str = content.resource.model_dump_json(by_alias=True) texts.append(json_str) resources.append(json5.loads(json_str)) @@ -94,7 +96,8 @@ async def get_tools(self) -> Dict: try: response = await session.list_tools() except Exception as e: - new_eg = enhance_error(e, f'MCP `{key}` list tool failed, details: ') + new_eg = enhance_error( + e, f'MCP `{key}` list tool failed, details: ') raise new_eg from e _session_tools = response.tools exclude = [] @@ -105,12 +108,19 @@ async def get_tools(self) -> Dict: elif self.exclude_functions: if key in self.exclude_functions: exclude = self.exclude_functions[key] - _session_tools = [t for t in _session_tools if t.name not in exclude] + _session_tools = [ + t for t in _session_tools if t.name not in exclude + ] if include: - _session_tools = [t for t in _session_tools if t.name in include] + _session_tools = [ + t for t in _session_tools if t.name in include + ] _session_tools = [ - Tool(tool_name=t.name, server_name=key, description=t.description, parameters=t.inputSchema) - for t in _session_tools + Tool( + tool_name=t.name, + server_name=key, + description=t.description, + parameters=t.inputSchema) for t in _session_tools ] tools[key].extend(_session_tools) return tools @@ -122,13 +132,18 @@ def print_tools(server_name: str, tools: ListToolsResult): if len(tools) > 10: tools = [tool.name for tool in tools][:10] logger.info( - f'\nConnected to server "{server_name}" with tools: \n{sep.join(tools)}\nOnly list first 10 of them.' + f'\nConnected to server "{server_name}" ' + f'with tools: \n{sep.join(tools)}\nOnly list first 10 of them.' ) else: tools = [tool.name for tool in tools] - logger.info(f'\nConnected to server "{server_name}" with tools: \n{sep.join(tools)}.') + logger.info(f'\nConnected to server "{server_name}" ' + f'with tools: \n{sep.join(tools)}.') - async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TIMEOUT, **kwargs): + async def connect_to_server(self, + server_name: str, + timeout: int = CONNECTION_TIMEOUT, + **kwargs): logger.info(f'connect to {server_name}') # transport: stdio, sse, streamable_http, websocket transport = kwargs.get('transport') or kwargs.get('type') @@ -137,19 +152,21 @@ async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TI session_kwargs = kwargs.get('session_kwargs') if url: if transport and transport.lower() == 'sse': - logger.info('`transport` or `type` is configured as "sse", using sse transport.') + logger.info( + '`transport` or `type` is configured as "sse", using sse transport.' + ) sse_transport = await self.exit_stack.enter_async_context( sse_client( - url, - kwargs.get('headers'), + url, kwargs.get('headers'), kwargs.get('timeout', DEFAULT_HTTP_TIMEOUT), - kwargs.get('sse_read_timeout', DEFAULT_SSE_READ_TIMEOUT), - ) - ) + kwargs.get('sse_read_timeout', + DEFAULT_SSE_READ_TIMEOUT))) read, write = sse_transport elif transport and transport.lower() == 'websocket': - logger.info('`transport` or `type` is configured as "websocket", using websocket transport.') + logger.info( + '`transport` or `type` is configured as "websocket", using websocket transport.' + ) try: from mcp.client.websocket import websocket_client except ImportError: @@ -158,22 +175,21 @@ async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TI 'To use Websocket connections, please install the required dependency with: ' "'pip install mcp[ws]' or 'pip install websockets'" ) from None - websocket_transport = await self.exit_stack.enter_async_context(websocket_client(url)) + websocket_transport = await self.exit_stack.enter_async_context( + websocket_client(url)) read, write = websocket_transport else: logger.info( 'Using streamable_http transport. To configure a different transport such as sse, please' - 'set the `type` or `transport` variable to "sse".' - ) + 'set the `type` or `transport` variable to "sse".') try: from mcp.client.streamable_http import streamablehttp_client except ImportError: raise ImportError( 'Could not import streamablehttp_client. ' 'To use streamable http connections, please upgrade to the latest version of mcp with: ' - "'pip install -U mcp'" - ) from None + "'pip install -U mcp'") from None httpx_client_factory = kwargs.get('httpx_client_factory') other_kwargs = {} if httpx_client_factory is not None: @@ -182,36 +198,46 @@ async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TI streamablehttp_client( url, headers=kwargs.get('headers'), - timeout=kwargs.get('timeout', DEFAULT_STREAMABLE_HTTP_TIMEOUT), - sse_read_timeout=kwargs.get('sse_read_timeout', DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT), - **other_kwargs, - ) - ) + timeout=kwargs.get('timeout', + DEFAULT_STREAMABLE_HTTP_TIMEOUT), + sse_read_timeout=kwargs.get( + 'sse_read_timeout', + DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT), + **other_kwargs)) read, write, _ = streamable_transport session_kwargs = session_kwargs or {} - timeout = max(session_kwargs.pop('read_timeout_seconds', timeout), 1) + timeout = max( + session_kwargs.pop('read_timeout_seconds', timeout), 1) session = await self.exit_stack.enter_async_context( - ClientSession(read, write, read_timeout_seconds=timedelta(seconds=timeout), **session_kwargs) - ) + ClientSession( + read, + write, + read_timeout_seconds=timedelta(seconds=timeout), + **session_kwargs)) elif command: # transport: 'stdio' args = kwargs.get('args') if not args: - raise ValueError("'args' parameter is required for stdio connection") + raise ValueError( + "'args' parameter is required for stdio connection") server_params = StdioServerParameters( command=command, args=args, env=kwargs.get('env'), encoding=kwargs.get('encoding', DEFAULT_ENCODING), - encoding_error_handler=kwargs.get('encoding_error_handler', DEFAULT_ENCODING_ERROR_HANDLER), + encoding_error_handler=kwargs.get( + 'encoding_error_handler', DEFAULT_ENCODING_ERROR_HANDLER), ) - stdio, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) - session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) + stdio, write = await self.exit_stack.enter_async_context( + stdio_client(server_params)) + session = await self.exit_stack.enter_async_context( + ClientSession(stdio, write)) else: - raise ValueError("'url' or 'command' parameter is required for connection") + raise ValueError( + "'url' or 'command' parameter is required for connection") await session.initialize() # Store session @@ -226,16 +252,20 @@ async def connect(self, timeout: int = CONNECTION_TIMEOUT): for name, server in mcp_config.items(): try: env_dict = server.pop('env', {}) - env_dict = {key: value if value else envs.get(key, '') for key, value in env_dict.items()} + env_dict = { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } if 'exclude' in server: self.exclude_functions[name] = server.pop('exclude') if 'include' in server: self.include_functions[name] = server.pop('include') - assert (not self.include_functions.get(name)) or (not self.exclude_functions.get(name)), ( - 'Set either `include` or `exclude` in tools config.' - ) + assert (not self.include_functions.get(name)) or ( + not self.exclude_functions.get(name) + ), 'Set either `include` or `exclude` in tools config.' timeout = server.pop('timeout', timeout) - await self.connect_to_server(server_name=name, env=env_dict, timeout=timeout, **server) + await self.connect_to_server( + server_name=name, env=env_dict, timeout=timeout, **server) except Exception as e: new_eg = enhance_error(e, f'Connect `{name}` failed, details:') raise new_eg from e @@ -252,10 +282,14 @@ async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]): else: servers[name] = server env_dict = server.pop('env', {}) - env_dict = {key: value if value else envs.get(key, '') for key, value in env_dict.items()} + env_dict = { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } if 'exclude' in server: self.exclude_functions[name] = server.pop('exclude') - await self.connect_to_server(server_name=name, env=env_dict, **server) + await self.connect_to_server( + server_name=name, env=env_dict, **server) self.mcp_config['mcpServers'].update(new_mcp_config) async def cleanup(self): diff --git a/ms_agent/tools/mineru/pdf_parser.py b/ms_agent/tools/mineru/pdf_parser.py index a352fe396..d210fd9b1 100644 --- a/ms_agent/tools/mineru/pdf_parser.py +++ b/ms_agent/tools/mineru/pdf_parser.py @@ -1,13 +1,16 @@ import os from magic_pdf.config.enums import SupportedPdfParseMethod -from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter +from magic_pdf.data.data_reader_writer import (FileBasedDataReader, + FileBasedDataWriter) from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze class PdfParser: + def __init__(self, parser_workdir: str): + # e.g. "your_workdir/resources/mineru" self._workdir = parser_workdir os.makedirs(self._workdir, exist_ok=True) @@ -15,7 +18,8 @@ def __init__(self, parser_workdir: str): self.relative_image_dir = 'images' self.markdown_dir = self._workdir - self.img_writer = FileBasedDataWriter(os.path.join(self._workdir, self.relative_image_dir)) + self.img_writer = FileBasedDataWriter( + os.path.join(self._workdir, self.relative_image_dir)) self.md_writer = FileBasedDataWriter(self.markdown_dir) self.data_reader = FileBasedDataReader('') @@ -39,7 +43,8 @@ def parse(self, f_path: str, reuse: bool = True) -> str: print(f'Processing file: {f_path}') file_name_no_suffix = os.path.splitext(os.path.basename(f_path))[0] - entry_md_file = os.path.join(self.markdown_dir, f'{file_name_no_suffix}.md') + entry_md_file = os.path.join(self.markdown_dir, + f'{file_name_no_suffix}.md') if reuse and os.path.exists(entry_md_file): print(f'File {entry_md_file} already exists. Skipping processing.') @@ -63,24 +68,32 @@ def parse(self, f_path: str, reuse: bool = True) -> str: pipe_result = infer_result.pipe_txt_mode(self.img_writer) # draw model result on each page - infer_result.draw_model(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_model.pdf')) + infer_result.draw_model( + os.path.join(self.markdown_dir, + f'{file_name_no_suffix}_model.pdf')) # draw layout result on each page - pipe_result.draw_layout(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_layout.pdf')) + pipe_result.draw_layout( + os.path.join(self.markdown_dir, + f'{file_name_no_suffix}_layout.pdf')) # draw spans result on each page - pipe_result.draw_span(os.path.join(self.markdown_dir, f'{file_name_no_suffix}_spans.pdf')) + pipe_result.draw_span( + os.path.join(self.markdown_dir, + f'{file_name_no_suffix}_spans.pdf')) # dump markdown - pipe_result.dump_md(self.md_writer, f'{file_name_no_suffix}.md', self.relative_image_dir) + pipe_result.dump_md(self.md_writer, f'{file_name_no_suffix}.md', + self.relative_image_dir) # dump content list pipe_result.dump_content_list( - self.md_writer, f'{file_name_no_suffix}_content_list.json', self.relative_image_dir - ) + self.md_writer, f'{file_name_no_suffix}_content_list.json', + self.relative_image_dir) # dump middle json - pipe_result.dump_middle_json(self.md_writer, f'{file_name_no_suffix}_middle.json') + pipe_result.dump_middle_json(self.md_writer, + f'{file_name_no_suffix}_middle.json') print(f'Finished processing file: {f_path}') diff --git a/ms_agent/tools/search/arxiv/__init__.py b/ms_agent/tools/search/arxiv/__init__.py index d4308d5f7..2cbf5d6fe 100644 --- a/ms_agent/tools/search/arxiv/__init__.py +++ b/ms_agent/tools/search/arxiv/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa -from ms_agent.tools.search.arxiv.schema import ArxivSearchRequest, ArxivSearchResult +from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, + ArxivSearchResult) from ms_agent.tools.search.arxiv.search import ArxivSearch diff --git a/ms_agent/tools/search/arxiv/schema.py b/ms_agent/tools/search/arxiv/schema.py index f0bc2aed4..154c603ae 100644 --- a/ms_agent/tools/search/arxiv/schema.py +++ b/ms_agent/tools/search/arxiv/schema.py @@ -1,11 +1,12 @@ # flake8: noqa -import arxiv -import json -from arxiv import SortCriterion, SortOrder from dataclasses import dataclass, field from typing import Any, Dict, Generator, List, Optional -from ms_agent.tools.search.search_base import BaseResult, SearchRequest, SearchResponse, SearchResult +import arxiv +import json +from arxiv import SortCriterion, SortOrder +from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, + SearchResponse, SearchResult) from ms_agent.utils.logger import get_logger logger = get_logger() @@ -16,17 +17,15 @@ class ArxivSearchRequest(SearchRequest): A class representing a search request to ArXiv. """ - def __init__( - self, - query: str = None, - num_results: Optional[int] = 10, - sort_strategy: SortCriterion = SortCriterion.Relevance, - sort_order: SortOrder = SortOrder.Descending, - categories: Optional[List[str]] = None, - date_from: Optional[str] = None, - date_to: Optional[str] = None, - **kwargs: Any, - ): + def __init__(self, + query: str = None, + num_results: Optional[int] = 10, + sort_strategy: SortCriterion = SortCriterion.Relevance, + sort_order: SortOrder = SortOrder.Descending, + categories: Optional[List[str]] = None, + date_from: Optional[str] = None, + date_to: Optional[str] = None, + **kwargs: Any): """ Initialize ArxivSearchRequest with search parameters. @@ -48,9 +47,12 @@ def __init__( self.sort_strategy_map = { 'relevance': SortCriterion.Relevance, 'lastUpdatedDate': SortCriterion.LastUpdatedDate, - 'submittedDate': SortCriterion.SubmittedDate, + 'submittedDate': SortCriterion.SubmittedDate + } + self.sort_order_map = { + 'descending': SortOrder.Descending, + 'ascending': SortOrder.Ascending } - self.sort_order_map = {'descending': SortOrder.Descending, 'ascending': SortOrder.Ascending} def to_dict(self) -> Dict[str, Any]: """ @@ -59,16 +61,18 @@ def to_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters as a dictionary """ - if isinstance(self.sort_strategy, str) and self.sort_strategy_map.get(self.sort_strategy): + if isinstance(self.sort_strategy, str) and self.sort_strategy_map.get( + self.sort_strategy): self.sort_strategy = self.sort_strategy_map[self.sort_strategy] - if isinstance(self.sort_order, str) and self.sort_order_map.get(self.sort_order): + if isinstance(self.sort_order, str) and self.sort_order_map.get( + self.sort_order): self.sort_order = self.sort_order_map[self.sort_order] return { 'query': self.query, 'max_results': self.num_results, 'sort_by': self.sort_strategy, - 'sort_order': self.sort_order, + 'sort_order': self.sort_order } def to_json(self) -> Dict[str, Any]: @@ -83,16 +87,18 @@ def to_json(self) -> Dict[str, Any]: 'query': self.query, 'max_results': self.num_results, 'sort_strategy': self.sort_strategy.value, - 'sort_order': self.sort_order.value, + 'sort_order': self.sort_order.value }, - ensure_ascii=False, - ) + ensure_ascii=False) class ArxivSearchResult(SearchResult): """ArXiv search result implementation.""" - def __init__(self, query: str, arguments: Dict[str, Any] = None, response: List['arxiv.Result'] = None): + def __init__(self, + query: str, + arguments: Dict[str, Any] = None, + response: List['arxiv.Result'] = None): """ Initialize ArxivSearchResult. @@ -134,20 +140,21 @@ def _process_results(self) -> SearchResponse: processed = [] for res in self.raw_response: if not isinstance(res, arxiv.Result): - print(f'***Warning: Result {res} is not an instance of arxiv.Result.') + print( + f'***Warning: Result {res} is not an instance of arxiv.Result.' + ) continue processed.append( BaseResult( - url=getattr(res, 'pdf_url', None) or getattr(res, 'entry_id', None), + url=getattr(res, 'pdf_url', None) + or getattr(res, 'entry_id', None), id=getattr(res, 'entry_id', None), title=getattr(res, 'title', None), highlights=None, highlight_scores=None, summary=getattr(res, 'summary', None), - markdown=None, - ) - ) + markdown=None)) return SearchResponse(results=processed) @@ -155,7 +162,8 @@ def _process_arguments(self) -> Dict[str, Any]: """Process the search arguments to be JSON serializable.""" sort_strategy = self.arguments.get('sort_strategy', None) if sort_strategy is None: - sort_strategy = self.arguments.get('sort_by', SortCriterion.Relevance) + sort_strategy = self.arguments.get('sort_by', + SortCriterion.Relevance) sort_order = self.arguments.get('sort_order', SortOrder.Descending) if isinstance(sort_strategy, SortCriterion): @@ -220,21 +228,31 @@ def to_list(self) -> List[Dict[str, Any]]: categories = getattr(res, 'categories', None) or [] - res_list.append( - { - 'url': (getattr(res, 'pdf_url', None) or getattr(res, 'entry_id', None) or ''), - 'id': getattr(res, 'entry_id', None) or '', - 'title': getattr(res, 'title', None) or '', - 'published_date': published_date, - 'summary': getattr(res, 'summary', None) or '', - 'highlights': None, - 'highlight_scores': None, - 'markdown': None, - 'authors': authors, - 'categories': categories, - 'arxiv_id': short_id or '', - 'resource_uri': f'arxiv://{short_id}' if short_id else '', - } - ) + res_list.append({ + 'url': (getattr(res, 'pdf_url', None) + or getattr(res, 'entry_id', None) or ''), + 'id': + getattr(res, 'entry_id', None) or '', + 'title': + getattr(res, 'title', None) or '', + 'published_date': + published_date, + 'summary': + getattr(res, 'summary', None) or '', + 'highlights': + None, + 'highlight_scores': + None, + 'markdown': + None, + 'authors': + authors, + 'categories': + categories, + 'arxiv_id': + short_id or '', + 'resource_uri': + f'arxiv://{short_id}' if short_id else '', + }) return res_list diff --git a/ms_agent/tools/search/arxiv/search.py b/ms_agent/tools/search/arxiv/search.py index 0761b2509..8f407d454 100644 --- a/ms_agent/tools/search/arxiv/search.py +++ b/ms_agent/tools/search/arxiv/search.py @@ -1,11 +1,12 @@ # flake8: noqa -import arxiv import os -from arxiv import SortCriterion, SortOrder from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING -from ms_agent.tools.search.arxiv.schema import ArxivSearchRequest, ArxivSearchResult +import arxiv +from arxiv import SortCriterion, SortOrder +from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, + ArxivSearchResult) from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.utils.logger import get_logger @@ -131,16 +132,20 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: date_from_dt = None date_to_dt = None if getattr(search_request, 'date_from', None): - date_from_dt = _parse_yyyy_mm_dd(search_request.date_from, end_of_day=False) + date_from_dt = _parse_yyyy_mm_dd( + search_request.date_from, end_of_day=False) if getattr(search_request, 'date_to', None): - date_to_dt = _parse_yyyy_mm_dd(search_request.date_to, end_of_day=True) + date_to_dt = _parse_yyyy_mm_dd( + search_request.date_to, end_of_day=True) if date_from_dt or date_to_dt: desired = int(search_request.num_results or 10) - search_args['max_results'] = min(max(desired + 10, desired), 50) + search_args['max_results'] = min( + max(desired + 10, desired), 50) response = [] - for paper in self.client.results(search=arxiv.Search(**search_args)): + for paper in self.client.results( + search=arxiv.Search(**search_args)): if date_from_dt or date_to_dt: paper_date = getattr(paper, 'published', None) if paper_date is None: @@ -170,8 +175,7 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: **search_args, **extra_args, }, - response=response, - ) + response=response) except Exception as e: raise RuntimeError(f'Failed to perform search: {e}') from e @@ -181,7 +185,6 @@ def _parse_yyyy_mm_dd(s: str, *, end_of_day: bool) -> datetime: def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for arXiv search engine.""" from ms_agent.llm.utils import Tool - return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -190,52 +193,62 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'type': 'object', 'properties': { 'query': { - 'type': 'string', - 'description': ( - 'Search query using quoted phrases for exact matches ' - '(e.g., \'"machine learning" OR "deep learning"\') or ' - 'specific technical terms. Avoid overly broad or generic terms.' - ), + 'type': + 'string', + 'description': + ('Search query using quoted phrases for exact matches ' + '(e.g., \'"machine learning" OR "deep learning"\') or ' + 'specific technical terms. Avoid overly broad or generic terms.' + ), }, 'num_results': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 15, - 'description': ( - 'Maximum number of results to return. Default is 5.Use 5-15 for comprehensive searches.' - ), + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 15, + 'description': + ('Maximum number of results to return. Default is 5.' + 'Use 5-15 for comprehensive searches.'), }, 'date_from': { - 'type': 'string', - 'description': ( - 'Start date for papers (YYYY-MM-DD format). Use to find recent work, e.g., "2023-01-01".' - ), + 'type': + 'string', + 'description': + ('Start date for papers (YYYY-MM-DD format). ' + 'Use to find recent work, e.g., "2023-01-01".'), }, 'date_to': { - 'type': 'string', - 'description': ( - 'End date for papers (YYYY-MM-DD format). ' - 'Use with date_from for historical windows, e.g., "2020-12-31".' - ), + 'type': + 'string', + 'description': + ('End date for papers (YYYY-MM-DD format). ' + 'Use with date_from for historical windows, e.g., "2020-12-31".' + ), }, 'categories': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'Strongly recommended: arXiv categories to focus search ' - '(e.g., ["cs.AI", "cs.MA"] for agent research, ["cs.LG"] for ML, ' - '["cs.CL"] for NLP, ["cs.CV"] for computer vision). ' - 'Greatly improves relevance.' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('Strongly recommended: arXiv categories to focus search ' + '(e.g., ["cs.AI", "cs.MA"] for agent research, ["cs.LG"] for ML, ' + '["cs.CL"] for NLP, ["cs.CV"] for computer vision). ' + 'Greatly improves relevance.'), }, 'sort_by': { - 'type': 'string', - 'enum': ['relevance', 'submittedDate', 'lastUpdatedDate'], - 'description': ( - 'How to sort results. "relevance" for best match, ' - '"submittedDate" for newest submissions, ' - '"lastUpdatedDate" for recently updated. Default is "relevance".' - ), + 'type': + 'string', + 'enum': + ['relevance', 'submittedDate', 'lastUpdatedDate'], + 'description': + ('How to sort results. "relevance" for best match, ' + '"submittedDate" for newest submissions, ' + '"lastUpdatedDate" for recently updated. Default is "relevance".' + ), }, 'sort_order': { 'type': 'string', @@ -257,8 +270,8 @@ def build_request_from_args(cls, **kwargs) -> ArxivSearchRequest: categories = [str(c).strip() for c in categories if str(c).strip()] if not _validate_categories(categories): logger.warning( - f"Invalid arXiv categories provided: {kwargs.get('categories')}. Ignoring categories filter." - ) + f"Invalid arXiv categories provided: {kwargs.get('categories')}. " + 'Ignoring categories filter.') categories = None # Build final query by AND-ing base query with category filter (OR across categories) diff --git a/ms_agent/tools/search/content_optimizer.py b/ms_agent/tools/search/content_optimizer.py index c316ac7e0..998b13b66 100644 --- a/ms_agent/tools/search/content_optimizer.py +++ b/ms_agent/tools/search/content_optimizer.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio -import json import os import re from concurrent.futures import ThreadPoolExecutor @@ -9,12 +8,12 @@ from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse -from omegaconf import DictConfig, OmegaConf - +import json from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor +from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -144,7 +143,6 @@ @dataclass class SearchResultMeta: """Metadata for a search result used in reranking.""" - url: str title: str snippet: str = '' @@ -157,7 +155,6 @@ class SearchResultMeta: @dataclass class SummaryResult: """Result of content summarization.""" - summary: str key_excerpts: str original_length: int @@ -181,7 +178,6 @@ def total_tokens(self) -> int: @dataclass class ContentOptimizerConfig: """Configuration for content optimization.""" - # Summarization settings summarizer_model: str = 'qwen-flash' summarizer_base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1' @@ -204,8 +200,7 @@ class ContentOptimizerConfig: 'blog': 0.6, # Technical blogs 'forum': 0.4, # Forums, Q&A sites 'unknown': 0.5, - } - ) + }) # Domain patterns for source classification @@ -311,7 +306,8 @@ def classify_source(url: str) -> str: return 'paper' # Check for documentation indicators - if any(doc_pattern in domain for doc_pattern in ['docs.', 'documentation.', 'developer.']): + if any(doc_pattern in domain + for doc_pattern in ['docs.', 'documentation.', 'developer.']): return 'official' for news_domain in NEWS_DOMAINS: @@ -379,7 +375,11 @@ def _build_llm_config(self) -> DictConfig: 'openai_base_url': self.config.summarizer_base_url, 'openai_api_key': self.config.summarizer_api_key, }, - 'generation_config': {'extra_body': {'enable_thinking': False}}, + 'generation_config': { + 'extra_body': { + 'enable_thinking': False + } + }, } return OmegaConf.create(config_dict) @@ -396,7 +396,9 @@ async def initialize(self) -> None: thread_name_prefix='content_summarizer_', ) self._initialized = True - logger.info(f'ContentSummarizer initialized with model: {self.config.summarizer_model}') + logger.info( + f'ContentSummarizer initialized with model: {self.config.summarizer_model}' + ) except Exception as e: logger.error(f'Failed to initialize ContentSummarizer: {e}') raise @@ -428,7 +430,8 @@ def _parse_summary_response(self, response_text: str) -> Tuple[str, str]: Tuple of (summary, key_excerpts) """ # Try to find JSON in the response - json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL) + json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', + response_text, re.DOTALL) if json_match: try: data = json.loads(json_match.group(1)) @@ -442,7 +445,7 @@ def _parse_summary_response(self, response_text: str) -> Tuple[str, str]: start_idx = response_text.find('{') end_idx = response_text.rfind('}') if start_idx != -1 and end_idx != -1: - json_str = response_text[start_idx : end_idx + 1] + json_str = response_text[start_idx:end_idx + 1] data = json.loads(json_str) return data.get('summary', ''), data.get('key_excerpts', '') except json.JSONDecodeError: @@ -465,7 +468,10 @@ def _call_llm_sync(self, prompt: str) -> Message: response = self._llm.generate(messages) return response - async def summarize(self, content: str, task_context: str = '', language: str = 'auto') -> SummaryResult: + async def summarize(self, + content: str, + task_context: str = '', + language: str = 'auto') -> SummaryResult: """ Summarize webpage content using the configured LLM. @@ -494,12 +500,13 @@ async def summarize(self, content: str, task_context: str = '', language: str = ) # Truncate content if too long - content_to_summarize = content[: self.config.max_content_chars] + content_to_summarize = content[:self.config.max_content_chars] # Detect language and select prompt if language == 'auto': # Simple heuristic: check for Chinese characters - chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', content_to_summarize[:1000])) + chinese_chars = len( + re.findall(r'[\u4e00-\u9fff]', content_to_summarize[:1000])) language = 'zh' if chinese_chars > 30 else 'en' prompt_template = SUMMARIZE_WEBPAGE_PROMPT if language == 'zh' else SUMMARIZE_WEBPAGE_PROMPT_EN @@ -517,7 +524,8 @@ async def summarize(self, content: str, task_context: str = '', language: str = # Run synchronous LLM call in executor with timeout loop = asyncio.get_event_loop() response_msg: Message = await asyncio.wait_for( - loop.run_in_executor(self._executor, self._call_llm_sync, prompt), + loop.run_in_executor(self._executor, self._call_llm_sync, + prompt), timeout=self.config.summarization_timeout, ) @@ -532,8 +540,8 @@ async def summarize(self, content: str, task_context: str = '', language: str = compression_ratio = compressed_length / original_length if original_length > 0 else 1.0 logger.debug( - f'Content summarized: {original_length} -> {compressed_length} chars (ratio: {compression_ratio:.2%})' - ) + f'Content summarized: {original_length} -> {compressed_length} chars ' + f'(ratio: {compression_ratio:.2%})') return SummaryResult( summary=summary, @@ -542,19 +550,25 @@ async def summarize(self, content: str, task_context: str = '', language: str = compressed_length=compressed_length, compression_ratio=compression_ratio, success=True, - model=str(getattr(self._llm, 'model', '') or self.config.summarizer_model), - prompt_tokens=int(getattr(response_msg, 'prompt_tokens', 0) or 0), - completion_tokens=int(getattr(response_msg, 'completion_tokens', 0) or 0), - cached_tokens=int(getattr(response_msg, 'cached_tokens', 0) or 0), - cache_creation_input_tokens=int(getattr(response_msg, 'cache_creation_input_tokens', 0) or 0), + model=str( + getattr(self._llm, 'model', '') + or self.config.summarizer_model), + prompt_tokens=int( + getattr(response_msg, 'prompt_tokens', 0) or 0), + completion_tokens=int( + getattr(response_msg, 'completion_tokens', 0) or 0), + cached_tokens=int( + getattr(response_msg, 'cached_tokens', 0) or 0), + cache_creation_input_tokens=int( + getattr(response_msg, 'cache_creation_input_tokens', 0) + or 0), api_calls=int(getattr(response_msg, 'api_calls', 0) or 0), ) except asyncio.TimeoutError: logger.warning( f'Summarization timed out after {self.config.summarization_timeout}s, ' - 'returning truncated original content' - ) + 'returning truncated original content') # Return truncated original content truncated = content_to_summarize[:100000] return SummaryResult( @@ -569,7 +583,9 @@ async def summarize(self, content: str, task_context: str = '', language: str = ) except Exception as e: - logger.warning(f'Summarization failed: {e}, returning truncated original content') + logger.warning( + f'Summarization failed: {e}, returning truncated original content' + ) truncated = content_to_summarize[:100000] return SummaryResult( summary=truncated, @@ -604,7 +620,8 @@ async def summarize_batch( semaphore = asyncio.Semaphore(max_concurrent) - async def _bounded_summarize(url: str, content: str) -> Tuple[str, SummaryResult]: + async def _bounded_summarize( + url: str, content: str) -> Tuple[str, SummaryResult]: async with semaphore: result = await self.summarize(content, task_context) return url, result @@ -727,7 +744,8 @@ def _compute_recency_score(self, published_at: str) -> float: # Calculate months difference if month: - months_diff = (current_year - year) * 12 + (current_month - month) + months_diff = (current_year - year) * 12 + ( + current_month - month) else: months_diff = (current_year - year) * 12 @@ -768,7 +786,8 @@ def _build_result_meta( url = result.get('url', '') title = result.get('title', '') snippet = result.get('summary', '') or result.get('snippet', '') - published_at = result.get('published_date', '') or result.get('published_at', '') + published_at = result.get('published_date', '') or result.get( + 'published_at', '') source_type = classify_source(url) @@ -780,7 +799,9 @@ def _build_result_meta( # Weighted combination # Title relevance: 40%, Source type: 30%, Recency: 20%, Snippet: 10% - relevance_score = title_relevance * 0.4 + source_weight * 0.3 + recency_score * 0.2 + snippet_relevance * 0.1 + relevance_score = ( + title_relevance * 0.4 + source_weight * 0.3 + recency_score * 0.2 + + snippet_relevance * 0.1) return SearchResultMeta( url=url, @@ -818,7 +839,10 @@ def rerank( return results[:k] # Build metadata for all results - metas = [self._build_result_meta(result, idx, query) for idx, result in enumerate(results)] + metas = [ + self._build_result_meta(result, idx, query) + for idx, result in enumerate(results) + ] # Sort by relevance score (descending) sorted_pairs = sorted( @@ -843,7 +867,8 @@ def rerank( return top_results @staticmethod - def deduplicate_by_url(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def deduplicate_by_url( + results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Remove duplicate results based on URL. @@ -965,15 +990,16 @@ async def summarize_contents( if not self._initialized: await self.initialize() - results = await self.summarizer.summarize_batch(contents, task_context, max_concurrent) + results = await self.summarizer.summarize_batch( + contents, task_context, max_concurrent) # Convert SummaryResult to formatted strings formatted = {} for url, result in results.items(): if result.key_excerpts: formatted[url] = ( - f'\n{result.summary}\n\n\n\n{result.key_excerpts}\n' - ) + f'\n{result.summary}\n\n\n' + f'\n{result.key_excerpts}\n') else: formatted[url] = result.summary @@ -994,7 +1020,8 @@ async def summarize_contents_with_usage( if not self._initialized: await self.initialize() - results = await self.summarizer.summarize_batch(contents, task_context, max_concurrent) + results = await self.summarizer.summarize_batch( + contents, task_context, max_concurrent) formatted: Dict[str, str] = {} # Aggregate usage across results (best-effort; failures may have 0 usage) @@ -1008,8 +1035,8 @@ async def summarize_contents_with_usage( for url, result in results.items(): if result.key_excerpts: formatted[url] = ( - f'\n{result.summary}\n\n\n\n{result.key_excerpts}\n' - ) + f'\n{result.summary}\n\n\n' + f'\n{result.key_excerpts}\n') else: formatted[url] = result.summary @@ -1038,7 +1065,8 @@ async def summarize_contents_with_usage( def create_content_optimizer( summarizer_model: str = 'qwen-flash', - summarizer_base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1', + summarizer_base_url: + str = 'https://dashscope.aliyuncs.com/compatible-mode/v1', summarizer_api_key: Optional[str] = None, max_content_chars: int = 500000, enable_rerank: bool = False, diff --git a/ms_agent/tools/search/exa/schema.py b/ms_agent/tools/search/exa/schema.py index b306fb8e8..a80a1c401 100644 --- a/ms_agent/tools/search/exa/schema.py +++ b/ms_agent/tools/search/exa/schema.py @@ -1,12 +1,14 @@ # flake8: noqa -import json from dataclasses import dataclass, field -from exa_py.api import SearchResponse from typing import Any, Dict, List, Optional +import json +from exa_py.api import SearchResponse + @dataclass class ExaSearchRequest: + # The search query string query: str @@ -42,7 +44,7 @@ def to_dict(self) -> Dict[str, Any]: 'start_published_date': self.start_published_date, 'end_published_date': self.end_published_date, 'start_crawl_date': self.start_crawl_date, - 'end_crawl_date': self.end_crawl_date, + 'end_crawl_date': self.end_crawl_date } def to_json(self) -> str: @@ -54,6 +56,7 @@ def to_json(self) -> str: @dataclass class ExaSearchResult: + # The original search query string query: str @@ -78,19 +81,22 @@ def to_list(self): res_list: List[Any] = [] for res in self.response.results: - res_list.append( - { - 'url': getattr(res, 'url', ''), - 'id': getattr(res, 'id', ''), - 'title': getattr(res, 'title'), - 'published_date': getattr(res, 'published_date', ''), - 'summary': getattr(res, 'summary', ''), - # 'text': getattr(res, 'text', ''), - # 'highlights': getattr(res, 'highlights', ''), - # 'highlight_scores': getattr(res, 'highlight_scores', ''), - # 'markdown': getattr(res, 'markdown', ''), - } - ) + res_list.append({ + 'url': + getattr(res, 'url', ''), + 'id': + getattr(res, 'id', ''), + 'title': + getattr(res, 'title'), + 'published_date': + getattr(res, 'published_date', ''), + 'summary': + getattr(res, 'summary', ''), + # 'text': getattr(res, 'text', ''), + # 'highlights': getattr(res, 'highlights', ''), + # 'highlight_scores': getattr(res, 'highlight_scores', ''), + # 'markdown': getattr(res, 'markdown', ''), + }) return res_list @@ -134,19 +140,18 @@ def load_from_disk(file_path: str) -> List[Dict[str, Any]]: return data -def dump_batch_search_results(results: List[ExaSearchResult], file_path: str) -> None: +def dump_batch_search_results(results: List[ExaSearchResult], + file_path: str) -> None: """ Dump a batch of search results to a local file. """ out_list: List[Dict[str, Any]] = [] for res in results: - out_list.append( - { - 'query': res.query, - 'arguments': res.arguments, - 'results': res.to_list(), - } - ) + out_list.append({ + 'query': res.query, + 'arguments': res.arguments, + 'results': res.to_list(), + }) with open(file_path, 'w', encoding='utf-8') as f: json.dump(out_list, f, ensure_ascii=False, indent=2) diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index c3eba0bd7..08fa6fb71 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -1,9 +1,9 @@ # flake8: noqa import os import threading -from exa_py import Exa from typing import TYPE_CHECKING, List, Optional, Set, Union +from exa_py import Exa from ms_agent.tools.search.exa.schema import ExaSearchRequest, ExaSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.utils.logger import get_logger @@ -36,9 +36,13 @@ class ExaSearch(SearchEngine): _global_exhausted_keys: Set[str] = set() _global_lock = threading.Lock() - def __init__(self, api_key: Union[str, list, None] = None, api_keys: Union[str, list, None] = None): + def __init__(self, + api_key: Union[str, list, None] = None, + api_keys: Union[str, list, None] = None): all_keys = self._collect_keys(api_key, api_keys) - assert all_keys, 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments or as environment variables' + assert all_keys, ( + 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments ' + 'or as environment variables') self._api_keys: List[str] = all_keys self._lock = threading.Lock() @@ -56,12 +60,11 @@ def __init__(self, api_key: Union[str, list, None] = None, api_keys: Union[str, if len(all_keys) > 1: with ExaSearch._global_lock: - n_exhausted = sum(1 for k in all_keys if k in ExaSearch._global_exhausted_keys) - logger.info( - f'Exa key pool: {len(all_keys)} keys, ' - f'{n_exhausted} previously exhausted, ' - f'starting at key {start_idx + 1}/{len(all_keys)}' - ) + n_exhausted = sum(1 for k in all_keys + if k in ExaSearch._global_exhausted_keys) + logger.info(f'Exa key pool: {len(all_keys)} keys, ' + f'{n_exhausted} previously exhausted, ' + f'starting at key {start_idx + 1}/{len(all_keys)}') @staticmethod def _collect_keys( @@ -116,7 +119,8 @@ def _add_source(value): def _is_credits_exhausted(error: Exception) -> bool: """Detect Exa 402 / NO_MORE_CREDITS errors.""" msg = str(error) - return '402' in msg and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg) + return ('402' in msg + and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg)) @staticmethod def _mask_key(key: str) -> str: @@ -158,7 +162,8 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: key_idx = self._current_key_idx try: - search_result.response = client.search_and_contents(**search_args) + search_result.response = client.search_and_contents( + **search_args) return search_result except Exception as e: if not self._is_credits_exhausted(e): @@ -176,30 +181,28 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: ) rotated = False for i in range(len(self._api_keys)): - if i not in instance_exhausted and not self._is_key_exhausted(i): + if i not in instance_exhausted and not self._is_key_exhausted( + i): self._current_key_idx = i self.client = Exa(api_key=self._api_keys[i]) - logger.info( - f'Rotated to Exa API key ' - f'{self._mask_key(self._api_keys[i])} ' - f'({i + 1}/{len(self._api_keys)})' - ) + logger.info(f'Rotated to Exa API key ' + f'{self._mask_key(self._api_keys[i])} ' + f'({i + 1}/{len(self._api_keys)})') rotated = True break if not rotated: raise RuntimeError( - f'All {len(self._api_keys)} Exa API keys have been exhausted. Last error: {e}' - ) from e + f'All {len(self._api_keys)} Exa API keys have ' + f'been exhausted. Last error: {e}') from e raise RuntimeError( - f'All {len(self._api_keys)} Exa API keys have been exhausted. Last error: {last_error}' - ) from last_error + f'All {len(self._api_keys)} Exa API keys have been exhausted. ' + f'Last error: {last_error}') from last_error @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for Exa search engine.""" from ms_agent.llm.utils import Tool - return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -207,44 +210,50 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Search the web using Exa neural search engine. ' 'Best for: semantic understanding, finding relevant content, ' 'recent web pages with date filtering. ' - 'Supports neural search (meaning-based) and keyword search.' - ), + 'Supports neural search (meaning-based) and keyword search.'), parameters={ 'type': 'object', 'properties': { 'query': { - 'type': 'string', - 'description': ( - 'The search query. For neural search, use natural language ' - 'descriptions. For keyword search, use Google-style queries.' - ), + 'type': + 'string', + 'description': + ('The search query. For neural search, use natural language ' + 'descriptions. For keyword search, use Google-style queries.' + ), }, 'num_results': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 10, - 'description': 'Number of results to return. Default is 5.', + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 10, + 'description': + 'Number of results to return. Default is 5.', }, 'type': { - 'type': 'string', + 'type': + 'string', 'enum': ['auto', 'neural', 'keyword'], - 'description': ( - 'Search type. "neural" for semantic similarity, ' - '"keyword" for exact matching, "auto" to let Exa decide. ' - 'Default is "auto".' - ), + 'description': + ('Search type. "neural" for semantic similarity, ' + '"keyword" for exact matching, "auto" to let Exa decide. ' + 'Default is "auto".'), }, 'start_published_date': { - 'type': 'string', - 'description': ( - 'Filter results published on/after this date. Format: YYYY-MM-DD (e.g., "2024-01-01").' - ), + 'type': + 'string', + 'description': + ('Filter results published on/after this date. ' + 'Format: YYYY-MM-DD (e.g., "2024-01-01").'), }, 'end_published_date': { - 'type': 'string', - 'description': ( - 'Filter results published on/before this date. Format: YYYY-MM-DD (e.g., "2024-12-31").' - ), + 'type': + 'string', + 'description': + ('Filter results published on/before this date. ' + 'Format: YYYY-MM-DD (e.g., "2024-12-31").'), }, }, 'required': ['query'], diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py index 7ec628dbf..d8ab34ef7 100644 --- a/ms_agent/tools/search/localsearch_tool.py +++ b/ms_agent/tools/search/localsearch_tool.py @@ -5,9 +5,12 @@ from pathlib import Path from typing import Any, Dict, List, Optional +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, + effective_localsearch_settings, +) from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase -from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch, effective_localsearch_settings from ms_agent.utils.logger import get_logger logger = get_logger() @@ -57,7 +60,9 @@ def _resolved_localsearch_paths_from_config(config) -> List[str]: def _format_configured_roots(paths: List[str]) -> str: if not paths: - return '(none — set tools.localsearch.paths in agent config, or legacy knowledge_search.paths)' + return ( + '(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') return '\n'.join(f'- {p}' for p in paths) @@ -86,10 +91,12 @@ def __init__(self, config, **kwargs): if tool_cfg is not None: self.exclude_func(tool_cfg) self._searcher: Optional[SirchmunkSearch] = None - self._configured_roots: List[str] = _resolved_localsearch_paths_from_config(config) + self._configured_roots: List[str] = ( + _resolved_localsearch_paths_from_config(config)) def _tool_description(self) -> str: - return _LOCALSEARCH_DESCRIPTION.format(configured_roots=_format_configured_roots(self._configured_roots)) + return _LOCALSEARCH_DESCRIPTION.format( + configured_roots=_format_configured_roots(self._configured_roots)) def _paths_param_description(self) -> str: roots = _format_configured_roots(self._configured_roots) @@ -97,8 +104,7 @@ def _paths_param_description(self) -> str: 'Optional. Narrow search to specific files or directories under the ' 'configured roots below. Each path must exist on disk and lie under ' 'one of these roots (or be exactly one of them).\n' - f'Configured roots:\n{roots}' - ) + f'Configured roots:\n{roots}') def _ensure_searcher(self) -> SirchmunkSearch: if self._searcher is None: @@ -116,43 +122,60 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: server_name=_SERVER, description=self._tool_description(), parameters={ - 'type': 'object', + 'type': + 'object', 'properties': { 'query': { - 'type': 'string', - 'description': 'Search keywords or natural-language question about local content.', + 'type': + 'string', + 'description': + 'Search keywords or natural-language question about local content.', }, 'paths': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': self._paths_param_description(), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + self._paths_param_description(), }, 'mode': { - 'type': 'string', + 'type': + 'string', 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], - 'description': 'Search mode; omit to use agent default (usually FAST).', + 'description': + 'Search mode; omit to use agent default (usually FAST).', }, 'max_depth': { 'type': 'integer', 'minimum': 1, 'maximum': 20, - 'description': 'Max directory depth for filesystem search.', + 'description': + 'Max directory depth for filesystem search.', }, 'top_k_files': { 'type': 'integer', 'minimum': 1, 'maximum': 20, - 'description': 'Max files for evidence / filename hits.', + 'description': + 'Max files for evidence / filename hits.', }, 'include': { 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Glob patterns to include (e.g. *.py, *.md).', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to include (e.g. *.py, *.md).', }, 'exclude': { 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Glob patterns to exclude (e.g. *.pyc).', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to exclude (e.g. *.pyc).', }, }, 'required': ['query'], @@ -161,7 +184,8 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: ] } - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict): + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict): del server_name if tool_name != _TOOL: return f'Unknown tool: {tool_name}' @@ -195,11 +219,11 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict): if paths_arg: resolved_paths = searcher.resolve_tool_paths(paths_arg) if not resolved_paths: - roots = _format_configured_roots(self._configured_roots) + roots = _format_configured_roots( + self._configured_roots) return ( 'Error: `paths` are invalid. Each path must exist on disk and lie ' - 'under one of these configured roots:\n' + roots - ) + 'under one of these configured roots:\n' + roots) answer = await searcher.query( query, @@ -242,7 +266,8 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict): result_parts.append('\nSource paths:') for item in excerpts[:12]: meta = item.get('metadata') or {} - result_parts.append(f'- {meta.get("source", "?")}') + result_parts.append( + f'- {meta.get("source", "?")}') result_text = '\n'.join(result_parts) return { @@ -254,3 +279,4 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict): except Exception as exc: logger.warning(f'localsearch failed: {exc}') return f'Local search failed: {exc}' + diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py index 443d4c79b..8a10c9d20 100644 --- a/ms_agent/tools/search/search_base.py +++ b/ms_agent/tools/search/search_base.py @@ -1,11 +1,12 @@ # flake8: noqa import enum -import json import os from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar +import json + if TYPE_CHECKING: from ms_agent.llm.utils import Tool @@ -62,7 +63,10 @@ class SearchResponse(Generic[T]): class SearchRequest(ABC): """Abstract base class for search requests.""" - def __init__(self, query: str, num_results: Optional[int] = 10, **kwargs: Any): + def __init__(self, + query: str, + num_results: Optional[int] = 10, + **kwargs: Any): """ Initialize SearchRequest with search parameters. @@ -92,7 +96,10 @@ def to_json(self) -> str: class SearchResult(ABC): """Base class for search results.""" - def __init__(self, query: str, arguments: Optional[Dict[str, Any]] = None, response: Any = None): + def __init__(self, + query: str, + arguments: Optional[Dict[str, Any]] = None, + response: Any = None): """ Initialize SearchResult. @@ -130,17 +137,15 @@ def to_list(self) -> List[Dict[str, Any]]: res_list: List[Dict[str, Any]] = [] for res in self.response.results: - res_list.append( - { - 'url': res.url, - 'id': res.id, - 'title': res.title, - 'highlights': res.highlights, - 'highlight_scores': res.highlight_scores, - 'summary': res.summary, - 'markdown': res.markdown, - } - ) + res_list.append({ + 'url': res.url, + 'id': res.id, + 'title': res.title, + 'highlights': res.highlights, + 'highlight_scores': res.highlight_scores, + 'summary': res.summary, + 'markdown': res.markdown, + }) return res_list @@ -196,7 +201,6 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': Tool definition dict """ from ms_agent.llm.utils import Tool - return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -231,4 +235,5 @@ def build_request_from_args(cls, **kwargs) -> SearchRequest: Returns: SearchRequest instance """ - raise NotImplementedError(f'{cls.__name__} must implement build_request_from_args') + raise NotImplementedError( + f'{cls.__name__} must implement build_request_from_args') diff --git a/ms_agent/tools/search/search_request.py b/ms_agent/tools/search/search_request.py index 81c8e13d7..e2aed6666 100644 --- a/ms_agent/tools/search/search_request.py +++ b/ms_agent/tools/search/search_request.py @@ -114,7 +114,7 @@ def get_rewrite_prompt(self) -> str: f'\n2. 其中,query参数的值直接使用用户原始输入,即:{self.user_prompt}' f'\n3. 参数需要符合搜索引擎的要求,num_results需要根据实际问题的复杂程度来估算,最大25,最小1,对于复杂的问题,num_results的值需要尽量大;' f'\n4. start_published_date和end_published_date需要根据实际问题的时间范围来估算,默认均为None。' - f'当前日期为:{datetime.now().strftime('%Y-%m-%d')}') + f'当前日期为:{datetime.now().strftime("%Y-%m-%d")}') def create_request(self, search_request_d: Dict[str, Any]) -> ExaSearchRequest: diff --git a/ms_agent/tools/search/serpapi/__init__.py b/ms_agent/tools/search/serpapi/__init__.py index cf16db488..8a46380ac 100644 --- a/ms_agent/tools/search/serpapi/__init__.py +++ b/ms_agent/tools/search/serpapi/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa -from ms_agent.tools.search.serpapi.schema import SerpApiSearchRequest, SerpApiSearchResult +from ms_agent.tools.search.serpapi.schema import (SerpApiSearchRequest, + SerpApiSearchResult) from ms_agent.tools.search.serpapi.search import SerpApiSearch diff --git a/ms_agent/tools/search/serpapi/schema.py b/ms_agent/tools/search/serpapi/schema.py index 5643ac4b7..633250473 100644 --- a/ms_agent/tools/search/serpapi/schema.py +++ b/ms_agent/tools/search/serpapi/schema.py @@ -2,7 +2,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional -from ms_agent.tools.search.search_base import BaseResult, SearchRequest, SearchResponse, SearchResult +from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, + SearchResponse, SearchResult) class SerpApiSearchRequest(SearchRequest): @@ -10,7 +11,11 @@ class SerpApiSearchRequest(SearchRequest): A class representing a search request to SerpApi. """ - def __init__(self, query: str, num_results: Optional[int] = 5, location: Optional[str] = None, **kwargs: Any): + def __init__(self, + query: str, + num_results: Optional[int] = 5, + location: Optional[str] = None, + **kwargs: Any): """ Initialize SerpApiSearchRequest with search parameters. @@ -29,13 +34,21 @@ def to_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters as a dictionary """ - return {'q': self.query, 'num': self.num_results, 'location': self.location} + return { + 'q': self.query, + 'num': self.num_results, + 'location': self.location + } class SerpApiSearchResult(SearchResult): """SerpApi search result implementation.""" - def __init__(self, provider: str, query: str, arguments: Dict[str, Any] = None, response: Dict[str, Any] = None): + def __init__(self, + provider: str, + query: str, + arguments: Dict[str, Any] = None, + response: Dict[str, Any] = None): """ Initialize SerpApiSearchResult. @@ -63,7 +76,8 @@ def _process_results(self) -> SearchResponse: processed = [] if self.provider.lower() in ['google', 'bing', 'baidu']: # Extract organic results - organic_results: List[Dict[str, Any]] = self.response.get('organic_results', []) + organic_results: List[Dict[str, Any]] = self.response.get( + 'organic_results', []) for res in organic_results: processed.append( BaseResult( @@ -73,10 +87,9 @@ def _process_results(self) -> SearchResponse: highlights=res.get('snippet_highlighted_words'), highlight_scores=None, summary=None, - markdown=None, - ) - ) + markdown=None)) else: - raise NotImplementedError(f"Provider '{self.provider}' is not supported yet.") + raise NotImplementedError( + f"Provider '{self.provider}' is not supported yet.") return SearchResponse(results=processed) diff --git a/ms_agent/tools/search/serpapi/search.py b/ms_agent/tools/search/serpapi/search.py index adebbe044..d42fef413 100644 --- a/ms_agent/tools/search/serpapi/search.py +++ b/ms_agent/tools/search/serpapi/search.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType -from ms_agent.tools.search.serpapi.schema import SerpApiSearchRequest, SerpApiSearchResult +from ms_agent.tools.search.serpapi.schema import (SerpApiSearchRequest, + SerpApiSearchResult) if TYPE_CHECKING: from ms_agent.llm.utils import Tool @@ -20,13 +21,16 @@ class SerpApiSearch(SearchEngine): engine_type = SearchEngineType.SERPAPI def __init__(self, api_key: str = None, provider: str = None): + api_key = api_key or os.getenv('SERPAPI_API_KEY') assert api_key, 'SERPAPI_API_KEY must be set either as an argument or as an environment variable' self.provider = (provider or 'google').lower() - self.client = self._get_search_client(provider=self.provider, api_key=api_key) + self.client = self._get_search_client( + provider=self.provider, api_key=api_key) - def search(self, search_request: SerpApiSearchRequest) -> SerpApiSearchResult: + def search(self, + search_request: SerpApiSearchRequest) -> SerpApiSearchResult: """ Perform a search using SerpApi and return the results. @@ -42,8 +46,10 @@ def search(self, search_request: SerpApiSearchRequest) -> SerpApiSearchResult: self.client.params_dict.update(search_args) response = self.client.get_dict() search_result = SerpApiSearchResult( - provider=self.provider, query=search_request.query, arguments=search_args, response=response - ) + provider=self.provider, + query=search_request.query, + arguments=search_args, + response=response) except Exception as e: raise RuntimeError(f'Failed to perform search: {e}') from e @@ -53,7 +59,6 @@ def search(self, search_request: SerpApiSearchRequest) -> SerpApiSearchResult: def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': """Return the tool definition for SerpApi search engine.""" from ms_agent.llm.utils import Tool - return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -62,28 +67,33 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Default provider is Google. ' 'Best for: general web search, current events, news, ' 'real-time information, and location-specific results. ' - 'Supports Google search operators.' - ), + 'Supports Google search operators.'), parameters={ 'type': 'object', 'properties': { 'query': { - 'type': 'string', - 'description': ( - 'Google-style search query. Use operators as needed: ' - 'quotes for exact phrases ("..."), OR, -term to exclude. ' - 'Date limits: before:YYYY-MM-DD, after:YYYY-MM-DD.' - ), + 'type': + 'string', + 'description': + ('Google-style search query. Use operators as needed: ' + 'quotes for exact phrases ("..."), OR, -term to exclude. ' + 'Date limits: before:YYYY-MM-DD, after:YYYY-MM-DD.'), }, 'num_results': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 10, - 'description': 'Number of results to return. Default is 5.', + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 10, + 'description': + 'Number of results to return. Default is 5.', }, 'location': { - 'type': 'string', - 'description': ('Geographic location filter. Default is null'), + 'type': + 'string', + 'description': + ('Geographic location filter. Default is null'), }, }, 'required': ['query'], diff --git a/ms_agent/tools/search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py index a49cae982..cd86f9d63 100644 --- a/ms_agent/tools/search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -38,7 +38,8 @@ def effective_localsearch_settings(config: DictConfig) -> Optional[Any]: tools = getattr(config, 'tools', None) tl = None if tools is not None: - tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr(tools, 'localsearch', None) + tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr( + tools, 'localsearch', None) ks = getattr(config, 'knowledge_search', None) if tl is not None and _paths_from_block(tl): @@ -80,13 +81,16 @@ def __init__(self, config: DictConfig): paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] - self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + self.search_paths: List[str] = [ + str(Path(p).expanduser().resolve()) for p in paths + ] _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() self.reuse_knowledge = rag_config.get('reuse_knowledge', True) - self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', + 0.85) self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) self.search_mode = rag_config.get('mode', 'FAST') self.max_loops = rag_config.get('max_loops', 10) @@ -96,19 +100,23 @@ def __init__(self, config: DictConfig): self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + if (self.llm_api_key is None or self.llm_base_url is None + or self.llm_model_name is None): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + self.llm_api_key = getattr(llm_config, + f'{service}_api_key', None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + self.llm_base_url = getattr(llm_config, + f'{service}_base_url', None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) self.embedding_model_id = rag_config.get('embedding_model', None) - self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + self.embedding_model_cache_dir = rag_config.get( + 'embedding_model_cache_dir', None) self._searcher = None self._initialized = False @@ -127,15 +135,15 @@ def _validate_config(self, config: DictConfig): raise ValueError( 'Missing localsearch configuration. Add ' '`tools.localsearch` with non-empty `paths` (or legacy ' - '`knowledge_search.paths`).' - ) + '`knowledge_search.paths`).') paths = _paths_from_block(block) if not paths: raise ValueError( - 'tools.localsearch.paths (or legacy knowledge_search.paths) must be specified and non-empty' - ) + 'tools.localsearch.paths (or legacy knowledge_search.paths) ' + 'must be specified and non-empty') - def resolve_tool_paths(self, paths: Optional[List[str]]) -> Optional[List[str]]: + def resolve_tool_paths( + self, paths: Optional[List[str]]) -> Optional[List[str]]: """Restrict per-call paths to configured search roots.""" if not paths: return None @@ -148,9 +156,12 @@ def resolve_tool_paths(self, paths: Optional[List[str]]) -> Optional[List[str]]: if not p.exists(): logger.warning(f'localsearch: path does not exist, skipped: {p}') continue - allowed = any(p == r or p.is_relative_to(r) for r in roots) + allowed = any( + p == r or p.is_relative_to(r) for r in roots) if not allowed: - logger.warning(f'localsearch: path outside configured search roots, skipped: {p}') + logger.warning( + f'localsearch: path outside configured search roots, ' + f'skipped: {p}') continue cleaned.append(str(p)) return cleaned or None @@ -173,9 +184,13 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - embedding_model_id = self.embedding_model_id if self.embedding_model_id else None - embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None - embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + embedding_model_id = ( + self.embedding_model_id if self.embedding_model_id else None) + embedding_cache_dir = ( + self.embedding_model_cache_dir + if self.embedding_model_cache_dir else None) + embedding = EmbeddingUtil( + model_id=embedding_model_id, cache_dir=embedding_cache_dir) self._searcher = AgenticSearch( llm=llm, @@ -190,10 +205,14 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + logger.info( + f'SirschmunkSearch initialized with paths: {self.search_paths}' + ) except ImportError as e: - raise ImportError(f'Failed to import sirchmunk: {e}. Please install sirchmunk: pip install sirchmunk') + raise ImportError( + f'Failed to import sirchmunk: {e}. ' + 'Please install sirchmunk: pip install sirchmunk') except Exception as e: raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') @@ -260,16 +279,19 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: try: for file_path in file_paths: if Path(file_path).exists(): - await self._searcher.scan_directory(str(Path(file_path).parent)) + await self._searcher.scan_directory( + str(Path(file_path).parent)) return True except Exception as e: logger.error(f'Failed to scan files: {e}') return False return True - async def retrieve( - self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters - ) -> List[Dict[str, Any]]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. Args: @@ -288,7 +310,8 @@ async def retrieve( try: mode = filters.get('mode', self.search_mode) max_loops = filters.get('max_loops', self.max_loops) - max_token_budget = filters.get('max_token_budget', self.max_token_budget) + max_token_budget = filters.get('max_token_budget', + self.max_token_budget) result = await self._searcher.search( query=query, @@ -301,9 +324,11 @@ async def retrieve( self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) return self._parse_search_result(result, score_threshold, limit) @@ -349,7 +374,8 @@ async def query( mode_eff = mode_eff.strip().upper() allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') if mode_eff not in allowed_modes: - return f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.' + return ( + f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') kw: Dict[str, Any] = dict( query=query, @@ -376,29 +402,34 @@ async def query( self._last_search_result = [] for item in result[:20]: if isinstance(item, dict): - src = item.get('path') or item.get('file_path') or item.get('file') or '' - self._last_search_result.append( - { - 'text': json.dumps(item, ensure_ascii=False), - 'score': 1.0, - 'metadata': { - 'source': str(src), - 'type': 'filename_match', - }, - } - ) + src = (item.get('path') or item.get('file_path') + or item.get('file') or '') + self._last_search_result.append({ + 'text': + json.dumps(item, ensure_ascii=False), + 'score': + 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + }) return json.dumps(result, ensure_ascii=False, indent=2) self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) - self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) + self._last_search_result = self._parse_search_result( + result, score_threshold=0.7, limit=5) - if hasattr(result, 'answer') and getattr(result, 'answer', None) is not None: + if hasattr(result, 'answer') and getattr(result, 'answer', + None) is not None: return result.answer if isinstance(result, str): @@ -410,7 +441,8 @@ async def query( logger.error(f'SirschmunkSearch query failed: {e}') return f'Query failed: {e}' - def _parse_search_result(self, result: Any, score_threshold: float, limit: int) -> List[Dict[str, Any]]: + def _parse_search_result(self, result: Any, score_threshold: float, + limit: int) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -436,57 +468,62 @@ def _parse_search_result(self, result: Any, score_threshold: float, limit: int) else: text_parts.append(str(snippet)) - results.append( - { - 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') - if hasattr(unit, 'abstraction_level') - else 'text', - }, - } - ) + results.append({ + 'text': + '\n'.join(text_parts) if text_parts else getattr( + unit, 'summary', ''), + 'score': + score, + 'metadata': { + 'source': + source, + 'type': + getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') else 'text', + }, + }) elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append( - { - 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), - }, - } - ) + results.append({ + 'text': + str(unit.content) + if hasattr(unit, 'content') else str(unit), + 'score': + score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, + }) elif isinstance(result, list): for item in result: if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': item.get('content', item.get('text', str(item))), - 'score': score, - 'metadata': item.get('metadata', {}), - } - ) + results.append({ + 'text': + item.get('content', item.get('text', str(item))), + 'score': + score, + 'metadata': + item.get('metadata', {}), + }) elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - } - ) + results.append({ + 'text': + result.get('content', result.get('text', str(result))), + 'score': + score, + 'metadata': + result.get('metadata', {}), + }) results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] diff --git a/ms_agent/tools/search/tavily/fetcher.py b/ms_agent/tools/search/tavily/fetcher.py index b38e2bf96..4082907a2 100644 --- a/ms_agent/tools/search/tavily/fetcher.py +++ b/ms_agent/tools/search/tavily/fetcher.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Tavily Extract API as ContentFetcher (replaces Jina for fetch_page / URL fetch).""" - import os import time from typing import Any, Dict, Optional, Tuple @@ -35,7 +34,8 @@ def __init__( ): key = api_key or os.getenv('TAVILY_API_KEY') if not key: - raise ValueError('TAVILY_API_KEY required for tavily_extract fetcher') + raise ValueError( + 'TAVILY_API_KEY required for tavily_extract fetcher') self._api_key = key self._extract_depth = extract_depth self._format = format diff --git a/ms_agent/tools/search/tavily/http.py b/ms_agent/tools/search/tavily/http.py index 7c1d3981b..d4916d271 100644 --- a/ms_agent/tools/search/tavily/http.py +++ b/ms_agent/tools/search/tavily/http.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Minimal HTTP JSON client for Tavily REST API (stdlib only).""" - import json from typing import Any, Dict from urllib.error import HTTPError, URLError @@ -45,6 +44,7 @@ def post_json( detail = json.loads(err_body) if err_body else {} except json.JSONDecodeError: detail = {'raw': err_body} - raise RuntimeError(f'Tavily HTTP {e.code}: {detail}') from e + raise RuntimeError( + f'Tavily HTTP {e.code}: {detail}') from e except URLError as e: raise RuntimeError(f'Tavily network error: {e}') from e diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py index 9b92d812f..75f3f0aed 100644 --- a/ms_agent/tools/search/tavily/schema.py +++ b/ms_agent/tools/search/tavily/schema.py @@ -86,23 +86,21 @@ def to_list(self) -> List[Dict[str, Any]]: raw = (r.get('raw_content') or '').strip() # Prefer full page text for downstream summarization; fallback to snippets body = raw if raw else snippet - rows.append( - { - 'url': url, - 'id': url, - 'title': title, - 'highlights': None, - 'highlight_scores': None, - 'summary': snippet, - 'markdown': raw if raw else None, - # Pipeline uses these keys: - 'content': body, - 'fetch_success': bool(raw), - 'score': r.get('score'), - 'tavily_images': r.get('images') or [], - 'favicon': r.get('favicon'), - } - ) + rows.append({ + 'url': url, + 'id': url, + 'title': title, + 'highlights': None, + 'highlight_scores': None, + 'summary': snippet, + 'markdown': raw if raw else None, + # Pipeline uses these keys: + 'content': body, + 'fetch_success': bool(raw), + 'score': r.get('score'), + 'tavily_images': r.get('images') or [], + 'favicon': r.get('favicon'), + }) return rows def extra_response_fields(self) -> Dict[str, Any]: diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py index aa42cb556..b4b7d3f3b 100644 --- a/ms_agent/tools/search/tavily/search.py +++ b/ms_agent/tools/search/tavily/search.py @@ -33,14 +33,17 @@ def __init__( ): key = api_key or os.getenv('TAVILY_API_KEY') if not key: - raise ValueError('TAVILY_API_KEY must be set in environment or web_search.tavily_api_key') + raise ValueError( + 'TAVILY_API_KEY must be set in environment or web_search.tavily_api_key' + ) self._api_key = key self._request_timeout = float(request_timeout) def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: body = search_request.to_api_body(self._api_key) try: - data = post_json(TAVILY_SEARCH_URL, body, timeout=self._request_timeout) + data = post_json( + TAVILY_SEARCH_URL, body, timeout=self._request_timeout) except Exception as e: raise RuntimeError(f'Tavily search failed: {e}') from e safe_args = {k: v for k, v in body.items() if k != 'api_key'} @@ -54,7 +57,6 @@ def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': from ms_agent.llm.utils import Tool - return Tool( tool_name=cls.get_tool_name(), server_name=server_name, @@ -62,8 +64,7 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Search the web using Tavily (built for AI agents). ' 'Returns ranked results with optional full-page markdown via ' '`include_raw_content`. Use `search_depth` advanced for best ' - 'relevance and richer `content` chunks (higher API credit use).' - ), + 'relevance and richer `content` chunks (higher API credit use).'), parameters={ 'type': 'object', 'properties': { @@ -80,18 +81,20 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'search_depth': { 'type': 'string', 'enum': ['advanced', 'basic', 'fast', 'ultra-fast'], - 'description': ( - 'advanced: best quality, 2 credits; basic/fast/ultra-fast: 1 credit (see Tavily docs).' - ), + 'description': + ('advanced: best quality, 2 credits; ' + 'basic/fast/ultra-fast: 1 credit (see Tavily docs).'), }, 'topic': { 'type': 'string', 'enum': ['general', 'news', 'finance'], - 'description': 'Search category (`news` / `finance` for focused verticals).', + 'description': + 'Search category (`news` / `finance` for focused verticals).', }, 'time_range': { 'type': 'string', - 'description': ('Filter by recency: day, week, month, year or d,w,m,y.'), + 'description': + ('Filter by recency: day, week, month, year or d,w,m,y.'), }, 'start_date': { 'type': 'string', @@ -104,39 +107,50 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'include_answer': { 'type': 'string', 'enum': ['false', 'true', 'basic', 'advanced'], - 'description': ('LLM answer: true/basic for short, advanced for detailed. Use false to skip.'), + 'description': + ('LLM answer: true/basic for short, advanced for detailed. ' + 'Use false to skip.'), }, 'include_raw_content': { 'type': 'string', - 'enum': ['false', 'true', 'markdown', 'text'], - 'description': ('full page text: markdown (recommended) or text; false to skip raw content.'), + 'enum': + ['false', 'true', 'markdown', 'text'], + 'description': + ('full page text: markdown (recommended) or text; ' + 'false to skip raw content.'), }, 'chunks_per_source': { 'type': 'integer', 'minimum': 1, 'maximum': 3, - 'description': ( - 'Relevant chunks per URL when search_depth=advanced. ' - 'Each chunk up to ~500 chars in `content` field.' - ), + 'description': + ('Relevant chunks per URL when search_depth=advanced. ' + 'Each chunk up to ~500 chars in `content` field.'), }, 'include_domains': { 'type': 'array', - 'items': {'type': 'string'}, + 'items': { + 'type': 'string' + }, 'description': 'Only include these domains (max 300).', }, 'exclude_domains': { 'type': 'array', - 'items': {'type': 'string'}, + 'items': { + 'type': 'string' + }, 'description': 'Exclude domains (max 150).', }, 'country': { 'type': 'string', - 'description': ('Boost results from country (e.g. united states). See Tavily docs for enum.'), + 'description': + ('Boost results from country (e.g. united states). ' + 'See Tavily docs for enum.'), }, 'exact_match': { 'type': 'boolean', - 'description': 'Only results with exact quoted phrases in query.', + 'description': + 'Only results with exact quoted phrases in query.', }, }, 'required': ['query'], @@ -190,7 +204,8 @@ def _boolish(name: str, default: Any) -> Any: include_answer=inc_ans, include_raw_content=inc_raw, include_images=bool(_boolish('include_images', False)), - include_image_descriptions=bool(_boolish('include_image_descriptions', False)), + include_image_descriptions=bool( + _boolish('include_image_descriptions', False)), include_favicon=bool(_boolish('include_favicon', False)), include_domains=list(kwargs.get('include_domains') or []), exclude_domains=list(kwargs.get('exclude_domains') or []), diff --git a/ms_agent/tools/search/web_search_spill.py b/ms_agent/tools/search/web_search_spill.py index f1b92ebb1..9c8da73c3 100644 --- a/ms_agent/tools/search/web_search_spill.py +++ b/ms_agent/tools/search/web_search_spill.py @@ -27,7 +27,6 @@ JSON gains ``spill`` with ``digest`` (instructions + quick index) and paths relative to ``output_dir`` so ``read_file`` can open them. """ - from __future__ import annotations import copy @@ -112,18 +111,18 @@ def _build_spill_markdown(item: Dict[str, Any]) -> str: return ''.join(lines) -def _shrink_item_after_spill(item: Dict[str, Any], spill_preview_chars: int) -> Dict[str, Any]: +def _shrink_item_after_spill(item: Dict[str, Any], + spill_preview_chars: int) -> Dict[str, Any]: """Replace heavy fields with short previews + pointers.""" out = dict(item) note = ( 'Full text spilled to disk; see content_path / manifest_path in parent ' - 'JSON spill block. Use read_file on content_path for this row.' - ) + 'JSON spill block. Use read_file on content_path for this row.') sm = out.get('summary') if isinstance(sm, str) and sm.strip(): out['summary'] = _preview(sm, spill_preview_chars) out.setdefault('content_note', note) - main = out.get('content') or '' + main = (out.get('content') or '') if isinstance(main, str) and main.strip(): out['content'] = _preview(main, spill_preview_chars) out['content_note'] = note @@ -132,14 +131,12 @@ def _shrink_item_after_spill(item: Dict[str, Any], spill_preview_chars: int) -> out['abstract'] = _preview(ab, min(800, spill_preview_chars)) ch = out.get('chunks') if isinstance(ch, list) and ch: - out['chunks'] = [ - { - 'chunk_id': c.get('chunk_id', ''), - 'content': _preview(str(c.get('content', '')), min(400, spill_preview_chars)), - } - for c in ch - if isinstance(c, dict) - ] + out['chunks'] = [{ + 'chunk_id': + c.get('chunk_id', ''), + 'content': + _preview(str(c.get('content', '')), min(400, spill_preview_chars)), + } for c in ch if isinstance(c, dict)] out['chunks_note'] = 'Full chunk bodies are in the spilled markdown file.' return out @@ -192,10 +189,13 @@ def order_by_size() -> List[int]: if _item_inline_chars(item) == 0: break full_md = _build_spill_markdown(item) - rel_body = os.path.join(spill_subdir, run_key, 'bodies', f'{idx:03d}.md').replace('\\', '/') - abs_body = os.path.normpath(os.path.join(output_dir, rel_body.replace('/', os.sep))) + rel_body = os.path.join(spill_subdir, run_key, 'bodies', + f'{idx:03d}.md').replace('\\', '/') + abs_body = os.path.normpath( + os.path.join(output_dir, rel_body.replace('/', os.sep))) os.makedirs(os.path.dirname(abs_body), exist_ok=True) - header = f'\n' + header = ( + f'\n') with open(abs_body, 'w', encoding='utf-8') as bf: bf.write(header + full_md) @@ -206,37 +206,55 @@ def order_by_size() -> List[int]: work[idx]['content_path'] = rel_body work[idx]['content_chars_spilled'] = before_chars - preview_src = (item.get('content') or item.get('summary') or item.get('abstract') or '')[:4000] - manifest_rows.append( - { - 'index': idx, - 'url': item.get('url', ''), - 'title': item.get('title', ''), - 'body_file': f'bodies/{idx:03d}.md', - 'content_path': rel_body, - 'chars_spilled': before_chars, - 'preview': _preview(preview_src, min(500, spill_preview_chars)), - } - ) + preview_src = ( + item.get('content') or item.get('summary') or item.get('abstract') + or '')[:4000] + manifest_rows.append({ + 'index': + idx, + 'url': + item.get('url', ''), + 'title': + item.get('title', ''), + 'body_file': + f'bodies/{idx:03d}.md', + 'content_path': + rel_body, + 'chars_spilled': + before_chars, + 'preview': + _preview(preview_src, min(500, spill_preview_chars)), + }) manifest: Dict[str, Any] = { - 'version': 1, - 'created_at_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), - 'query': query, - 'engine': engine, - 'run_key': run_key, - 'lifecycle': ( - 'Ephemeral: lives under this task output_dir; delete the task directory ' - 'to remove. ms-agent does not auto-prune.' - ), - 'inline_chars_before': total, - 'inline_chars_after': _total_inline_chars(work), - 'spill_threshold_chars': spill_max_inline_chars, - 'spilled_row_indices': spilled_indices, - 'rows': manifest_rows, + 'version': + 1, + 'created_at_utc': + time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), + 'query': + query, + 'engine': + engine, + 'run_key': + run_key, + 'lifecycle': + ('Ephemeral: lives under this task output_dir; delete the task directory ' + 'to remove. ms-agent does not auto-prune.'), + 'inline_chars_before': + total, + 'inline_chars_after': + _total_inline_chars(work), + 'spill_threshold_chars': + spill_max_inline_chars, + 'spilled_row_indices': + spilled_indices, + 'rows': + manifest_rows, } - rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace('\\', '/') - abs_manifest = os.path.normpath(os.path.join(output_dir, rel_manifest.replace('/', os.sep))) + rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace( + '\\', '/') + abs_manifest = os.path.normpath( + os.path.join(output_dir, rel_manifest.replace('/', os.sep))) with open(abs_manifest, 'w', encoding='utf-8') as mf: json.dump(manifest, mf, ensure_ascii=False, indent=2) @@ -244,24 +262,31 @@ def order_by_size() -> List[int]: 'Large web_search payload was written to disk under this task output_dir.', f'- **Manifest (map of rows → files, sizes)**: `{rel_manifest}`', f'- **Bodies**: `{spill_subdir}/{run_key}/bodies/`', - 'Read **manifest.json** first, then **read_file** on specific `bodies/NNN.md` files as needed.', + 'Read **manifest.json** first, then **read_file** on specific ' + '`bodies/NNN.md` files as needed.', '', '**Quick index**', ] for row in manifest_rows: lines.append( f'{row["index"]}. {row.get("title") or "(no title)"} — ' - f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)' - ) + f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)') digest = '\n'.join(lines) spill_meta = { - 'spilled': True, - 'run_key': run_key, - 'artifact_dir': f'{spill_subdir}/{run_key}'.replace('\\', '/'), - 'manifest_path': rel_manifest, - 'digest': digest, - 'inline_chars_before_spill': total, - 'inline_chars_after_spill': _total_inline_chars(work), + 'spilled': + True, + 'run_key': + run_key, + 'artifact_dir': + f'{spill_subdir}/{run_key}'.replace('\\', '/'), + 'manifest_path': + rel_manifest, + 'digest': + digest, + 'inline_chars_before_spill': + total, + 'inline_chars_after_spill': + _total_inline_chars(work), } return work, spill_meta diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 410ba0a60..16d6005d2 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -10,8 +10,11 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase -from ms_agent.tools.jina_reader import JinaReaderConfig, fetch_single_text_with_meta -from ms_agent.tools.search.content_optimizer import ContentOptimizer, ContentOptimizerConfig, SearchResultReranker +from ms_agent.tools.jina_reader import (JinaReaderConfig, + fetch_single_text_with_meta) +from ms_agent.tools.search.content_optimizer import (ContentOptimizer, + ContentOptimizerConfig, + SearchResultReranker) from ms_agent.tools.search.search_base import ENGINE_TOOL_NAMES, SearchEngine from ms_agent.tools.search.web_search_spill import maybe_spill_web_search_payload from ms_agent.utils.logger import get_logger @@ -39,14 +42,14 @@ def default_per_url_fetch_timeout_s( retries = max(0, int(fetch_retries)) # Up to (retries+1) attempts each up to ``ft``; 1.35 leaves slack for urllib backoff. jina_budget = ft * float(retries + 1) * 1.35 - tail = max(10.0, float(direct_fetch_timeout)) + (float(playwright_timeout_ms) / 1000.0) + 30.0 + tail = max(10.0, float(direct_fetch_timeout)) + ( + float(playwright_timeout_ms) / 1000.0) + 30.0 raw = jina_budget + tail return max(210.0, min(720.0, raw)) def _json_dumps(data: Any) -> str: import json - return json.dumps(data, ensure_ascii=False, indent=2) @@ -76,7 +79,10 @@ class TextChunk: end_pos: int -def chunk_text_simple(text: str, chunk_size: int = 1500, overlap: int = 200, prefix: str = '') -> List[TextChunk]: +def chunk_text_simple(text: str, + chunk_size: int = 1500, + overlap: int = 200, + prefix: str = '') -> List[TextChunk]: """ Simple text chunking by character count with overlap. Tries to break at paragraph or sentence boundaries when possible. @@ -95,7 +101,13 @@ def chunk_text_simple(text: str, chunk_size: int = 1500, overlap: int = 200, pre text = text.strip() if len(text) <= chunk_size: - return [TextChunk(chunk_id=f'{prefix}0' if prefix else '0', content=text, start_pos=0, end_pos=len(text))] + return [ + TextChunk( + chunk_id=f'{prefix}0' if prefix else '0', + content=text, + start_pos=0, + end_pos=len(text)) + ] chunks: List[TextChunk] = [] start = 0 @@ -122,12 +134,11 @@ def chunk_text_simple(text: str, chunk_size: int = 1500, overlap: int = 200, pre if chunk_content: chunks.append( TextChunk( - chunk_id=f'{prefix}{chunk_idx}' if prefix else str(chunk_idx), + chunk_id=f'{prefix}{chunk_idx}' + if prefix else str(chunk_idx), content=chunk_content, start_pos=start, - end_pos=end, - ) - ) + end_pos=end)) chunk_idx += 1 # Move start with overlap @@ -157,7 +168,11 @@ class JinaContentFetcher(ContentFetcher): def __init__(self, config: Optional[JinaReaderConfig] = None): self.config = config or JinaReaderConfig() - def fetch(self, url: str, max_chars: Optional[int] = MAX_FETCH_CHARS) -> Tuple[str, Dict[str, Any]]: + def fetch( + self, + url: str, + max_chars: Optional[int] = MAX_FETCH_CHARS + ) -> Tuple[str, Dict[str, Any]]: content, source_meta = fetch_single_text_with_meta(url, self.config) metadata: Dict[str, Any] = { 'fetcher': 'jina_reader', @@ -177,7 +192,8 @@ def fetch(self, url: str, max_chars: Optional[int] = MAX_FETCH_CHARS) -> Tuple[s # pass -def get_content_fetcher(fetcher_type: str = 'jina_reader', **kwargs) -> ContentFetcher: +def get_content_fetcher(fetcher_type: str = 'jina_reader', + **kwargs) -> ContentFetcher: """Factory function to get content fetcher by type.""" if fetcher_type == 'jina_reader': config = JinaReaderConfig( @@ -185,15 +201,17 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', **kwargs) -> ContentF retries=kwargs.get('retries', 3), direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), - playwright_fetch_fallback=bool(kwargs.get('playwright_fetch_fallback', True)), - playwright_retry_min_chars=int(kwargs.get('playwright_retry_min_chars', 400) or 400), - playwright_timeout_ms=int(kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), ) return JinaContentFetcher(config) if fetcher_type == 'tavily_extract': from ms_agent.tools.search.tavily.fetcher import TavilyExtractFetcher - return TavilyExtractFetcher( api_key=kwargs.get('tavily_api_key'), extract_depth=str(kwargs.get('tavily_extract_depth', 'advanced')), @@ -208,19 +226,26 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', **kwargs) -> ContentF # elif fetcher_type == 'docling': # return DoclingContentFetcher(**kwargs) else: - logger.warning(f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader") + logger.warning( + f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader" + ) return JinaContentFetcher( JinaReaderConfig( timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), - direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), - direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), - playwright_fetch_fallback=bool(kwargs.get('playwright_fetch_fallback', True)), - playwright_retry_min_chars=int(kwargs.get('playwright_retry_min_chars', 400) or 400), - playwright_timeout_ms=int(kwargs.get('playwright_timeout_ms', 30_000) or 30_000), - playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), - ) - ) + direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', + True)), + direct_fetch_timeout=float( + kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int( + kwargs.get('playwright_settle_ms', 350)), + )) def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: @@ -237,28 +262,26 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch - return ExaSearch elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch - return SerpApiSearch elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch - return ArxivSearch elif engine_type == 'tavily': from ms_agent.tools.search.tavily import TavilySearch - return TavilySearch else: - logger.warning(f"Unknown search engine '{engine_type}', falling back to arxiv") + logger.warning( + f"Unknown search engine '{engine_type}', falling back to arxiv") from ms_agent.tools.search.arxiv import ArxivSearch - return ArxivSearch -def get_search_engine(engine_type: str, api_key: Optional[str] = None, **kwargs) -> SearchEngine: +def get_search_engine(engine_type: str, + api_key: Optional[str] = None, + **kwargs) -> SearchEngine: """ Get search engine instance by type. @@ -275,45 +298,46 @@ def get_search_engine(engine_type: str, api_key: Optional[str] = None, **kwargs) if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch - return ExaSearch( api_key=api_key or os.getenv('EXA_API_KEY'), api_keys=kwargs.get('api_keys') or os.getenv('EXA_API_KEYS'), ) elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch - - default_provider = 'google' if engine_type in ('serpapi', 'serp') else engine_type + default_provider = ('google' if engine_type in ('serpapi', 'serp') else + engine_type) return SerpApiSearch( api_key=api_key or os.getenv('SERPAPI_API_KEY'), provider=kwargs.get('provider', default_provider), ) elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch - return ArxivSearch() elif engine_type == 'tavily': from ms_agent.tools.search.tavily import TavilySearch - return TavilySearch( api_key=api_key or os.getenv('TAVILY_API_KEY'), request_timeout=float(kwargs.get('request_timeout', 120.0)), ) else: - logger.warning(f"Unknown search engine '{engine_type}', falling back to arxiv") + logger.warning( + f"Unknown search engine '{engine_type}', falling back to arxiv") from ms_agent.tools.search.arxiv import ArxivSearch - return ArxivSearch() # Kept for backward compatibility -def build_search_request(engine_type: str, query: str, num_results: int = 5, **kwargs): +def build_search_request(engine_type: str, + query: str, + num_results: int = 5, + **kwargs): """Build a search request for the specified engine. DEPRECATED: Use SearchEngine.build_request_from_args() instead. """ engine_cls = get_search_engine_class(engine_type) - return engine_cls.build_request_from_args(query=query, num_results=num_results, **kwargs) + return engine_cls.build_request_from_args( + query=query, num_results=num_results, **kwargs) class WebSearchTool(ToolBase): @@ -372,8 +396,12 @@ def get_global_summarization_usage(cls) -> Dict[str, Any]: """Get process-wide summarization usage totals (best-effort).""" with cls._GLOBAL_SUMMARY_USAGE_LOCK: total = dict(cls._GLOBAL_SUMMARY_USAGE_TOTAL) - by_model = {k: dict(v) for k, v in cls._GLOBAL_SUMMARY_USAGE_BY_MODEL.items()} - total['total_tokens'] = total.get('prompt_tokens', 0) + total.get('completion_tokens', 0) + by_model = { + k: dict(v) + for k, v in cls._GLOBAL_SUMMARY_USAGE_BY_MODEL.items() + } + total['total_tokens'] = total.get('prompt_tokens', 0) + total.get( + 'completion_tokens', 0) return { 'total': total, 'by_model': by_model, @@ -384,7 +412,8 @@ def log_global_summarization_usage(cls) -> None: """Log process-wide summarization totals once at end-of-run.""" usage = cls.get_global_summarization_usage() total = usage.get('total', {}) or {} - if not (total.get('prompt_tokens', 0) or total.get('completion_tokens', 0) or total.get('api_calls', 0)): + if not (total.get('prompt_tokens', 0) or total.get( + 'completion_tokens', 0) or total.get('api_calls', 0)): return logger.info( '[web_search_summarization_usage_process_total] ' @@ -418,23 +447,31 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Parse engine configuration - support both single and multi-engine modes - engines_config = getattr(tool_cfg, 'engines', None) if tool_cfg else None + engines_config = getattr(tool_cfg, 'engines', + None) if tool_cfg else None if engines_config: # Multi-engine mode: engines: [exa, arxiv] # Note: OmegaConf ListConfig is iterable but not isinstance of list/tuple - if hasattr(engines_config, '__iter__') and not isinstance(engines_config, str): - self._engine_types = [str(e).lower().strip() for e in engines_config] + if hasattr(engines_config, + '__iter__') and not isinstance(engines_config, str): + self._engine_types = [ + str(e).lower().strip() for e in engines_config + ] else: self._engine_types = [str(engines_config).lower().strip()] else: # Single engine mode (backward compatible): engine: exa - single_engine = getattr(tool_cfg, 'engine', 'arxiv') if tool_cfg else 'arxiv' + single_engine = getattr(tool_cfg, 'engine', + 'arxiv') if tool_cfg else 'arxiv' self._engine_types = [single_engine.lower().strip()] # Validate engine types - self._engine_types = [e for e in self._engine_types if e in self.SUPPORTED_ENGINES] + self._engine_types = [ + e for e in self._engine_types if e in self.SUPPORTED_ENGINES + ] if not self._engine_types: - logger.warning('No valid engines configured, falling back to arxiv') + logger.warning( + 'No valid engines configured, falling back to arxiv') self._engine_types = ['arxiv'] # API keys for each engine. @@ -446,17 +483,15 @@ def __init__(self, config, **kwargs): getattr(tool_cfg, 'exa_api_keys', None) or getattr(tool_cfg, 'exa_api_key', None) or getattr(tool_cfg, 'api_key', None) # backward compat - or os.getenv('EXA_API_KEYS') - or os.getenv('EXA_API_KEY') - ) - if tool_cfg - else (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), - 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) - if tool_cfg - else os.getenv('SERPAPI_API_KEY'), - 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) or os.getenv('TAVILY_API_KEY')) - if tool_cfg - else os.getenv('TAVILY_API_KEY'), + or os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')) + if tool_cfg else + (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), + 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) + or os.getenv('SERPAPI_API_KEY')) + if tool_cfg else os.getenv('SERPAPI_API_KEY'), + 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) + or os.getenv('TAVILY_API_KEY')) if tool_cfg else + os.getenv('TAVILY_API_KEY'), } # Tavily search defaults from optional `tavily:` sub-block in YAML @@ -466,64 +501,79 @@ def __init__(self, config, **kwargs): if tv is not None: try: from omegaconf import OmegaConf - if OmegaConf.is_config(tv): - self._tavily_defaults = dict(OmegaConf.to_container(tv, resolve=True)) + self._tavily_defaults = dict( + OmegaConf.to_container(tv, resolve=True)) elif isinstance(tv, dict): self._tavily_defaults = dict(tv) except Exception: if isinstance(tv, dict): self._tavily_defaults = dict(tv) - self._tavily_request_timeout = ( - float(getattr(tool_cfg, 'tavily_request_timeout', 120.0) or 120.0) if tool_cfg else 120.0 - ) + self._tavily_request_timeout = float( + getattr(tool_cfg, 'tavily_request_timeout', 120.0) + or 120.0) if tool_cfg else 120.0 # SerpApi provider (google, bing, baidu) - self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', 'google') if tool_cfg else 'google' + self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', + 'google') if tool_cfg else 'google' # Default result count - self._max_results = int(getattr(tool_cfg, 'max_results', 5) or 5) if tool_cfg else 5 + self._max_results = int(getattr(tool_cfg, 'max_results', 5) + or 5) if tool_cfg else 5 # Content fetcher config - self._fetcher_type = getattr(tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' - self._fetch_timeout = float(getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 - self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) or 3) if tool_cfg else 3 - self._jina_direct_fetch_fallback = ( - bool(getattr(tool_cfg, 'jina_direct_fetch_fallback', True)) if tool_cfg else True - ) + self._fetcher_type = getattr( + tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' + self._fetch_timeout = float( + getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 + self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) + or 3) if tool_cfg else 3 + self._jina_direct_fetch_fallback = bool( + getattr(tool_cfg, 'jina_direct_fetch_fallback', True) + ) if tool_cfg else True if tool_cfg is not None and hasattr(tool_cfg, 'jina_direct_fetch_timeout'): - self._jina_direct_fetch_timeout = float(tool_cfg.jina_direct_fetch_timeout) + self._jina_direct_fetch_timeout = float( + tool_cfg.jina_direct_fetch_timeout) else: self._jina_direct_fetch_timeout = 15.0 - self._jina_playwright_fetch_fallback = ( - bool(getattr(tool_cfg, 'jina_playwright_fetch_fallback', True)) if tool_cfg else True - ) - self._jina_playwright_retry_min_chars = ( - int(getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400) if tool_cfg else 400 - ) - self._jina_playwright_timeout_ms = ( - int(getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000) if tool_cfg else 30000 - ) + self._jina_playwright_fetch_fallback = bool( + getattr(tool_cfg, 'jina_playwright_fetch_fallback', True) + ) if tool_cfg else True + self._jina_playwright_retry_min_chars = int( + getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400 + ) if tool_cfg else 400 + self._jina_playwright_timeout_ms = int( + getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000 + ) if tool_cfg else 30000 if tool_cfg is not None and hasattr(tool_cfg, 'jina_playwright_settle_ms'): - self._jina_playwright_settle_ms = int(tool_cfg.jina_playwright_settle_ms) + self._jina_playwright_settle_ms = int( + tool_cfg.jina_playwright_settle_ms) else: self._jina_playwright_settle_ms = 350 - self._fetch_content_default = bool(getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True + self._fetch_content_default = bool( + getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True # Chunking config - self._enable_chunking = bool(getattr(tool_cfg, 'enable_chunking', False)) if tool_cfg else False - self._chunk_size = int(getattr(tool_cfg, 'chunk_size', 2000) or 2000) if tool_cfg else 2000 - self._chunk_overlap = int(getattr(tool_cfg, 'chunk_overlap', 200) or 200) if tool_cfg else 200 + self._enable_chunking = bool( + getattr(tool_cfg, 'enable_chunking', False)) if tool_cfg else False + self._chunk_size = int(getattr(tool_cfg, 'chunk_size', 2000) + or 2000) if tool_cfg else 2000 + self._chunk_overlap = int( + getattr(tool_cfg, 'chunk_overlap', 200) + or 200) if tool_cfg else 200 # Concurrency - self._max_concurrent_fetch = int(getattr(tool_cfg, 'max_concurrent_fetch', 3) or 3) if tool_cfg else 3 + self._max_concurrent_fetch = int( + getattr(tool_cfg, 'max_concurrent_fetch', 3) + or 3) if tool_cfg else 3 # Hard cap (seconds) per URL for asyncio.wait_for around run_in_executor. # When hit, this URL gets empty content + fetch_error; other URLs in the # same web_search call keep their already-fetched bodies. Set 0 to disable # the asyncio cap (underlying urllib/Jina timeouts still apply). if tool_cfg is not None and hasattr(tool_cfg, 'per_url_fetch_timeout'): - self._per_url_fetch_timeout_s = float(tool_cfg.per_url_fetch_timeout) + self._per_url_fetch_timeout_s = float( + tool_cfg.per_url_fetch_timeout) else: self._per_url_fetch_timeout_s = default_per_url_fetch_timeout_s( self._fetch_timeout, @@ -531,47 +581,61 @@ def __init__(self, config, **kwargs): self._jina_direct_fetch_timeout, self._jina_playwright_timeout_ms, ) - self._max_concurrent_summarization = ( - int(getattr(tool_cfg, 'max_concurrent_summarization', 5) or 5) if tool_cfg else 5 - ) + self._max_concurrent_summarization = int( + getattr(tool_cfg, 'max_concurrent_summarization', 5) + or 5) if tool_cfg else 5 # Content optimization config (summarization & reranking) - self._enable_summarization = bool(getattr(tool_cfg, 'enable_summarization', False)) if tool_cfg else False - self._summarizer_model = getattr(tool_cfg, 'summarizer_model', 'qwen-flash') if tool_cfg else 'qwen-flash' - self._summarizer_base_url = ( - getattr(tool_cfg, 'summarizer_base_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1') - if tool_cfg - else 'https://dashscope.aliyuncs.com/compatible-mode/v1' - ) - self._summarizer_api_key = getattr(tool_cfg, 'summarizer_api_key', None) if tool_cfg else None - self._max_content_chars = int(getattr(tool_cfg, 'max_content_chars', 500000) or 500000) if tool_cfg else 500000 - self._summarizer_max_workers = int(getattr(tool_cfg, 'summarizer_max_workers', 5) or 5) if tool_cfg else 5 - self._summarization_timeout = ( - float(getattr(tool_cfg, 'summarization_timeout', 90.0) or 90.0) if tool_cfg else 90.0 - ) + self._enable_summarization = bool( + getattr(tool_cfg, 'enable_summarization', + False)) if tool_cfg else False + self._summarizer_model = getattr( + tool_cfg, 'summarizer_model', + 'qwen-flash') if tool_cfg else 'qwen-flash' + self._summarizer_base_url = getattr( + tool_cfg, 'summarizer_base_url', + 'https://dashscope.aliyuncs.com/compatible-mode/v1' + ) if tool_cfg else 'https://dashscope.aliyuncs.com/compatible-mode/v1' + self._summarizer_api_key = getattr(tool_cfg, 'summarizer_api_key', + None) if tool_cfg else None + self._max_content_chars = int( + getattr(tool_cfg, 'max_content_chars', 500000) + or 500000) if tool_cfg else 500000 + self._summarizer_max_workers = int( + getattr(tool_cfg, 'summarizer_max_workers', 5) + or 5) if tool_cfg else 5 + self._summarization_timeout = float( + getattr(tool_cfg, 'summarization_timeout', 90.0) + or 90.0) if tool_cfg else 90.0 # Large payload spill (write bodies to disk; keep JSON small) - self._spill_enabled = bool(getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True - self._spill_max_inline_chars = ( - int(getattr(tool_cfg, 'spill_max_inline_chars', 120000) or 120000) if tool_cfg else 120000 - ) - self._spill_subdir = ( - str(getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') or 'web_search_artifacts') - if tool_cfg - else 'web_search_artifacts' - ) - self._spill_preview_chars = int(getattr(tool_cfg, 'spill_preview_chars', 600) or 600) if tool_cfg else 600 + self._spill_enabled = bool( + getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True + self._spill_max_inline_chars = int( + getattr(tool_cfg, 'spill_max_inline_chars', 120000) + or 120000) if tool_cfg else 120000 + self._spill_subdir = str( + getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') + or 'web_search_artifacts') if tool_cfg else 'web_search_artifacts' + self._spill_preview_chars = int( + getattr(tool_cfg, 'spill_preview_chars', 600) + or 600) if tool_cfg else 600 # Reranking config - self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', False)) if tool_cfg else False - self._rerank_top_k = int(getattr(tool_cfg, 'rerank_top_k', 3) or 3) if tool_cfg else 3 + self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', + False)) if tool_cfg else False + self._rerank_top_k = int(getattr(tool_cfg, 'rerank_top_k', 3) + or 3) if tool_cfg else 3 # Task context for summarization (can be set dynamically) - self._task_context = getattr(tool_cfg, 'task_context', '') if tool_cfg else '' + self._task_context = getattr(tool_cfg, 'task_context', + '') if tool_cfg else '' # Runtime instances (lazy init) - self._engines: Dict[str, SearchEngine] = {} # engine_type -> engine instance - self._engine_classes: Dict[str, Type[SearchEngine]] = {} # engine_type -> engine class + self._engines: Dict[str, SearchEngine] = { + } # engine_type -> engine instance + self._engine_classes: Dict[str, Type[SearchEngine]] = { + } # engine_type -> engine class self._content_fetcher: Optional[ContentFetcher] = None self._content_optimizer: Optional[ContentOptimizer] = None self._executor: Optional[ThreadPoolExecutor] = None @@ -594,7 +658,8 @@ async def connect(self) -> None: # Create engine instance if engine_type == 'exa': - self._engines[engine_type] = engine_cls(api_key=self._api_keys.get('exa')) + self._engines[engine_type] = engine_cls( + api_key=self._api_keys.get('exa')) elif engine_type == 'serpapi': self._engines[engine_type] = engine_cls( api_key=self._api_keys.get('serpapi'), @@ -610,7 +675,8 @@ async def connect(self) -> None: logger.info(f'Initialized search engine: {engine_type}') except Exception as e: - logger.warning(f'Failed to initialize {engine_type} engine: {e}') + logger.warning( + f'Failed to initialize {engine_type} engine: {e}') if not self._engines: raise RuntimeError('No search engines could be initialized') @@ -628,16 +694,20 @@ async def connect(self) -> None: 'playwright_settle_ms': self._jina_playwright_settle_ms, } if wcfg is not None: - _fk.update( - { - 'tavily_extract_depth': getattr(wcfg, 'tavily_extract_depth', 'advanced'), - 'tavily_extract_format': getattr(wcfg, 'tavily_extract_format', 'markdown'), - 'tavily_extract_chunks_per_source': int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), - 'tavily_extract_include_images': bool(getattr(wcfg, 'tavily_extract_include_images', False)), - 'tavily_extract_include_favicon': bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), - 'tavily_extract_include_usage': bool(getattr(wcfg, 'tavily_extract_include_usage', False)), - } - ) + _fk.update({ + 'tavily_extract_depth': + getattr(wcfg, 'tavily_extract_depth', 'advanced'), + 'tavily_extract_format': + getattr(wcfg, 'tavily_extract_format', 'markdown'), + 'tavily_extract_chunks_per_source': + int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), + 'tavily_extract_include_images': + bool(getattr(wcfg, 'tavily_extract_include_images', False)), + 'tavily_extract_include_favicon': + bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), + 'tavily_extract_include_usage': + bool(getattr(wcfg, 'tavily_extract_include_usage', False)), + }) self._content_fetcher = get_content_fetcher(self._fetcher_type, **_fk) # Use daemon threads: tool-call timeouts can cancel the awaiting coroutine, # but not the underlying sync network calls running in executor threads. @@ -651,9 +721,9 @@ async def connect(self) -> None: optimizer_config = ContentOptimizerConfig( summarizer_model=self._summarizer_model, summarizer_base_url=self._summarizer_base_url, - summarizer_api_key=( - self._summarizer_api_key or os.getenv('DASHSCOPE_API_KEY') or os.getenv('OPENAI_API_KEY') - ), + summarizer_api_key=(self._summarizer_api_key + or os.getenv('DASHSCOPE_API_KEY') + or os.getenv('OPENAI_API_KEY')), max_content_chars=self._max_content_chars, summarizer_max_workers=self._summarizer_max_workers, summarization_timeout=self._summarization_timeout, @@ -663,9 +733,12 @@ async def connect(self) -> None: self._content_optimizer = ContentOptimizer(optimizer_config) if self._enable_summarization: await self._content_optimizer.initialize() - logger.info(f'Content optimizer initialized with model: {self._summarizer_model}') + logger.info( + f'Content optimizer initialized with model: {self._summarizer_model}' + ) else: - logger.info('Content reranking enabled (summarization disabled)') + logger.info( + 'Content reranking enabled (summarization disabled)') async def cleanup(self) -> None: """Cleanup resources.""" @@ -683,12 +756,11 @@ async def cleanup(self) -> None: self._engine_classes.clear() # Optional: instance-level totals can be noisy when multiple sub-agents # create their own WebSearchTool instances. Default off; use env var to enable. - if os.getenv('MS_AGENT_WEB_SEARCH_LOG_INSTANCE_SUMMARY_USAGE', '').lower() in ('1', 'true', 'yes'): - if ( - self._summary_usage_total.get('prompt_tokens', 0) - or self._summary_usage_total.get('completion_tokens', 0) - or self._summary_usage_total.get('api_calls', 0) - ): + if os.getenv('MS_AGENT_WEB_SEARCH_LOG_INSTANCE_SUMMARY_USAGE', + '').lower() in ('1', 'true', 'yes'): + if (self._summary_usage_total.get('prompt_tokens', 0) + or self._summary_usage_total.get('completion_tokens', 0) + or self._summary_usage_total.get('api_calls', 0)): model = self._summary_usage_model or self._summarizer_model logger.info( '[web_search_summarization_usage_total] ' @@ -723,19 +795,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: continue # Get engine's tool definition - tool_def = engine_cls.get_tool_definition(server_name=self.SERVER_NAME) + tool_def = engine_cls.get_tool_definition( + server_name=self.SERVER_NAME) # Add fetch_content parameter if content fetcher is available if self._content_fetcher: tool_params = dict(tool_def.get('parameters', {})) tool_props = dict(tool_params.get('properties', {})) tool_props['fetch_content'] = { - 'type': 'boolean', - 'description': ( - 'Whether to fetch and parse full page content. ' - 'Set to false for faster results with only titles/snippets. ' - f'Default is {self._fetch_content_default}. Suggested to set to True.' - ), + 'type': + 'boolean', + 'description': + ('Whether to fetch and parse full page content. ' + 'Set to false for faster results with only titles/snippets. ' + f'Default is {self._fetch_content_default}. Suggested to set to True.' + ), } tool_params['properties'] = tool_props tool_def['parameters'] = tool_params @@ -747,9 +821,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='fetch_page', server_name=self.SERVER_NAME, - description=( - 'Fetch and parse a single web page by URL. Use this when you have a specific URL to read.' - ), + description=('Fetch and parse a single web page by URL. ' + 'Use this when you have a specific URL to read.'), parameters={ 'type': 'object', 'properties': { @@ -760,12 +833,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, 'required': ['url'], }, - ) - ) + )) return {self.SERVER_NAME: tools} - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: """Route tool calls to appropriate handler.""" if tool_name == 'fetch_page': return await self.fetch_page(**(tool_args or {})) @@ -775,9 +848,12 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) engine_type = tool_to_engine.get(tool_name) if not engine_type or engine_type not in self._engines: - return _json_dumps( - {'status': 'error', 'message': f'Unknown tool: {tool_name}. Available: {list(tool_to_engine.keys())}'} - ) + return _json_dumps({ + 'status': + 'error', + 'message': + f'Unknown tool: {tool_name}. Available: {list(tool_to_engine.keys())}' + }) return await self._execute_search(engine_type, tool_args or {}) @@ -803,9 +879,11 @@ def _fetch_content_sync(self, url: str) -> Dict[str, Any]: content, chunk_size=self._chunk_size, overlap=self._chunk_overlap, - prefix=f'{hash(url) & 0xFFFFFF:06x}_', - ) - result['chunks'] = [{'chunk_id': c.chunk_id, 'content': c.content} for c in chunks] + prefix=f'{hash(url) & 0xFFFFFF:06x}_') + result['chunks'] = [{ + 'chunk_id': c.chunk_id, + 'content': c.content + } for c in chunks] return result except Exception as e: @@ -820,7 +898,8 @@ def _fetch_content_sync(self, url: str) -> Dict[str, Any]: async def _fetch_content_async(self, url: str) -> Dict[str, Any]: """Async wrapper for content fetching.""" loop = asyncio.get_event_loop() - return await loop.run_in_executor(self._executor, self._fetch_content_sync, url) + return await loop.run_in_executor(self._executor, + self._fetch_content_sync, url) def _url_log_preview(self, url: str, max_len: int = 220) -> str: u = (url or '').strip() @@ -853,10 +932,7 @@ async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: logger.warning( '[web_search] fetch TIMEOUT url=%s elapsed=%.1fs cap=%.1fs — ' 'this URL is dropped for this response; others are unchanged', - preview, - elapsed, - cap, - ) + preview, elapsed, cap) return { 'url': u, 'content': '', @@ -866,20 +942,20 @@ async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: 'fetch_timed_out': True, } elapsed = time.perf_counter() - t0 - src = (out or {}).get('content_source') or (out or {}).get('fetcher', '') or '' + src = (out or {}).get('content_source') or (out or {}).get( + 'fetcher', '') or '' ok = bool((out or {}).get('fetch_success')) - logger.info('[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', preview, elapsed, ok, src) - return ( - out - if out is not None - else { - 'url': u, - 'content': '', - 'fetch_success': False, - } - ) + logger.info( + '[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', + preview, elapsed, ok, src) + return out if out is not None else { + 'url': u, + 'content': '', + 'fetch_success': False, + } - async def _fetch_multiple_async(self, urls: List[str]) -> List[Dict[str, Any]]: + async def _fetch_multiple_async(self, + urls: List[str]) -> List[Dict[str, Any]]: """Fetch multiple URLs concurrently with semaphore.""" semaphore = asyncio.Semaphore(self._max_concurrent_fetch) @@ -891,30 +967,29 @@ async def _bounded_fetch(url: str) -> Dict[str, Any]: return await asyncio.gather(*tasks) def _do_search( - self, engine_type: str, engine: SearchEngine, engine_cls: Type[SearchEngine], tool_args: Dict[str, Any] + self, engine_type: str, engine: SearchEngine, + engine_cls: Type[SearchEngine], + tool_args: Dict[str, Any] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Perform search; returns (result rows, extra top-level metadata e.g. Tavily).""" try: merged = dict(tool_args) - if engine_type == 'tavily' and getattr(self, '_tavily_defaults', None): + if engine_type == 'tavily' and getattr(self, '_tavily_defaults', + None): merged = {**self._tavily_defaults, **merged} # Keys only for engine / fetcher YAML, not TavilySearchRequest - for _k in ( - 'request_timeout', - 'tavily_extract_depth', - 'tavily_extract_format', - 'tavily_extract_chunks_per_source', - 'tavily_extract_include_images', - 'tavily_extract_include_favicon', - 'tavily_extract_include_usage', - ): + for _k in ('request_timeout', 'tavily_extract_depth', + 'tavily_extract_format', + 'tavily_extract_chunks_per_source', + 'tavily_extract_include_images', + 'tavily_extract_include_favicon', + 'tavily_extract_include_usage'): merged.pop(_k, None) request = engine_cls.build_request_from_args(**merged) result = engine.search(request) rows = result.to_list() extra: Dict[str, Any] = {} from ms_agent.tools.search.tavily.schema import TavilySearchResult - if isinstance(result, TavilySearchResult): extra = result.extra_response_fields() return rows, extra @@ -922,7 +997,8 @@ def _do_search( logger.error(f'Search failed ({engine_type}): {e}') return [], {} - async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> str: + async def _execute_search(self, engine_type: str, + tool_args: Dict[str, Any]) -> str: """ Execute search using the specified engine. The search pipeline with optimization: @@ -938,30 +1014,40 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> """ query = tool_args.get('query', '').strip() if not query: - return _json_dumps({'status': 'error', 'message': 'Query is required.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Query is required.' + }) call_id_for_spill = str(tool_args.pop('__call_id', '') or '') # Get fetch_content preference, default to configured value - fetch_content = tool_args.pop('fetch_content', self._fetch_content_default) + fetch_content = tool_args.pop('fetch_content', + self._fetch_content_default) # Get task context for summarization (can be passed in tool_args) task_context = tool_args.pop('task_context', self._task_context) if 'num_results' not in tool_args or tool_args['num_results'] is None: - tool_args['num_results'] = 10 if engine_type == 'arxiv' else self._max_results + tool_args[ + 'num_results'] = 10 if engine_type == 'arxiv' else self._max_results engine = self._engines.get(engine_type) engine_cls = self._engine_classes.get(engine_type) if not engine or not engine_cls: - return _json_dumps({'status': 'error', 'message': f'Engine {engine_type} not initialized.'}) + return _json_dumps({ + 'status': + 'error', + 'message': + f'Engine {engine_type} not initialized.' + }) # Perform search loop = asyncio.get_event_loop() search_results, tavily_extra = await loop.run_in_executor( - self._executor, self._do_search, engine_type, engine, engine_cls, tool_args - ) + self._executor, self._do_search, engine_type, engine, engine_cls, + tool_args) if not search_results: out_empty: Dict[str, Any] = { @@ -986,13 +1072,16 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> query, top_k=self._rerank_top_k, ) - logger.info(f'Reranked {original_count} results to top {len(search_results)} for query: {query[:50]}...') + logger.info( + f'Reranked {original_count} results to top {len(search_results)} ' + f'for query: {query[:50]}...') # Step 3: Fetch content for (filtered) results (skip URLs already filled e.g. Tavily raw_content) fetch_attempts = 0 fetch_timeouts = 0 if fetch_content and self._content_fetcher: - search_results = SearchResultReranker.deduplicate_by_url(search_results) + search_results = SearchResultReranker.deduplicate_by_url( + search_results) urls: List[str] = [] for r in search_results: u = r.get('url') @@ -1004,7 +1093,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> if urls: fetch_attempts = len(urls) fetch_results = await self._fetch_multiple_async(urls) - fetch_timeouts = sum(1 for r in fetch_results if r.get('fetch_timed_out')) + fetch_timeouts = sum( + 1 for r in fetch_results if r.get('fetch_timed_out')) # Merge search metadata with fetched content url_to_fetch = {r['url']: r for r in fetch_results} @@ -1013,7 +1103,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> if url and url in url_to_fetch: fetched = url_to_fetch[url] sr['content'] = fetched.get('content', '') - sr['fetch_success'] = fetched.get('fetch_success', False) + sr['fetch_success'] = fetched.get( + 'fetch_success', False) if fetched.get('fetch_error'): sr['fetch_error'] = fetched['fetch_error'] else: @@ -1024,7 +1115,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> sr.pop('fetch_timed_out', None) if fetched.get('content_source'): sr['content_source'] = fetched['content_source'] - if fetched.get('published_at') and not sr.get('published_date'): + if fetched.get('published_at' + ) and not sr.get('published_date'): sr['published_at'] = fetched['published_at'] if self._enable_chunking and fetched.get('chunks'): sr['chunks'] = fetched['chunks'] @@ -1040,74 +1132,89 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> ] if contents_to_summarize: - logger.info(f'Summarizing {len(contents_to_summarize)} pages...') + logger.info( + f'Summarizing {len(contents_to_summarize)} pages...') # Summarize all contents in parallel + collect usage summaries, summarization_usage = await self._content_optimizer.summarize_contents_with_usage( contents_to_summarize, task_context=task_context, - max_concurrent=min(self._max_concurrent_summarization, len(contents_to_summarize)), + max_concurrent=min(self._max_concurrent_summarization, + len(contents_to_summarize)), ) # Update global usage totals for this tool instance (independent from LLMAgent) try: if summarization_usage: self._summary_usage_model = str( - summarization_usage.get('model') or self._summary_usage_model or '' - ) - self._summary_usage_total['api_calls'] += int(summarization_usage.get('api_calls', 0) or 0) + summarization_usage.get('model') + or self._summary_usage_model or '') + self._summary_usage_total['api_calls'] += int( + summarization_usage.get('api_calls', 0) or 0) self._summary_usage_total['prompt_tokens'] += int( - summarization_usage.get('prompt_tokens', 0) or 0 - ) + summarization_usage.get('prompt_tokens', 0) or 0) self._summary_usage_total['completion_tokens'] += int( - summarization_usage.get('completion_tokens', 0) or 0 - ) + summarization_usage.get('completion_tokens', 0) + or 0) self._summary_usage_total['cached_tokens'] += int( - summarization_usage.get('cached_tokens', 0) or 0 - ) - self._summary_usage_total['cache_creation_input_tokens'] += int( - summarization_usage.get('cache_creation_input_tokens', 0) or 0 - ) + summarization_usage.get('cached_tokens', 0) or 0) + self._summary_usage_total[ + 'cache_creation_input_tokens'] += int( + summarization_usage.get( + 'cache_creation_input_tokens', 0) or 0) # Process-wide totals (thread-safe; sub-agents may run in background threads) - model = str(summarization_usage.get('model') or self._summarizer_model) + model = str( + summarization_usage.get('model') + or self._summarizer_model) with WebSearchTool._GLOBAL_SUMMARY_USAGE_LOCK: - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['pages'] += int( - summarization_usage.get('pages', 0) or 0 - ) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['api_calls'] += int( - summarization_usage.get('api_calls', 0) or 0 - ) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['prompt_tokens'] += int( - summarization_usage.get('prompt_tokens', 0) or 0 - ) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['completion_tokens'] += int( - summarization_usage.get('completion_tokens', 0) or 0 - ) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['cached_tokens'] += int( - summarization_usage.get('cached_tokens', 0) or 0 - ) - WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL['cache_creation_input_tokens'] += int( - summarization_usage.get('cache_creation_input_tokens', 0) or 0 - ) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'pages'] += int( + summarization_usage.get('pages', 0) or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'api_calls'] += int( + summarization_usage.get('api_calls', 0) + or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'prompt_tokens'] += int( + summarization_usage.get( + 'prompt_tokens', 0) or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'completion_tokens'] += int( + summarization_usage.get( + 'completion_tokens', 0) or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'cached_tokens'] += int( + summarization_usage.get( + 'cached_tokens', 0) or 0) + WebSearchTool._GLOBAL_SUMMARY_USAGE_TOTAL[ + 'cache_creation_input_tokens'] += int( + summarization_usage.get( + 'cache_creation_input_tokens', 0) or 0) m = WebSearchTool._GLOBAL_SUMMARY_USAGE_BY_MODEL.setdefault( - model, - { + model, { 'pages': 0, 'api_calls': 0, 'prompt_tokens': 0, 'completion_tokens': 0, 'cached_tokens': 0, 'cache_creation_input_tokens': 0, - }, - ) - m['pages'] += int(summarization_usage.get('pages', 0) or 0) - m['api_calls'] += int(summarization_usage.get('api_calls', 0) or 0) - m['prompt_tokens'] += int(summarization_usage.get('prompt_tokens', 0) or 0) - m['completion_tokens'] += int(summarization_usage.get('completion_tokens', 0) or 0) - m['cached_tokens'] += int(summarization_usage.get('cached_tokens', 0) or 0) + }) + m['pages'] += int( + summarization_usage.get('pages', 0) or 0) + m['api_calls'] += int( + summarization_usage.get('api_calls', 0) or 0) + m['prompt_tokens'] += int( + summarization_usage.get('prompt_tokens', 0) + or 0) + m['completion_tokens'] += int( + summarization_usage.get( + 'completion_tokens', 0) or 0) + m['cached_tokens'] += int( + summarization_usage.get('cached_tokens', 0) + or 0) m['cache_creation_input_tokens'] += int( - summarization_usage.get('cache_creation_input_tokens', 0) or 0 - ) + summarization_usage.get( + 'cache_creation_input_tokens', 0) or 0) logger.info( '[web_search_summarization_usage] ' f"model={summarization_usage.get('model', self._summarizer_model)} " @@ -1120,7 +1227,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> f"cache_creation_input_tokens={summarization_usage.get('cache_creation_input_tokens', 0)}" ) except Exception as e: - logger.warning(f'Failed to record summarization usage: {e}') + logger.warning( + f'Failed to record summarization usage: {e}') # Replace original content with summaries for sr in search_results: @@ -1130,31 +1238,42 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> sr['content'] = summaries[url] sr['content_summarized'] = True sr['original_content_length'] = original_len - logger.debug(f'Summarized content for {url[:50]}: {original_len} -> {len(sr["content"])} chars') + logger.debug( + f'Summarized content for {url[:50]}: ' + f"{original_len} -> {len(sr['content'])} chars") # Format output output_results = [] for sr in search_results: item = { - 'url': sr.get('url', ''), - 'title': sr.get('title', ''), - 'published_at': sr.get('published_date') or sr.get('published_at', ''), + 'url': + sr.get('url', ''), + 'title': + sr.get('title', ''), + 'published_at': + sr.get('published_date') or sr.get('published_at', ''), } # Preserve arXiv-specific metadata (aligned with arxiv-mcp-server) if engine_type == 'arxiv': - item.update( - { - 'id': sr.get('arxiv_id', '') or '', # arXiv short id - 'abs_url': sr.get('id', '') or '', # entry_id (abstract page) - 'pdf_url': sr.get('pdf_url', '') or '', - 'abstract': sr.get('summary', '') or '', - 'authors': sr.get('authors', []) or [], - 'categories': sr.get('categories', []) or [], - 'resource_uri': sr.get('resource_uri', '') or '', - 'published': sr.get('published_date') or sr.get('published_at', ''), - } - ) + item.update({ + 'id': + sr.get('arxiv_id', '') or '', # arXiv short id + 'abs_url': + sr.get('id', '') or '', # entry_id (abstract page) + 'pdf_url': + sr.get('pdf_url', '') or '', + 'abstract': + sr.get('summary', '') or '', + 'authors': + sr.get('authors', []) or [], + 'categories': + sr.get('categories', []) or [], + 'resource_uri': + sr.get('resource_uri', '') or '', + 'published': + sr.get('published_date') or sr.get('published_at', ''), + }) if fetch_content: item['content'] = sr.get('content', '') @@ -1168,7 +1287,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> # Add summarization metadata if applicable if sr.get('content_summarized'): item['content_summarized'] = True - item['original_content_length'] = sr.get('original_content_length', 0) + item['original_content_length'] = sr.get( + 'original_content_length', 0) if self._enable_chunking and sr.get('chunks'): item['chunks'] = sr['chunks'] @@ -1197,9 +1317,12 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> } if fetch_content and self._content_fetcher: response['fetch_stats'] = { - 'per_url_timeout_s': self._per_url_fetch_timeout_s, - 'urls_fetched_this_call': fetch_attempts, - 'urls_timed_out': fetch_timeouts, + 'per_url_timeout_s': + self._per_url_fetch_timeout_s, + 'urls_fetched_this_call': + fetch_attempts, + 'urls_timed_out': + fetch_timeouts, } if tavily_extra: response.update(tavily_extra) @@ -1211,25 +1334,34 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> 'summarization_enabled': self._enable_summarization, } if self._enable_rerank: - response['optimization']['original_result_count'] = original_count + response['optimization'][ + 'original_result_count'] = original_count response['optimization']['filtered_to'] = len(output_results) if self._enable_summarization: - summarized_count = sum(1 for r in output_results if r.get('content_summarized')) + summarized_count = sum(1 for r in output_results + if r.get('content_summarized')) response['optimization']['pages_summarized'] = summarized_count # Include per-call usage + cumulative totals (separate from LLMAgent usage) if summarization_usage: - response['optimization']['summarization_usage'] = summarization_usage + response['optimization'][ + 'summarization_usage'] = summarization_usage response['optimization']['summarization_usage_total'] = { - 'model': self._summary_usage_model or self._summarizer_model, - 'api_calls': self._summary_usage_total.get('api_calls', 0), - 'prompt_tokens': self._summary_usage_total.get('prompt_tokens', 0), - 'completion_tokens': self._summary_usage_total.get('completion_tokens', 0), - 'total_tokens': ( - self._summary_usage_total.get('prompt_tokens', 0) - + self._summary_usage_total.get('completion_tokens', 0) - ), - 'cached_tokens': self._summary_usage_total.get('cached_tokens', 0), - 'cache_creation_input_tokens': self._summary_usage_total.get('cache_creation_input_tokens', 0), + 'model': + self._summary_usage_model or self._summarizer_model, + 'api_calls': + self._summary_usage_total.get('api_calls', 0), + 'prompt_tokens': + self._summary_usage_total.get('prompt_tokens', 0), + 'completion_tokens': + self._summary_usage_total.get('completion_tokens', 0), + 'total_tokens': + (self._summary_usage_total.get('prompt_tokens', 0) + + self._summary_usage_total.get('completion_tokens', 0)), + 'cached_tokens': + self._summary_usage_total.get('cached_tokens', 0), + 'cache_creation_input_tokens': + self._summary_usage_total.get( + 'cache_creation_input_tokens', 0), } # Process-wide totals so far (across all WebSearchTool instances) response['optimization'][ @@ -1237,7 +1369,8 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> ] = WebSearchTool.get_global_summarization_usage() # yapf: disable if self._spill_enabled: - od = getattr(self, 'output_dir', None) or getattr(getattr(self, 'config', None), 'output_dir', '') or '' + od = getattr(self, 'output_dir', None) or getattr( + getattr(self, 'config', None), 'output_dir', '') or '' if od: try: new_results, spill_meta = maybe_spill_web_search_payload( @@ -1254,34 +1387,47 @@ async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> response['results'] = new_results response['spill'] = spill_meta except Exception as e: - logger.warning(f'web_search spill failed (returning full inline JSON): {e}') + logger.warning( + f'web_search spill failed (returning full inline JSON): {e}' + ) return _json_dumps(response) async def fetch_page(self, url: str) -> str: """Fetch and parse a single web page.""" if not url or not url.strip(): - return _json_dumps({'status': 'error', 'message': 'URL is required.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'URL is required.' + }) result = await self._fetch_content_async_bounded(url.strip()) - return _json_dumps( - { - 'status': 'ok' if result.get('fetch_success') else 'error', - 'url': url, - 'content': result.get('content', ''), - 'published_at': result.get('published_at', ''), - 'fetch_success': result.get('fetch_success', False), - 'fetch_error': result.get('fetch_error', ''), - 'fetch_timed_out': bool(result.get('fetch_timed_out')), - 'chunks': result.get('chunks') if self._enable_chunking else None, - } - ) + return _json_dumps({ + 'status': + 'ok' if result.get('fetch_success') else 'error', + 'url': + url, + 'content': + result.get('content', ''), + 'published_at': + result.get('published_at', ''), + 'fetch_success': + result.get('fetch_success', False), + 'fetch_error': + result.get('fetch_error', ''), + 'fetch_timed_out': + bool(result.get('fetch_timed_out')), + 'chunks': + result.get('chunks') if self._enable_chunking else None, + }) # Backward compatibility aliases - async def web_search( - self, query: str, num_results: Optional[int] = None, fetch_content: bool = True, **kwargs - ) -> str: + async def web_search(self, + query: str, + num_results: Optional[int] = None, + fetch_content: bool = True, + **kwargs) -> str: """ Search the web and optionally fetch page content. @@ -1291,7 +1437,12 @@ async def web_search( # Use first engine as default engine_type = self._engine_types[0] if self._engine_types else 'arxiv' - tool_args = {'query': query, 'num_results': num_results, 'fetch_content': fetch_content, **kwargs} + tool_args = { + 'query': query, + 'num_results': num_results, + 'fetch_content': fetch_content, + **kwargs + } return await self._execute_search(engine_type, tool_args) diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py index 20ad080d2..6bb0be4d6 100644 --- a/ms_agent/tools/search_engine.py +++ b/ms_agent/tools/search_engine.py @@ -3,7 +3,6 @@ from typing import Any, Dict, Optional from dotenv import load_dotenv - from ms_agent.config.env import Env from ms_agent.tools.search.arxiv import ArxivSearch from ms_agent.tools.search.exa import ExaSearch @@ -30,7 +29,10 @@ def set_search_env_overrides(env_overrides: Optional[Dict[str, str]]) -> None: if hasattr(_search_env_local, 'overrides'): delattr(_search_env_local, 'overrides') return - _search_env_local.overrides = {k: v for k, v in env_overrides.items() if v is not None} + _search_env_local.overrides = { + k: v + for k, v in env_overrides.items() if v is not None + } def get_search_env_overrides() -> Dict[str, str]: @@ -60,11 +62,12 @@ def load_base_config(file_path: str) -> Dict[str, Any]: Env.load_env() if not os.path.exists(file_path): - logger.warning(f'Config file {file_path} does not exist. Using default config (ArxivSearch).') + logger.warning( + f'Config file {file_path} does not exist. Using default config (ArxivSearch).' + ) return {} import yaml - with open(file_path, 'r') as file: config = yaml.safe_load(file) @@ -127,16 +130,12 @@ def get_web_search_tool(config_file: str): # Engine override precedence: # 1) Thread-local override (per-request, e.g. FinResearch UI) # 2) Global environment variable (shared default) - engine_override = ( - ((local_env.get(SEARCH_ENGINE_OVERRIDE_ENV, '') or '') or (os.getenv(SEARCH_ENGINE_OVERRIDE_ENV, '') or '')) - .strip() - .lower() - ) - if engine_override and engine_override in ( - SearchEngineType.EXA.value, - SearchEngineType.SERPAPI.value, - SearchEngineType.ARXIV.value, - ): + engine_override = ((local_env.get(SEARCH_ENGINE_OVERRIDE_ENV, '') or '') + or (os.getenv(SEARCH_ENGINE_OVERRIDE_ENV, '') + or '')).strip().lower() + if engine_override and engine_override in (SearchEngineType.EXA.value, + SearchEngineType.SERPAPI.value, + SearchEngineType.ARXIV.value): search_config['engine'] = engine_override engine_name = (search_config.get('engine', '') or '').lower() @@ -146,12 +145,14 @@ def get_web_search_tool(config_file: str): override_serp_key = local_env.get('SERPAPI_API_KEY') if engine_name == SearchEngineType.EXA.value: - return ExaSearch(api_key=override_exa_key or search_config.get('exa_api_key', os.getenv('EXA_API_KEY', None))) + return ExaSearch( + api_key=override_exa_key or search_config.get( + 'exa_api_key', os.getenv('EXA_API_KEY', None))) elif engine_name == SearchEngineType.SERPAPI.value: return SerpApiSearch( - api_key=override_serp_key or search_config.get('serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)), - provider=search_config.get('provider', 'google').lower(), - ) + api_key=override_serp_key or search_config.get( + 'serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)), + provider=search_config.get('provider', 'google').lower()) elif engine_name == SearchEngineType.ARXIV.value: return ArxivSearch() else: diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py index 29ae693c8..01bd75f64 100644 --- a/ms_agent/tools/task_control_tool.py +++ b/ms_agent/tools/task_control_tool.py @@ -2,11 +2,10 @@ import json from typing import Any, Dict, Optional -from omegaconf import DictConfig - from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig logger = get_logger() @@ -48,8 +47,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: server_name=_SERVER, description=( 'List all background tasks and their current status. ' - 'Returns task_id, tool_name, description, status, and duration.' - ), + 'Returns task_id, tool_name, description, status, and duration.'), parameters={ 'type': 'object', 'properties': {}, @@ -76,7 +74,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: ] } - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: if self._task_manager is None: return 'TaskManager not available.' @@ -91,17 +90,14 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) duration = f'{t.ended_at - t.started_at:.1f}s' elif t.status == 'running': import time - duration = f'{time.monotonic() - t.started_at:.1f}s (running)' - rows.append( - { - 'task_id': t.task_id, - 'tool_name': t.tool_name, - 'description': t.description, - 'status': t.status, - 'duration': duration, - } - ) + rows.append({ + 'task_id': t.task_id, + 'tool_name': t.tool_name, + 'description': t.description, + 'status': t.status, + 'duration': duration, + }) return json.dumps(rows, ensure_ascii=False, indent=2) if tool_name == 'cancel_task': diff --git a/ms_agent/tools/todolist_tool.py b/ms_agent/tools/todolist_tool.py index 9ab5cbced..aee860134 100644 --- a/ms_agent/tools/todolist_tool.py +++ b/ms_agent/tools/todolist_tool.py @@ -1,9 +1,9 @@ -import json import os import time from dataclasses import dataclass from typing import Any, Dict, List, Optional +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -40,14 +40,18 @@ def _write_text(path: str, content: str) -> None: def _validate_status(status: str) -> str: allowed = {'pending', 'in_progress', 'completed', 'cancelled'} if status not in allowed: - raise ValueError(f'Invalid todo status "{status}", must be one of {sorted(allowed)}.') + raise ValueError( + f'Invalid todo status "{status}", must be one of {sorted(allowed)}.' + ) return status def _validate_priority(priority: str) -> str: allowed = {'high', 'medium', 'low'} if priority not in allowed: - raise ValueError(f'Invalid todo priority "{priority}", must be one of {sorted(allowed)}.') + raise ValueError( + f'Invalid todo priority "{priority}", must be one of {sorted(allowed)}.' + ) return priority @@ -78,10 +82,14 @@ def __init__(self, config, **kwargs): tool_cfg = getattr(getattr(config, 'tools'), 'todo_list') self.exclude_func(tool_cfg) - self._plan_filename = getattr(tool_cfg, 'plan_filename', 'plan.json') if tool_cfg else 'plan.json' - self._plan_md_filename = getattr(tool_cfg, 'plan_md_filename', 'plan.md') if tool_cfg else 'plan.md' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' - self._auto_render_md = bool(getattr(tool_cfg, 'auto_render_md', True)) if tool_cfg else True + self._plan_filename = getattr(tool_cfg, 'plan_filename', + 'plan.json') if tool_cfg else 'plan.json' + self._plan_md_filename = getattr(tool_cfg, 'plan_md_filename', + 'plan.md') if tool_cfg else 'plan.md' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', + '.locks') if tool_cfg else '.locks' + self._auto_render_md = bool(getattr(tool_cfg, 'auto_render_md', + True)) if tool_cfg else True async def connect(self) -> None: # Nothing to connect; file-based tool. @@ -101,47 +109,57 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_write', server_name=self.SERVER_NAME, - description=( - 'Create or update the structured todo list (plan.json) for this session/workdir. ' - 'Use merge=true to merge by id (partial updates allowed for existing ids); ' - 'merge=false replaces the list (full items required).' - ), + description= + ('Create or update the structured todo list (plan.json) for this session/workdir. ' + 'Use merge=true to merge by id (partial updates allowed for existing ids); ' + 'merge=false replaces the list (full items required).'), parameters={ 'type': 'object', 'properties': { 'merge': { - 'type': 'boolean', - 'description': ( - 'If true, merge todo items into existing list by id (partial updates allowed). ' - 'If false, replace the list entirely.' - ), - 'default': True, + 'type': + 'boolean', + 'description': + ('If true, merge todo items into existing list by id (partial updates allowed). ' + 'If false, replace the list entirely.'), + 'default': + True, }, 'todos': { 'type': 'array', - 'description': 'The updated/created todo list.', + 'description': + 'The updated/created todo list.', 'items': { 'type': 'object', 'properties': { 'id': { - 'type': 'string', - 'description': ( - 'Unique identifier for the todo item. e.g. "T_1", "T_2", ...' - ), + 'type': + 'string', + 'description': + ('Unique identifier for the todo item. ' + 'e.g. "T_1", "T_2", ...'), }, 'content': { - 'type': 'string', - 'description': 'Brief description of the task', + 'type': + 'string', + 'description': + 'Brief description of the task', }, 'status': { - 'type': 'string', - 'enum': ['pending', 'in_progress', 'completed', 'cancelled'], - 'description': 'Current status of the task', + 'type': + 'string', + 'enum': [ + 'pending', 'in_progress', + 'completed', 'cancelled' + ], + 'description': + 'Current status of the task', }, 'priority': { 'type': 'string', 'enum': ['high', 'medium', 'low'], - 'description': 'Priority level of the task', + 'description': + 'Priority level of the task', 'default': 'medium', }, }, @@ -160,7 +178,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_read', server_name=self.SERVER_NAME, - description='Read the current todo list for this session/workdir.', + description= + 'Read the current todo list for this session/workdir.', parameters={ 'type': 'object', 'properties': {}, @@ -171,16 +190,17 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='todo_render_md', server_name=self.SERVER_NAME, - description='Render plan.md from plan.json (checkbox view).', + description= + 'Render plan.md from plan.json (checkbox view).', parameters={ 'type': 'object', 'properties': { 'path': { - 'type': 'string', - 'description': ( - 'Optional relative output path for the markdown file. ' - 'Defaults to plan.md in the workdir.' - ), + 'type': + 'string', + 'description': + ('Optional relative output path for the markdown file. ' + 'Defaults to plan.md in the workdir.'), } }, 'required': [], @@ -191,7 +211,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_plan_locked(self, paths: _PlanPaths) -> Dict[str, Any]: @@ -223,13 +244,15 @@ def _load_plan_locked(self, paths: _PlanPaths) -> Dict[str, Any]: data['updated_at'] = _now_iso() return data - def _save_plan_locked(self, paths: _PlanPaths, plan: Dict[str, Any]) -> None: + def _save_plan_locked(self, paths: _PlanPaths, plan: Dict[str, + Any]) -> None: plan = dict(plan or {}) plan['schema_version'] = int(plan.get('schema_version', 1) or 1) plan['updated_at'] = _now_iso() _write_text(paths.plan_json, _json_dumps(plan)) - def _normalize_todos(self, todos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _normalize_todos(self, todos: List[Dict[str, + Any]]) -> List[Dict[str, Any]]: normalized: List[Dict[str, Any]] = [] for idx, item in enumerate(todos or []): if not isinstance(item, dict): @@ -239,9 +262,11 @@ def _normalize_todos(self, todos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: status = str(item.get('status', '')).strip() priority = str(item.get('priority', 'medium') or 'medium').strip() if not todo_id: - raise ValueError(f'todos[{idx}].id is required and must be non-empty.') + raise ValueError( + f'todos[{idx}].id is required and must be non-empty.') if not content: - raise ValueError(f'todos[{idx}].content is required and must be non-empty.') + raise ValueError( + f'todos[{idx}].content is required and must be non-empty.') _validate_status(status) _validate_priority(priority) # Keep extra fields as-is @@ -275,7 +300,8 @@ def _normalize_todo_updates( todo_id = str(item.get('id', '')).strip() if not todo_id: - raise ValueError(f'todos[{idx}].id is required and must be non-empty.') + raise ValueError( + f'todos[{idx}].id is required and must be non-empty.') is_new = todo_id not in existing_ids @@ -286,20 +312,27 @@ def _normalize_todo_updates( if 'content' in item: content = str(item.get('content', '')).strip() if not content: - raise ValueError(f'todos[{idx}].content is required and must be non-empty.') + raise ValueError( + f'todos[{idx}].content is required and must be non-empty.' + ) upd['content'] = content elif is_new: - raise ValueError(f'todos[{idx}] is a new id "{todo_id}" so content is required.') + raise ValueError( + f'todos[{idx}] is a new id "{todo_id}" so content is required.' + ) if 'status' in item: status = str(item.get('status', '')).strip() _validate_status(status) upd['status'] = status elif is_new: - raise ValueError(f'todos[{idx}] is a new id "{todo_id}" so status is required.') + raise ValueError( + f'todos[{idx}] is a new id "{todo_id}" so status is required.' + ) if 'priority' in item: - priority = str(item.get('priority', 'medium') or 'medium').strip() + priority = str(item.get('priority', 'medium') + or 'medium').strip() _validate_priority(priority) upd['priority'] = priority @@ -307,11 +340,16 @@ def _normalize_todo_updates( return normalized - def _merge_todos(self, base: List[Dict[str, Any]], updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _merge_todos(self, base: List[Dict[str, Any]], + updates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: base_by_id: Dict[str, Dict[str, Any]] = { - str(t.get('id')): dict(t) for t in (base or []) if isinstance(t, dict) and t.get('id') + str(t.get('id')): dict(t) + for t in (base or []) if isinstance(t, dict) and t.get('id') } - order: List[str] = [str(t.get('id')) for t in (base or []) if isinstance(t, dict) and t.get('id')] + order: List[str] = [ + str(t.get('id')) for t in (base or []) + if isinstance(t, dict) and t.get('id') + ] for upd in updates or []: tid = str(upd.get('id')) if tid in base_by_id: @@ -348,7 +386,9 @@ def _render_plan_md_text(self, plan: Dict[str, Any]) -> str: lines.append('') return '\n'.join(lines) - async def todo_write(self, todos: List[Dict[str, Any]], merge: bool = True) -> str: + async def todo_write(self, + todos: List[Dict[str, Any]], + merge: bool = True) -> str: paths = self._paths() _ensure_dir(self.output_dir) _ensure_dir(paths.lock_dir) @@ -360,7 +400,8 @@ async def todo_write(self, todos: List[Dict[str, Any]], merge: bool = True) -> s # For merge=true, allow partial updates for existing ids. existing_full = self._normalize_todos(existing) existing_ids = {str(t.get('id')) for t in existing_full} - updates = self._normalize_todo_updates(todos, existing_ids=existing_ids) + updates = self._normalize_todo_updates( + todos, existing_ids=existing_ids) merged = self._merge_todos(existing_full, updates) plan['todos'] = self._normalize_todos(merged) else: @@ -372,16 +413,18 @@ async def todo_write(self, todos: List[Dict[str, Any]], merge: bool = True) -> s md_text = self._render_plan_md_text(plan) _write_text(paths.plan_md, md_text) - render_markdown_todo(paths.plan_md, title='CURRENT PLAN', use_pager=False) + render_markdown_todo( + paths.plan_md, title='CURRENT PLAN', use_pager=False) # Return a JSON list (opencode-style) so the model can easily read it. - return _json_dumps( - { - 'status': 'ok', - 'plan_path': os.path.relpath(paths.plan_json, self.output_dir), - 'todos': plan.get('todos', []), - } - ) + return _json_dumps({ + 'status': + 'ok', + 'plan_path': + os.path.relpath(paths.plan_json, self.output_dir), + 'todos': + plan.get('todos', []), + }) async def todo_read(self) -> str: paths = self._paths() @@ -390,7 +433,8 @@ async def todo_read(self) -> str: with file_lock(paths.lock_dir, self._plan_filename): plan = self._load_plan_locked(paths) if self._auto_render_md: - render_markdown_todo(paths.plan_md, title='CURRENT PLAN', use_pager=False) + render_markdown_todo( + paths.plan_md, title='CURRENT PLAN', use_pager=False) return _json_dumps(plan.get('todos', [])) @@ -398,7 +442,8 @@ async def todo_render_md(self, path: Optional[str] = None) -> str: paths = self._paths() _ensure_dir(self.output_dir) _ensure_dir(paths.lock_dir) - out_path = paths.plan_md if not path else os.path.join(self.output_dir, path) + out_path = paths.plan_md if not path else os.path.join( + self.output_dir, path) with file_lock(paths.lock_dir, self._plan_filename): plan = self._load_plan_locked(paths) diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index f108b6b87..70703fbfb 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,7 +2,6 @@ import asyncio import importlib import inspect -import json import os import sys import uuid @@ -10,6 +9,7 @@ from types import TracebackType from typing import Any, Dict, List, Optional +import json from ms_agent.llm.utils import Tool, ToolCall from ms_agent.tools.agent_tool import AgentTool from ms_agent.tools.base import ToolBase @@ -33,42 +33,55 @@ class ToolManager: - """Interacting with Agent class, hold all tools""" + """Interacting with Agent class, hold all tools + """ TOOL_SPLITER = '---' - def __init__( - self, config, mcp_config: Optional[Dict[str, Any]] = None, mcp_client: Optional[MCPClient] = None, **kwargs - ): + def __init__(self, + config, + mcp_config: Optional[Dict[str, Any]] = None, + mcp_client: Optional[MCPClient] = None, + **kwargs): self.config = config self.trust_remote_code = kwargs.get('trust_remote_code', False) self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False - if hasattr(config, 'tools') and hasattr(config.tools, 'image_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, + 'image_generator'): self.extra_tools.append(ImageGenerator(config)) - if hasattr(config, 'tools') and hasattr(config.tools, 'video_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, + 'video_generator'): self.extra_tools.append(VideoGenerator(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'file_system'): - self.extra_tools.append(FileSystemTool(config, trust_remote_code=self.trust_remote_code)) + self.extra_tools.append( + FileSystemTool( + config, trust_remote_code=self.trust_remote_code)) if hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'): code_exec_cfg = getattr(config.tools, 'code_executor') - implementation = getattr(code_exec_cfg, 'implementation', 'sandbox') - if isinstance(implementation, str) and implementation.lower() == 'python_env': + implementation = getattr(code_exec_cfg, 'implementation', + 'sandbox') + if isinstance(implementation, + str) and implementation.lower() == 'python_env': self.extra_tools.append(LocalCodeExecutionTool(config)) - elif isinstance(implementation, str) and implementation.lower() == 'sandbox': + elif isinstance(implementation, + str) and implementation.lower() == 'sandbox': self.extra_tools.append(CodeExecutionTool(config)) else: - logger.warning(f'Unknown code execution implementation: {implementation},using sandbox instead.') + logger.warning( + f'Unknown code execution implementation: {implementation},' + f'using sandbox instead.') self.extra_tools.append(CodeExecutionTool(config)) - if hasattr(config, 'tools') and hasattr(config.tools, 'financial_data_fetcher'): + if hasattr(config, 'tools') and hasattr(config.tools, + 'financial_data_fetcher'): from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher - self.extra_tools.append(FinancialDataFetcher(config)) if hasattr(config, 'tools') and ( - getattr(config.tools, 'agent_tools', None) or hasattr(config.tools, 'split_task') - ): - agent_tool = AgentTool(config, trust_remote_code=self.trust_remote_code) + getattr(config.tools, 'agent_tools', None) + or hasattr(config.tools, 'split_task')): + agent_tool = AgentTool( + config, trust_remote_code=self.trust_remote_code) if agent_tool.enabled: self.extra_tools.append(agent_tool) if hasattr(config, 'tools') and hasattr(config.tools, 'todo_list'): @@ -79,11 +92,13 @@ def __init__( self.extra_tools.append(LocalSearchTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'task_control'): from ms_agent.tools.task_control_tool import TaskControlTool - self.extra_tools.append(TaskControlTool(config)) - self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) - local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None - if hasattr(config, 'tools') and hasattr(config.tools, TOOL_PLUGIN_NAME): + self.tool_call_timeout = getattr(config, 'tool_call_timeout', + TOOL_CALL_TIMEOUT) + local_dir = self.config.local_dir if hasattr(self.config, + 'local_dir') else None + if hasattr(config, 'tools') and hasattr(config.tools, + TOOL_PLUGIN_NAME): plugins = getattr(config.tools, TOOL_PLUGIN_NAME) for plugin in plugins: subdir = os.path.dirname(plugin) @@ -104,7 +119,11 @@ def __init__( if _plugin.endswith('.py'): _plugin = _plugin[:-3] plugin_file = importlib.import_module(_plugin) - module_classes = {name: cls for name, cls in inspect.getmembers(plugin_file, inspect.isclass)} + module_classes = { + name: cls + for name, cls in inspect.getmembers( + plugin_file, inspect.isclass) + } for name, cls in module_classes.items(): # Find cls which base class is `ToolBase` if issubclass(cls, ToolBase) and cls.__module__ == _plugin: @@ -154,12 +173,15 @@ async def cleanup(self): pass async def reindex_tool(self): - def extend_tool(tool_ins: ToolBase, server_name: str, tool_list: List[Tool]): + + def extend_tool(tool_ins: ToolBase, server_name: str, + tool_list: List[Tool]): for tool in tool_list: # Subtract the length of the tool name splitter - max_server_len = MAX_TOOL_NAME_LEN - len(tool['tool_name']) - len(self.TOOL_SPLITER) + max_server_len = MAX_TOOL_NAME_LEN - len( + tool['tool_name']) - len(self.TOOL_SPLITER) if len(server_name) > max_server_len: - key = f"{server_name[: max(0, max_server_len)]}{self.TOOL_SPLITER}{tool['tool_name']}" + key = f"{server_name[:max(0, max_server_len)]}{self.TOOL_SPLITER}{tool['tool_name']}" else: key = f"{server_name}{self.TOOL_SPLITER}{tool['tool_name']}" assert key not in self._tool_index, f'Tool name duplicated {tool["tool_name"]}' @@ -179,7 +201,7 @@ async def get_tools(self): # Return tools in deterministic order to improve prompt/prefix cache hit rate # across process restarts and across different MCP tool listing orders. tools = [value[2] for value in self._tool_index.values()] - return sorted(tools, key=lambda t: (t.get('tool_name', ''),)) + return sorted(tools, key=lambda t: (t.get('tool_name', ''), )) async def single_call_tool(self, tool_info: ToolCall): if self._concurrent_limiter is None: @@ -187,7 +209,8 @@ async def single_call_tool(self, tool_info: ToolCall): self._init_lock = asyncio.Lock() async with self._init_lock: if self._concurrent_limiter is None: - self._concurrent_limiter = asyncio.Semaphore(MAX_CONCURRENT_TOOLS) + self._concurrent_limiter = asyncio.Semaphore( + MAX_CONCURRENT_TOOLS) async with self._concurrent_limiter: brief_info = json.dumps(tool_info, ensure_ascii=False) @@ -208,27 +231,27 @@ async def single_call_tool(self, tool_info: ToolCall): call_args = dict(tool_args or {}) call_id = tool_info.get('id') or str(uuid.uuid4()) call_args['__call_id'] = call_id - elif isinstance(tool_ins, LocalCodeExecutionTool) and tool_name.endswith( - f'{self.TOOL_SPLITER}shell_executor' - ): + elif isinstance( + tool_ins, + LocalCodeExecutionTool) and tool_name.endswith( + f'{self.TOOL_SPLITER}shell_executor'): call_args = dict(tool_args or {}) - call_args['__call_id'] = tool_info.get('id') or str(uuid.uuid4()) + call_args['__call_id'] = tool_info.get('id') or str( + uuid.uuid4()) response = await asyncio.wait_for( tool_ins.call_tool( - server_name, tool_name=tool_name.split(self.TOOL_SPLITER)[1], tool_args=call_args - ), - timeout=self.tool_call_timeout, - ) + server_name, + tool_name=tool_name.split(self.TOOL_SPLITER)[1], + tool_args=call_args), + timeout=self.tool_call_timeout) return response except asyncio.TimeoutError: import traceback - logger.warning(traceback.format_exc()) # TODO: How to get the information printed by the tool before hanging to return to the model? return f'Execute tool call timeout: {brief_info}' except Exception as e: import traceback - logger.warning(traceback.format_exc()) return f'Tool calling failed: {brief_info}, details: {str(e)}' @@ -238,6 +261,7 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]): return result async def __aenter__(self) -> 'ToolManager': + return self async def __aexit__( diff --git a/ms_agent/tools/video_generator/ds_video_gen.py b/ms_agent/tools/video_generator/ds_video_gen.py index 48db015df..d757038d3 100644 --- a/ms_agent/tools/video_generator/ds_video_gen.py +++ b/ms_agent/tools/video_generator/ds_video_gen.py @@ -8,27 +8,32 @@ class DSVideoGenerator: + def __init__(self, config, temp_dir): self.config = config self.temp_dir = temp_dir os.makedirs(self.temp_dir, exist_ok=True) - async def generate_video(self, positive_prompt, size='1280x720', seconds=4): + async def generate_video(self, + positive_prompt, + size='1280x720', + seconds=4): video_generator = self.config.tools.video_generator - base_url = (getattr(video_generator, 'base_url', None) or 'https://dashscope.aliyuncs.com').strip('/') + base_url = (getattr(video_generator, 'base_url', None) + or 'https://dashscope.aliyuncs.com').strip('/') api_key = video_generator.api_key model_id = video_generator.model assert api_key is not None task_id = str(uuid.uuid4())[:8] output_file = os.path.join(self.temp_dir, f'{task_id}.mp4') - video_url = await self._generate_video(base_url, api_key, model_id, positive_prompt, size, seconds) + video_url = await self._generate_video(base_url, api_key, model_id, + positive_prompt, size, seconds) await self.download_video(video_url, output_file) return output_file @staticmethod async def download_video(video_url, output_file): import aiohttp - max_retries = 3 retry_count = 0 @@ -36,7 +41,8 @@ async def download_video(video_url, output_file): headers = {} while retry_count < max_retries: try: - async with session.get(video_url, headers=headers) as video_resp: + async with session.get( + video_url, headers=headers) as video_resp: video_resp.raise_for_status() video_content = await video_resp.read() with open(output_file, 'wb') as f: @@ -52,7 +58,6 @@ async def download_video(video_url, output_file): @staticmethod async def _generate_video(base_url, api_key, model, prompt, size, seconds): import aiohttp - base_url = base_url.strip('/') create_endpoint = '/api/v1/services/aigc/model-evaluation/async-inference/' @@ -68,19 +73,23 @@ async def _generate_video(base_url, api_key, model, prompt, size, seconds): 'size': size, 'seconds': seconds, }, - 'parameters': {}, + 'parameters': {} } async with aiohttp.ClientSession() as session: - async with session.post(f'{base_url}{create_endpoint}', headers=headers, json=payload) as resp: + async with session.post( + f'{base_url}{create_endpoint}', headers=headers, + json=payload) as resp: resp.raise_for_status() response_data = await resp.json() task_id = response_data['output']['task_id'] if not task_id: - raise RuntimeError(f'No task ID in response: {response_data}') + raise RuntimeError( + f'No task ID in response: {response_data}') - return await DSVideoGenerator._poll_video_task(session, base_url, task_id, headers) + return await DSVideoGenerator._poll_video_task( + session, base_url, task_id, headers) @staticmethod async def _poll_video_task(session, base_url, task_id, headers): @@ -97,21 +106,28 @@ async def _poll_video_task(session, base_url, task_id, headers): await asyncio.sleep(poll_interval) elapsed_time += poll_interval - async with session.get(f'{base_url}{poll_endpoint}', headers=headers) as result: + async with session.get( + f'{base_url}{poll_endpoint}', headers=headers) as result: result.raise_for_status() data = await result.json() status = data['output']['task_status'] - logger.info(f'Task {task_id} status: {status}, detailed message: {str(data)}') + logger.info( + f'Task {task_id} status: {status}, detailed message: {str(data)}' + ) if status in success_statuses: video_url = data['output']['video_url'] if not video_url: - raise RuntimeError(f'Video URL not found in response: {data}') + raise RuntimeError( + f'Video URL not found in response: {data}') return video_url elif status in failed_statuses: - error_msg = data['output'].get('message') or 'Unknown error' + error_msg = data['output'].get( + 'message') or 'Unknown error' raise RuntimeError(f'Video generation failed: {error_msg}') poll_interval = min(poll_interval * 1.2, max_poll_interval) - raise TimeoutError(f'Video generation task {task_id} timed out after {max_wait_time} seconds') + raise TimeoutError( + f'Video generation task {task_id} timed out after {max_wait_time} seconds' + ) diff --git a/ms_agent/tools/video_generator/video_gen.py b/ms_agent/tools/video_generator/video_gen.py index 9af5e2599..7578ebf59 100644 --- a/ms_agent/tools/video_generator/video_gen.py +++ b/ms_agent/tools/video_generator/video_gen.py @@ -6,14 +6,15 @@ class VideoGenerator(ToolBase): + def __init__(self, config): super().__init__(config) - self.temp_dir = os.path.join(self.output_dir, '.temp', 'video_generator') + self.temp_dir = os.path.join(self.output_dir, '.temp', + 'video_generator') os.makedirs(self.temp_dir, exist_ok=True) video_generator = self.config.video_generator if video_generator.type == 'dashscope': from .ds_video_gen import DSVideoGenerator - self.generator = DSVideoGenerator(self.config, self.temp_dir) else: raise NotImplementedError() @@ -27,25 +28,32 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='generate_video', server_name='video_generator', - description='Generate a video with a positive prompt, and return the video file path.', + description= + 'Generate a video with a positive prompt, and return the video file path.', parameters={ 'type': 'object', 'properties': { - 'positive_prompt': {'type': 'string', 'description': 'The prompt to generate the image.'}, - 'seconds': { - 'type': 'integer', - 'description': 'The generated video seconds, supported is 4/8/12', + 'positive_prompt': { + 'type': 'string', + 'description': + 'The prompt to generate the image.' }, + 'seconds': { + 'type': + 'integer', + 'description': + 'The generated video seconds, supported is 4/8/12' + } }, 'required': ['positive_prompt'], - 'additionalProperties': False, - }, - ) + 'additionalProperties': False + }) ] } async def generate_video(self, positive_prompt, **kwargs): return await self.generator.generate_video(positive_prompt, **kwargs) - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await self.generate_video(**tool_args) diff --git a/ms_agent/utils/__init__.py b/ms_agent/utils/__init__.py index 09061daab..32655bd93 100644 --- a/ms_agent/utils/__init__.py +++ b/ms_agent/utils/__init__.py @@ -2,6 +2,7 @@ from .llm_utils import async_retry, retry from .logger import get_logger from .prompt import get_fact_retrieval_prompt -from .utils import assert_package_exist, enhance_error, read_history, save_history, strtobool +from .utils import (assert_package_exist, enhance_error, read_history, + save_history, strtobool) MAX_CONTINUE_RUNS = 3 diff --git a/ms_agent/utils/artifact_manager.py b/ms_agent/utils/artifact_manager.py index 92ac60604..9953f25c0 100644 --- a/ms_agent/utils/artifact_manager.py +++ b/ms_agent/utils/artifact_manager.py @@ -49,7 +49,8 @@ def pack_text_result( out.update(extra) return out - safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id)[:120] or 'call' + safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id + )[:120] or 'call' rel_dir = Path(tool_name) / safe_id out_dir = self._artifact_root / rel_dir out_dir.mkdir(parents=True, exist_ok=True) @@ -59,14 +60,21 @@ def pack_text_result( fpath = out_dir / fname fpath.write_text(body, encoding='utf-8', errors='replace') rel = fpath.relative_to(self._root).as_posix() - preview = _make_preview(body, self.preview_head_chars, self.preview_tail_chars) + preview = _make_preview(body, self.preview_head_chars, + self.preview_tail_chars) result = { - 'output': stdout[: self.preview_head_chars] if len(stdout) > self.preview_head_chars else stdout, - 'error': (stderr[: self.preview_head_chars] if stderr else None), - 'truncated': True, - 'artifact_path': rel, - 'preview': preview, - 'artifact_bytes': len(enc), + 'output': stdout[:self.preview_head_chars] + if len(stdout) > self.preview_head_chars else stdout, + 'error': + (stderr[:self.preview_head_chars] if stderr else None), + 'truncated': + True, + 'artifact_path': + rel, + 'preview': + preview, + 'artifact_bytes': + len(enc), } if extra: result.update(extra) @@ -87,15 +95,23 @@ def pack_json_shell_result( call_id=call_id, stdout=stdout, stderr=stderr, - extra={k: v for k, v in payload.items() if k not in ('output', 'error')}, + extra={ + k: v + for k, v in payload.items() if k not in ('output', 'error') + }, ) # pack_text_result merged extra into top level; rebuild standard shell shape out = { - 'success': payload.get('success'), - 'output': packed.get('output'), - 'error': packed.get('error'), - 'return_code': payload.get('return_code'), - 'truncated': packed.get('truncated', False), + 'success': + payload.get('success'), + 'output': + packed.get('output'), + 'error': + packed.get('error'), + 'return_code': + payload.get('return_code'), + 'truncated': + packed.get('truncated', False), } if packed.get('artifact_path'): out['artifact_path'] = packed['artifact_path'] @@ -107,4 +123,4 @@ def pack_json_shell_result( def _make_preview(text: str, head: int, tail: int) -> str: if len(text) <= head + tail: return text - return text[:head] + '\n... [truncated] ...\n' + text[-tail:] + return (text[:head] + '\n... [truncated] ...\n' + text[-tail:]) diff --git a/ms_agent/utils/constants.py b/ms_agent/utils/constants.py index 9929a26e4..e068e654e 100644 --- a/ms_agent/utils/constants.py +++ b/ms_agent/utils/constants.py @@ -61,30 +61,36 @@ class ServiceConfig: @dataclass class ModelscopeConfig(ServiceConfig): + def __init__(self): super().__init__(base_url='https://api-inference.modelscope.cn/v1') @dataclass class DashscopeConfig(ServiceConfig): + def __init__(self): - super().__init__(base_url='https://dashscope.aliyuncs.com/compatible-mode/v1') + super().__init__( + base_url='https://dashscope.aliyuncs.com/compatible-mode/v1') @dataclass class DeepseekConfig(ServiceConfig): + def __init__(self): super().__init__(base_url='https://api.deepseek.com/v1') @dataclass class AnthropicConfig(ServiceConfig): + def __init__(self): # without /v1, using Anthropic API super().__init__(base_url='https://api.anthropic.com') class OpenaiConfig(ServiceConfig): + def __init__(self): super().__init__(base_url='https://api.openai.com/v1') diff --git a/ms_agent/utils/llm_utils.py b/ms_agent/utils/llm_utils.py index c9fc50036..4f800bd45 100644 --- a/ms_agent/utils/llm_utils.py +++ b/ms_agent/utils/llm_utils.py @@ -11,15 +11,15 @@ T = TypeVar('T') -def retry( - max_attempts: int = 3, - delay: float = 1.0, - backoff_factor: float = 2.0, - exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, -): +def retry(max_attempts: int = 3, + delay: float = 1.0, + backoff_factor: float = 2.0, + exceptions: Union[Type[Exception], Tuple[Type[Exception], + ...]] = Exception): """Retry doing something""" def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) def wrapper(*args, **kwargs) -> T: current_delay = delay @@ -30,7 +30,6 @@ def wrapper(*args, **kwargs) -> T: return func(*args, **kwargs) except exceptions as e: import traceback - logger.warning(traceback.format_exc()) last_exception = e if attempt < max_attempts: @@ -43,8 +42,7 @@ def wrapper(*args, **kwargs) -> T: else: logger.error( f'Attempt to call {func.__name__} over {max_attempts} times. ' - f'The last exception message: {e}' - ) + f'The last exception message: {e}') raise last_exception return wrapper @@ -52,15 +50,15 @@ def wrapper(*args, **kwargs) -> T: return decorator -def async_retry( - max_attempts: int = 3, - delay: float = 1.0, - backoff_factor: float = 2.0, - exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, -): +def async_retry(max_attempts: int = 3, + delay: float = 1.0, + backoff_factor: float = 2.0, + exceptions: Union[Type[Exception], Tuple[Type[Exception], + ...]] = Exception): """Retry doing something""" def decorator(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: current_delay = delay @@ -73,7 +71,6 @@ async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: return except exceptions as e: import traceback - logger.warning(traceback.format_exc()) last_exception = e if attempt < max_attempts: @@ -86,8 +83,7 @@ async def wrapper(*args, **kwargs) -> AsyncGenerator[T, Any]: else: logger.error( f'Attempt to call {func.__name__} over {max_attempts} times. ' - f'The last exception message: {e}' - ) + f'The last exception message: {e}') raise last_exception return wrapper diff --git a/ms_agent/utils/logger.py b/ms_agent/utils/logger.py index b005d1ab7..0ac9d48a8 100644 --- a/ms_agent/utils/logger.py +++ b/ms_agent/utils/logger.py @@ -30,8 +30,10 @@ def warning_once(self, msg, *args, **kwargs): self.warning(msg) -def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'): - """Get logging logger +def get_logger(log_file: Optional[str] = None, + log_level: Optional[int] = None, + file_mode: str = 'w'): + """ Get logging logger Args: log_file: Log filename, if specified, file handler will be added to diff --git a/ms_agent/utils/parser_utils.py b/ms_agent/utils/parser_utils.py index ae2f6ba46..5455dae52 100644 --- a/ms_agent/utils/parser_utils.py +++ b/ms_agent/utils/parser_utils.py @@ -1,15 +1,15 @@ -import json import os import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional +import json + @dataclass class ImportInfo: """Detailed information about an import statement""" - # Source file path (resolved path) source_file: str # Original import statement @@ -24,7 +24,8 @@ class ImportInfo: is_type_only: bool = False def __repr__(self): - items_str = ', '.join(self.imported_items) if self.imported_items else 'all' + items_str = ', '.join( + self.imported_items) if self.imported_items else 'all' alias_str = f' as {self.alias}' if self.alias else '' return f"ImportInfo(file='{self.source_file}', items=[{items_str}]{alias_str})" @@ -63,7 +64,8 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 1: from ... import ... from_pattern = r'^\s*from\s+([\w.]+)\s+import\s+(?:\(([^)]+)\)|([^\n]+))' - for match in re.finditer(from_pattern, code_content, re.MULTILINE | re.DOTALL): + for match in re.finditer(from_pattern, code_content, + re.MULTILINE | re.DOTALL): info = self._extract_from_import(match, code_content) if info: imports.append(info) @@ -76,7 +78,8 @@ def parse(self, code_content: str) -> List[ImportInfo]: return imports - def _extract_from_import(self, match, code_content) -> Optional[ImportInfo]: + def _extract_from_import(self, match, + code_content) -> Optional[ImportInfo]: """Extract 'from ... import ...' statement""" module_path = match.group(1) # Group 2 is parenthesized multi-line imports, group 3 is single-line imports @@ -87,7 +90,7 @@ def _extract_from_import(self, match, code_content) -> Optional[ImportInfo]: cleaned_items = [] for line in lines: if '#' in line: - line = line[: line.index('#')] + line = line[:line.index('#')] cleaned_items.append(line.strip()) imports_str = ','.join(cleaned_items) @@ -115,8 +118,7 @@ def _extract_from_import(self, match, code_content) -> Optional[ImportInfo]: source_file=file_path, raw_statement=match.group(0), imported_items=imported_items, - import_type='namespace' if '*' in imported_items else 'named', - ) + import_type='namespace' if '*' in imported_items else 'named') def _extract_simple_import(self, match) -> List[ImportInfo]: """Extract 'import ...' statement""" @@ -145,9 +147,7 @@ def _extract_simple_import(self, match) -> List[ImportInfo]: raw_statement=f'import {module}', imported_items=[module.split('.')[-1]], import_type='default', - alias=alias, - ) - ) + alias=alias)) return results @@ -206,7 +206,8 @@ def safe_relpath(path): target_dir = os.path.join(target_dir, module_file_path) # Try as package - package_init = os.path.normpath(os.path.join(target_dir, '__init__.py')) + package_init = os.path.normpath( + os.path.join(target_dir, '__init__.py')) if os.path.exists(package_init): return safe_relpath(package_init) @@ -226,22 +227,26 @@ def safe_relpath(path): module_file_path = module_path.replace('.', os.sep) # Try as package (relative to current file) - package_init = os.path.normpath(os.path.join(self.current_dir, module_file_path, '__init__.py')) + package_init = os.path.normpath( + os.path.join(self.current_dir, module_file_path, '__init__.py')) if os.path.exists(package_init): return safe_relpath(package_init) # Try as module (relative to current file) - module_file = os.path.normpath(os.path.join(self.current_dir, module_file_path + '.py')) + module_file = os.path.normpath( + os.path.join(self.current_dir, module_file_path + '.py')) if os.path.exists(module_file): return safe_relpath(module_file) # Try from output_dir (absolute import) if self.output_dir: - package_init_abs = os.path.normpath(os.path.join(self.output_dir, module_file_path, '__init__.py')) + package_init_abs = os.path.normpath( + os.path.join(self.output_dir, module_file_path, '__init__.py')) if os.path.exists(package_init_abs): return os.path.join(module_file_path, '__init__.py') - module_file_abs = os.path.normpath(os.path.join(self.output_dir, module_file_path + '.py')) + module_file_abs = os.path.normpath( + os.path.join(self.output_dir, module_file_path + '.py')) if os.path.exists(module_file_abs): return module_file_path + '.py' @@ -264,14 +269,16 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 1: Mixed import - import Default, { Named } from 'path' # Must come BEFORE Pattern 2 and 3 to avoid partial matches mixed_pattern = r"^\s*import\s+(type\s+)?(\w+)\s*,\s*\{([^}]+)\}\s*from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(mixed_pattern, code_content, re.MULTILINE | re.DOTALL): + for match in re.finditer(mixed_pattern, code_content, + re.MULTILINE | re.DOTALL): infos = self._extract_mixed_import(match) if infos: imports.extend(infos) # Pattern 2: Named import - import { A, B } from 'path' (supports multiline) named_pattern = r"^\s*import\s+(type\s+)?\{([^}]+)\}\s*from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(named_pattern, code_content, re.MULTILINE | re.DOTALL): + for match in re.finditer(named_pattern, code_content, + re.MULTILINE | re.DOTALL): info = self._extract_named_import(match) if info: imports.append(info) @@ -285,35 +292,40 @@ def parse(self, code_content: str) -> List[ImportInfo]: # Pattern 4: Namespace import - import * as name from 'path' namespace_pattern = r"^\s*import\s+(type\s+)?\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(namespace_pattern, code_content, re.MULTILINE): + for match in re.finditer(namespace_pattern, code_content, + re.MULTILINE): info = self._extract_namespace_import(match) if info: imports.append(info) # Pattern 5: Side-effect import - import 'path' side_effect_pattern = r"^\s*import\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(side_effect_pattern, code_content, re.MULTILINE): + for match in re.finditer(side_effect_pattern, code_content, + re.MULTILINE): info = self._extract_side_effect_import(match) if info: imports.append(info) # Pattern 6: Named re-export - export { A, B } from 'path' (supports multiline) export_named_pattern = r"^\s*export\s+(type\s+)?\{([^}]+)\}\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_named_pattern, code_content, re.MULTILINE | re.DOTALL): + for match in re.finditer(export_named_pattern, code_content, + re.MULTILINE | re.DOTALL): info = self._extract_export_named(match) if info: imports.append(info) # Pattern 7: Wildcard re-export - export * from 'path' export_wildcard_pattern = r"^\s*export\s+(type\s+)?\*\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_wildcard_pattern, code_content, re.MULTILINE): + for match in re.finditer(export_wildcard_pattern, code_content, + re.MULTILINE): info = self._extract_export_wildcard(match) if info: imports.append(info) # Pattern 8: Named wildcard re-export - export * as name from 'path' export_named_wildcard_pattern = r"^\s*export\s+(type\s+)?\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]" - for match in re.finditer(export_named_wildcard_pattern, code_content, re.MULTILINE): + for match in re.finditer(export_named_wildcard_pattern, code_content, + re.MULTILINE): info = self._extract_export_named_wildcard(match) if info: imports.append(info) @@ -359,9 +371,7 @@ def _extract_mixed_import(self, match) -> List[ImportInfo]: raw_statement=match.group(0), imported_items=[default_name], import_type='default', - is_type_only=is_type, - ) - ) + is_type_only=is_type)) # Create named import info results.append( @@ -370,9 +380,7 @@ def _extract_mixed_import(self, match) -> List[ImportInfo]: raw_statement=match.group(0), imported_items=named_items, import_type='named', - is_type_only=is_type, - ) - ) + is_type_only=is_type)) return results @@ -405,8 +413,7 @@ def _extract_named_import(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=items, import_type='named', - is_type_only=is_type, - ) + is_type_only=is_type) def _extract_default_import(self, match) -> Optional[ImportInfo]: """Extract: import React from 'path'""" @@ -424,8 +431,7 @@ def _extract_default_import(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=[name], import_type='default', - is_type_only=is_type, - ) + is_type_only=is_type) def _extract_namespace_import(self, match) -> Optional[ImportInfo]: """Extract: import * as name from 'path'""" @@ -444,8 +450,7 @@ def _extract_namespace_import(self, match) -> Optional[ImportInfo]: imported_items=['*'], import_type='namespace', alias=name, - is_type_only=is_type, - ) + is_type_only=is_type) def _extract_side_effect_import(self, match) -> Optional[ImportInfo]: """Extract: import 'path'""" @@ -456,8 +461,10 @@ def _extract_side_effect_import(self, match) -> Optional[ImportInfo]: resolved_path = import_path return ImportInfo( - source_file=resolved_path, raw_statement=match.group(0), imported_items=[], import_type='side-effect' - ) + source_file=resolved_path, + raw_statement=match.group(0), + imported_items=[], + import_type='side-effect') def _extract_export_named(self, match) -> Optional[ImportInfo]: """Extract: export { A, B } from 'path'""" @@ -488,8 +495,7 @@ def _extract_export_named(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=items, import_type='named', - is_type_only=is_type, - ) + is_type_only=is_type) def _extract_export_wildcard(self, match) -> Optional[ImportInfo]: """Extract: export * from 'path'""" @@ -506,8 +512,7 @@ def _extract_export_wildcard(self, match) -> Optional[ImportInfo]: raw_statement=match.group(0), imported_items=['*'], import_type='namespace', - is_type_only=is_type, - ) + is_type_only=is_type) def _extract_export_named_wildcard(self, match) -> Optional[ImportInfo]: """Extract: export * as name from 'path'""" @@ -526,8 +531,7 @@ def _extract_export_named_wildcard(self, match) -> Optional[ImportInfo]: imported_items=['*'], import_type='namespace', alias=name, - is_type_only=is_type, - ) + is_type_only=is_type) def _resolve_js_path(self, import_path: str) -> Optional[str]: """Resolve JavaScript/TypeScript import path to file @@ -536,7 +540,8 @@ def _resolve_js_path(self, import_path: str) -> Optional[str]: Returns None for external packages. """ # Check if it's an external package (doesn't start with . or /) - is_external = not import_path.startswith('.') and not import_path.startswith('/') + is_external = not import_path.startswith( + '.') and not import_path.startswith('/') # External packages return None early if is_external: @@ -581,11 +586,14 @@ def to_relative(path): abs_resolved = os.path.join(self.output_dir, resolved) else: # Both are relative, make absolute from current working directory - abs_resolved = os.path.abspath(os.path.join(self.output_dir, resolved)) + abs_resolved = os.path.abspath( + os.path.join(self.output_dir, resolved)) # Try as directory with index file first if os.path.isdir(abs_resolved): - for index_file in ['index.ts', 'index.tsx', 'index.js', 'index.jsx']: + for index_file in [ + 'index.ts', 'index.tsx', 'index.js', 'index.jsx' + ]: index_path = os.path.join(abs_resolved, index_file) if os.path.exists(index_path): # Return relative path with index file @@ -595,19 +603,8 @@ def to_relative(path): # Try different extensions extensions = [ - '.ts', - '.tsx', - '.js', - '.jsx', - '.mjs', - '.cjs', - '.json', - '.css', - '.scss', - '.sass', - '.less', - '.module.css', - '.module.scss', + '.ts', '.tsx', '.js', '.jsx', '.mjs', '.cjs', '.json', '.css', + '.scss', '.sass', '.less', '.module.css', '.module.scss' ] for ext in extensions: @@ -651,7 +648,9 @@ def to_relative(path): def _load_path_aliases(self) -> Dict[str, str]: """Load path aliases from tsconfig.json and vite.config""" aliases = {} - excluded_dirs = {'node_modules', 'dist', 'build', '.git', '__pycache__'} + excluded_dirs = { + 'node_modules', 'dist', 'build', '.git', '__pycache__' + } # Search for config files for root, dirs, files in os.walk(self.output_dir): @@ -659,12 +658,16 @@ def _load_path_aliases(self) -> Dict[str, str]: # tsconfig.json if 'tsconfig.json' in files: - self._parse_tsconfig_aliases(os.path.join(root, 'tsconfig.json'), root, aliases) + self._parse_tsconfig_aliases( + os.path.join(root, 'tsconfig.json'), root, aliases) # vite.config.* - for config_file in ['vite.config.js', 'vite.config.ts', 'vite.config.mjs']: + for config_file in [ + 'vite.config.js', 'vite.config.ts', 'vite.config.mjs' + ]: if config_file in files: - self._parse_vite_config_aliases(os.path.join(root, config_file), root, aliases) + self._parse_vite_config_aliases( + os.path.join(root, config_file), root, aliases) # Default aliases if not aliases: @@ -677,36 +680,41 @@ def _load_path_aliases(self) -> Dict[str, str]: return aliases - def _parse_tsconfig_aliases(self, tsconfig_path: str, base_dir: str, aliases: Dict[str, str]): + def _parse_tsconfig_aliases(self, tsconfig_path: str, base_dir: str, + aliases: Dict[str, str]): """Parse tsconfig.json and extract path aliases""" try: with open(tsconfig_path, 'r', encoding='utf-8') as f: content = f.read() # Remove comments - content = re.sub(r'//.*?\n|/\*.*?\*/', '', content, flags=re.DOTALL) + content = re.sub( + r'//.*?\n|/\*.*?\*/', '', content, flags=re.DOTALL) tsconfig = json.loads(content) - if 'compilerOptions' in tsconfig and 'paths' in tsconfig['compilerOptions']: + if 'compilerOptions' in tsconfig and 'paths' in tsconfig[ + 'compilerOptions']: base_url = tsconfig['compilerOptions'].get('baseUrl', '.') - for alias, paths in tsconfig['compilerOptions']['paths'].items(): + for alias, paths in tsconfig['compilerOptions'][ + 'paths'].items(): clean_alias = alias.rstrip('/*') if paths and len(paths) > 0: target = paths[0].rstrip('/*') - resolved_target = os.path.normpath(os.path.join(base_dir, base_url, target)) + resolved_target = os.path.normpath( + os.path.join(base_dir, base_url, target)) if clean_alias not in aliases: aliases[clean_alias] = resolved_target except (json.JSONDecodeError, IOError, KeyError): pass - def _parse_vite_config_aliases(self, config_path: str, base_dir: str, aliases: Dict[str, str]): + def _parse_vite_config_aliases(self, config_path: str, base_dir: str, + aliases: Dict[str, str]): """Parse vite.config and extract path aliases""" try: with open(config_path, 'r', encoding='utf-8') as f: content = f.read() alias_pattern = ( r"['\"]([^'\"]+)['\"]\s*:\s*(?:path\.resolve\([^,]+,\s*['\"]" - r"([^'\"]+)['\"]\)|['\"]([^'\"]+)['\"])" - ) + r"([^'\"]+)['\"]\)|['\"]([^'\"]+)['\"])") for match in re.finditer(alias_pattern, content): alias_key = match.group(1) target = match.group(2) or match.group(3) @@ -724,7 +732,7 @@ def _resolve_alias_path(self, import_path: str) -> Optional[str]: if import_path == alias: return target elif import_path.startswith(alias + '/'): - remainder = import_path[len(alias) + 1 :] + remainder = import_path[len(alias) + 1:] return os.path.join(target, remainder) return None @@ -766,8 +774,10 @@ def _extract_java_import(self, match) -> Optional[ImportInfo]: items = [import_path.split('.')[-1]] return ImportInfo( - source_file=file_path, raw_statement=match.group(0), imported_items=items, import_type=import_type - ) + source_file=file_path, + raw_statement=match.group(0), + imported_items=items, + import_type=import_type) def _resolve_java_path(self, import_path: str) -> Optional[str]: """Resolve Java import to file path""" @@ -789,7 +799,8 @@ class ImportParserFactory: """Factory to get appropriate parser for file type""" @staticmethod - def get_parser(file_ext: str, output_dir: str, current_file: str, current_dir: str) -> Optional[BaseImportParser]: + def get_parser(file_ext: str, output_dir: str, current_file: str, + current_dir: str) -> Optional[BaseImportParser]: """Get parser instance for given file extension""" parsers = [ PythonImportParser, @@ -805,7 +816,8 @@ def get_parser(file_ext: str, output_dir: str, current_file: str, current_dir: s return None -def parse_imports(current_file: str, code_content: str, output_dir: str) -> List[ImportInfo]: +def parse_imports(current_file: str, code_content: str, + output_dir: str) -> List[ImportInfo]: """ Parse imports from code content (main entry point for backward compatibility) @@ -821,11 +833,13 @@ def parse_imports(current_file: str, code_content: str, output_dir: str) -> List List of ImportInfo objects for project files only (external packages are excluded) """ # Detect file extension - file_ext = os.path.splitext(current_file)[1].lstrip('.').lower() if current_file else '' + file_ext = os.path.splitext(current_file)[1].lstrip( + '.').lower() if current_file else '' current_dir = os.path.dirname(current_file) if current_file else '.' # Get appropriate parser - parser = ImportParserFactory.get_parser(file_ext, output_dir, current_file, current_dir) + parser = ImportParserFactory.get_parser(file_ext, output_dir, current_file, + current_dir) if not parser: return [] @@ -858,41 +872,28 @@ def parse_imports(current_file: str, code_content: str, output_dir: str) -> List # They don't start with '.', '/', or contain path separators (except scoped packages) # Check if it's a scoped package (starts with @ but file doesn't exist) - is_scoped_package = source.startswith('@') and not os.path.exists(os.path.join(output_dir, source)) + is_scoped_package = source.startswith('@') and not os.path.exists( + os.path.join(output_dir, source)) # Check if it's a project file (exists in output_dir) - full_path = os.path.join(output_dir, source) if not os.path.isabs(source) else source + full_path = os.path.join( + output_dir, source) if not os.path.isabs(source) else source is_project_file = os.path.exists(full_path) # Check if source has common code file extension # This helps identify resolved file paths vs package names - common_extensions = ( - '.js', - '.jsx', - '.ts', - '.tsx', - '.mjs', - '.cjs', - '.java', - '.py', - '.pyw', - '.css', - '.scss', - '.json', - ) + common_extensions = ('.js', '.jsx', '.ts', '.tsx', '.mjs', '.cjs', + '.java', '.py', '.pyw', '.css', '.scss', '.json') has_code_extension = source.endswith(common_extensions) # Check if it's an external package (package name without path separators) # For Java: java.util.List has dots but no file extension, so it's external # For JS: utils.js has extension, so it's a file - is_external = is_scoped_package or ( - not is_project_file - and not has_code_extension - and not source.startswith('.') - and not source.startswith('/') - and '/' not in source - and os.sep not in source - ) + is_external = ( + is_scoped_package + or (not is_project_file and not has_code_extension + and not source.startswith('.') and not source.startswith('/') + and '/' not in source and os.sep not in source)) if not is_external: project_imports.append(imp) diff --git a/ms_agent/utils/patcher.py b/ms_agent/utils/patcher.py index b0602d435..4d721f266 100644 --- a/ms_agent/utils/patcher.py +++ b/ms_agent/utils/patcher.py @@ -4,7 +4,8 @@ T = TypeVar('T') -def patch(target_object: Any, attribute_name: str, patch_value: Any) -> Callable[[Callable[..., T]], Callable[..., T]]: +def patch(target_object: Any, attribute_name: str, + patch_value: Any) -> Callable[[Callable[..., T]], Callable[..., T]]: """ A decorator factory that patches an attribute of an object for the duration of a function's execution. @@ -29,7 +30,9 @@ def wrapper(*args: Any, **kwargs: Any) -> T: """ # Check if the target attribute exists if not hasattr(target_object, attribute_name): - raise AttributeError(f'{target_object} does not have attribute {attribute_name}') + raise AttributeError( + f'{target_object} does not have attribute {attribute_name}' + ) # 1. Save the original value (similar to __enter__) original_value = getattr(target_object, attribute_name) diff --git a/ms_agent/utils/push_to_hub.py b/ms_agent/utils/push_to_hub.py index 3df0f9e56..a164603dd 100644 --- a/ms_agent/utils/push_to_hub.py +++ b/ms_agent/utils/push_to_hub.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 -import json import mimetypes import os import re @@ -9,10 +8,11 @@ from pathlib import Path from typing import List, Optional, Tuple +import json import requests - from ms_agent.utils.logger import get_logger -from ms_agent.utils.utils import get_files_from_dir, is_package_installed, text_hash +from ms_agent.utils.utils import (get_files_from_dir, is_package_installed, + text_hash) logger = get_logger() @@ -22,7 +22,8 @@ class PushToHub(ABC): The abstract base class for pushing files to a remote hub (e.g., GitHub). """ - def __init__(self, *args, **kwargs): ... + def __init__(self, *args, **kwargs): + ... @abstractmethod def push(self, *args, **kwargs): @@ -31,16 +32,15 @@ def push(self, *args, **kwargs): class PushToGitHub(PushToHub): + GITHUB_API_URL = 'https://api.github.com' - def __init__( - self, - user_name: str, - repo_name: str, - token: str, - visibility: Optional[str] = 'public', - description: Optional[str] = None, - ): + def __init__(self, + user_name: str, + repo_name: str, + token: str, + visibility: Optional[str] = 'public', + description: Optional[str] = None): """ Initialize the `PushToGitHub` class with authentication. @@ -72,7 +72,9 @@ def __init__( super().__init__() if not all([user_name, repo_name, token]): - raise ValueError('GitHub username, repository name, and token must be provided.') + raise ValueError( + 'GitHub username, repository name, and token must be provided.' + ) self.user_name = user_name self.repo_name = repo_name @@ -82,12 +84,10 @@ def __init__( # Create a session and set authentication headers self.session = requests.Session() - self.session.headers.update( - { - 'Authorization': f'token {self.token}', - 'Accept': 'application/vnd.github.v3+json', - } - ) + self.session.headers.update({ + 'Authorization': f'token {self.token}', + 'Accept': 'application/vnd.github.v3+json', + }) self._check_auth() self._create_github_repo( @@ -103,8 +103,11 @@ def _check_auth(self): RuntimeError: If authentication fails. """ user_data_resp = self.session.get(f'{self.GITHUB_API_URL}/user') - if user_data_resp.status_code != 200 or user_data_resp.json()['login'] != self.user_name: - raise RuntimeError('Authentication failed! Please check your username and Personal Access Token.') + if user_data_resp.status_code != 200 or user_data_resp.json( + )['login'] != self.user_name: + raise RuntimeError( + 'Authentication failed! Please check your username and Personal Access Token.' + ) def _create_github_repo( self, @@ -129,37 +132,49 @@ def _create_github_repo( raise ValueError('Repository name cannot be empty.') if visibility not in ['public', 'private']: - raise ValueError("Visibility must be either 'public' or 'private'.") + raise ValueError( + "Visibility must be either 'public' or 'private'.") if description is None: description = f'Repository - `{repo_name}` created by MS-Agent.' # Create the first commit with README url = f'{self.GITHUB_API_URL}/user/repos' - payload = {'name': repo_name, 'description': description, 'private': visibility == 'private', 'auto_init': True} + payload = { + 'name': repo_name, + 'description': description, + 'private': visibility == 'private', + 'auto_init': True + } response = self.session.post(url, json=payload) if response.status_code == 201: - logger.info(f"Successfully created and initialized repository: {response.json()['html_url']}") + logger.info( + f"Successfully created and initialized repository: {response.json()['html_url']}" + ) return response.json() elif response.status_code == 422: - error_message = response.json().get('errors', [{}])[0].get('message', '') + error_message = response.json().get('errors', + [{}])[0].get('message', '') if 'name already exists' in error_message: - logger.info(f"Repository '{repo_name}' already exists. Will attempt to upload files to it.") + logger.info( + f"Repository '{repo_name}' already exists. Will attempt to upload files to it." + ) return None else: - raise ValueError(f'Validation error (422) while creating repository: {response.json()}') + raise ValueError( + f'Validation error (422) while creating repository: {response.json()}' + ) else: logger.error(response.json()) - raise RuntimeError(f'Failed to create repository: {response.status_code}') - - def _upload_files( - self, - files_to_upload: List[Path], - work_dir: Path, - path_in_repo: Optional[str] = None, - branch: Optional[str] = 'main', - commit_message: Optional[str] = None, - ) -> None: + raise RuntimeError( + f'Failed to create repository: {response.status_code}') + + def _upload_files(self, + files_to_upload: List[Path], + work_dir: Path, + path_in_repo: Optional[str] = None, + branch: Optional[str] = 'main', + commit_message: Optional[str] = None) -> None: """ Upload multiple files to a GitHub repository in a single commit. @@ -187,7 +202,8 @@ def _upload_files( commit_response.raise_for_status() base_tree_sha = commit_response.json()['tree']['sha'] - logger.info(f"Found '{branch}' branch, latest commit: {latest_commit_sha[:7]}") + logger.info( + f"Found '{branch}' branch, latest commit: {latest_commit_sha[:7]}") # 2. Create a blob for each file blobs = [] @@ -195,10 +211,13 @@ def _upload_files( repo_base_path = Path(path_in_repo or '') for full_path in files_to_upload: - file_relative_path: str = str(full_path.relative_to(work_dir)).replace('\\', '/') + + file_relative_path: str = str( + full_path.relative_to(work_dir)).replace('\\', '/') mime_type, _ = mimetypes.guess_type(full_path) - is_binary = not (mime_type and mime_type.startswith('text/')) if mime_type else False + is_binary = not (mime_type and mime_type.startswith('text/') + ) if mime_type else False with open(full_path, 'rb') as f: content_bytes = f.read() @@ -217,14 +236,22 @@ def _upload_files( blob_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/blobs' blob_payload = {'content': content, 'encoding': encoding} - response = self.session.post(blob_url, data=json.dumps(blob_payload)) + response = self.session.post( + blob_url, data=json.dumps(blob_payload)) response.raise_for_status() remote_path = repo_base_path / file_relative_path remote_path_str = str(remote_path).replace('\\', '/') - blobs.append({'path': remote_path_str, 'mode': '100644', 'type': 'blob', 'sha': response.json()['sha']}) - logger.info(f" - Local: '{str(full_path)}' -> Remote: '{remote_path_str}'") + blobs.append({ + 'path': remote_path_str, + 'mode': '100644', + 'type': 'blob', + 'sha': response.json()['sha'] + }) + logger.info( + f" - Local: '{str(full_path)}' -> Remote: '{remote_path_str}'" + ) # 3. Create a tree object tree_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/trees' @@ -237,11 +264,13 @@ def _upload_files( # 4. Create a commit commit_url = f'{self.GITHUB_API_URL}/repos/{self.user_name}/{self.repo_name}/git/commits' commit_payload = { - 'message': commit_message or f"Upload files to '{path_in_repo or '/'}'", + 'message': commit_message + or f"Upload files to '{path_in_repo or '/'}'", 'tree': tree_sha, - 'parents': [latest_commit_sha], + 'parents': [latest_commit_sha] } - response = self.session.post(commit_url, data=json.dumps(commit_payload)) + response = self.session.post( + commit_url, data=json.dumps(commit_payload)) response.raise_for_status() new_commit_sha = response.json()['sha'] logger.info(f'Commit created: {new_commit_sha[:7]}') @@ -253,15 +282,13 @@ def _upload_files( logger.info(f"Branch '{branch}' successfully points to the new commit") - def push( - self, - folder_path: str, - path_in_repo: Optional[str] = None, - branch: Optional[str] = 'main', - commit_message: Optional[str] = None, - exclude: Optional[List[str]] = None, - **kwargs, - ) -> None: + def push(self, + folder_path: str, + path_in_repo: Optional[str] = None, + branch: Optional[str] = 'main', + commit_message: Optional[str] = None, + exclude: Optional[List[str]] = None, + **kwargs) -> None: """ Push files from a local directory to the GitHub repository. @@ -322,9 +349,12 @@ def __init__( """ if not is_package_installed('modelscope'): - raise ImportError('ModelScope package is not installed. Please install it with `pip install modelscope`.') + raise ImportError( + 'ModelScope package is not installed. Please install it with `pip install modelscope`.' + ) - from modelscope.hub.api import HubApi, get_endpoint + from modelscope.hub.api import HubApi + from modelscope.hub.api import get_endpoint self.api = HubApi() self.token = token @@ -333,7 +363,10 @@ def __init__( super().__init__() @staticmethod - def _preprocess(folder_path: str, path_in_repo_url: str, add_powered_by: bool = True) -> 'Tuple[str, str]': + def _preprocess(folder_path: str, + path_in_repo_url: str, + add_powered_by: bool = True) -> 'Tuple[str, str]': + report_filename = 'report.md' file_path = os.path.join(folder_path, report_filename) file_path_hash: str = text_hash(text=file_path, keep_n_chars=8) @@ -343,7 +376,9 @@ def _preprocess(folder_path: str, path_in_repo_url: str, add_powered_by: bool = new_file_path = os.path.join(current_cache_path, new_report_filename) if not os.path.exists(file_path): - logger.warning(f'The report file: {file_path} does not exist. Skipping preprocessing.') + logger.warning( + f'The report file: {file_path} does not exist. Skipping preprocessing.' + ) return '', '' try: @@ -355,15 +390,12 @@ def _preprocess(folder_path: str, path_in_repo_url: str, add_powered_by: bool = try: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - if add_powered_by and not content.lstrip().startswith('""" + '\n\n' + content pattern = r'!\[(.*?)\]\((resources/.*?)\)' replacement = rf'![\1]({path_in_repo_url}\2)' @@ -372,26 +404,35 @@ def _preprocess(folder_path: str, path_in_repo_url: str, add_powered_by: bool = if count > 0: with open(file_path, 'w', encoding='utf-8') as f: f.write(new_content) - logger.info(f"Preprocessed {count} 'resources/' links in {file_path}.") + logger.info( + f"Preprocessed {count} 'resources/' links in {file_path}.") else: - logger.info(f'No "resources/" links found in {file_path}. No changes made.') + logger.info( + f'No "resources/" links found in {file_path}. No changes made.' + ) except IOError as e: logger.error(f'Error reading or writing the report file: {e}') return '', '' except Exception as e: - logger.error(f'Unexpected error during preprocessing of PushToModelScope: {e}') + logger.error( + f'Unexpected error during preprocessing of PushToModelScope: {e}' + ) return '', '' return file_path, new_file_path @staticmethod def _postprocess(report_file_path: str, report_file_path_in_cache: str): + try: shutil.move(report_file_path_in_cache, report_file_path) - shutil.rmtree(os.path.dirname(report_file_path_in_cache), ignore_errors=True) + shutil.rmtree( + os.path.dirname(report_file_path_in_cache), ignore_errors=True) except FileNotFoundError: - logger.warning(f'The backup file of report: {report_file_path_in_cache} does not exist.') + logger.warning( + f'The backup file of report: {report_file_path_in_cache} does not exist.' + ) def push( self, @@ -436,7 +477,8 @@ def push( revision='master', ) target_url: str = f'{self.endpoint}/{repo_type}s/{repo_id}/files' - logger.info(f'Successfully pushed files to ModelScope: {target_url}') + logger.info( + f'Successfully pushed files to ModelScope: {target_url}') except Exception as e: logger.error(f'Failed to push files to ModelScope: {e}') finally: @@ -448,6 +490,7 @@ def push( class PushToHuggingFace(PushToHub): + def __init__(self, token: str): """ Initialize the `PushToHuggingFace` with authentication. @@ -498,7 +541,8 @@ def push( raise ValueError('Repository ID cannot be empty.') if repo_type not in ['model', 'dataset']: - raise ValueError("Repository type must be either 'model' or 'dataset'.") + raise ValueError( + "Repository type must be either 'model' or 'dataset'.") try: self.api.upload_folder( @@ -518,5 +562,6 @@ def push( f'https://huggingface.co/{repo_type_in_url}{repo_id}/tree/main/{path_in_repo or ""}' ) except Exception as e: - logger.error(f'Failed to push files to {repo_id} on HuggingFace: {e}') + logger.error( + f'Failed to push files to {repo_id} on HuggingFace: {e}') raise e diff --git a/ms_agent/utils/rate_limiter.py b/ms_agent/utils/rate_limiter.py index 5273dbf75..3bb94ec52 100644 --- a/ms_agent/utils/rate_limiter.py +++ b/ms_agent/utils/rate_limiter.py @@ -73,7 +73,8 @@ async def _wait_if_needed(self): elapsed = now - self._last_request_time if elapsed < self.min_request_interval: wait_time = self.min_request_interval - elapsed - logger.debug(f'Enforcing min interval: waiting {wait_time:.3f}s') + logger.debug( + f'Enforcing min interval: waiting {wait_time:.3f}s') await asyncio.sleep(wait_time) now = time.time() @@ -87,14 +88,18 @@ async def _wait_if_needed(self): # If rate limit reached, wait until oldest request expires if len(self._request_times) >= self.max_requests_per_second: oldest_request = self._request_times[0] - wait_time = 1.0 - (now - oldest_request) + 0.01 # Add 10ms margin + wait_time = 1.0 - (now + - oldest_request) + 0.01 # Add 10ms margin if wait_time > 0: - logger.debug(f'Rate limit reached ({self.max_requests_per_second} req/s): waiting {wait_time:.3f}s') + logger.debug( + f'Rate limit reached ({self.max_requests_per_second} req/s): ' + f'waiting {wait_time:.3f}s') await asyncio.sleep(wait_time) now = time.time() # Clean up expired records cutoff_time = now - 1.0 - while self._request_times and self._request_times[0] < cutoff_time: + while self._request_times and self._request_times[ + 0] < cutoff_time: self._request_times.popleft() # Record this request time @@ -135,15 +140,23 @@ def get_stats(self) -> dict: with self._lock: now = time.time() cutoff_time = now - 1.0 - recent_requests = sum(1 for t in self._request_times if t >= cutoff_time) + recent_requests = sum(1 for t in self._request_times + if t >= cutoff_time) return { - 'max_requests_per_second': self.max_requests_per_second, - 'min_request_interval': self.min_request_interval, - 'max_concurrent': self.max_concurrent, - 'recent_requests_count': recent_requests, - 'available_concurrent_slots': self._semaphore._value, - 'last_request_ago': now - self._last_request_time if self._last_request_time > 0 else None, + 'max_requests_per_second': + self.max_requests_per_second, + 'min_request_interval': + self.min_request_interval, + 'max_concurrent': + self.max_concurrent, + 'recent_requests_count': + recent_requests, + 'available_concurrent_slots': + self._semaphore._value, + 'last_request_ago': + now - self._last_request_time + if self._last_request_time > 0 else None, } def reset(self): @@ -202,8 +215,7 @@ def __init__( logger.info( f'AdaptiveRateLimiter initialized: {initial_requests_per_second} req/s ' - f'(range: {min_requests_per_second}-{max_requests_per_second})' - ) + f'(range: {min_requests_per_second}-{max_requests_per_second})') def record_success(self): """Record successful request""" @@ -215,13 +227,13 @@ def record_success(self): # Consecutive successes reached threshold, attempt to increase rate if self._consecutive_successes >= self._success_threshold: old_rps = self.max_requests_per_second - new_rps = min(round(old_rps * self._recovery_factor), self._max_rps) + new_rps = min( + round(old_rps * self._recovery_factor), self._max_rps) if new_rps > old_rps: self.max_requests_per_second = new_rps logger.info( f'Rate limit increased: {old_rps} → {new_rps} req/s ' - f'(after {self._consecutive_successes} successes)' - ) + f'(after {self._consecutive_successes} successes)') self._consecutive_successes = 0 def record_error(self, is_rate_limit_error: bool = False): @@ -240,41 +252,46 @@ def record_error(self, is_rate_limit_error: bool = False): # If rate limit error, immediately reduce rate if is_rate_limit_error: old_rps = self.max_requests_per_second - new_rps = max(int(old_rps * self._backoff_factor), self._min_rps) + new_rps = max( + int(old_rps * self._backoff_factor), self._min_rps) if new_rps < old_rps: self.max_requests_per_second = new_rps # Also increase minimum request interval - self.min_request_interval = min(self.min_request_interval * 1.5, 2.0) + self.min_request_interval = min( + self.min_request_interval * 1.5, 2.0) logger.warning( f'Rate limit error detected! Reducing rate: {old_rps} → {new_rps} req/s, ' - f'min_interval → {self.min_request_interval:.2f}s' - ) + f'min_interval → {self.min_request_interval:.2f}s') self._consecutive_errors = 0 # Consecutive errors reached threshold, reduce rate elif self._consecutive_errors >= self._error_threshold: old_rps = self.max_requests_per_second - new_rps = max(int(old_rps * self._backoff_factor), self._min_rps) + new_rps = max( + int(old_rps * self._backoff_factor), self._min_rps) if new_rps < old_rps: self.max_requests_per_second = new_rps logger.warning( f'Multiple errors detected! Reducing rate: {old_rps} → {new_rps} req/s ' - f'(after {self._consecutive_errors} errors)' - ) + f'(after {self._consecutive_errors} errors)') self._consecutive_errors = 0 def get_stats(self) -> dict: """Get extended statistics""" stats = super().get_stats() with self._lock: - stats.update( - { - 'total_requests': self._total_requests, - 'total_errors': self._total_errors, - 'error_rate': self._total_errors / max(self._total_requests, 1), - 'consecutive_successes': self._consecutive_successes, - 'consecutive_errors': self._consecutive_errors, - 'current_requests_per_second': self.max_requests_per_second, - } - ) + stats.update({ + 'total_requests': + self._total_requests, + 'total_errors': + self._total_errors, + 'error_rate': + self._total_errors / max(self._total_requests, 1), + 'consecutive_successes': + self._consecutive_successes, + 'consecutive_errors': + self._consecutive_errors, + 'current_requests_per_second': + self.max_requests_per_second, + }) return stats diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py index 649acab26..2014b49da 100644 --- a/ms_agent/utils/snapshot.py +++ b/ms_agent/utils/snapshot.py @@ -8,9 +8,8 @@ All git commands are run with GIT_DIR and GIT_WORK_TREE explicitly set, so the snapshot repo is fully isolated from any surrounding repository. """ - -import json import os +import json import subprocess from typing import Optional @@ -22,7 +21,8 @@ _META_FILE = 'snapshot_meta.json' -def _git(args: list[str], work_tree: str, git_dir: str, check: bool = True) -> subprocess.CompletedProcess: +def _git(args: list[str], work_tree: str, git_dir: str, + check: bool = True) -> subprocess.CompletedProcess: env = os.environ.copy() env['GIT_DIR'] = git_dir env['GIT_WORK_TREE'] = work_tree @@ -51,7 +51,10 @@ def _configure_snapshot_repo_for_automation(work_tree: str, git_dir: str) -> Non Git-supported way to disable hooks (POSIX ``/dev/null``, Windows ``nul``). """ try: - _git(['config', 'core.hooksPath', os.devnull], work_tree=work_tree, git_dir=git_dir, check=False) + _git(['config', 'core.hooksPath', os.devnull], + work_tree=work_tree, + git_dir=git_dir, + check=False) except Exception: pass @@ -64,8 +67,10 @@ def _ensure_repo(output_dir: str) -> str: # Use non-bare init with explicit GIT_DIR — no --bare so work tree is supported. # Do NOT pass a path argument; GIT_DIR env var points git at our custom dir. _git(['init'], work_tree=output_dir, git_dir=git_dir) - _git(['config', 'user.email', 'ms-agent@snapshot'], work_tree=output_dir, git_dir=git_dir) - _git(['config', 'user.name', 'ms-agent'], work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.email', 'ms-agent@snapshot'], + work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.name', 'ms-agent'], + work_tree=output_dir, git_dir=git_dir) # Exclude the snapshot dir itself from tracking info_dir = os.path.join(git_dir, 'info') os.makedirs(info_dir, exist_ok=True) @@ -98,7 +103,8 @@ def _save_meta(output_dir: str, meta: dict) -> None: json.dump(meta, f, indent=2) -def take_snapshot(output_dir: str, message: str, message_count: int = 0) -> Optional[str]: +def take_snapshot(output_dir: str, message: str, + message_count: int = 0) -> Optional[str]: """ Stage all changes in output_dir and create a snapshot commit. @@ -121,13 +127,15 @@ def take_snapshot(output_dir: str, message: str, message_count: int = 0) -> Opti _git(['add', '-A'], work_tree=output_dir, git_dir=git_dir) # Check if there's anything to commit - status = _git(['status', '--porcelain'], work_tree=output_dir, git_dir=git_dir) + status = _git(['status', '--porcelain'], + work_tree=output_dir, git_dir=git_dir) if not status.stdout.strip(): return None # Nothing changed # Truncate message to keep commit subject readable subject = message.strip().replace('\n', ' ')[:120] - result = _git(['commit', '--no-verify', '-m', subject], work_tree=output_dir, git_dir=git_dir) + result = _git(['commit', '--no-verify', '-m', subject], + work_tree=output_dir, git_dir=git_dir) commit_hash = None for line in result.stdout.splitlines(): @@ -146,7 +154,8 @@ def take_snapshot(output_dir: str, message: str, message_count: int = 0) -> Opti return commit_hash except FileNotFoundError: - logger.warning_once('[snapshot] git not found — snapshots disabled.') + logger.warning_once( + '[snapshot] git not found — snapshots disabled.') return None except subprocess.CalledProcessError as e: logger.warning(f'[snapshot] git error: {e.stderr.strip()}') @@ -179,20 +188,19 @@ def list_snapshots(output_dir: str) -> list[dict]: parts = line.split('\t', 2) if len(parts) == 3: h = parts[0] - snapshots.append( - { - 'hash': h, - 'date': parts[1], - 'message': parts[2], - 'message_count': meta.get(h, {}).get('message_count', 0), - } - ) + snapshots.append({ + 'hash': h, + 'date': parts[1], + 'message': parts[2], + 'message_count': meta.get(h, {}).get('message_count', 0), + }) return snapshots except Exception: return [] -def restore_snapshot(output_dir: str, commit_hash: str) -> tuple[bool, int]: +def restore_snapshot(output_dir: str, + commit_hash: str) -> tuple[bool, int]: """ Restore output_dir to the state at commit_hash. @@ -204,7 +212,8 @@ def restore_snapshot(output_dir: str, commit_hash: str) -> tuple[bool, int]: logger.warning('[snapshot] No snapshot repo found.') return False, 0 try: - _git(['checkout', commit_hash, '--', '.'], work_tree=output_dir, git_dir=git_dir) + _git(['checkout', commit_hash, '--', '.'], + work_tree=output_dir, git_dir=git_dir) logger.info(f'[snapshot] Restored to {commit_hash}') meta = _load_meta(output_dir) message_count = meta.get(commit_hash, {}).get('message_count', 0) diff --git a/ms_agent/utils/stats.py b/ms_agent/utils/stats.py index def45af31..7ed705a1d 100644 --- a/ms_agent/utils/stats.py +++ b/ms_agent/utils/stats.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio -import json import os import time from datetime import datetime from typing import Any, Dict, Iterable, Optional +import json from ms_agent.llm.utils import Message from .logger import get_logger @@ -23,7 +23,8 @@ def _get_lock(path: str) -> asyncio.Lock: return lock -def get_stats_path(config: Any, default_filename: str = 'workflow_stats.json') -> str: +def get_stats_path(config: Any, + default_filename: str = 'workflow_stats.json') -> str: stats_file = getattr(config, 'stats_file', None) output_dir = getattr(config, 'output_dir', './output') if stats_file: @@ -46,7 +47,8 @@ def summarize_usage(messages: Optional[Iterable[Message]]) -> Dict[str, int]: prompt_tokens += int(getattr(msg, 'prompt_tokens', 0) or 0) completion_tokens += int(getattr(msg, 'completion_tokens', 0) or 0) cached_tokens += int(getattr(msg, 'cached_tokens', 0) or 0) - cache_creation_input_tokens += int(getattr(msg, 'cache_creation_input_tokens', 0) or 0) + cache_creation_input_tokens += int( + getattr(msg, 'cache_creation_input_tokens', 0) or 0) api_calls += int(getattr(msg, 'api_calls', 0) or 0) return { 'prompt_tokens': prompt_tokens, @@ -70,7 +72,8 @@ async def append_stats(path: str, record: Dict[str, Any]) -> None: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) or [] except Exception as exc: - logger.warning(f'Failed to read stats file {path}, resetting: {exc}') + logger.warning( + f'Failed to read stats file {path}, resetting: {exc}') data = [] if not isinstance(data, list): data = [] @@ -82,17 +85,16 @@ async def append_stats(path: str, record: Dict[str, Any]) -> None: def build_timing_record( - *, - event: str, - agent_tag: Optional[str], - agent_type: Optional[str], - started_at: str, - ended_at: str, - duration_s: float, - status: str, - usage: Optional[Dict[str, int]] = None, - extra: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + *, + event: str, + agent_tag: Optional[str], + agent_type: Optional[str], + started_at: str, + ended_at: str, + duration_s: float, + status: str, + usage: Optional[Dict[str, int]] = None, + extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: record = { 'event': event, 'agent_tag': agent_tag, diff --git a/ms_agent/utils/stream_writer.py b/ms_agent/utils/stream_writer.py index 045c32e1c..46c5cf95b 100644 --- a/ms_agent/utils/stream_writer.py +++ b/ms_agent/utils/stream_writer.py @@ -21,7 +21,6 @@ File path: ``{output_dir}/subagents/{call_id}.stream.jsonl`` """ - import json import os import threading @@ -77,17 +76,16 @@ def on_start(self, agent_tag: Optional[str]) -> None: self._agent_tag = agent_tag try: self._file = open(self._path, 'w', encoding='utf-8') - self._write_line( - { - 'type': 'header', - 'call_id': self._call_id, - 'tool_name': self._tool_name, - 'agent_tag': agent_tag or '', - 'ts': _now_iso(), - } - ) + self._write_line({ + 'type': 'header', + 'call_id': self._call_id, + 'tool_name': self._tool_name, + 'agent_tag': agent_tag or '', + 'ts': _now_iso(), + }) except Exception as exc: - logger.warning('SubAgentStreamWriter: failed to open %s: %s', self._path, exc) + logger.warning( + 'SubAgentStreamWriter: failed to open %s: %s', self._path, exc) self._file = None def on_chunk(self, history: Any) -> None: @@ -104,15 +102,13 @@ def on_chunk(self, history: Any) -> None: with self._lock: if self._closed or self._file is None: return - for msg in messages[self._last_written_count :]: - self._write_line( - { - 'type': 'message', - 'index': self._last_written_count, - 'message': _msg_to_dict(msg), - 'ts': _now_iso(), - } - ) + for msg in messages[self._last_written_count:]: + self._write_line({ + 'type': 'message', + 'index': self._last_written_count, + 'message': _msg_to_dict(msg), + 'ts': _now_iso(), + }) self._last_written_count += 1 def on_end(self, history: Any) -> None: @@ -129,20 +125,19 @@ def on_end(self, history: Any) -> None: self._closed = True if self._file is not None: try: - self._write_line( - { - 'type': 'footer', - 'call_id': self._call_id, - 'agent_tag': self._agent_tag or '', - 'status': 'complete', - 'total_messages': self._last_written_count, - 'ts': _now_iso(), - } - ) + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'complete', + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) self._file.flush() self._file.close() except Exception as exc: - logger.warning('SubAgentStreamWriter: close error on %s: %s', self._path, exc) + logger.warning( + 'SubAgentStreamWriter: close error on %s: %s', self._path, exc) finally: self._file = None @@ -158,17 +153,15 @@ def on_error(self, error: str) -> None: self._closed = True if self._file is not None: try: - self._write_line( - { - 'type': 'footer', - 'call_id': self._call_id, - 'agent_tag': self._agent_tag or '', - 'status': 'error', - 'error': error, - 'total_messages': self._last_written_count, - 'ts': _now_iso(), - } - ) + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'error', + 'error': error, + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) self._file.flush() self._file.close() except Exception: diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py index cff5b5691..a897e2fa9 100644 --- a/ms_agent/utils/task_manager.py +++ b/ms_agent/utils/task_manager.py @@ -14,8 +14,8 @@ @dataclass class BackgroundTask: task_id: str - task_type: str # 'agent' | 'shell' - tool_name: str # which tool spawned this + task_type: str # 'agent' | 'shell' + tool_name: str # which tool spawned this description: str status: str = 'running' # 'running' | 'completed' | 'failed' | 'killed' proc: Optional[Any] = field(default=None, repr=False) # mp.Process or asyncio.Task diff --git a/ms_agent/utils/thread_util.py b/ms_agent/utils/thread_util.py index 815407b48..16e46eba9 100644 --- a/ms_agent/utils/thread_util.py +++ b/ms_agent/utils/thread_util.py @@ -5,16 +5,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps -from tqdm.auto import tqdm - from ms_agent.utils.logger import get_logger +from tqdm.auto import tqdm logger = get_logger() -DEFAULT_MAX_WORKERS = int(os.getenv('DEFAULT_MAX_WORKERS', min(8, os.cpu_count() + 4))) +DEFAULT_MAX_WORKERS = int( + os.getenv('DEFAULT_MAX_WORKERS', min(8, + os.cpu_count() + 4))) -def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, disable_tqdm: bool = False, tqdm_desc: str = None): +def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, + disable_tqdm: bool = False, + tqdm_desc: str = None): """ A decorator to execute a function in a threaded manner using ThreadPoolExecutor. @@ -40,22 +43,26 @@ def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, disable_tqdm: bool = """ def decorator(func): + @wraps(func) def wrapper(iterable, *args, **kwargs): results = [] # Create a tqdm progress bar with the total number of items to process with tqdm( - unit_scale=True, - unit_divisor=1024, - initial=0, - total=len(iterable), - desc=tqdm_desc or f'Processing {len(iterable)} items', - disable=disable_tqdm, + unit_scale=True, + unit_divisor=1024, + initial=0, + total=len(iterable), + desc=tqdm_desc or f'Processing {len(iterable)} items', + disable=disable_tqdm, ) as pbar: # Define a wrapper function to update the progress bar with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all tasks - futures = {executor.submit(func, item, *args, **kwargs): item for item in iterable} + futures = { + executor.submit(func, item, *args, **kwargs): item + for item in iterable + } # Update the progress bar as tasks complete for future in as_completed(futures): @@ -92,14 +99,16 @@ def weakref_cb(_, q=self._work_queue): num_threads = len(self._threads) if num_threads < self._max_workers: - thread_name = '%s_%d' % (self._thread_name_prefix or self, num_threads) + thread_name = '%s_%d' % (self._thread_name_prefix + or self, num_threads) # Import internal helpers from stdlib to keep behavior consistent. - from concurrent.futures.thread import _threads_queues, _worker # type: ignore + from concurrent.futures.thread import _worker, _threads_queues # type: ignore t = threading.Thread( name=thread_name, target=_worker, - args=(weakref.ref(self, weakref_cb), self._work_queue, self._initializer, self._initargs), + args=(weakref.ref(self, weakref_cb), self._work_queue, + self._initializer, self._initargs), ) t.daemon = True t.start() diff --git a/ms_agent/utils/tokenizer_util.py b/ms_agent/utils/tokenizer_util.py index 423b4084f..b26e5228f 100644 --- a/ms_agent/utils/tokenizer_util.py +++ b/ms_agent/utils/tokenizer_util.py @@ -59,10 +59,14 @@ def segment(self, content: str) -> List[str]: token_ids = self.encode(content) # Decode each token ID individually to get its string representation - token_strings = [self.tokenizer.decode([tid], skip_special_tokens=True) for tid in token_ids] + token_strings = [ + self.tokenizer.decode([tid], skip_special_tokens=True) + for tid in token_ids + ] return token_strings - def count_tokens(self, contents: Union[str, List[str]]) -> Union[int, List[int]]: + def count_tokens(self, + contents: Union[str, List[str]]) -> Union[int, List[int]]: """ Batch count tokens for multiple texts. diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index 0a2e1e1a7..a6d0bc87d 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -5,7 +5,6 @@ import html import importlib import importlib.util -import json import os.path import re import subprocess @@ -16,6 +15,7 @@ from pathlib import Path from typing import List, Optional, Tuple, Union +import json import requests import yaml from omegaconf import DictConfig, OmegaConf @@ -30,10 +30,11 @@ else: # Define a placeholder class for type-checking compatibility class BuiltInExceptionGroup(BaseException): + def __init__(self, message, exceptions): self.message = message self.exceptions = exceptions - self.args = (message,) + self.args = (message, ) def __str__(self): return f'{self.message}: {self.exceptions}' @@ -168,7 +169,8 @@ def escape_yaml_string(text: str) -> str: return text -def save_history(output_dir: str, task: str, config: DictConfig, messages: List['Message']): +def save_history(output_dir: str, task: str, config: DictConfig, + messages: List['Message']): """ Saves the specified configuration and conversation history to a cache directory for later retrieval or restoration. @@ -203,7 +205,10 @@ def save_history(output_dir: str, task: str, config: DictConfig, messages: List[ with open(config_file, 'w') as f: OmegaConf.save(config, f) with open(message_file, 'w') as f: - json.dump([message.to_dict() for message in messages], f, indent=4, ensure_ascii=False) + json.dump([message.to_dict() for message in messages], + f, + indent=4, + ensure_ascii=False) def read_history(output_dir: str, task: str): @@ -235,9 +240,8 @@ def read_history(output_dir: str, task: str): TypeError / AttributeError: If the deserialized JSON data lacks expected keys or structure for Message objects. """ - from ms_agent.config import Config from ms_agent.llm import Message - + from ms_agent.config import Config cache_dir = os.path.join(output_dir, DEFAULT_MEMORY_DIR) os.makedirs(cache_dir, exist_ok=True) config_file = os.path.join(cache_dir, f'{task}.yaml') @@ -305,7 +309,6 @@ def json_loads(text: str) -> dict: JSON decoding error is raised. """ import json5 - text = text.strip('\n') if text.startswith('```') and text.endswith('\n```'): text = '\n'.join(text.split('\n')[1:-1]) @@ -329,12 +332,14 @@ def download_pdf(url: str, out_file_path: str, reuse: bool = True): """ if reuse and os.path.exists(out_file_path): - logger.info(f"File '{out_file_path}' already exists. Skipping download.") + logger.info( + f"File '{out_file_path}' already exists. Skipping download.") return try: response = requests.get(url, stream=True) - response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + response.raise_for_status( + ) # Raise an exception for bad status codes (4xx or 5xx) with open(out_file_path, 'wb') as pdf_file: for chunk in response.iter_content(chunk_size=8192): @@ -372,7 +377,6 @@ def load_image_from_url_to_pil(url: str) -> 'Image.Image': A PIL Image object if successful, None otherwise. """ from PIL import Image - try: response = requests.get(url, timeout=(10, 25)) # Raise an HTTPError for bad responses (4xx or 5xx) @@ -400,7 +404,6 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': tuple: (PIL Image object, file extension string) or None if failed """ from PIL import Image - try: header, encoded = uri.split(',', 1) if ';base64' in header: @@ -419,11 +422,14 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': logger.error(f'Error opening image with PIL for uri_to_pil: {e}') return None except Exception as e: - logger.error(f'Unexpected error loading image from URI for uri_to_pil: {e}') + logger.error( + f'Unexpected error loading image from URI for uri_to_pil: {e}') return None -def validate_url(img_url: str, backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str: +def validate_url( + img_url: str, + backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str: """ Validates and resolves a relative image URL using the base URL from the HTML document's metadata. @@ -442,23 +448,23 @@ def validate_url(img_url: str, backend: 'docling.backend.html_backend.HTMLDocume from urllib.parse import urljoin, urlparse # Check if we have a valid soup object in the backend - if not backend or not hasattr(backend, 'soup') or not backend.soup or not backend.soup.head: + if not backend or not hasattr( + backend, 'soup') or not backend.soup or not backend.soup.head: return None # Potential sources of base URLs to try sources = [ # Try base tag lambda: backend.soup.head.find('base', href=True)['href'] - if backend.soup.head.find('base', href=True) - else None, + if backend.soup.head.find('base', href=True) else None, # Try canonical link - lambda: backend.soup.head.find('link', rel='canonical', href=True)['href'] - if backend.soup.head.find('link', rel='canonical', href=True) - else None, + lambda: backend.soup.head.find('link', rel='canonical', href=True)[ + 'href'] if backend.soup.head.find( + 'link', rel='canonical', href=True) else None, # Try OG URL meta tag - lambda: backend.soup.head.find('meta', property='og:url', content=True)['content'] - if backend.soup.head.find('meta', property='og:url', content=True) - else None, + lambda: backend.soup.head.find( + 'meta', property='og:url', content=True)['content'] if backend.soup + .head.find('meta', property='og:url', content=True) else None ] # Try each source until we find a valid base URL @@ -502,9 +508,7 @@ def get_default_config(): os.path.dirname(__file__), # ms_agent/utils/ '..', # ↑ up to ms_agent/ 'agent', # → agent/ - 'agent.yaml', - ) - ) + 'agent.yaml')) with open(config_path, 'r', encoding='utf-8') as file: return yaml.safe_load(file) @@ -574,7 +578,8 @@ def txt_to_html(txt_path: str, html_path: Optional[str] = None) -> str: return html_path -def get_files_from_dir(folder_path: Union[str, Path], exclude: Optional[List[str]] = None) -> List[Path]: +def get_files_from_dir(folder_path: Union[str, Path], + exclude: Optional[List[str]] = None) -> List[Path]: """ Get all files in the target directory recursively, excluding files that match any of the given regex patterns. @@ -602,12 +607,11 @@ def get_files_from_dir(folder_path: Union[str, Path], exclude: Optional[List[str # Filter files based on exclusion patterns file_list = [ - file_path - for file_path in files - if not any( - pattern.search(str(file_path.resolve().relative_to(folder_path.resolve())).replace('\\', '/')) - for pattern in exclude_patterns - ) + file_path for file_path in files if not any( + pattern.search( + str(file_path.resolve().relative_to( + folder_path.resolve())).replace('\\', '/')) + for pattern in exclude_patterns) ] return file_list @@ -626,7 +630,9 @@ def is_package_installed(package_or_import_name: str) -> bool: return importlib.util.find_spec(package_or_import_name) is not None -def install_package(package_name: str, import_name: Optional[str] = None, extend_module: str = None): +def install_package(package_name: str, + import_name: Optional[str] = None, + extend_module: str = None): """ Check and install a package using pip. @@ -646,7 +652,8 @@ def install_package(package_name: str, import_name: Optional[str] = None, extend package_name = f'{package_name}[{extend_module}]' if not is_package_installed(import_name): - subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_name]) + subprocess.check_call( + [sys.executable, '-m', 'pip', 'install', package_name]) logger.info(f'Package {package_name} installed successfully.') else: logger.info(f'Package {import_name} is already installed.') @@ -663,7 +670,7 @@ def extract_by_tag(text: str, tag: str) -> str: Returns: str: The content found between the specified tags, or an empty string if not found. """ - pattern = rf'<{tag}>(.*?)' + pattern = fr'<{tag}>(.*?)' match = re.search(pattern, text, re.DOTALL) if match: return match.group(1).strip() @@ -746,12 +753,15 @@ def file_lock(lock_dir: str, filename: str, timeout: float = 120.0): while True: try: - lock_fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + lock_fd = os.open(lock_file_path, + os.O_CREAT | os.O_EXCL | os.O_WRONLY) os.write(lock_fd, f'{os.getpid()}'.encode()) break except FileExistsError: if time.time() - start_time >= timeout: - raise TimeoutError(f'Failed to acquire lock for {filename} after {timeout} seconds') + raise TimeoutError( + f'Failed to acquire lock for {filename} after {timeout} seconds' + ) time.sleep(0.1) # Wait 100ms before retry try: @@ -766,7 +776,10 @@ def file_lock(lock_dir: str, filename: str, timeout: float = 120.0): pass -def render_markdown_todo(md_path: str, *, title: str = ' CURRENT PLAN', use_pager: bool = False) -> None: +def render_markdown_todo(md_path: str, + *, + title: str = ' CURRENT PLAN', + use_pager: bool = False) -> None: """ Render a Markdown todo list nicely in terminal using Rich. - Cross-platform (Windows/Linux/macOS) @@ -777,17 +790,15 @@ def render_markdown_todo(md_path: str, *, title: str = ' CURRENT PLAN', use_page from rich.panel import Panel from rich.theme import Theme - theme = Theme( - { - 'markdown.code': 'bold', - 'markdown.code_block': 'dim', - 'markdown.h1': 'bold', - 'markdown.h2': 'bold', - 'markdown.h3': 'bold', - 'markdown.link': 'underline', - 'markdown.list': '', - } - ) + theme = Theme({ + 'markdown.code': 'bold', + 'markdown.code_block': 'dim', + 'markdown.h1': 'bold', + 'markdown.h2': 'bold', + 'markdown.h3': 'bold', + 'markdown.link': 'underline', + 'markdown.list': '', + }) console = Console(theme=theme, soft_wrap=True, highlight=False) try: diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py index dbc485321..a8380b710 100644 --- a/ms_agent/utils/workspace_policy.py +++ b/ms_agent/utils/workspace_policy.py @@ -68,9 +68,11 @@ def resolve_under_roots(self, user_path: str | Path) -> Path: except ValueError: continue else: - raise WorkspacePolicyError(f'Path is outside allowed workspace roots: {resolved}') + raise WorkspacePolicyError( + f'Path is outside allowed workspace roots: {resolved}') if self._is_denied(resolved): - raise WorkspacePolicyError(f'Path matches a deny_globs pattern: {resolved}') + raise WorkspacePolicyError( + f'Path matches a deny_globs pattern: {resolved}') return resolved def _is_denied(self, path: Path) -> bool: @@ -106,12 +108,16 @@ def assert_shell_command_allowed(self, command: str) -> None: if not command or not command.strip(): raise WorkspacePolicyError('Empty shell command') if len(command) > self.max_command_chars: - raise WorkspacePolicyError(f'Shell command exceeds max length ({self.max_command_chars})') + raise WorkspacePolicyError( + f'Shell command exceeds max length ({self.max_command_chars})') mode = self.shell_default_mode if mode == 'read_only': - if _shell_looks_mutating_or_network(command, allow_network=False): - raise WorkspacePolicyError('Shell is in read_only mode: mutating or network commands are not allowed') + if _shell_looks_mutating_or_network(command, + allow_network=False): + raise WorkspacePolicyError( + 'Shell is in read_only mode: mutating or network commands are not allowed' + ) elif mode == 'workspace_write': if not self.shell_network_enabled and _shell_looks_network(command): raise WorkspacePolicyError( @@ -140,13 +146,15 @@ def _shell_looks_network(command: str) -> bool: return any(t in lowered for t in tokens) -def _shell_looks_mutating_or_network(command: str, *, allow_network: bool) -> bool: +def _shell_looks_mutating_or_network(command: str, *, + allow_network: bool) -> bool: if not allow_network and _shell_looks_network(command): return True # redirection that creates/overwrites files if re.search(r'[>]{1,2}\s*[^\s]', command): return True - if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', command): + if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', + command): return True return False @@ -172,12 +180,14 @@ def dir_skipped(dirpath: Path) -> bool: return True parts = rel.split('/') for i in range(len(parts)): - sub = '/'.join(parts[: i + 1]) - if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch(sub + '/', pat): + sub = '/'.join(parts[:i + 1]) + if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch( + sub + '/', pat): return True return False - for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=False): + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): dp = Path(dirpath) if dir_skipped(dp): dirnames[:] = [] diff --git a/ms_agent/workflow/base.py b/ms_agent/workflow/base.py index 2c21d7bcb..9d484118c 100644 --- a/ms_agent/workflow/base.py +++ b/ms_agent/workflow/base.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from typing import Dict, Optional -from omegaconf import DictConfig - from ms_agent.config import Config +from omegaconf import DictConfig class Workflow(ABC): @@ -23,14 +22,12 @@ class Workflow(ABC): - mcp_server_file (Optional[str]): Path to an MCP server file if needed. Default is None. """ - def __init__( - self, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - trust_remote_code: bool = False, - **kwargs, - ): + def __init__(self, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + **kwargs): if config_dir_or_id is None: self.config = config else: diff --git a/ms_agent/workflow/chain_workflow.py b/ms_agent/workflow/chain_workflow.py index 7ea8cd55c..2b2b739e1 100644 --- a/ms_agent/workflow/chain_workflow.py +++ b/ms_agent/workflow/chain_workflow.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os -from omegaconf import DictConfig - from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow +from omegaconf import DictConfig logger = get_logger() @@ -27,7 +26,9 @@ def build_workflow(self): if isinstance(next_tasks, str): has_next.add(next_tasks) else: - assert len(next_tasks) == 1, 'ChainWorkflow only supports one next task' + assert len( + next_tasks + ) == 1, 'ChainWorkflow only supports one next task' has_next.update(next_tasks) for task_name in self.config.keys(): @@ -88,7 +89,8 @@ async def run(self, inputs, **kwargs): init_args['task'] = task init_args['load_cache'] = self.load_cache if isinstance(config, str): - init_args['config_dir_or_id'] = os.path.join(self.config.local_dir, config) + init_args['config_dir_or_id'] = os.path.join( + self.config.local_dir, config) else: init_args['config'] = config init_args['env'] = self.env diff --git a/ms_agent/workflow/dag_workflow.py b/ms_agent/workflow/dag_workflow.py index 97e3ac65c..1419ccf46 100644 --- a/ms_agent/workflow/dag_workflow.py +++ b/ms_agent/workflow/dag_workflow.py @@ -3,11 +3,10 @@ from collections import defaultdict, deque from typing import Any, Dict, List, Set -from omegaconf import DictConfig - from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow +from omegaconf import DictConfig logger = get_logger() @@ -38,7 +37,9 @@ def build_workflow(self): self.nodes = set(list(self.graph.keys()) + list(indegree.keys())) # Find root tasks (indegree==0) - self.roots = [t for t in tasks if 'next' in self.config[t] and indegree[t] == 0] + self.roots = [ + t for t in tasks if 'next' in self.config[t] and indegree[t] == 0 + ] if not self.roots: raise ValueError('No root task found for DagWorkflow') @@ -75,10 +76,12 @@ async def run(self, inputs: Any, **kwargs): task_input = inputs else: parent_outs = [outputs[p] for p in self.parents[task]] - task_input = parent_outs if len(parent_outs) > 1 else parent_outs[0] + task_input = parent_outs if len( + parent_outs) > 1 else parent_outs[0] task_info: DictConfig = getattr(self.config, task) - agent_cfg_path = os.path.join(self.config.local_dir, task_info.agent_config) + agent_cfg_path = os.path.join(self.config.local_dir, + task_info.agent_config) if not hasattr(task_info, 'agent'): task_info.agent = DictConfig({}) init_args = getattr(task_info.agent, 'kwargs', {}) @@ -95,5 +98,8 @@ async def run(self, inputs: Any, **kwargs): outputs[task] = result # Return results of terminal nodes (no outgoing edges) - terminals = [t for t in self.config.keys() if t not in self.graph and t in self.nodes] + terminals = [ + t for t in self.config.keys() + if t not in self.graph and t in self.nodes + ] return {t: outputs[t] for t in terminals} diff --git a/ms_agent/workflow/deep_research/__init__.py b/ms_agent/workflow/deep_research/__init__.py index f79565769..df00c463b 100644 --- a/ms_agent/workflow/deep_research/__init__.py +++ b/ms_agent/workflow/deep_research/__init__.py @@ -1,12 +1,6 @@ -from .principle import ( - BSGMatrixPrinciple, - MECEPrinciple, - ParetoPrinciple, - Principle, - PyramidPrinciple, - SWOTPrinciple, - ValueChainPrinciple, -) +from .principle import (BSGMatrixPrinciple, MECEPrinciple, ParetoPrinciple, + Principle, PyramidPrinciple, SWOTPrinciple, + ValueChainPrinciple) from .research_workflow import ResearchWorkflow from .research_workflow_beta import ResearchWorkflowBeta diff --git a/ms_agent/workflow/deep_research/principle.py b/ms_agent/workflow/deep_research/principle.py index 89dd929cb..818a89092 100644 --- a/ms_agent/workflow/deep_research/principle.py +++ b/ms_agent/workflow/deep_research/principle.py @@ -2,12 +2,11 @@ class Principle: + def __init__(self, breakdown_prompt: str = None): - self.breakdown_prompt: str = ( - breakdown_prompt - or """\n首先生成一份系统性的分析方案,自上而下breakdown,输出markdown格式:\n + + self.breakdown_prompt: str = breakdown_prompt or """\n首先生成一份系统性的分析方案,自上而下breakdown,输出markdown格式:\n """ - ) self.todo_prompt: str = """"\n基于上述breakdown,生成todo list,输出markdown格式,形式必须遵循:\n # Title @@ -57,56 +56,48 @@ def __init__(self, breakdown_prompt: str = None): class BSGMatrixPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt - or '\n首先使用Boston Matrix Analysis Principle(Boston Consulting Group matrix analysis)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用Boston Matrix Analysis Principle(Boston Consulting Group matrix analysis)来拆解和分析上述问题,输出markdown格式:' class ParetoPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt or '\n首先使用Pareto Principle(80/20 Rule)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用Pareto Principle(80/20 Rule)来拆解和分析上述问题,输出markdown格式:' class MECEPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt - or '\n首先使用MECE原则(Mutually Exclusive and Collectively Exhaustive)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用MECE原则(Mutually Exclusive and Collectively Exhaustive)来拆解和分析上述问题,输出markdown格式:' class PyramidPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt or '\n首先使用金字塔原理(Pyramid Principle)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用金字塔原理(Pyramid Principle)来拆解和分析上述问题,输出markdown格式:' class SWOTPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt or '\n首先使用SWOT分析法(SWOT Analysis)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用SWOT分析法(SWOT Analysis)来拆解和分析上述问题,输出markdown格式:' class ValueChainPrinciple(Principle): + def __init__(self, breakdown_prompt: str = None): super().__init__(breakdown_prompt=breakdown_prompt) - self.breakdown_prompt = ( - breakdown_prompt or '\n首先使用价值链分析(Value Chain Analysis)来拆解和分析上述问题,输出markdown格式:' - ) + self.breakdown_prompt = breakdown_prompt or '\n首先使用价值链分析(Value Chain Analysis)来拆解和分析上述问题,输出markdown格式:' diff --git a/ms_agent/workflow/deep_research/research_utils.py b/ms_agent/workflow/deep_research/research_utils.py index 56dcb7952..7d5ccdb69 100644 --- a/ms_agent/workflow/deep_research/research_utils.py +++ b/ms_agent/workflow/deep_research/research_utils.py @@ -32,7 +32,8 @@ class ResearchRequest(BaseModel): query: str = Field(..., description='Research query') depth: int = Field(default=2, ge=1, le=5, description='Research depth') - breadth: int = Field(default=4, ge=1, le=10, description='Research breadth') + breadth: int = Field( + default=4, ge=1, le=10, description='Research breadth') class ResearchResponse(BaseModel): @@ -47,8 +48,10 @@ class ResearchResponse(BaseModel): class LearningsResponse(BaseModel): """Response model for processed search results.""" - learnings: List[str] = Field(..., description='List of learnings extracted from search results') - follow_up_questions: List[str] = Field(..., description='List of follow-up questions for further research') + learnings: List[str] = Field( + ..., description='List of learnings extracted from search results') + follow_up_questions: List[str] = Field( + ..., description='List of follow-up questions for further research') class ProgressTracker: @@ -60,10 +63,12 @@ def __init__(self): def __enter__(self): self.progress = Progress( - SpinnerColumn(), TextColumn('[progress.description]{task.description}'), console=console - ) + SpinnerColumn(), + TextColumn('[progress.description]{task.description}'), + console=console) self.progress.__enter__() - self.task_id = self.progress.add_task('Starting research...', total=None) + self.task_id = self.progress.add_task( + 'Starting research...', total=None) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/ms_agent/workflow/deep_research/research_workflow.py b/ms_agent/workflow/deep_research/research_workflow.py index d3dcc48a4..03d64bf80 100644 --- a/ms_agent/workflow/deep_research/research_workflow.py +++ b/ms_agent/workflow/deep_research/research_workflow.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable import copy -import json import os import re from typing import Any, Dict, List, Optional, Union +import json from ms_agent.llm.openai import OpenAIChat from ms_agent.utils import get_logger @@ -484,4 +484,4 @@ def run(self, # Dump report to markdown file with open(self.workdir_structure['report_md'], 'w', encoding='utf-8') as f_report: f_report.write(resp_content) - logger.info(f'Report saved to {self.workdir_structure['report_md']}') + logger.info(f'Report saved to {self.workdir_structure["report_md"]}') diff --git a/ms_agent/workflow/deep_research/research_workflow_beta.py b/ms_agent/workflow/deep_research/research_workflow_beta.py index 6dd4003a2..24816ae52 100644 --- a/ms_agent/workflow/deep_research/research_workflow_beta.py +++ b/ms_agent/workflow/deep_research/research_workflow_beta.py @@ -8,8 +8,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import click -from rich.prompt import Confirm, Prompt - from ms_agent.llm.openai import OpenAIChat from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.tools.search.exa.schema import dump_batch_search_results @@ -18,13 +16,12 @@ from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import remove_resource_info, text_hash from ms_agent.workflow.deep_research.principle import MECEPrinciple, Principle -from ms_agent.workflow.deep_research.research_utils import ( - LearningsResponse, - ProgressTracker, - ResearchProgress, - ResearchResult, -) +from ms_agent.workflow.deep_research.research_utils import (LearningsResponse, + ProgressTracker, + ResearchProgress, + ResearchResult) from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow +from rich.prompt import Confirm, Prompt logger = get_logger() @@ -346,7 +343,7 @@ async def generate_search_queries( if learnings: learnings_prompt = ( f'\n\nHere are some learnings from previous research, ' - f'use them to generate more specific queries: {', '.join(learnings)}' + f'use them to generate more specific queries: {", ".join(learnings)}' ) rewrite_prompt = ( @@ -595,7 +592,7 @@ async def process_search_results( f'information dense as possible.\n' f'- Make sure to include any entities like people, places, companies, products, ' f'things, etc in the learnings, as well as any exact metrics, numbers, or dates.\n' - f'{multimodal_prompt if self._enable_multimodal else ''}' + f'{multimodal_prompt if self._enable_multimodal else ""}' f'- The learnings will be used to research the topic further.\n' f'- Do NOT repeat the query verbatim as a learning. ' f'Do NOT invent facts not present in .\n' @@ -623,7 +620,7 @@ async def process_search_results( response_data = response_data.get('learnings_extraction', {}) or response_data except Exception as e: logger.error(f'Error parsing JSON response: {e}') - logger.error(f'Raw response content: {response.get('content', '')}') + logger.error(f'Raw response content: {response.get("content", "")}') return LearningsResponse(learnings=[], follow_up_questions=[]) learnings = response_data.get('learnings', [])[:num_learnings] @@ -673,7 +670,7 @@ async def _process_single_query( if new_depth > 0 and len(processed_results.follow_up_questions) > 0: logger.info( f'Researching deeper, breadth: {new_breadth}, ' - f'depth: {progress_manager.get_current().current_depth if progress_manager else 'N/A'}' + f'depth: {progress_manager.get_current().current_depth if progress_manager else "N/A"}' ) # Use atomic increment to avoid race conditions if progress_manager is not None: @@ -682,8 +679,8 @@ async def _process_single_query( # Create next query from follow-up questions next_query = ( f'Previous Query: {search_request.query}\n' - f'Previous research goal: {getattr(search_request, 'research_goal', '')}\n' - f'Follow-up research directions: {', '.join(processed_results.follow_up_questions)}' + f'Previous research goal: {getattr(search_request, "research_goal", "")}\n' + f'Follow-up research directions: {", ".join(processed_results.follow_up_questions)}' ).strip() # Continue with deeper research, passing through the progress manager @@ -872,7 +869,7 @@ async def write_final_report(self, prompt: str, f'\n{learnings_text}\n' f'\n\nPlease respond with valid JSON that matches provided schema:\n{json_schema}\n' f'Please respond in the language of the . ' - f'{multimodal_prompt if self._enable_multimodal else ''}' + f'{multimodal_prompt if self._enable_multimodal else ""}' ) response = await self._chat_async( @@ -891,7 +888,7 @@ async def write_final_report(self, prompt: str, fix_prompt = ( f'The response is not valid JSON. Please fix it. ' f'You can only return the fixed JSON, no other text. ' - f'The response is: {response.get('content', '')}' + f'The response is: {response.get("content", "")}' ) response = await self._chat_async( messages=[ @@ -1085,7 +1082,7 @@ async def _run(self, encoding='utf-8') as f_report: f_report.write(report) logger.info( - f'Report saved to {self.workdir_structure['report_md']}') + f'Report saved to {self.workdir_structure["report_md"]}') else: # Generate and save answer answer = await self.write_final_answer( @@ -1100,7 +1097,7 @@ async def _run(self, encoding='utf-8') as f_answer: f_answer.write(answer) logger.info( - f'Answer saved to {self.workdir_structure['report_md']}') + f'Answer saved to {self.workdir_structure["report_md"]}') return self.workdir_structure['report_md'] diff --git a/ms_agent/workflow/loader.py b/ms_agent/workflow/loader.py index 3df2d6b57..7e0589cfa 100644 --- a/ms_agent/workflow/loader.py +++ b/ms_agent/workflow/loader.py @@ -1,21 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import Dict, Optional -from omegaconf import DictConfig, OmegaConf - from ms_agent.config.config import Config +from omegaconf import DictConfig, OmegaConf class WorkflowLoader: + @classmethod - def build( - cls, - config_dir_or_id: Optional[str] = None, - config: Optional[DictConfig] = None, - env: Optional[Dict[str, str]] = None, - trust_remote_code: bool = False, - **kwargs, - ): + def build(cls, + config_dir_or_id: Optional[str] = None, + config: Optional[DictConfig] = None, + env: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + **kwargs): wf_config: Optional[DictConfig] = None if config_dir_or_id is not None: wf_config: DictConfig = Config.from_task(config_dir_or_id, env) @@ -27,7 +25,6 @@ def build( from ms_agent.workflow.chain_workflow import ChainWorkflow from ms_agent.workflow.dag_workflow import DagWorkflow - wf_type = ChainWorkflow.WORKFLOW_NAME.lower() wf_type = getattr(wf_config, 'type', '').lower() or wf_type @@ -38,8 +35,7 @@ def build( env=env, mcp_server_file=kwargs.get('mcp_server_file'), load_cache=kwargs.get('load_cache', False), - trust_remote_code=trust_remote_code, - ) + trust_remote_code=trust_remote_code) elif wf_type == DagWorkflow.WORKFLOW_NAME.lower(): wf_instance = DagWorkflow( config_dir_or_id=config_dir_or_id, @@ -47,8 +43,7 @@ def build( env=env, mcp_server_file=kwargs.get('mcp_server_file'), load_cache=kwargs.get('load_cache', False), - trust_remote_code=trust_remote_code, - ) + trust_remote_code=trust_remote_code) elif wf_type == 'ResearchWorkflow'.lower(): # TODO raise NotImplementedError() diff --git a/projects/code_genesis/tools/build_sandbox_image.py b/projects/code_genesis/tools/build_sandbox_image.py deleted file mode 100644 index 8995df386..000000000 --- a/projects/code_genesis/tools/build_sandbox_image.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -"""Build code_genesis sandbox image via Docker API (Colima / API-compatible daemon). - -Avoids requiring the standalone `docker` CLI binary; uses the PyPI `docker` package -(``pip install docker`` / requirements/code.txt) like the rest of ms-agent. -""" -from __future__ import annotations - -import sys -from pathlib import Path - -IMAGE_NAME = "code-genesis-sandbox" -IMAGE_TAG = "version1" - -DOCKERFILE = r"""FROM python:3.12-slim - -# Install system dependencies and Node.js -RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \ - && apt-get update -o Acquire::Retries=5 \ - && apt-get install -y --no-install-recommends \ - curl \ - git \ - build-essential \ - && curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \ - && apt-get install -y --no-install-recommends nodejs \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -# Configure npm to use a Chinese mirror. Comment out this line if not needed. -RUN npm config set registry https://registry.npmmirror.com/ - -# Install Jupyter kernel gateway (required by sandbox) -RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com \ - jupyter_kernel_gateway \ - jupyter_client \ - ipykernel - -# Install Python kernel -RUN python -m ipykernel install --sys-prefix --name python3 --display-name "Python 3" - -WORKDIR /data - -EXPOSE 8888 -CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip=0.0.0.0", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin=*"] -""" - - -def _repo_root() -> Path: - # projects/code_genesis/tools/thisfile -> parents[3] == repo root - return Path(__file__).resolve().parents[3] - - -def main() -> int: - try: - import docker - except ImportError: - print( - "Missing Python package 'docker'. Run: pip install docker\n" - "or: pip install -r requirements/code.txt", - file=sys.stderr, - ) - return 1 - - root = _repo_root() - dockerfile_path = root / "Dockerfile.sandbox" - dockerfile_path.write_text(DOCKERFILE, encoding="utf-8") - try: - client = docker.from_env() - client.ping() - print("Pulling python:3.12-slim ...") - client.images.pull("python:3.12-slim") - tag = f"{IMAGE_NAME}:{IMAGE_TAG}" - print(f"Building {tag} (context: {root}) ...") - stream = client.api.build( - path=str(root), - dockerfile="Dockerfile.sandbox", - tag=tag, - rm=True, - forcerm=True, - decode=True, - ) - for chunk in stream: - if not chunk: - continue - if "stream" in chunk and chunk["stream"]: - print(chunk["stream"], end="") - if "errorDetail" in chunk: - print(chunk.get("error", chunk["errorDetail"]), file=sys.stderr) - return 1 - print(f"Done: {tag}") - finally: - if dockerfile_path.is_file(): - dockerfile_path.unlink() - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/projects/code_genesis/tools/build_sandbox_image.sh b/projects/code_genesis/tools/build_sandbox_image.sh index 9c368011d..8aae6859a 100755 --- a/projects/code_genesis/tools/build_sandbox_image.sh +++ b/projects/code_genesis/tools/build_sandbox_image.sh @@ -1,10 +1,53 @@ -#!/usr/bin/env bash -# Build sandbox image using Docker HTTP API (PyPI `docker`); Colima supplies the daemon. -# No Docker Desktop and no standalone `docker` CLI required. - -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -if [[ -n "${VIRTUAL_ENV:-}" && -x "${VIRTUAL_ENV}/bin/python" ]]; then - exec "${VIRTUAL_ENV}/bin/python" "${SCRIPT_DIR}/build_sandbox_image.py" "$@" -fi -exec python3 "${SCRIPT_DIR}/build_sandbox_image.py" "$@" +#!/bin/bash + +# Build Docker sandbox image for code_genesis +# Includes Python + Node.js for full-stack project support + +set -e + +IMAGE_NAME="code-genesis-sandbox" +IMAGE_TAG="version1" + +echo "Building code-genesis sandbox Docker image..." + +docker pull python:3.12-slim + +cat > Dockerfile.sandbox << 'EOF' +FROM python:3.12-slim + +# Install system dependencies and Node.js +RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \ + && apt-get update -o Acquire::Retries=5 \ + && apt-get install -y --no-install-recommends \ + curl \ + git \ + build-essential \ + && curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Configure npm to use a Chinese mirror. Comment out this line if not needed. +RUN npm config set registry https://registry.npmmirror.com/ + +# Install Jupyter kernel gateway (required by sandbox) +RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com \ + jupyter_kernel_gateway \ + jupyter_client \ + ipykernel + +# Install Python kernel +RUN python -m ipykernel install --sys-prefix --name python3 --display-name "Python 3" + +WORKDIR /data + +EXPOSE 8888 +CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip=0.0.0.0", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin=*"] +EOF + +echo "Building Docker image: ${IMAGE_NAME}:${IMAGE_TAG}" +docker build -f Dockerfile.sandbox -t "${IMAGE_NAME}:${IMAGE_TAG}" . + +rm Dockerfile.sandbox + +echo "Done: ${IMAGE_NAME}:${IMAGE_TAG}" +echo "Contains: Python 3.12, Node.js 20, npm, git, curl" diff --git a/projects/code_genesis/workflow/api_search.py b/projects/code_genesis/workflow/api_search.py index 60a635530..4e380d17a 100644 --- a/projects/code_genesis/workflow/api_search.py +++ b/projects/code_genesis/workflow/api_search.py @@ -1,15 +1,16 @@ -import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.constants import DEFAULT_INDEX_DIR class ApiSearch(ToolBase): + def __init__(self, config): super().__init__(config) index_dir = getattr(config, 'index_cache_dir', DEFAULT_INDEX_DIR) @@ -42,19 +43,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'keywords': { - 'type': 'string', - 'description': 'The keywords/regex in the url to search api of.', + 'type': + 'string', + 'description': + 'The keywords/regex in the url to search api of.', } }, 'required': [], - 'additionalProperties': False, - }, - ), + 'additionalProperties': False + }), ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await self.url_search(**tool_args) async def url_search(self, keywords: str = None): @@ -81,7 +84,9 @@ async def url_search(self, keywords: str = None): use_regex = True except re.error: # Not a valid regex, treat as comma-separated keywords - keyword_list = [kw.strip() for kw in keywords.split(',') if kw.strip()] + keyword_list = [ + kw.strip() for kw in keywords.split(',') if kw.strip() + ] use_regex = False def search_in_file(file_path): @@ -102,10 +107,12 @@ def search_in_file(file_path): is_match = regex_pattern.search(url) is not None else: # Substring matching (any keyword matches) - is_match = any(keyword in url for keyword in keyword_list) + is_match = any(keyword in url + for keyword in keyword_list) if is_match: - matches.append(json.dumps(protocol, ensure_ascii=False)) + matches.append( + json.dumps(protocol, ensure_ascii=False)) except Exception: # noqa return [] if matches: @@ -113,8 +120,7 @@ def search_in_file(file_path): matches.insert( 0, f'API{" with keywords: " + str(keywords) + " " + match_mode if keywords else ""} defined ' - f'in {file_path}:', - ) + f'in {file_path}:') matches.append('\n') return matches @@ -126,7 +132,10 @@ def search_in_file(file_path): # Use thread pool to search files in parallel all_matches = [] with ThreadPoolExecutor(max_workers=8) as executor: - future_to_file = {executor.submit(search_in_file, f): f for f in files_to_search} + future_to_file = { + executor.submit(search_in_file, f): f + for f in files_to_search + } for future in as_completed(future_to_file): matches = future.result() all_matches.extend(matches) diff --git a/projects/code_genesis/workflow/architect.py b/projects/code_genesis/workflow/architect.py index 889b34384..b340ab78a 100644 --- a/projects/code_genesis/workflow/architect.py +++ b/projects/code_genesis/workflow/architect.py @@ -6,6 +6,7 @@ class ArchitectureAgent(LLMAgent): + async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() diff --git a/projects/code_genesis/workflow/coding.py b/projects/code_genesis/workflow/coding.py index 70e6df08e..c6cfc398f 100644 --- a/projects/code_genesis/workflow/coding.py +++ b/projects/code_genesis/workflow/coding.py @@ -1,6 +1,5 @@ import asyncio import dataclasses -import json import os import re import shutil @@ -9,17 +8,18 @@ from pathlib import Path from typing import List, Optional, Set -from omegaconf import DictConfig - +import json from ms_agent import LLMAgent from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.memory.condenser.code_condenser import CodeCondenser from ms_agent.tools.code_server import LSPCodeServer from ms_agent.utils import get_logger -from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, DEFAULT_TAG +from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, + DEFAULT_TAG) from ms_agent.utils.parser_utils import ImportInfo, parse_imports from ms_agent.utils.utils import extract_code_blocks, file_lock +from omegaconf import DictConfig logger = get_logger() @@ -54,14 +54,13 @@ class Programmer(LLMAgent): - def __init__( - self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - code_file: str = None, - **kwargs, - ): + + def __init__(self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + code_file: str = None, + **kwargs): # Validate and adjust config before passing to parent config = self._validate_config(config) super().__init__(config, tag, trust_remote_code, **kwargs) @@ -91,19 +90,22 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check edit_file_config.api_key edit_file_api_key = None try: - edit_file_api_key = ( - config.get('tools', {}).get('file_system', {}).get('edit_file_config', {}).get('api_key') - ) + edit_file_api_key = config.get('tools', {}).get( + 'file_system', {}).get('edit_file_config', {}).get('api_key') except Exception: pass if not edit_file_api_key: # Remove edit_file from include list try: - include_list = config.get('tools', {}).get('file_system', {}).get('include', []) + include_list = config.get('tools', + {}).get('file_system', + {}).get('include', []) if include_list and 'edit_file' in include_list: include_list.remove('edit_file') - logger.warning('[coding] edit_file_config.api_key not set, removing edit_file from tools') + logger.warning( + '[coding] edit_file_config.api_key not set, removing edit_file from tools' + ) except Exception: pass else: @@ -131,7 +133,8 @@ def stop_nothing(self): self.llm.args['extra_body']['stop_sequences'] = self.stop_words[1] def is_stop_imports(self): - return self.llm.args['extra_body']['stop_sequences'] == self.stop_words[0] + return self.llm.args['extra_body'][ + 'stop_sequences'] == self.stop_words[0] def find_all_files(self): self.all_code_files = [] @@ -180,28 +183,36 @@ def read_file(path): contents = content.split('\n') comments = ['*', '#', '-', '%', '/'] - contents = [c for c in contents if not any(c.strip().startswith(cm) for cm in comments)] - all_files = parse_imports(code_file, '\n'.join(contents), self.output_dir) or [] + contents = [ + c for c in contents + if not any(c.strip().startswith(cm) for cm in comments) + ] + all_files = parse_imports(code_file, '\n'.join(contents), + self.output_dir) or [] all_read_files = find_all_read_files() all_notes = [] for file in all_files: if 'react' in file.source_file or 'vue' in file.source_file: continue if file.source_file == code_file: - all_notes.append(f'You should not import the file itself: {code_file}') + all_notes.append( + f'You should not import the file itself: {code_file}') continue - file.imported_items = [item for item in file.imported_items if item not in ('*', 'default')] + file.imported_items = [ + item for item in file.imported_items + if item not in ('*', 'default') + ] filename = os.path.join(self.output_dir, file.source_file) if not os.path.exists(filename): if file.source_file in self.all_code_files: all_notes.append( - f'The dependency you import: {file.source_file} does not exist, the order may be incorrect.' - ) + f'The dependency you import: {file.source_file} does not exist, ' + f'the order may be incorrect.') else: all_notes.append( - f'The dependency you import: {file.source_file} is not in the code plan, stop importing it.' - ) + f'The dependency you import: {file.source_file} is not in the code plan, ' + f'stop importing it.') elif os.path.isfile(filename): if file.source_file not in all_read_files: all_notes.append( @@ -210,7 +221,8 @@ def read_file(path): elif os.path.isdir(filename): index_file_path = self.find_index_file(filename) if index_file_path: - index_file_path = str(Path(index_file_path).relative_to(self.output_dir)) + index_file_path = str( + Path(index_file_path).relative_to(self.output_dir)) if index_file_path not in all_read_files: all_notes.append( f'Extra file {index_file_path} content in imports:\n{read_file(index_file_path)}' @@ -218,9 +230,9 @@ def read_file(path): if all_notes: all_notes = '\n'.join(all_notes) - user_content = ( - f'Problems found in your imports:\n\n{all_notes}\nCorrect the errors and regenerate the code:\n' - ) + user_content = (f'Problems found in your imports:\n' + f'\n{all_notes}\n' + f'Correct the errors and regenerate the code:\n') messages.append(Message(role='user', content=user_content)) else: messages.pop(-1) @@ -230,12 +242,14 @@ def read_file(path): async def _incremental_check(self, code_file: str, partial_code: str): if self.lsp_check: - lsp_result = await self._incremental_lsp_check(code_file, partial_code) + lsp_result = await self._incremental_lsp_check( + code_file, partial_code) else: lsp_result = None if self.post_import_check: - import_result = await self._after_import_check(code_file, partial_code) + import_result = await self._after_import_check( + code_file, partial_code) else: import_result = None return (lsp_result or '') + '\n' + (import_result or '') @@ -246,27 +260,38 @@ def find_index_file(full_path): return None else: result = None - for index_file in ['index.ts', 'index.tsx', 'index.js', 'index.jsx', 'index.vue', '__init__.py']: + for index_file in [ + 'index.ts', 'index.tsx', 'index.js', 'index.jsx', + 'index.vue', '__init__.py' + ]: index_path = os.path.join(full_path, index_file) if os.path.exists(index_path): result = index_path break return result - async def _after_import_check(self, code_file: str, partial_code: str) -> Optional[str]: + async def _after_import_check(self, code_file: str, + partial_code: str) -> Optional[str]: errors = [] partial_code = partial_code.split('\n') comments = ['*', '#', '-', '%', '/'] - contents = [c for c in partial_code if not any(c.strip().startswith(cm) for cm in comments)] + contents = [ + c for c in partial_code + if not any(c.strip().startswith(cm) for cm in comments) + ] partial_code = '\n'.join(contents) - all_imports: List[ImportInfo] = parse_imports(code_file, partial_code, self.output_dir) + all_imports: List[ImportInfo] = parse_imports(code_file, partial_code, + self.output_dir) for info in all_imports: source_file = info.source_file if not source_file or 'react' in source_file or 'vue' in source_file: continue - info.imported_items = [item for item in info.imported_items if item not in ('*', 'default')] + info.imported_items = [ + item for item in info.imported_items + if item not in ('*', 'default') + ] if not os.path.isabs(source_file): full_path = os.path.join(self.output_dir, source_file) @@ -283,17 +308,14 @@ async def _after_import_check(self, code_file: str, partial_code: str) -> Option errors.append( f'Import error in {code_file}:\n' f" Directory '{source_file}' exists but has no index file (__init__.py, index.ts, etc.)\n" - f' Statement: {info.raw_statement}\n' - ) + f' Statement: {info.raw_statement}\n') continue else: full_path = index_file_path else: - errors.append( - f'Import error in {code_file}:\n' - f" File '{source_file}' does not exist\n" - f' Statement: {info.raw_statement}\n' - ) + errors.append(f'Import error in {code_file}:\n' + f" File '{source_file}' does not exist\n" + f' Statement: {info.raw_statement}\n') continue # 2. Check if imported symbols exist in the file @@ -315,12 +337,12 @@ async def _after_import_check(self, code_file: str, partial_code: str) -> Option errors.append( f'Import error in {code_file}:\n' f" Items {missing_items} not found in '{source_file}'\n" - f' Statement: {info.raw_statement}\n' - ) + f' Statement: {info.raw_statement}\n') return '\n'.join(errors) if errors else None - async def _incremental_lsp_check(self, code_file: str, partial_code: str) -> Optional[str]: + async def _incremental_lsp_check(self, code_file: str, + partial_code: str) -> Optional[str]: lsp_servers = self.shared_lsp_context.get('lsp_servers', {}) if not lsp_servers: return None @@ -357,8 +379,11 @@ async def _incremental_lsp_check(self, code_file: str, partial_code: str) -> Opt return await lsp_server.call_tool( 'lsp_code_server', tool_name='update_and_check', - tool_args={'file_path': code_file, 'content': partial_code, 'language': lang}, - ) + tool_args={ + 'file_path': code_file, + 'content': partial_code, + 'language': lang + }) def filter_code_files(self): code_files = [] @@ -378,20 +403,20 @@ def increment_unchecked_file(self): self.unchecked_files.pop(key) logger.error( f"Unchecked file {key} still have problem:\n{self.unchecked_issues.get('key')}\n" - f'But the checking limit has reached.' - ) + f'But the checking limit has reached.') async def after_tool_call(self, messages: List[Message]): - is_prepare = len(messages[-1].tool_calls or []) > 0 or messages[-1].role != 'assistant' - is_code_finish = '' in messages[-1].content and '' in messages[-1].content and not is_prepare + is_prepare = len(messages[-1].tool_calls + or []) > 0 or messages[-1].role != 'assistant' + is_code_finish = '' in messages[ + -1].content and '' in messages[ + -1].content and not is_prepare is_import = ( - self.is_stop_imports() - and not is_code_finish - and not is_prepare + self.is_stop_imports() and not is_code_finish and not is_prepare and '' in messages[-1].content - and '' not in messages[-1].content - ) - is_check = messages[-1].role == 'assistant' and len(messages[-1].tool_calls or []) == 0 and not is_import + and '' not in messages[-1].content) + is_check = messages[-1].role == 'assistant' and len( + messages[-1].tool_calls or []) == 0 and not is_import message = messages[-1] all_issues = [] @@ -399,6 +424,7 @@ async def after_tool_call(self, messages: List[Message]): self._before_import_check(messages) if is_code_finish: + # Saving code result, remaining_text = extract_code_blocks(message.content) if result: @@ -430,12 +456,15 @@ async def after_tool_call(self, messages: List[Message]): if is_check: # After checking when fix ended or write ended for uncheck_file in list(self.unchecked_files.keys()): - with open(os.path.join(self.output_dir, uncheck_file), 'r') as f: + with open(os.path.join(self.output_dir, uncheck_file), + 'r') as f: _code = f.read() - lsp_feedback = await self._incremental_check(uncheck_file, _code) + lsp_feedback = await self._incremental_check( + uncheck_file, _code) lsp_feedback = lsp_feedback.strip() if lsp_feedback: - all_issues.append(f'❎Issues in {uncheck_file}:' + lsp_feedback) + all_issues.append(f'❎Issues in {uncheck_file}:' + + lsp_feedback) self.unchecked_issues[uncheck_file] = lsp_feedback else: logger.info(f'✅No issues found in {uncheck_file}.') @@ -477,15 +506,16 @@ async def after_tool_call(self, messages: List[Message]): if self.error_counter > 2: raise RuntimeError('The model does not output any response!') - new_task = is_code_finish and self.code_files and (not self.unchecked_files) + new_task = is_code_finish and self.code_files and ( + not self.unchecked_files) if new_task: last_file = self.code_files[-1] messages.append( Message( role='user', - content=f'\nA code file in your imports not found, you should write it first: {last_file}\n', - ) - ) + content= + f'\nA code file in your imports not found, you should write it first: {last_file}\n' + )) # Condense code block and prepare index files # await self.code_condenser.run(messages) @@ -493,6 +523,7 @@ async def after_tool_call(self, messages: List[Message]): @dataclasses.dataclass class FileRelation: + name: str description: str done: bool = False @@ -500,6 +531,7 @@ class FileRelation: class CodingAgent(CodeAgent): + def __init__(self, config, tag, trust_remote_code, **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) # Shared LSP context across all Programmers @@ -517,23 +549,31 @@ async def _init_lsp_servers(self): # Detect all languages in the project detected_languages = set() - if any(kw in framework for kw in ['typescript', 'javascript', 'react', 'node', 'npm', 'html']): + if any(kw in framework for kw in + ['typescript', 'javascript', 'react', 'node', 'npm', 'html']): detected_languages.add('typescript') - if any(kw in framework for kw in ['python', 'django', 'flask', 'fastapi']): + if any(kw in framework + for kw in ['python', 'django', 'flask', 'fastapi']): detected_languages.add('python') - if any(kw in framework for kw in ['java ', 'java\n', 'spring', 'maven', 'gradle']): + if any(kw in framework + for kw in ['java ', 'java\n', 'spring', 'maven', 'gradle']): detected_languages.add('java') if not detected_languages: logger.info('No supported languages detected in framework.txt') return - logger.info(f"Initializing LSP servers for languages: {', '.join(detected_languages)}") + logger.info( + f"Initializing LSP servers for languages: {', '.join(detected_languages)}" + ) # Initialize LSP server for each detected language - lsp_config = DictConfig({'workspace_dir': self.output_dir, 'output_dir': self.output_dir}) + lsp_config = DictConfig({ + 'workspace_dir': self.output_dir, + 'output_dir': self.output_dir + }) lsp_servers = {} for lang in detected_languages: @@ -545,8 +585,12 @@ async def _init_lsp_servers(self): for lang, lsp_server in lsp_servers.items(): logger.info(f'Building LSP index for {lang}...') await lsp_server.call_tool( - 'lsp_code_server', tool_name='check_directory', tool_args={'directory': '', 'language': lang} - ) + 'lsp_code_server', + tool_name='check_directory', + tool_args={ + 'directory': '', + 'language': lang + }) logger.info(f'LSP index built for {lang}') self.shared_lsp_context['lsp_servers'] = lsp_servers @@ -563,20 +607,9 @@ async def _cleanup_lsp_servers(self): except Exception: # noqa pass - async def write_code( - self, - topic, - user_story, - framework, - protocol, - file_order, - name, - description, - index, - last_batch, - siblings, - next_batch, - ): + async def write_code(self, topic, user_story, framework, protocol, + file_order, name, description, index, last_batch, + siblings, next_batch): logger.info(f'Writing {name}') _config = deepcopy(self.config) messages = [ @@ -593,8 +626,7 @@ async def write_code( f'File description: {description}\n' f'Previous batch output:\n{last_batch}\n' f'Other workers writing in parallel:\n{siblings}\n' - f'Next batch planned:\n{next_batch}\n', - ), + f'Next batch planned:\n{next_batch}\n'), ] _config = deepcopy(self.config) @@ -605,8 +637,7 @@ async def write_code( tag=f'programmer-{name.replace(os.sep, "-")}', trust_remote_code=True, code_file=name, - shared_lsp_context=self.shared_lsp_context, - ) # Pass shared context + shared_lsp_context=self.shared_lsp_context) # Pass shared context await programmer.run(messages) async def execute_code(self, inputs, **kwargs): @@ -625,7 +656,8 @@ async def execute_code(self, inputs, **kwargs): file_orders = self.construct_file_orders() file_relation = OrderedDict() self.refresh_file_status(file_relation) - shutil.rmtree(os.path.join(self.output_dir, 'locks'), ignore_errors=True) + shutil.rmtree( + os.path.join(self.output_dir, 'locks'), ignore_errors=True) for idx, files in enumerate(file_orders): while True: @@ -657,8 +689,7 @@ async def execute_code(self, inputs, **kwargs): index=idx, last_batch=last_batch, siblings='\n'.join(set(files) - {name}), - next_batch=next_batch, - ) + next_batch=next_batch) for name, description in files.items() ] @@ -718,7 +749,8 @@ def refresh_file_status(self, file_relation): description = file['description'] file_path = os.path.join(self.output_dir, file_name) if file_name not in file_relation: - file_relation[file_name] = FileRelation(name=file_name, description=description) + file_relation[file_name] = FileRelation( + name=file_name, description=description) file_relation[file_name].done = os.path.exists(file_path) def construct_file_information(self, file_relation, add_output_dir=False): diff --git a/projects/code_genesis/workflow/file_design.py b/projects/code_genesis/workflow/file_design.py index d1dae4756..b33f977a5 100644 --- a/projects/code_genesis/workflow/file_design.py +++ b/projects/code_genesis/workflow/file_design.py @@ -1,12 +1,13 @@ -import json import os from typing import List +import json from ms_agent import LLMAgent from ms_agent.llm import Message class FileDesignAgent(LLMAgent): + async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -31,11 +32,15 @@ async def after_tool_call(self, messages: List[Message]): if self.runtime.should_stop: query = None - if os.path.isfile(os.path.join(self.output_dir, 'file_design.txt')): - with open(os.path.join(self.output_dir, 'file_design.txt'), 'r') as f: + if os.path.isfile( + os.path.join(self.output_dir, 'file_design.txt')): + with open( + os.path.join(self.output_dir, 'file_design.txt'), + 'r') as f: file_design = json.load(f) - with open(os.path.join(self.output_dir, 'modules.txt'), 'r') as f: + with open(os.path.join(self.output_dir, 'modules.txt'), + 'r') as f: modules = f.readlines() files1 = set() @@ -58,7 +63,8 @@ async def after_tool_call(self, messages: List[Message]): f'please provide the correct file order without these files.' ) else: - query = 'The file design you provided is missing, please provide the complete file design.' + query = ('The file design you provided is missing, ' + 'please provide the complete file design.') if query: messages.append(Message(role='user', content=query)) diff --git a/projects/code_genesis/workflow/file_order.py b/projects/code_genesis/workflow/file_order.py index c380e926f..592a68025 100644 --- a/projects/code_genesis/workflow/file_order.py +++ b/projects/code_genesis/workflow/file_order.py @@ -1,12 +1,13 @@ -import json import os from typing import List +import json from ms_agent import LLMAgent from ms_agent.llm import Message class FileOrderAgent(LLMAgent): + async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -32,10 +33,14 @@ async def after_tool_call(self, messages: List[Message]): query = None if os.path.isfile(os.path.join(self.output_dir, 'file_order.txt')): - with open(os.path.join(self.output_dir, 'file_order.txt'), 'r') as f: + with open( + os.path.join(self.output_dir, 'file_order.txt'), + 'r') as f: file_order = json.load(f) - with open(os.path.join(self.output_dir, 'file_design.txt'), 'r') as f: + with open( + os.path.join(self.output_dir, 'file_design.txt'), + 'r') as f: file_design = json.load(f) files1 = set() @@ -58,7 +63,8 @@ async def after_tool_call(self, messages: List[Message]): f'please provide the correct file order without these files.' ) else: - query = 'The file order you provided is missing, please provide the complete file order.' + query = ('The file order you provided is missing, ' + 'please provide the complete file order.') if query: messages.append(Message(role='user', content=query)) diff --git a/projects/code_genesis/workflow/install.py b/projects/code_genesis/workflow/install.py index 69ea4ee13..82f091147 100644 --- a/projects/code_genesis/workflow/install.py +++ b/projects/code_genesis/workflow/install.py @@ -5,6 +5,7 @@ class InstallAgent(LLMAgent): + async def run(self, messages, **kwargs): with open(os.path.join(self.output_dir, 'topic.txt'), 'r') as f: topic = f.read() @@ -18,8 +19,7 @@ async def run(self, messages, **kwargs): query = ( f'Topic: {topic}\nFramework: {framework}\nFile Design: {file_design}\n' f'Your `workflow_dir` is "./", ' - 'Please write dependency files and install dependencies.' - ) + 'Please write dependency files and install dependencies.') messages = [ Message(role='system', content=self.config.prompt.system), diff --git a/projects/code_genesis/workflow/refine.py b/projects/code_genesis/workflow/refine.py index 6247f5369..17c9ffba4 100644 --- a/projects/code_genesis/workflow/refine.py +++ b/projects/code_genesis/workflow/refine.py @@ -1,24 +1,26 @@ -import json import os import sys from typing import List, OrderedDict +import json from coding import CodingAgent -from omegaconf import DictConfig - from ms_agent import LLMAgent from ms_agent.llm import Message from ms_agent.memory.condenser.refine_condenser import RefineCondenser from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG +from omegaconf import DictConfig logger = get_logger() class RefineAgent(LLMAgent): - def __init__( - self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs - ): + + def __init__(self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs): # Validate and adjust config before passing to parent config = self._validate_config(config) super().__init__(config, tag, trust_remote_code, **kwargs) @@ -34,19 +36,22 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check edit_file_config.api_key edit_file_api_key = None try: - edit_file_api_key = ( - config.get('tools', {}).get('file_system', {}).get('edit_file_config', {}).get('api_key') - ) + edit_file_api_key = config.get('tools', {}).get( + 'file_system', {}).get('edit_file_config', {}).get('api_key') except Exception: pass if not edit_file_api_key: # Remove edit_file from include list try: - include_list = config.get('tools', {}).get('file_system', {}).get('include', []) + include_list = config.get('tools', + {}).get('file_system', + {}).get('include', []) if 'edit_file' in include_list: include_list.remove('edit_file') - logger.warning('[refine] edit_file_config.api_key not set, removing edit_file from tools') + logger.warning( + '[refine] edit_file_config.api_key not set, removing edit_file from tools' + ) except Exception: pass else: @@ -55,9 +60,9 @@ def _validate_config(self, config: DictConfig) -> DictConfig: # Check EDGEONE_PAGES_API_TOKEN edgeone_token = None try: - edgeone_token = ( - config.get('tools', {}).get('edgeone-pages-mcp', {}).get('env', {}).get('EDGEONE_PAGES_API_TOKEN') - ) + edgeone_token = config.get('tools', {}).get( + 'edgeone-pages-mcp', {}).get('env', + {}).get('EDGEONE_PAGES_API_TOKEN') except Exception: pass @@ -66,11 +71,15 @@ def _validate_config(self, config: DictConfig) -> DictConfig: try: if 'edgeone-pages-mcp' in config.get('tools', {}): del config['tools']['edgeone-pages-mcp'] - logger.warning('[refine] EDGEONE_PAGES_API_TOKEN not set, removing edgeone-pages-mcp from tools') + logger.warning( + '[refine] EDGEONE_PAGES_API_TOKEN not set, removing edgeone-pages-mcp from tools' + ) except Exception: pass else: - logger.info(f'[refine] EDGEONE_PAGES_API_TOKEN is configured: {edgeone_token[:10]}...') + logger.info( + f'[refine] EDGEONE_PAGES_API_TOKEN is configured: {edgeone_token[:10]}...' + ) return OmegaConf.create(config) @@ -102,8 +111,7 @@ async def run(self, messages, **kwargs): f'Project files are at the current working directory (/data). ' f'All relative paths work directly.\n' f'When creating the deployment zip file, name it workspace.zip.\n' - f'Please refine the project and deploy it to EdgeOne Pages:', - ), + f'Please refine the project and deploy it to EdgeOne Pages:'), ] return await super().run(messages, **kwargs) @@ -113,7 +121,9 @@ async def after_tool_call(self, messages: List[Message]): if self.runtime.should_stop: if not sys.stdin.isatty(): # Running in WebUI - notify user that agent is waiting for input - logger.info('[refine] Agent completed initial refinement. Waiting for user feedback.') + logger.info( + '[refine] Agent completed initial refinement. Waiting for user feedback.' + ) # # Add a system message to notify the user # messages.append( @@ -127,12 +137,15 @@ async def after_tool_call(self, messages: List[Message]): try: query = sys.stdin.readline().strip() if query: - logger.info(f'[refine] Received input from WebUI: {query}') + logger.info( + f'[refine] Received input from WebUI: {query}') messages.append(Message(role='user', content=query)) self.runtime.should_stop = False return else: - logger.warning('[refine] Received empty input, continuing to wait...') + logger.warning( + '[refine] Received empty input, continuing to wait...' + ) return except (EOFError, OSError, ValueError) as e: logger.error(f'[refine] Error reading from stdin: {e}') diff --git a/projects/code_genesis/workflow/user_story.py b/projects/code_genesis/workflow/user_story.py index f92ee0380..28c62f984 100644 --- a/projects/code_genesis/workflow/user_story.py +++ b/projects/code_genesis/workflow/user_story.py @@ -6,6 +6,7 @@ class SplitModuleAgent(LLMAgent): + async def on_task_end(self, messages: List[Message]): assert os.path.isfile(os.path.join(self.output_dir, 'user_story.txt')) topic = '' diff --git a/projects/deep_research/run.py b/projects/deep_research/run.py index 2cf8ef18c..d2094fdc2 100644 --- a/projects/deep_research/run.py +++ b/projects/deep_research/run.py @@ -6,17 +6,16 @@ from ms_agent.tools.search_engine import get_web_search_tool from ms_agent.workflow.deep_research.principle import MECEPrinciple from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow -from ms_agent.workflow.deep_research.research_workflow_beta import ResearchWorkflowBeta +from ms_agent.workflow.deep_research.research_workflow_beta import \ + ResearchWorkflowBeta -def run_workflow( - user_prompt: str, - task_dir: str, - chat_client: OpenAIChat, - search_engine: SearchEngine, - reuse: bool, - use_ray: bool = False, -): +def run_workflow(user_prompt: str, + task_dir: str, + chat_client: OpenAIChat, + search_engine: SearchEngine, + reuse: bool, + use_ray: bool = False): """ Run the deep research workflow, which follows a lightweight and efficient pipeline: 1. Receive a user prompt and generate search queries. @@ -44,17 +43,15 @@ def run_workflow( research_workflow.run(user_prompt=user_prompt) -def run_deep_workflow( - user_prompt: str, - task_dir: str, - chat_client: OpenAIChat, - search_engine: SearchEngine, - breadth: int = 4, - depth: int = 2, - is_report: bool = True, - show_progress: bool = True, - use_ray: bool = False, -): +def run_deep_workflow(user_prompt: str, + task_dir: str, + chat_client: OpenAIChat, + search_engine: SearchEngine, + breadth: int = 4, + depth: int = 2, + is_report: bool = True, + show_progress: bool = True, + use_ray: bool = False): """ Run the expandable deep research workflow (beta version). This version is more flexible and scalable than the original deep research workflow. @@ -81,17 +78,23 @@ def run_deep_workflow( """ research_workflow = ResearchWorkflowBeta( - client=chat_client, search_engine=search_engine, workdir=task_dir, use_ray=use_ray, enable_multimodal=True - ) + client=chat_client, + search_engine=search_engine, + workdir=task_dir, + use_ray=use_ray, + enable_multimodal=True) asyncio.run( research_workflow.run( - user_prompt=user_prompt, breadth=breadth, depth=depth, is_report=is_report, show_progress=show_progress - ) - ) + user_prompt=user_prompt, + breadth=breadth, + depth=depth, + is_report=is_report, + show_progress=show_progress)) if __name__ == '__main__': + query: str = 'Survey of the AI Agent within the recent 3 month, including the latest research papers, open-source projects, and industry applications.' # noqa task_workdir: str = '/path/to/your_workdir' # Specify your task work directory here reuse: bool = False @@ -107,8 +110,9 @@ def run_deep_workflow( api_key='xxx-xxx', base_url='https://api-inference.modelscope.cn/v1/', model='Qwen/Qwen3-235B-A22B-Instruct-2507', - generation_config={'extra_body': {'enable_thinking': False}}, - ) + generation_config={'extra_body': { + 'enable_thinking': False + }}) # Get web-search engine client # For the ExaSearch, you can get your API key from https://exa.ai diff --git a/projects/deep_research/v2/callbacks/quality_checker.py b/projects/deep_research/v2/callbacks/quality_checker.py index a406face4..36fadf902 100644 --- a/projects/deep_research/v2/callbacks/quality_checker.py +++ b/projects/deep_research/v2/callbacks/quality_checker.py @@ -1,13 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import json from abc import ABC, abstractmethod from typing import List, Optional -from omegaconf import DictConfig, OmegaConf - +import json from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -46,49 +45,47 @@ class ModelQualityChecker(ReportQualityChecker): """ _SYSTEM_PROMPTS = { - 'en': ( - 'You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' - 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' - 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' - 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' - 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' - 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' - '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' - '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' - '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' - 'is stored in xxx file", "See full analysis in evidence/xxx".\n' - '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n' - '4. Multiple reference/bibliography sections appear in the report (e.g., per-chapter reference lists), or any ' - 'variant heading such as "## References (Merged)", "## 参考文献(合并版)", "## 参考资料", etc. ' - 'Only one unified reference section at the very end is allowed.\n\n' - 'OUTPUT FORMAT:\n' - 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' - '{"pass": true} or {"pass": false, "reason": ""}\n' - 'Do NOT output anything else.' - ), - 'zh': ( - '你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' - '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' - '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' - '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' - '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' - '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' - '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' - '4. 报告中出现多个参考文献/引用列表章节(如各章节末尾的独立引用列表),或使用变体标题如"## 参考文献(合并版)"、"## 参考资料"、"## References (Merged)"等。' - '报告仅允许在末尾保留唯一一个统一的参考文献章节。\n\n' - '输出格式:\n' - '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' - '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' - '不要输出任何其他内容。' - ), + 'en': + ('You are a strict report quality auditor. Your ONLY job is to detect whether a research report violates any of the rules listed below.\n' + 'You MUST check ONLY against these rules — do NOT invent additional criteria or penalize anything not explicitly listed here.\n' + 'If a problem is NOT described by rules below, you MUST ignore it and return {"pass": true}. ' + 'Specifically: duplicate/repeated content, heading numbering gaps, structural ordering issues, stylistic choices, ' + 'and the density of inline citations within otherwise substantive paragraphs are all OUT OF SCOPE and must NOT cause a failure.\n\n' + 'RULES — flag the report ONLY if ANY of the following are clearly found:\n' + '1. Sections where detailed content has been replaced by ellipsis or brevity markers such as "...for brevity", ' + '"Content truncated for brevity", "omitted for brevity", "(remaining content follows the same pattern)", etc.\n' + '2. Sections that refer the reader to an external file instead of containing actual content, e.g. "This section ' + 'is stored in xxx file", "See full analysis in evidence/xxx".\n' + '3. Sections that guide the reader to view the reference source instead of writing substantive content, e.g. "See [1]", "Reference [2]".\n' + '4. Multiple reference/bibliography sections appear in the report (e.g., per-chapter reference lists), or any ' + 'variant heading such as "## References (Merged)", "## 参考文献(合并版)", "## 参考资料", etc. ' + 'Only one unified reference section at the very end is allowed.\n\n' + 'OUTPUT FORMAT:\n' + 'Respond with EXACTLY one JSON object. No markdown fences, no explanation outside the JSON.\n' + '{"pass": true} or {"pass": false, "reason": ""}\n' + 'Do NOT output anything else.'), + 'zh': + ('你是一个严格的研究报告质量审核员,你唯一的任务是判断报告是否违反了下方列出的规则。\n' + '你只能依据以下规则进行检查,不得自行发明额外标准,也不得基于规则未涉及的内容判定不通过。如果某个问题不属于下方规则的任何一条,你必须忽略它并返回 {"pass": true}。\n' + '特别说明:重复/相似内容、标题编号跳跃、章节结构顺序问题、文体风格选择、以及在有实质论述的段落中密集使用行内引注,都不在检查范围内,不得因此判定不通过。\n\n' + '规则 — 仅当明确发现以下任一问题时才判定不通过:\n' + '1. 正文被省略号或缩略标记替代,如"此处省略"、"篇幅所限不再展开"、"……以下类似"、"内容已截断"、"...for brevity"、"omitted for brevity"等。\n' + '2. 正文引导读者查看外部文件而非包含实际内容,如"该部分内容保存在xxx文件中"、"详见附件"、"See full analysis in evidence/xxx"。\n' + '3. 正文引导读者查看引用来源而没有撰写实质性内容,如"详见[1]"、"参考[2]"。\n' + '4. 报告中出现多个参考文献/引用列表章节(如各章节末尾的独立引用列表),或使用变体标题如"## 参考文献(合并版)"、"## 参考资料"、"## References (Merged)"等。' + '报告仅允许在末尾保留唯一一个统一的参考文献章节。\n\n' + '输出格式:\n' + '只返回一个JSON对象,不要使用markdown代码块,不要在JSON之外输出任何文字。\n' + '{"pass": true} 或者 {"reason": "<不得超过三句话;引用具体违反的规则编号>", "pass": false}\n' + '不要输出任何其他内容。'), } _USER_TEMPLATES = { - 'en': ( - 'Please audit the following research report against the rules provided in the system instruction.\n\n' - '---BEGIN REPORT---\n{report}\n---END REPORT---' - ), - 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n---报告开始---\n{report}\n---报告结束---'), + 'en': + ('Please audit the following research report against the rules provided in the system instruction.\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---'), + 'zh': ('请依据系统指令中提供的规则审核以下研究报告。\n\n' + '---报告开始---\n{report}\n---报告结束---'), } _MAX_REPORT_CHARS = 80000 @@ -99,27 +96,25 @@ def __init__(self, config: DictConfig): qc_cfg = getattr(qc_cfg, 'quality_check', DictConfig({})) self._model: str = str(getattr(qc_cfg, 'model', 'qwen3.5-plus')) - self._api_key: Optional[str] = getattr(qc_cfg, 'openai_api_key', None) or getattr( - config.llm, 'openai_api_key', None - ) - self._base_url: Optional[str] = getattr(qc_cfg, 'openai_base_url', None) or getattr( - config.llm, 'openai_base_url', None - ) + self._api_key: Optional[str] = getattr( + qc_cfg, 'openai_api_key', None) or getattr(config.llm, + 'openai_api_key', None) + self._base_url: Optional[str] = getattr( + qc_cfg, 'openai_base_url', None) or getattr( + config.llm, 'openai_base_url', None) self._client: Optional[OpenAILLM] = None def _build_llm_config(self) -> DictConfig: """Build lightweight llm config for quality checker.""" - return OmegaConf.create( - { - 'llm': { - 'model': self._model, - 'openai_api_key': self._api_key, - 'openai_base_url': self._base_url, - }, - 'generation_config': {}, - } - ) + return OmegaConf.create({ + 'llm': { + 'model': self._model, + 'openai_api_key': self._api_key, + 'openai_base_url': self._base_url, + }, + 'generation_config': {}, + }) def _ensure_client(self): if self._client is not None: @@ -131,23 +126,23 @@ async def check(self, content: str, lang: str) -> Optional[str]: report_text = content if len(report_text) > self._MAX_REPORT_CHARS: - report_text = report_text[: self._MAX_REPORT_CHARS] + report_text = report_text[:self._MAX_REPORT_CHARS] sys_prompt = self._SYSTEM_PROMPTS.get(lang, self._SYSTEM_PROMPTS['en']) - usr_template = self._USER_TEMPLATES.get(lang, self._USER_TEMPLATES['en']) + usr_template = self._USER_TEMPLATES.get(lang, + self._USER_TEMPLATES['en']) try: - response = self._client.generate( - messages=[ - Message(role='system', content=sys_prompt), - Message( - role='user', - content=usr_template.format(report=report_text), - ), - ] - ) + response = self._client.generate(messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', + content=usr_template.format(report=report_text), + ), + ]) raw = (response.content or '').strip() - logger.info(f'ModelQualityChecker ({self._model}): raw response: {raw}') + logger.info( + f'ModelQualityChecker ({self._model}): raw response: {raw}') verdict = json.loads(raw) if verdict.get('pass', True): @@ -155,7 +150,8 @@ async def check(self, content: str, lang: str) -> Optional[str]: return verdict.get('reason', 'placeholder_content') except json.JSONDecodeError: - logger.warning(f'ModelQualityChecker: failed to parse JSON from model response: {raw!r}') + logger.warning(f'ModelQualityChecker: failed to parse JSON from ' + f'model response: {raw!r}') return None except Exception as exc: logger.warning(f'ModelQualityChecker: model call failed: {exc}') @@ -179,5 +175,6 @@ def build_quality_checkers(config: DictConfig) -> List[ReportQualityChecker]: checkers: List[ReportQualityChecker] = [] checkers.append(ModelQualityChecker(config)) - logger.info(f'Quality checker chain initialised with {len(checkers)} checker(s).') + logger.info( + f'Quality checker chain initialised with {len(checkers)} checker(s).') return checkers diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index a3f4bac8f..477623a74 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # yapf: disable -import json import os import re import shutil from typing import Any, Dict, List, Optional, Set -from omegaconf import DictConfig - -from callbacks.quality_checker import ReportQualityChecker, build_quality_checkers +import json +from callbacks.quality_checker import (ReportQualityChecker, + build_quality_checkers) from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_MEMORY_DIR +from omegaconf import DictConfig logger = get_logger() @@ -446,7 +446,7 @@ def _format_trajectory(self, messages: List[Dict[str, Any]]) -> str: lines.append('') elif role == 'tool': - lines.append(f'{labels['tool_result']} ({tool_name})') + lines.append(f'{labels["tool_result"]} ({tool_name})') # Truncate very long tool results if content and len(content) > 20000: content = content[:20000] + '\n...(truncated)' @@ -485,7 +485,7 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): labels = self._TRAJECTORY_LABELS.get( self.lang, self._TRAJECTORY_LABELS['en']) - trajectory_str = (f'{labels['trajectory_intro']}\n\n' + trajectory_str = (f'{labels["trajectory_intro"]}\n\n' f'{trajectory_text}') if messages[insert_pos].role == 'user': diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 2569638fb..8306a6ba7 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -5,14 +5,14 @@ import shutil from typing import List, Optional -from omegaconf import DictConfig, OmegaConf - -from callbacks.quality_checker import ReportQualityChecker, build_quality_checkers +from callbacks.quality_checker import (ReportQualityChecker, + build_quality_checkers) from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index 093f7e0a7..735a2d47a 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -1,16 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import json import os import re import uuid from typing import Any, List, Optional -from omegaconf import DictConfig - +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() @@ -62,25 +61,26 @@ class SearcherCallback(Callback): # Bilingual round-reminder templates keyed by language code. _ROUND_REMINDER_TEMPLATES = { - 'zh': ( - '你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' - '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' - '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' - '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' - ), - 'en': ( - 'You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' - '- From now on: Prioritize summarizing existing evidence and progress, fill critical gaps, and reduce exploratory divergence.\n' - '- Within the very few remaining rounds, immediately prepare and output the final JSON response.\n' - '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' - ), + 'zh': + ('你已接近最大允许的对话轮数上限,请立刻开始收敛准备最终交付。\n' + '- 从现在开始:优先总结已有证据与进度、补齐关键缺口、减少发散探索。\n' + '- 在接下来的极少数轮次内,立刻准备并输出最终的 JSON 回复。\n' + '- 当前轮次信息:round=,max_chat_round=,剩余≈ 轮。' + ), + 'en': + ('You are approaching the maximum allowed conversation round limit. Begin converging immediately and prepare the final delivery.\n' + '- From now on: Prioritize summarizing existing evidence and progress, fill critical gaps, and reduce exploratory divergence.\n' + '- Within the very few remaining rounds, immediately prepare and output the final JSON response.\n' + '- Current round info: round=, max_chat_round=, remaining ≈ rounds.' + ), } def __init__(self, config: DictConfig): super().__init__(config) self.output_dir = getattr(config, 'output_dir', './output') self.search_task_id: Optional[str] = None - self.search_result_path = os.path.join(self.output_dir, f'search_result_{uuid.uuid4().hex[:4]}.json') + self.search_result_path = os.path.join( + self.output_dir, f'search_result_{uuid.uuid4().hex[:4]}.json') # Resolve language from config for bilingual prompt selection. self.lang = self._resolve_lang(config) self._ensure_output_dir() @@ -103,7 +103,8 @@ def _ensure_output_dir(self) -> None: try: os.makedirs(self.output_dir, exist_ok=True) except Exception as e: - logger.warning(f'Failed to create output_dir {self.output_dir!r}: {e}') + logger.warning( + f'Failed to create output_dir {self.output_dir!r}: {e}') @staticmethod def _sanitize_task_id(task_id: Any, max_len: int = 10) -> Optional[str]: @@ -137,21 +138,27 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): if not isinstance(message.content, str): continue search_task_description = json.loads(message.content) - raw_task_id = search_task_description.get('task_id') or search_task_description.get('任务ID') + raw_task_id = search_task_description.get( + 'task_id') or search_task_description.get('任务ID') safe_task_id = self._sanitize_task_id(raw_task_id) self.search_task_id = safe_task_id if safe_task_id: - self.search_result_path = os.path.join(self.output_dir, f'search_result_{safe_task_id}.json') + self.search_result_path = os.path.join( + self.output_dir, + f'search_result_{safe_task_id}.json') except json.JSONDecodeError: - logger.warning(f'Failed to parse search task description: {message.content}') + logger.warning( + f'Failed to parse search task description: {message.content}' + ) continue except Exception as e: logger.warning( - f'Unexpected error when parsing search task description: {message.content}, with error: {e}' - ) + f'Unexpected error when parsing search task description: {message.content}, ' + f'with error: {e}') continue - async def on_generate_response(self, runtime: Runtime, messages: List[Message]): + async def on_generate_response(self, runtime: Runtime, + messages: List[Message]): """ Inject a round-aware reminder into the system prompt near max rounds. @@ -181,8 +188,10 @@ async def on_generate_response(self, runtime: Runtime, messages: List[Message]): custom_message = None if round_reminder_cfg is not None: enabled = bool(getattr(round_reminder_cfg, 'enabled', False)) - remind_before = getattr(round_reminder_cfg, 'remind_before_max_round', remind_before) - remind_at_round = getattr(round_reminder_cfg, 'remind_at_round', None) + remind_before = getattr(round_reminder_cfg, + 'remind_before_max_round', remind_before) + remind_at_round = getattr(round_reminder_cfg, 'remind_at_round', + None) custom_message = getattr(round_reminder_cfg, 'message', None) if not enabled: @@ -208,18 +217,21 @@ async def on_generate_response(self, runtime: Runtime, messages: List[Message]): reminder_mark = '\n[ROUND_REMINDER]\n' # Avoid injecting duplicates (e.g. if resumed from history at the same round). for m in reversed(messages[-10:]): - if m.role == 'user' and isinstance(m.content, str) and '[ROUND_REMINDER]' in m.content: + if m.role == 'user' and isinstance( + m.content, str) and '[ROUND_REMINDER]' in m.content: return remaining = max_chat_round - runtime.round if not custom_message or not isinstance(custom_message, str): - custom_message = self._ROUND_REMINDER_TEMPLATES.get(self.lang, self._ROUND_REMINDER_TEMPLATES['en']) + custom_message = self._ROUND_REMINDER_TEMPLATES.get( + self.lang, self._ROUND_REMINDER_TEMPLATES['en']) injected = custom_message injected = injected.replace('', str(runtime.round)) injected = injected.replace('', str(max_chat_round)) injected = injected.replace('', str(remaining)) - messages.append(Message(role='user', content=reminder_mark + injected + '\n')) + messages.append( + Message(role='user', content=reminder_mark + injected + '\n')) async def on_task_end(self, runtime: Runtime, messages: List[Message]): """ @@ -228,9 +240,12 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): """ self._ensure_output_dir() json_path = self.search_result_path - md_path = (json_path[:-5] + '.md') if json_path.endswith('.json') else (json_path.split('.')[0] + '.md') + md_path = (json_path[:-5] + + '.md') if json_path.endswith('.json') else ( + json_path.split('.')[0] + '.md') if os.path.exists(json_path) or os.path.exists(md_path): - logger.info(f'Search result already exists at {json_path} or {md_path}') + logger.info( + f'Search result already exists at {json_path} or {md_path}') return # Find the last assistant message without tool calls @@ -250,28 +265,39 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): except (json.JSONDecodeError, TypeError): parsed_json = _parse_search_result_json(content) if parsed_json is not None: - logger.info('Searcher: parsed JSON from fenced or embedded payload') + logger.info( + 'Searcher: parsed JSON from fenced or embedded payload' + ) else: parsed_json = _parse_search_result_json(str(content)) if parsed_json is not None: try: with open(json_path, 'x', encoding='utf-8') as f: - json.dump(parsed_json, f, ensure_ascii=False, indent=2) - logger.info(f'Searcher: Search result saved to {json_path}') + json.dump( + parsed_json, f, ensure_ascii=False, indent=2) + logger.info( + f'Searcher: Search result saved to {json_path}') except FileExistsError: - logger.info(f'Search result already exists at {json_path}') + logger.info( + f'Search result already exists at {json_path}') else: - logger.warning('Failed to parse search result as JSON, saving as markdown') - text = content if isinstance(content, str) else str(content) + logger.warning( + 'Failed to parse search result as JSON, saving as markdown' + ) + text = content if isinstance(content, + str) else str(content) try: with open(md_path, 'x', encoding='utf-8') as f: f.write(text) - logger.info(f'Searcher: Search result saved to {md_path}') + logger.info( + f'Searcher: Search result saved to {md_path}') except FileExistsError: - logger.info(f'Search result already exists at {md_path}') + logger.info( + f'Search result already exists at {md_path}') except Exception as e: - logger.warning(f'Unexpected error when saving search result: {e}') + logger.warning( + f'Unexpected error when saving search result: {e}') return logger.warning('Searcher: No final search result found in messages') diff --git a/projects/deep_research/v2/eval/dr_bench_runner.py b/projects/deep_research/v2/eval/dr_bench_runner.py index ce3978c6b..1917564bf 100644 --- a/projects/deep_research/v2/eval/dr_bench_runner.py +++ b/projects/deep_research/v2/eval/dr_bench_runner.py @@ -17,9 +17,7 @@ """ from __future__ import annotations - import argparse -import json import os import subprocess import sys @@ -30,6 +28,8 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple +import json + try: # Auto-load environment variables from a nearby `.env` (if present). from dotenv import find_dotenv, load_dotenv @@ -55,7 +55,10 @@ def _read_jsonl(path: str) -> List[Dict]: return items -def _append_jsonl(path: str, obj: Dict, *, lock: Optional[threading.Lock] = None) -> None: +def _append_jsonl(path: str, + obj: Dict, + *, + lock: Optional[threading.Lock] = None) -> None: os.makedirs(os.path.dirname(path) or '.', exist_ok=True) if lock is None: with open(path, 'a', encoding='utf-8') as f: @@ -264,7 +267,8 @@ def _report_is_stable( stable_since = now_s return False, sig, stable_since - return (now_s - stable_since) >= max(0.0, stable_window_s), sig, stable_since + return (now_s - stable_since) >= max(0.0, + stable_window_s), sig, stable_since def _run_one_task( @@ -320,12 +324,18 @@ def _run_one_task( # (e.g. process hung at shutdown). Force-reap to unblock # the batch runner. # - post_finish_grace_s = float(os.getenv('DR_BENCH_POST_FINISH_GRACE_S', '180') or 180.0) - post_report_exit_grace_s = float(os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '3600') or 3600.0) - report_stable_window_s = float(os.getenv('DR_BENCH_REPORT_STABLE_WINDOW_S', '2') or 2.0) - poll_interval_s = float(os.getenv('DR_BENCH_SUBPROCESS_POLL_INTERVAL_S', '0.5') or 0.5) - terminate_timeout_s = float(os.getenv('DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S', '5') or 5.0) - kill_timeout_s = float(os.getenv('DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S', '2') or 2.0) + post_finish_grace_s = float( + os.getenv('DR_BENCH_POST_FINISH_GRACE_S', '180') or 180.0) + post_report_exit_grace_s = float( + os.getenv('DR_BENCH_POST_REPORT_EXIT_GRACE_S', '3600') or 3600.0) + report_stable_window_s = float( + os.getenv('DR_BENCH_REPORT_STABLE_WINDOW_S', '2') or 2.0) + poll_interval_s = float( + os.getenv('DR_BENCH_SUBPROCESS_POLL_INTERVAL_S', '0.5') or 0.5) + terminate_timeout_s = float( + os.getenv('DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S', '5') or 5.0) + kill_timeout_s = float( + os.getenv('DR_BENCH_SUBPROCESS_KILL_TIMEOUT_S', '2') or 2.0) report_seen_stable_at: Optional[float] = None report_last_sig: Optional[Tuple[float, int]] = None @@ -354,11 +364,9 @@ def _run_one_task( # --- Condition 1: .researcher_task_finished marker --- if marker_seen_at is None and os.path.exists(marker_path): marker_seen_at = now_s - if ( - marker_seen_at is not None - and proc.poll() is None - and (now_s - marker_seen_at) >= max(0.0, post_finish_grace_s) - ): + if (marker_seen_at is not None and proc.poll() is None + and (now_s - marker_seen_at) >= max( + 0.0, post_finish_grace_s)): _terminate_process( proc, terminate_timeout_s=terminate_timeout_s, @@ -369,7 +377,8 @@ def _run_one_task( # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) - if report_path_hint and _is_direct_final_report_path(workdir, report_path_hint): + if report_path_hint and _is_direct_final_report_path( + workdir, report_path_hint): stable, report_last_sig, report_stable_since = _report_is_stable( report_path_hint, stable_window_s=report_stable_window_s, @@ -382,11 +391,10 @@ def _run_one_task( report_seen_stable_at = now_s else: report_seen_stable_at = None - if ( - report_seen_stable_at is not None - and proc.poll() is None - and (now_s - report_seen_stable_at) >= max(0.0, post_report_exit_grace_s) - ): + if (report_seen_stable_at is not None + and proc.poll() is None + and (now_s - report_seen_stable_at) >= max( + 0.0, post_report_exit_grace_s)): _terminate_process( proc, terminate_timeout_s=terminate_timeout_s, @@ -398,7 +406,8 @@ def _run_one_task( # Drain available stdout without blocking. if select is not None: try: - r, _, _ = select.select([proc.stdout], [], [], poll_interval_s) + r, _, _ = select.select([proc.stdout], [], [], + poll_interval_s) except Exception: r = [] if r: @@ -413,7 +422,8 @@ def _run_one_task( print(f'[{task.task_id}] {line}', end='') else: with print_lock: - print(f'[{task.task_id}] {line}', end='') + print( + f'[{task.task_id}] {line}', end='') continue else: # No select available; degrade to polling only. @@ -439,11 +449,9 @@ def _run_one_task( returncode = 0 if returncode != 0: tail = ''.join(tail_lines)[-20000:] - return ( - task.task_id, - None, - (f'ms-agent exited with code={returncode}. log={log_path}. output tail:\n{tail}'), - ) + return task.task_id, None, ( + f'ms-agent exited with code={returncode}. ' + f'log={log_path}. output tail:\n{tail}') else: with open(log_path, 'w', encoding='utf-8') as logf: # Use Popen+poll so we can force-reap hung-at-exit children once @@ -462,11 +470,9 @@ def _run_one_task( # --- Condition 1: .researcher_task_finished marker --- if marker_seen_at is None and os.path.exists(marker_path): marker_seen_at = now_s - if ( - marker_seen_at is not None - and proc2.poll() is None - and (now_s - marker_seen_at) >= max(0.0, post_finish_grace_s) - ): + if (marker_seen_at is not None and proc2.poll() is None + and (now_s - marker_seen_at) >= max( + 0.0, post_finish_grace_s)): _terminate_process( proc2, terminate_timeout_s=terminate_timeout_s, @@ -477,7 +483,8 @@ def _run_one_task( # --- Condition 2: report stable for a long time (fallback) --- report_path_hint = _find_report_md(workdir) - if report_path_hint and _is_direct_final_report_path(workdir, report_path_hint): + if report_path_hint and _is_direct_final_report_path( + workdir, report_path_hint): stable, report_last_sig, report_stable_since = _report_is_stable( report_path_hint, stable_window_s=report_stable_window_s, @@ -490,11 +497,10 @@ def _run_one_task( report_seen_stable_at = now_s else: report_seen_stable_at = None - if ( - report_seen_stable_at is not None - and proc2.poll() is None - and (now_s - report_seen_stable_at) >= max(0.0, post_report_exit_grace_s) - ): + if (report_seen_stable_at is not None + and proc2.poll() is None + and (now_s - report_seen_stable_at) >= max( + 0.0, post_report_exit_grace_s)): _terminate_process( proc2, terminate_timeout_s=terminate_timeout_s, @@ -512,23 +518,17 @@ def _run_one_task( returncode = 0 if returncode != 0: tail = _tail_text_from_file(log_path, max_chars=20000) - return ( - task.task_id, - None, - (f'ms-agent exited with code={returncode}. log={log_path}. output tail:\n{tail}'), - ) + return task.task_id, None, ( + f'ms-agent exited with code={returncode}. ' + f'log={log_path}. output tail:\n{tail}') except Exception as e: return task.task_id, None, f'subprocess failed: {e}' report_path = _find_report_md(workdir) if not report_path: - return ( - task.task_id, - None, - ( - f'final_report.md not found in workdir={workdir}. ' - f'log={log_path}. ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' - ), + return task.task_id, None, ( + f'final_report.md not found in workdir={workdir}. ' + f'log={log_path}. ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' ) try: @@ -538,23 +538,28 @@ def _run_one_task( return task.task_id, None, f'failed to read report: {e} (path={report_path})' if not article.strip(): - return ( - task.task_id, - None, - ( - f'empty report content (path={report_path}). log={log_path}. ' - f'ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' - ), + return task.task_id, None, ( + f'empty report content (path={report_path}). log={log_path}. ' + f'ms-agent output tail:\n{_tail_text_from_file(log_path, max_chars=20000)}' ) return task.task_id, article, None def main() -> None: - parser = argparse.ArgumentParser(description='Run ms-agent v2 on dr_bench queries and dump raw_data jsonl.') - parser.add_argument('--query_file', required=True, help='Path to dr_bench query.jsonl') - parser.add_argument('--output_jsonl', required=True, help='Output path for dr_bench raw_data/.jsonl') - parser.add_argument('--model_name', default='ms_deepresearch', help='Model/agent name used in output file naming') + parser = argparse.ArgumentParser( + description= + 'Run ms-agent v2 on dr_bench queries and dump raw_data jsonl.') + parser.add_argument( + '--query_file', required=True, help='Path to dr_bench query.jsonl') + parser.add_argument( + '--output_jsonl', + required=True, + help='Output path for dr_bench raw_data/.jsonl') + parser.add_argument( + '--model_name', + default='ms_deepresearch', + help='Model/agent name used in output file naming') parser.add_argument( '--config', default='projects/deep_research/v2/researcher.yaml', @@ -563,31 +568,47 @@ def main() -> None: parser.add_argument( '--work_root', default='eval/dr_bench/results/runs', - help='Root dir to store per-task workdirs. Will create ///', + help= + 'Root dir to store per-task workdirs. Will create ///', ) - parser.add_argument('--limit', type=int, default=0, help='Limit number of tasks (0 means all)') - parser.add_argument('--workers', type=int, default=1, help='Concurrency level (subprocess-based)') + parser.add_argument( + '--limit', + type=int, + default=0, + help='Limit number of tasks (0 means all)') + parser.add_argument( + '--workers', + type=int, + default=1, + help='Concurrency level (subprocess-based)') parser.add_argument( '--python', default=sys.executable, - help='Python executable to run ms-agent (defaults to current interpreter)', + help= + 'Python executable to run ms-agent (defaults to current interpreter)', ) - parser.add_argument('--trust_remote_code', action='store_true', help='Pass --trust_remote_code true to ms-agent') + parser.add_argument( + '--trust_remote_code', + action='store_true', + help='Pass --trust_remote_code true to ms-agent') parser.add_argument( '--ms_agent_root', default='.', - help='Path to ms-agent repo root (contains ms_agent/). Defaults to current working directory.', + help= + 'Path to ms-agent repo root (contains ms_agent/). Defaults to current working directory.', ) parser.add_argument( '--stream_subprocess_output', action='store_true', - help='Stream ms-agent stdout/stderr to console (also written to /ms_agent.log).', + help= + 'Stream ms-agent stdout/stderr to console (also written to /ms_agent.log).', ) parser.add_argument( '--extra', nargs=argparse.REMAINDER, default=[], - help='Extra args passed through to ms-agent (e.g. --llm.model xxx --generation_config.stream false)', + help= + 'Extra args passed through to ms-agent (e.g. --llm.model xxx --generation_config.stream false)', ) args = parser.parse_args() @@ -621,7 +642,7 @@ def main() -> None: tasks.append(Task(task_id=task_id, prompt=prompt)) if args.limit and args.limit > 0: - tasks = tasks[: args.limit] + tasks = tasks[:args.limit] done_ids = _load_existing_ids(output_jsonl) # Backfill: if a workdir already has a top-level final report file but the @@ -656,12 +677,15 @@ def main() -> None: print(msg) return - print(f'Will run {len(tasks)} tasks (workers={args.workers}). Output: {output_jsonl}') + print( + f'Will run {len(tasks)} tasks (workers={args.workers}). Output: {output_jsonl}' + ) os.makedirs(os.path.dirname(output_jsonl) or '.', exist_ok=True) # Ensure ms-agent is importable at runtime for subprocess (best-effort check) if not os.path.exists(os.path.join(ms_agent_root, 'ms_agent')): - raise FileNotFoundError(f'ms_agent_root seems wrong: {ms_agent_root} (missing ms_agent/)') + raise FileNotFoundError( + f'ms_agent_root seems wrong: {ms_agent_root} (missing ms_agent/)') extra_args = args.extra or [] print_lock = threading.Lock() @@ -683,7 +707,13 @@ def main() -> None: if err: print(f'[{tid}] ERROR: {err}', file=sys.stderr) continue - _append_jsonl(output_jsonl, {'id': tid, 'prompt': t.prompt, 'article': article}, lock=write_lock) + _append_jsonl( + output_jsonl, { + 'id': tid, + 'prompt': t.prompt, + 'article': article + }, + lock=write_lock) print(f'[{tid}] OK') return @@ -710,12 +740,20 @@ def main() -> None: try: tid, article, err = fut.result() except Exception as e: - print(f'[{t.task_id}] ERROR: future failed: {e}', file=sys.stderr) + print( + f'[{t.task_id}] ERROR: future failed: {e}', + file=sys.stderr) continue if err: print(f'[{tid}] ERROR: {err}', file=sys.stderr) continue - _append_jsonl(output_jsonl, {'id': tid, 'prompt': t.prompt, 'article': article}, lock=write_lock) + _append_jsonl( + output_jsonl, { + 'id': tid, + 'prompt': t.prompt, + 'article': article + }, + lock=write_lock) print(f'[{tid}] OK') diff --git a/projects/deep_research/v2/reporter.py b/projects/deep_research/v2/reporter.py index 2ce320951..d9c6f2507 100644 --- a/projects/deep_research/v2/reporter.py +++ b/projects/deep_research/v2/reporter.py @@ -2,12 +2,11 @@ import os from typing import Any, AsyncGenerator, List, Union -from omegaconf import DictConfig - from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG +from omegaconf import DictConfig logger = get_logger() @@ -24,19 +23,23 @@ class ReporterAgent(LLMAgent): 5. Assemble final reports """ - def __init__( - self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs - ): + def __init__(self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) # Reporter-specific configuration self._reports_dir = 'reports' - if hasattr(config, 'tools') and hasattr(config.tools, 'report_generator'): + if hasattr(config, 'tools') and hasattr(config.tools, + 'report_generator'): report_cfg = config.tools.report_generator self._reports_dir = getattr(report_cfg, 'reports_dir', 'reports') async def run( - self, inputs: Union[str, List[str], List[Message], List[List[Message]]], **kwargs + self, inputs: Union[str, List[str], List[Message], + List[List[Message]]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: # Add context about the reporter's role if isinstance(inputs, str): @@ -48,7 +51,9 @@ async def run( if os.path.exists(evidence_dir): evidence_index = os.path.join(evidence_dir, 'index.json') if os.path.exists(evidence_index): - logger.info(f'ReporterAgent: Evidence index found at {evidence_index}') + logger.info( + f'ReporterAgent: Evidence index found at {evidence_index}' + ) inputs = enhanced_input diff --git a/projects/deep_research/v2/researcher.py b/projects/deep_research/v2/researcher.py index c8ace978c..6af015717 100644 --- a/projects/deep_research/v2/researcher.py +++ b/projects/deep_research/v2/researcher.py @@ -1,13 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, AsyncGenerator, List, Union -from omegaconf import DictConfig - from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from ms_agent.utils.stats import append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage +from ms_agent.utils.stats import (append_stats, build_timing_record, + get_stats_path, monotonic, now_iso, + summarize_usage) +from omegaconf import DictConfig logger = get_logger() @@ -17,12 +18,15 @@ class ResearcherAgent(LLMAgent): Researcher Agent that conducts deep research tasks using LLMs and various tools. """ - def __init__( - self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs - ): + def __init__(self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) - async def run_loop(self, messages: Union[List[Message], str], **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], + **kwargs) -> AsyncGenerator[Any, Any]: start_ts = now_iso() start_time = monotonic() last_messages: List[Message] = [] @@ -62,7 +66,7 @@ async def on_task_end(self, messages: List[Message]): await super().on_task_end(messages) try: from ms_agent.tools.search.websearch_tool import WebSearchTool - WebSearchTool.log_global_summarization_usage() except Exception as exc: - logger.warning(f'Failed to log web search summarization usage: {exc}') + logger.warning( + f'Failed to log web search summarization usage: {exc}') diff --git a/projects/deep_research/v2/time_handler.py b/projects/deep_research/v2/time_handler.py index 90514f0a8..a36282b5f 100644 --- a/projects/deep_research/v2/time_handler.py +++ b/projects/deep_research/v2/time_handler.py @@ -2,9 +2,8 @@ from datetime import datetime from typing import Any -from omegaconf import DictConfig - from ms_agent.config.config import ConfigLifecycleHandler +from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): @@ -30,7 +29,8 @@ def task_begin(self, config: DictConfig, tag: str) -> DictConfig: def traverse_and_replace(_config: Any): if isinstance(_config, DictConfig): for name, value in _config.items(): - if isinstance(value, DictConfig) or isinstance(value, list): + if isinstance(value, DictConfig) or isinstance( + value, list): traverse_and_replace(value) elif isinstance(value, str): new_value = value @@ -38,7 +38,8 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace(placeholder, var_value) + new_value = new_value.replace( + placeholder, var_value) setattr(_config, name, new_value) elif isinstance(_config, list): @@ -51,7 +52,8 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace(placeholder, var_value) + new_value = new_value.replace( + placeholder, var_value) _config[i] = new_value traverse_and_replace(config) diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 7ef061b4c..1379ce511 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import json import os import re import time import uuid from typing import Any, Dict, List, Optional +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock @@ -161,7 +161,8 @@ def _render_analysis_card(analysis: Dict[str, Any]) -> str: if analysis.get('task_id'): lines.append(f"- **Task ID**: `{analysis['task_id']}`") if analysis.get('based_on_note_ids'): - ids_str = ', '.join(f'`{nid}`' for nid in analysis.get('based_on_note_ids', [])) + ids_str = ', '.join(f'`{nid}`' + for nid in analysis.get('based_on_note_ids', [])) lines.append(f'- **Based on Notes**: {ids_str}') if analysis.get('tags'): tags_str = ', '.join(f'`{t}`' for t in analysis['tags']) @@ -284,15 +285,14 @@ def _parse_note_from_md(content: str, note_id: str) -> Dict[str, Any]: elif header == 'Sources': sources = [] for line in body.split('\n'): - match = re.search(r'- \[(\w+)\] (.+?)(?:\s+\(published: ([^)]+)\))?$', line) + match = re.search( + r'- \[(\w+)\] (.+?)(?:\s+\(published: ([^)]+)\))?$', line) if match: - sources.append( - { - 'url': match.group(2).strip(), - 'source_tier': match.group(1), - 'published_at': match.group(3) or '', - } - ) + sources.append({ + 'url': match.group(2).strip(), + 'source_tier': match.group(1), + 'published_at': match.group(3) or '' + }) note['sources'] = sources return note @@ -321,31 +321,44 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Configurable paths - self._evidence_dir = getattr(tool_cfg, 'evidence_dir', 'evidence') if tool_cfg else 'evidence' - self._chunks_dir = getattr(tool_cfg, 'chunks_dir', 'chunks') if tool_cfg else 'chunks' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' + self._evidence_dir = getattr(tool_cfg, 'evidence_dir', + 'evidence') if tool_cfg else 'evidence' + self._chunks_dir = getattr(tool_cfg, 'chunks_dir', + 'chunks') if tool_cfg else 'chunks' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', + '.locks') if tool_cfg else '.locks' # Feature flags - self._enable_chunk_storage = bool(getattr(tool_cfg, 'enable_chunk_storage', False)) if tool_cfg else False + self._enable_chunk_storage = bool( + getattr(tool_cfg, 'enable_chunk_storage', + False)) if tool_cfg else False async def connect(self) -> None: """Initialize directory structure.""" _ensure_dir(self.output_dir) _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'notes')) - _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'analyses')) + _ensure_dir( + os.path.join(self.output_dir, self._evidence_dir, 'analyses')) # Backward-compat: older runs may have used evidence/conclusions/ - _ensure_dir(os.path.join(self.output_dir, self._evidence_dir, 'conclusions')) + _ensure_dir( + os.path.join(self.output_dir, self._evidence_dir, 'conclusions')) _ensure_dir(os.path.join(self.output_dir, self._chunks_dir)) _ensure_dir(os.path.join(self.output_dir, self._lock_subdir)) def _paths(self) -> Dict[str, str]: return { - 'index': os.path.join(self.output_dir, self._evidence_dir, 'index.json'), - 'notes_dir': os.path.join(self.output_dir, self._evidence_dir, 'notes'), - 'analyses_dir': os.path.join(self.output_dir, self._evidence_dir, 'analyses'), - 'legacy_conclusions_dir': os.path.join(self.output_dir, self._evidence_dir, 'conclusions'), - 'chunks_dir': os.path.join(self.output_dir, self._chunks_dir), - 'lock_dir': os.path.join(self.output_dir, self._lock_subdir), + 'index': + os.path.join(self.output_dir, self._evidence_dir, 'index.json'), + 'notes_dir': + os.path.join(self.output_dir, self._evidence_dir, 'notes'), + 'analyses_dir': + os.path.join(self.output_dir, self._evidence_dir, 'analyses'), + 'legacy_conclusions_dir': + os.path.join(self.output_dir, self._evidence_dir, 'conclusions'), + 'chunks_dir': + os.path.join(self.output_dir, self._chunks_dir), + 'lock_dir': + os.path.join(self.output_dir, self._lock_subdir), } async def _get_tools_inner(self) -> Dict[str, Any]: @@ -354,85 +367,107 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='write_note', server_name=self.SERVER_NAME, - description=( - 'Write a new evidence note (card) to the evidence store. ' - 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' - 'Returns the generated note_id.' - ), + description= + ('Write a new evidence note (card) to the evidence store. ' + 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' + 'Returns the generated note_id.'), parameters={ - 'type': 'object', + 'type': + 'object', 'properties': { 'title': { - 'type': 'string', - 'description': ( - 'Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' - 'Optional: if omitted, a title is derived from the first line of `content`.' - ), + 'type': + 'string', + 'description': + ('Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' + 'Optional: if omitted, a title is derived from the first line of `content`.'), }, 'content': { - 'type': 'string', - 'description': ( - 'The full evidence text for this note. ' - 'State the core finding or observation, then provide all ' - 'supporting details: specific data points, statistics, quotes, ' - 'case studies, reasoning, and any other substantive information. ' - 'Be thorough — preserve all valuable details from the source material. ' - 'Multi-paragraph allowed.' - ), + 'type': + 'string', + 'description': + ('The full evidence text for this note. ' + 'State the core finding or observation, then provide all ' + 'supporting details: specific data points, statistics, quotes, ' + 'case studies, reasoning, and any other substantive information. ' + 'Be thorough — preserve all valuable details from the source material. ' + 'Multi-paragraph allowed.'), }, 'contradicts': { - 'type': 'string', - 'description': ( - 'Optional: Evidence text that contradicts this finding. ' - 'Include if there are conflicting sources or caveats.' - ), + 'type': + 'string', + 'description': + ('Optional: Evidence text that contradicts this finding. ' + 'Include if there are conflicting sources or caveats.' + ), }, 'sources': { 'type': 'array', - 'description': 'List of source references for this evidence.', + 'description': + 'List of source references for this evidence.', 'items': { 'type': 'object', 'properties': { - 'url': {'type': 'string', 'description': 'Source URL'}, - 'published_at': { + 'url': { 'type': 'string', - 'description': 'Publication date (YYYY-MM-DD)', + 'description': 'Source URL' + }, + 'published_at': { + 'type': + 'string', + 'description': + 'Publication date (YYYY-MM-DD)' }, 'source_tier': { - 'type': 'string', - 'enum': ['official', 'primary', 'secondary', 'unknown'], - 'description': ( - 'Source credibility tier (for example, Official ' - 'Documents/Papers/Standards > ' - 'Primary News/Announcements > Secondary Blogs)' - ), + 'type': + 'string', + 'enum': [ + 'official', 'primary', + 'secondary', 'unknown' + ], + 'description': + ('Source credibility tier (for example, Official ' + 'Documents/Papers/Standards > ' + 'Primary News/Announcements > Secondary Blogs)' + ), }, }, 'required': ['url'], }, }, 'summary': { - 'type': 'string', - 'description': 'One-sentence summary of this evidence.', + 'type': + 'string', + 'description': + 'One-sentence summary of this evidence.', }, 'task_id': { - 'type': 'string', - 'description': 'The plan task this evidence relates to.', + 'type': + 'string', + 'description': + 'The plan task this evidence relates to.', }, 'tags': { 'type': 'array', - 'items': {'type': 'string'}, + 'items': { + 'type': 'string' + }, 'description': 'Tags for categorization.', }, 'quality_score': { - 'type': 'integer', - 'minimum': 0, - 'maximum': 100, - 'description': 'Optional: Confidence/quality score (0-100).', + 'type': + 'integer', + 'minimum': + 0, + 'maximum': + 100, + 'description': + 'Optional: Confidence/quality score (0-100).', }, }, 'required': ['content'], - 'additionalProperties': False, + 'additionalProperties': + False, }, ), Tool( @@ -444,7 +479,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'properties': { 'note_id': { 'type': 'string', - 'description': 'The ID of the note to retrieve.', + 'description': + 'The ID of the note to retrieve.', }, }, 'required': ['note_id'], @@ -454,10 +490,9 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='list_notes', server_name=self.SERVER_NAME, - description=( - 'List all evidence notes, optionally filtered by task_id or tags. ' - 'Returns a summary list (not full content).' - ), + description= + ('List all evidence notes, optionally filtered by task_id or tags. ' + 'Returns a summary list (not full content).'), parameters={ 'type': 'object', 'properties': { @@ -466,9 +501,13 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Optional: Filter by task ID.', }, 'tags': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Optional: Filter by tags (notes must have ALL specified tags).', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: Filter by tags (notes must have ALL specified tags).', }, # 'min_quality': { # 'type': 'integer', @@ -482,7 +521,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='search_notes', server_name=self.SERVER_NAME, - description='Search notes by keyword in title, claim, or summary.', + description= + 'Search notes by keyword in title, claim, or summary.', parameters={ 'type': 'object', 'properties': { @@ -514,60 +554,77 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='write_analysis', server_name=self.SERVER_NAME, - description=( - 'Write an interim **analysis** record to the evidence store. ' - 'Use this tool whenever you need to turn multiple evidence notes into reusable reasoning artifacts, e.g.: ' - '(1) synthesis / interim summaries; ' - '(2) comparisons and trade-off decisions (A vs B, pros/cons, why choose X); ' - '(3) framework building (typologies, evaluation rubrics, scoring criteria, checklists); ' - '(4) mapping & reconciliation (align competing definitions/metrics, resolve conflicts, record assumptions); ' - '(5) scenario framing and uncertainty tracking (what-if branches, key sensitivities/risks, open questions); ' - '(6) rankings/recommendations that require rationale (e.g., pick top 2–3 options and justify). ' - '(7) Structured / visual intermediate artifacts (e.g., mind-map-style hierarchical outlines, and ' - 'text-based flow/relationship diagrams—prefer Mermaid syntax when possible).' - '(8) other intermediate analysis that requires reasoning, justification and recording.' - 'This is **not** the final report; it is an intermediate analysis that should cite supporting evidence via ' - 'based_on_note_ids when possible so downstream writing can reuse it. ' - 'Returns the generated analysis_id.' - ), + description= + ('Write an interim **analysis** record to the evidence store. ' + 'Use this tool whenever you need to turn multiple evidence notes into reusable reasoning artifacts, e.g.: ' + '(1) synthesis / interim summaries; ' + '(2) comparisons and trade-off decisions (A vs B, pros/cons, why choose X); ' + '(3) framework building (typologies, evaluation rubrics, scoring criteria, checklists); ' + '(4) mapping & reconciliation (align competing definitions/metrics, resolve conflicts, record assumptions); ' + '(5) scenario framing and uncertainty tracking (what-if branches, key sensitivities/risks, open questions); ' + '(6) rankings/recommendations that require rationale (e.g., pick top 2–3 options and justify). ' + '(7) Structured / visual intermediate artifacts (e.g., mind-map-style hierarchical outlines, and ' + 'text-based flow/relationship diagrams—prefer Mermaid syntax when possible).' + '(8) other intermediate analysis that requires reasoning, justification and recording.' + 'This is **not** the final report; it is an intermediate analysis that should cite supporting evidence via ' + 'based_on_note_ids when possible so downstream writing can reuse it. ' + 'Returns the generated analysis_id.'), parameters={ 'type': 'object', 'properties': { 'title': { - 'type': 'string', - 'description': 'Brief title describing this analysis (e.g., "Interim comparison: Framework A vs B").', + 'type': + 'string', + 'description': + 'Brief title describing this analysis (e.g., "Interim comparison: Framework A vs B").', }, 'content': { - 'type': 'string', - 'description': ( - 'The analysis content in Markdown. ' - 'This should capture synthesis/comparison, constraints, assumptions, and reasoning. ' - 'Multi-paragraph allowed.' - ), + 'type': + 'string', + 'description': + ('The analysis content in Markdown. ' + 'This should capture synthesis/comparison, constraints, assumptions, and reasoning. ' + 'Multi-paragraph allowed.'), }, 'summary': { - 'type': 'string', - 'description': 'Optional: One-sentence summary of this analysis.', + 'type': + 'string', + 'description': + 'Optional: One-sentence summary of this analysis.', }, 'task_id': { - 'type': 'string', - 'description': 'Optional: The plan task this analysis relates to.', + 'type': + 'string', + 'description': + 'Optional: The plan task this analysis relates to.', }, 'based_on_note_ids': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Optional: List of note_ids this analysis is based on.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: List of note_ids this analysis is based on.', }, 'tags': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Optional: Tags for categorization.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: Tags for categorization.', }, 'quality_score': { - 'type': 'integer', - 'minimum': 0, - 'maximum': 100, - 'description': 'Optional: Confidence/quality score (0-100).', + 'type': + 'integer', + 'minimum': + 0, + 'maximum': + 100, + 'description': + 'Optional: Confidence/quality score (0-100).', }, }, 'required': ['title', 'content', 'summary', 'tags'], @@ -582,12 +639,16 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'object', 'properties': { 'analysis_id': { - 'type': 'string', - 'description': 'The ID of the analysis to retrieve.', + 'type': + 'string', + 'description': + 'The ID of the analysis to retrieve.', }, 'parse_analysis': { - 'type': 'boolean', - 'description': 'Optional: Whether to parse stored markdown back to structured dict.', + 'type': + 'boolean', + 'description': + 'Optional: Whether to parse stored markdown back to structured dict.', }, }, 'required': ['analysis_id'], @@ -597,10 +658,9 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='list_analyses', server_name=self.SERVER_NAME, - description=( - 'List all analyses, optionally filtered by task_id or tags. ' - 'Returns a summary list (not full content).' - ), + description= + ('List all analyses, optionally filtered by task_id or tags. ' + 'Returns a summary list (not full content).'), parameters={ 'type': 'object', 'properties': { @@ -609,9 +669,13 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Optional: Filter by task ID.', }, 'tags': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Optional: Filter by tags (analyses must have ALL specified tags).', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Optional: Filter by tags (analyses must have ALL specified tags).', }, }, 'required': [], @@ -621,7 +685,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='search_analyses', server_name=self.SERVER_NAME, - description='Search analyses by keyword in title, summary, or tags.', + description= + 'Search analyses by keyword in title, summary, or tags.', parameters={ 'type': 'object', 'properties': { @@ -643,7 +708,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'properties': { 'analysis_id': { 'type': 'string', - 'description': 'The ID of the analysis to delete.', + 'description': + 'The ID of the analysis to delete.', }, }, 'required': ['analysis_id'], @@ -660,12 +726,13 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'required': [], 'additionalProperties': False, }, - ), + ) ] } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: @@ -675,13 +742,16 @@ def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: return { 'schema_version': 2, 'updated_at': _now_iso(), - 'notes': {}, # note_id -> {title, task_id, summary, sources, tags, quality_score, created_at} - 'analyses': {}, # analysis_id -> {title, task_id, summary, based_on_note_ids, tags, quality_score, created_at, path} + 'notes': + {}, # note_id -> {title, task_id, summary, sources, tags, quality_score, created_at} + 'analyses': + {}, # analysis_id -> {title, task_id, summary, based_on_note_ids, tags, quality_score, created_at, path} } # Backward/forward compatible defaults if 'notes' not in data or not isinstance(data.get('notes'), dict): data['notes'] = {} - if 'analyses' not in data or not isinstance(data.get('analyses'), dict): + if 'analyses' not in data or not isinstance( + data.get('analyses'), dict): data['analyses'] = {} # Backward-compat: older schema used "conclusions" key. @@ -690,12 +760,14 @@ def _load_index_locked(self, paths: Dict[str, str]) -> Dict[str, Any]: data['analyses'] = legacy return data - def _save_index_locked(self, paths: Dict[str, str], index: Dict[str, Any]) -> None: + def _save_index_locked(self, paths: Dict[str, str], + index: Dict[str, Any]) -> None: """Save index.json.""" index['updated_at'] = _now_iso() _write_text(paths['index'], _json_dumps(index)) - def _add_to_index(self, index: Dict[str, Any], note: Dict[str, Any]) -> None: + def _add_to_index(self, index: Dict[str, Any], note: Dict[str, + Any]) -> None: """Add a note's metadata to the index.""" note_id = note['note_id'] index['notes'][note_id] = { @@ -708,7 +780,9 @@ def _add_to_index(self, index: Dict[str, Any], note: Dict[str, Any]) -> None: 'created_at': note.get('created_at', ''), } - def _add_analysis_to_index(self, index: Dict[str, Any], analysis: Dict[str, Any], analysis_path: str) -> None: + def _add_analysis_to_index(self, index: Dict[str, Any], + analysis: Dict[str, Any], + analysis_path: str) -> None: """Add an analysis' metadata to the index.""" aid = analysis['analysis_id'] index['analyses'][aid] = { @@ -729,14 +803,16 @@ def _remove_from_index(self, index: Dict[str, Any], note_id: str) -> bool: return True return False - def _remove_analysis_from_index(self, index: Dict[str, Any], analysis_id: str) -> bool: + def _remove_analysis_from_index(self, index: Dict[str, Any], + analysis_id: str) -> bool: """Remove an analysis from the index. Returns True if found and removed.""" if analysis_id in index.get('analyses', {}): del index['analyses'][analysis_id] return True return False - def _store_chunk(self, chunk_id: str, content: str, metadata: Dict[str, Any]) -> str: + def _store_chunk(self, chunk_id: str, content: str, + metadata: Dict[str, Any]) -> str: """ Store a text chunk. Reserved for future implementation. @@ -787,12 +863,10 @@ async def write_note( content = (content or '').strip() if not content: - return _json_dumps( - { - 'status': 'error', - 'message': 'write_note requires non-empty content.', - } - ) + return _json_dumps({ + 'status': 'error', + 'message': 'write_note requires non-empty content.', + }) if title is None or not str(title).strip(): first_line = content.split('\n', 1)[0].strip() @@ -816,7 +890,8 @@ async def write_note( if sources: # Validate source tiers for src in sources: - src['source_tier'] = _validate_source_tier(src.get('source_tier', 'unknown')) + src['source_tier'] = _validate_source_tier( + src.get('source_tier', 'unknown')) note['sources'] = sources if summary: note['summary'] = summary.strip() @@ -839,13 +914,11 @@ async def write_note( self._add_to_index(index, note) self._save_index_locked(paths, index) - return _json_dumps( - { - 'status': 'ok', - 'note_id': note_id, - 'path': os.path.relpath(note_path, self.output_dir), - } - ) + return _json_dumps({ + 'status': 'ok', + 'note_id': note_id, + 'path': os.path.relpath(note_path, self.output_dir), + }) async def write_analysis( self, @@ -874,13 +947,16 @@ async def write_analysis( if task_id: analysis['task_id'] = task_id.strip() if based_on_note_ids: - analysis['based_on_note_ids'] = [nid.strip() for nid in based_on_note_ids if nid.strip()] + analysis['based_on_note_ids'] = [ + nid.strip() for nid in based_on_note_ids if nid.strip() + ] if tags: analysis['tags'] = [t.strip() for t in tags if t.strip()] if quality_score is not None: analysis['quality_score'] = max(0, min(100, quality_score)) - analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') analysis_content = _render_analysis_card(analysis) _write_text(analysis_path, analysis_content) @@ -889,25 +965,33 @@ async def write_analysis( self._add_analysis_to_index(index, analysis, analysis_path) self._save_index_locked(paths, index) - return _json_dumps( - { - 'status': 'ok', - 'analysis_id': analysis_id, - 'path': os.path.relpath(analysis_path, self.output_dir), - } - ) - - async def get_analysis(self, analysis_id: str, parse_analysis: Optional[bool] = False) -> str: + return _json_dumps({ + 'status': + 'ok', + 'analysis_id': + analysis_id, + 'path': + os.path.relpath(analysis_path, self.output_dir), + }) + + async def get_analysis(self, + analysis_id: str, + parse_analysis: Optional[bool] = False) -> str: """Retrieve an analysis by ID.""" paths = self._paths() - analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') - legacy_path = os.path.join(paths['legacy_conclusions_dir'], f'conclusion_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], + f'conclusion_{analysis_id}.md') if not os.path.exists(analysis_path) and os.path.exists(legacy_path): analysis_path = legacy_path if not os.path.exists(analysis_path): - return _json_dumps({'status': 'error', 'message': f'Analysis {analysis_id} not found.'}) + return _json_dumps({ + 'status': 'error', + 'message': f'Analysis {analysis_id} not found.' + }) with open(analysis_path, 'r', encoding='utf-8') as f: content = f.read() @@ -915,16 +999,16 @@ async def get_analysis(self, analysis_id: str, parse_analysis: Optional[bool] = if not parse_analysis: return _json_dumps({'status': 'ok', 'raw_content': content}) analysis = _parse_analysis_from_md(content, analysis_id) - return _json_dumps( - { - 'status': 'ok', - 'analysis_id': analysis_id, - 'analysis': analysis, - 'raw_content': content, - } - ) + return _json_dumps({ + 'status': 'ok', + 'analysis_id': analysis_id, + 'analysis': analysis, + 'raw_content': content, + }) - async def list_analyses(self, task_id: Optional[str] = None, tags: Optional[List[str]] = None) -> str: + async def list_analyses(self, + task_id: Optional[str] = None, + tags: Optional[List[str]] = None) -> str: """List analyses with optional filters.""" paths = self._paths() _ensure_dir(paths['lock_dir']) @@ -941,28 +1025,33 @@ async def list_analyses(self, task_id: Optional[str] = None, tags: Optional[List a_tags = set(meta.get('tags', [])) if not all(t in a_tags for t in tags): continue - results.append( - { - 'analysis_id': aid, - 'title': meta.get('title', ''), - 'task_id': meta.get('task_id', ''), - 'summary': meta.get('summary', ''), - 'based_on_note_ids': meta.get('based_on_note_ids', []), - 'tags': meta.get('tags', []), - 'quality_score': meta.get('quality_score'), - 'created_at': meta.get('created_at', ''), - 'path': meta.get('path', ''), - } - ) + results.append({ + 'analysis_id': + aid, + 'title': + meta.get('title', ''), + 'task_id': + meta.get('task_id', ''), + 'summary': + meta.get('summary', ''), + 'based_on_note_ids': + meta.get('based_on_note_ids', []), + 'tags': + meta.get('tags', []), + 'quality_score': + meta.get('quality_score'), + 'created_at': + meta.get('created_at', ''), + 'path': + meta.get('path', ''), + }) results.sort(key=lambda x: x.get('created_at', ''), reverse=True) - return _json_dumps( - { - 'status': 'ok', - 'count': len(results), - 'analyses': results, - } - ) + return _json_dumps({ + 'status': 'ok', + 'count': len(results), + 'analyses': results, + }) async def search_analyses(self, keyword: str) -> str: """Search analyses by keyword.""" @@ -971,7 +1060,10 @@ async def search_analyses(self, keyword: str) -> str: keyword_lower = keyword.lower().strip() if not keyword_lower: - return _json_dumps({'status': 'error', 'message': 'Keyword is required.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Keyword is required.' + }) with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) @@ -979,48 +1071,50 @@ async def search_analyses(self, keyword: str) -> str: analyses_meta = index.get('analyses', {}) results = [] for aid, meta in analyses_meta.items(): - searchable = ' '.join( - [ - meta.get('title', ''), - meta.get('summary', ''), - ] - ).lower() + searchable = ' '.join([ + meta.get('title', ''), + meta.get('summary', ''), + ]).lower() a_tags = meta.get('tags', []) searchable += ' ' + ' '.join(a_tags).lower() if keyword_lower in searchable: - results.append( - { - 'analysis_id': aid, - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - 'task_id': meta.get('task_id', ''), - 'quality_score': meta.get('quality_score'), - } - ) + results.append({ + 'analysis_id': aid, + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'task_id': meta.get('task_id', ''), + 'quality_score': meta.get('quality_score'), + }) - return _json_dumps( - { - 'status': 'ok', - 'keyword': keyword, - 'count': len(results), - 'analyses': results, - } - ) + return _json_dumps({ + 'status': 'ok', + 'keyword': keyword, + 'count': len(results), + 'analyses': results, + }) async def delete_analysis(self, analysis_id: str) -> str: """Delete an analysis by ID.""" paths = self._paths() _ensure_dir(paths['lock_dir']) - analysis_path = os.path.join(paths['analyses_dir'], f'analysis_{analysis_id}.md') - legacy_path = os.path.join(paths['legacy_conclusions_dir'], f'conclusion_{analysis_id}.md') + analysis_path = os.path.join(paths['analyses_dir'], + f'analysis_{analysis_id}.md') + legacy_path = os.path.join(paths['legacy_conclusions_dir'], + f'conclusion_{analysis_id}.md') with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) removed = self._remove_analysis_from_index(index, analysis_id) - if not removed and not os.path.exists(analysis_path) and not os.path.exists(legacy_path): - return _json_dumps({'status': 'error', 'message': f'Analysis {analysis_id} not found.'}) + if not removed and not os.path.exists( + analysis_path) and not os.path.exists(legacy_path): + return _json_dumps({ + 'status': + 'error', + 'message': + f'Analysis {analysis_id} not found.' + }) self._save_index_locked(paths, index) @@ -1031,13 +1125,18 @@ async def delete_analysis(self, analysis_id: str) -> str: return _json_dumps({'status': 'ok', 'deleted': analysis_id}) - async def get_note(self, note_id: str, parse_note: Optional[bool] = False) -> str: + async def get_note(self, + note_id: str, + parse_note: Optional[bool] = False) -> str: """Retrieve a note by ID.""" paths = self._paths() note_path = os.path.join(paths['notes_dir'], f'note_{note_id}.md') if not os.path.exists(note_path): - return _json_dumps({'status': 'error', 'message': f'Note {note_id} not found.'}) + return _json_dumps({ + 'status': 'error', + 'message': f'Note {note_id} not found.' + }) with open(note_path, 'r', encoding='utf-8') as f: content = f.read() @@ -1046,11 +1145,17 @@ async def get_note(self, note_id: str, parse_note: Optional[bool] = False) -> st return _json_dumps({'status': 'ok', 'raw_content': content}) else: note = _parse_note_from_md(content, note_id) - return _json_dumps({'status': 'ok', 'note_id': note_id, 'note': note, 'raw_content': content}) - - async def list_notes( - self, task_id: Optional[str] = None, tags: Optional[List[str]] = None, min_quality: Optional[int] = None - ) -> str: + return _json_dumps({ + 'status': 'ok', + 'note_id': note_id, + 'note': note, + 'raw_content': content + }) + + async def list_notes(self, + task_id: Optional[str] = None, + tags: Optional[List[str]] = None, + min_quality: Optional[int] = None) -> str: """List notes with optional filters. Args: @@ -1080,29 +1185,25 @@ async def list_notes( if score is None or score < min_quality: continue - results.append( - { - 'note_id': nid, - 'title': meta.get('title', ''), - 'task_id': meta.get('task_id', ''), - 'summary': meta.get('summary', ''), - 'sources': meta.get('sources', []), - 'tags': meta.get('tags', []), - 'quality_score': meta.get('quality_score'), - 'created_at': meta.get('created_at', ''), - } - ) + results.append({ + 'note_id': nid, + 'title': meta.get('title', ''), + 'task_id': meta.get('task_id', ''), + 'summary': meta.get('summary', ''), + 'sources': meta.get('sources', []), + 'tags': meta.get('tags', []), + 'quality_score': meta.get('quality_score'), + 'created_at': meta.get('created_at', ''), + }) # Sort by created_at descending results.sort(key=lambda x: x.get('created_at', ''), reverse=True) - return _json_dumps( - { - 'status': 'ok', - 'count': len(results), - 'notes': results, - } - ) + return _json_dumps({ + 'status': 'ok', + 'count': len(results), + 'notes': results, + }) async def search_notes(self, keyword: str) -> str: """Search notes by keyword.""" @@ -1111,7 +1212,10 @@ async def search_notes(self, keyword: str) -> str: keyword_lower = keyword.lower().strip() if not keyword_lower: - return _json_dumps({'status': 'error', 'message': 'Keyword is required.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Keyword is required.' + }) with file_lock(paths['lock_dir'], 'evidence_index'): index = self._load_index_locked(paths) @@ -1121,34 +1225,28 @@ async def search_notes(self, keyword: str) -> str: for nid, meta in notes_meta.items(): # Search in title, summary - searchable = ' '.join( - [ - meta.get('title', ''), - meta.get('summary', ''), - ] - ).lower() + searchable = ' '.join([ + meta.get('title', ''), + meta.get('summary', ''), + ]).lower() tags = meta.get('tags', []) searchable += ' ' + ' '.join(tags).lower() if keyword_lower in searchable: - results.append( - { - 'note_id': nid, - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - 'task_id': meta.get('task_id', ''), - 'quality_score': meta.get('quality_score'), - } - ) + results.append({ + 'note_id': nid, + 'title': meta.get('title', ''), + 'summary': meta.get('summary', ''), + 'task_id': meta.get('task_id', ''), + 'quality_score': meta.get('quality_score'), + }) - return _json_dumps( - { - 'status': 'ok', - 'keyword': keyword, - 'count': len(results), - 'notes': results, - } - ) + return _json_dumps({ + 'status': 'ok', + 'keyword': keyword, + 'count': len(results), + 'notes': results, + }) async def delete_note(self, note_id: str) -> str: """Delete a note by ID.""" @@ -1163,7 +1261,10 @@ async def delete_note(self, note_id: str) -> str: removed = self._remove_from_index(index, note_id) if not removed and not os.path.exists(note_path): - return _json_dumps({'status': 'error', 'message': f'Note {note_id} not found.'}) + return _json_dumps({ + 'status': 'error', + 'message': f'Note {note_id} not found.' + }) self._save_index_locked(paths, index) @@ -1188,13 +1289,11 @@ async def load_index(self) -> str: notes = index.get('notes', {}) analyses = index.get('analyses', {}) - return _json_dumps( - { - 'status': 'ok', - 'updated_at': index.get('updated_at', ''), - 'total_notes': len(notes), - 'total_analyses': len(analyses), - 'notes': notes, - 'analyses': analyses, - } - ) + return _json_dumps({ + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + }) diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 5d0bcfbdd..819da108c 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import json import os import re import time import uuid from typing import Any, Dict, List, Optional +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -46,17 +46,20 @@ def _write_text(path: str, content: str) -> None: def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: """Normalize `chapters` from the model (list, JSON string, or nested strings).""" if chapters is None: - return [], ('commit_outline requires `chapters` (array of chapter objects, or a JSON string of that array).') + return [], ( + 'commit_outline requires `chapters` (array of chapter objects, ' + 'or a JSON string of that array).') raw: Any = chapters if isinstance(raw, str): try: raw = json.loads(raw.strip()) except json.JSONDecodeError as e: return [], ( - f'commit_outline `chapters` must be a JSON array of objects, or a JSON string of that array: {e}' - ) + 'commit_outline `chapters` must be a JSON array of objects, ' + f'or a JSON string of that array: {e}') if not isinstance(raw, list): - return [], (f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') + return [], ( + f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') out: List[Dict[str, Any]] = [] for i, ch in enumerate(raw): if isinstance(ch, str): @@ -64,10 +67,11 @@ def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Opti ch = json.loads(ch.strip()) except json.JSONDecodeError: return [], ( - f'commit_outline chapters[{i}] must be an object; string entry is not valid JSON for an object.' - ) + f'commit_outline chapters[{i}] must be an object; ' + 'string entry is not valid JSON for an object.') if not isinstance(ch, dict): - return [], (f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') + return [], ( + f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') out.append(ch) return out, None @@ -77,9 +81,14 @@ def _render_outline_md(outline: Dict[str, Any]) -> str: lines = [f"# {outline.get('title', 'Report Outline')}", ''] for ch in outline.get('chapters', []): - status_icon = {'pending': '⏳', 'in_progress': '🔄', 'completed': '✅'}.get(ch.get('status', 'pending'), '⏳') + status_icon = { + 'pending': '⏳', + 'in_progress': '🔄', + 'completed': '✅' + }.get(ch.get('status', 'pending'), '⏳') - lines.append(f"## Chapter {ch['chapter_id']}: {ch['title']} {status_icon}") + lines.append( + f"## Chapter {ch['chapter_id']}: {ch['title']} {status_icon}") if ch.get('goals'): lines.append('') @@ -94,7 +103,8 @@ def _render_outline_md(outline: Dict[str, Any]) -> str: if ch.get('candidate_evidence'): lines.append('') - lines.append(f"**Related evidence:** {', '.join(ch['candidate_evidence'])}") + lines.append( + f"**Related evidence:** {', '.join(ch['candidate_evidence'])}") lines.append('') @@ -106,19 +116,27 @@ def _render_outline_progress_md(outline: Dict[str, Any]) -> str: chapters = outline.get('chapters', []) total = len(chapters) completed = sum(1 for ch in chapters if ch.get('status') == 'completed') - in_progress = sum(1 for ch in chapters if ch.get('status') == 'in_progress') + in_progress = sum(1 for ch in chapters + if ch.get('status') == 'in_progress') pending = total - completed - in_progress lines = [f"# {outline.get('title', 'Report Outline')}", ''] - lines.append(f'Progress: {completed}/{total} completed | {in_progress} in progress | {pending} pending') + lines.append( + f'Progress: {completed}/{total} completed | {in_progress} in progress | {pending} pending' + ) lines.append('') lines.append('## Chapters') lines.append('') for ch in chapters: status = ch.get('status', 'pending') - status_icon = {'pending': '⏳', 'in_progress': '🔄', 'completed': '✅'}.get(status, '⏳') - lines.append(f"- {status_icon} Chapter {ch['chapter_id']}: {ch['title']}") + status_icon = { + 'pending': '⏳', + 'in_progress': '🔄', + 'completed': '✅' + }.get(status, '⏳') + lines.append( + f"- {status_icon} Chapter {ch['chapter_id']}: {ch['title']}") lines.append('') return '\n'.join(lines) @@ -154,28 +172,43 @@ def __init__(self, config, **kwargs): self.exclude_func(tool_cfg) # Configurable paths - self._reports_dir = getattr(tool_cfg, 'reports_dir', 'reports') if tool_cfg else 'reports' - self._evidence_dir = getattr(tool_cfg, 'evidence_dir', 'evidence') if tool_cfg else 'evidence' - self._lock_subdir = getattr(tool_cfg, 'lock_subdir', '.locks') if tool_cfg else '.locks' + self._reports_dir = getattr(tool_cfg, 'reports_dir', + 'reports') if tool_cfg else 'reports' + self._evidence_dir = getattr(tool_cfg, 'evidence_dir', + 'evidence') if tool_cfg else 'evidence' + self._lock_subdir = getattr(tool_cfg, 'lock_subdir', + '.locks') if tool_cfg else '.locks' async def connect(self) -> None: """Initialize directory structure.""" _ensure_dir(self.output_dir) - _ensure_dir(os.path.join(self.output_dir, self._reports_dir, 'chapters')) + _ensure_dir( + os.path.join(self.output_dir, self._reports_dir, 'chapters')) _ensure_dir(os.path.join(self.output_dir, self._lock_subdir)) def _paths(self) -> Dict[str, str]: return { - 'outline_json': os.path.join(self.output_dir, self._reports_dir, 'outline.json'), - 'outline_md': os.path.join(self.output_dir, self._reports_dir, 'outline.md'), - 'outline_progress_md': os.path.join(self.output_dir, self._reports_dir, 'outline_progress.md'), - 'chapters_dir': os.path.join(self.output_dir, self._reports_dir, 'chapters'), - 'conflict_json': os.path.join(self.output_dir, self._reports_dir, 'conflict.json'), - 'draft_md': os.path.join(self.output_dir, self._reports_dir, 'draft.md'), - 'report_md': os.path.join(self.output_dir, self._reports_dir, 'report.md'), - 'evidence_index': os.path.join(self.output_dir, self._evidence_dir, 'index.json'), - 'evidence_notes_dir': os.path.join(self.output_dir, self._evidence_dir, 'notes'), - 'lock_dir': os.path.join(self.output_dir, self._lock_subdir), + 'outline_json': + os.path.join(self.output_dir, self._reports_dir, 'outline.json'), + 'outline_md': + os.path.join(self.output_dir, self._reports_dir, 'outline.md'), + 'outline_progress_md': + os.path.join(self.output_dir, self._reports_dir, + 'outline_progress.md'), + 'chapters_dir': + os.path.join(self.output_dir, self._reports_dir, 'chapters'), + 'conflict_json': + os.path.join(self.output_dir, self._reports_dir, 'conflict.json'), + 'draft_md': + os.path.join(self.output_dir, self._reports_dir, 'draft.md'), + 'report_md': + os.path.join(self.output_dir, self._reports_dir, 'report.md'), + 'evidence_index': + os.path.join(self.output_dir, self._evidence_dir, 'index.json'), + 'evidence_notes_dir': + os.path.join(self.output_dir, self._evidence_dir, 'notes'), + 'lock_dir': + os.path.join(self.output_dir, self._lock_subdir), } def _filter_candidate_evidence( @@ -215,11 +248,11 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_outline', server_name=self.SERVER_NAME, - description=( - 'Generate the report outline with chapter structure. ' - 'Each chapter must be bound to relevant evidence (note_ids). ' - 'Ensures all evidence is covered by at least one chapter.' - ), + description= + ('Generate the report outline with chapter structure. ' + 'Each chapter must be bound to relevant evidence (note_ids). ' + 'Ensures all evidence is covered by at least one chapter.' + ), parameters={ 'type': 'object', 'properties': { @@ -231,38 +264,52 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'array', 'description': 'List of chapter definitions.', 'items': { - 'type': 'object', + 'type': + 'object', 'properties': { 'title': { 'type': 'string', 'description': 'Chapter title.', }, 'goals': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Main objectives of this chapter.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Main objectives of this chapter.', }, 'sections_description': { - 'type': 'string', - 'description': ( - 'Detailed section-by-section plan for this chapter ' - '(NOT a single-sentence summary). ' - 'Write subsections as a numbered list in markdown. ' - 'For EACH subsection include: ' - '(a) subsection title, (b) 2-5 bullet key ' - 'points / questions to answer, ' - '(c) expected output form: narrative synthesis is required; ' - 'optionally add an artifact ' - '(e.g., table/checklist) to support the narrative.' - ), + 'type': + 'string', + 'description': + ('Detailed section-by-section plan for this chapter ' + '(NOT a single-sentence summary). ' + 'Write subsections as a numbered list in markdown. ' + 'For EACH subsection include: ' + '(a) subsection title, (b) 2-5 bullet key ' + 'points / questions to answer, ' + '(c) expected output form: narrative synthesis is required; ' + 'optionally add an artifact ' + '(e.g., table/checklist) to support the narrative.' + ), }, 'candidate_evidence': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'List of note_ids relevant to this chapter.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'List of note_ids relevant to this chapter.', }, }, - 'required': ['title', 'goals', 'sections_description', 'candidate_evidence'], + 'required': [ + 'title', 'goals', + 'sections_description', + 'candidate_evidence' + ], }, }, }, @@ -273,11 +320,11 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='prepare_chapter_bundle', server_name=self.SERVER_NAME, - description=( - 'Prepare metadata and evidence content for writing a specific chapter. ' - 'Returns the chapter info with full evidence details for review. ' - 'Call this before commit_chapter to review evidence quality.' - ), + description= + ('Prepare metadata and evidence content for writing a specific chapter. ' + 'Returns the chapter info with full evidence details for review. ' + 'Call this before commit_chapter to review evidence quality.' + ), parameters={ 'type': 'object', 'properties': { @@ -286,12 +333,15 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'The chapter number (1-based).', }, 'relevant_evidence': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'List of note_ids maybe used in this chapter. ' - 'The note_ids in this list will be loaded for review.' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('List of note_ids maybe used in this chapter. ' + 'The note_ids in this list will be loaded for review.' + ), }, # 'need_raw_chunks': { # 'type': 'boolean', @@ -306,57 +356,70 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_chapter', server_name=self.SERVER_NAME, - description=( - 'Write the content of a specific chapter. ' - 'The chapter will be saved as chapter_XX.md and status updated to completed.' - ), + description= + ('Write the content of a specific chapter. ' + 'The chapter will be saved as chapter_XX.md and status updated to completed.' + ), parameters={ - 'type': 'object', + 'type': + 'object', 'properties': { 'chapter_id': { 'type': 'integer', 'description': 'The chapter number (1-based).', }, 'reranked_evidence': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'List of note_ids reranked and chosen for this chapter.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'List of note_ids reranked and chosen for this chapter.', }, 'content': { - 'type': 'string', - 'description': ( - 'The markdown content of the chapter. ' - 'The content should include citations to the resources used in this chapter.' - 'Make sure the content is based on the reranked evidence.' - ), + 'type': + 'string', + 'description': + ('The markdown content of the chapter. ' + 'The content should include citations to the resources used in this chapter.' + 'Make sure the content is based on the reranked evidence.' + ), }, 'cited_urls': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'List of resource urls actually cited in this chapter.' - 'Keep the same order as cited in content.' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('List of resource urls actually cited in this chapter.' + 'Keep the same order as cited in content.'), }, }, # Keep schema consistent with Python signature (reranked_evidence has no default) - 'required': ['chapter_id', 'reranked_evidence', 'content', 'cited_urls'], - 'additionalProperties': False, + 'required': [ + 'chapter_id', 'reranked_evidence', 'content', + 'cited_urls' + ], + 'additionalProperties': + False, }, ), Tool( tool_name='load_chunk', server_name=self.SERVER_NAME, - description=( - 'Load raw chunk content when evidence summaries are insufficient. ' - 'Reserved for future implementation.' - ), + description= + ('Load raw chunk content when evidence summaries are insufficient. ' + 'Reserved for future implementation.'), parameters={ 'type': 'object', 'properties': { 'chunk_ids': { 'type': 'array', - 'items': {'type': 'string'}, + 'items': { + 'type': 'string' + }, 'description': 'List of chunk IDs to load.', }, }, @@ -367,7 +430,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='commit_conflict', server_name=self.SERVER_NAME, - description='Record a conflict or contradiction between evidence.', + description= + 'Record a conflict or contradiction between evidence.', parameters={ 'type': 'object', 'properties': { @@ -376,17 +440,24 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'description': 'Description of the conflict.', }, 'evidence_ids': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Note IDs involved in the conflict.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Note IDs involved in the conflict.', }, 'chapter_id': { 'type': 'integer', - 'description': 'Optional: Related chapter number.', + 'description': + 'Optional: Related chapter number.', }, 'resolution': { - 'type': 'string', - 'description': 'Optional: How the conflict was resolved.', + 'type': + 'string', + 'description': + 'Optional: How the conflict was resolved.', }, }, 'required': ['description', 'evidence_ids'], @@ -396,7 +467,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='update_outline', server_name=self.SERVER_NAME, - description='Update a specific chapter in the outline (title, goals, or evidence bindings).', + description= + 'Update a specific chapter in the outline (title, goals, or evidence bindings).', parameters={ 'type': 'object', 'properties': { @@ -406,35 +478,45 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, 'updates': { 'type': 'object', - 'description': 'Fields to update (title, goals, sections_description, candidate_evidence).', + 'description': + 'Fields to update (title, goals, sections_description, candidate_evidence).', 'properties': { 'title': { 'type': 'string', 'description': 'Title of the chapter.', }, 'goals': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'Main objectives of this chapter.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Main objectives of this chapter.', }, 'sections_description': { - 'type': 'string', - 'description': ( - 'Detailed section-by-section plan for ' - 'this chapter (NOT a single-sentence summary). ' - 'Write subsections as a numbered list in markdown. ' - 'For EACH subsection include: ' - '(a) subsection title, (b) 2-5 bullet key ' - 'points / questions to answer, ' - '(c) expected output form: narrative synthesis ' - 'is required; optionally add an artifact ' - '(e.g., table/checklist) to support the narrative.' - ), + 'type': + 'string', + 'description': + ('Detailed section-by-section plan for ' + 'this chapter (NOT a single-sentence summary). ' + 'Write subsections as a numbered list in markdown. ' + 'For EACH subsection include: ' + '(a) subsection title, (b) 2-5 bullet key ' + 'points / questions to answer, ' + '(c) expected output form: narrative synthesis ' + 'is required; optionally add an artifact ' + '(e.g., table/checklist) to support the narrative.' + ), }, 'candidate_evidence': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': 'List of note_ids relevant to this chapter.', + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'List of note_ids relevant to this chapter.', }, }, }, @@ -446,22 +528,24 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='assemble_draft', server_name=self.SERVER_NAME, - description=( - 'Assemble all chapters into a draft (draft.md) with TOC and references. ' - 'Returns the draft path along with a summary of recorded conflicts. ' - 'The model should then review the draft and conflicts to produce the final report.' - ), + description= + ('Assemble all chapters into a draft (draft.md) with TOC and references. ' + 'Returns the draft path along with a summary of recorded conflicts. ' + 'The model should then review the draft and conflicts to produce the final report.' + ), parameters={ 'type': 'object', 'properties': { 'include_toc': { 'type': 'boolean', - 'description': 'Whether to include table of contents.', + 'description': + 'Whether to include table of contents.', 'default': True, }, 'include_references': { 'type': 'boolean', - 'description': 'Whether to include references section.', + 'description': + 'Whether to include references section.', 'default': True, }, }, @@ -499,22 +583,30 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**(tool_args or {})) def _load_outline(self, paths: Dict[str, str]) -> Optional[Dict[str, Any]]: """Load outline.json.""" return _safe_read_json(paths['outline_json']) - def _save_outline(self, paths: Dict[str, str], outline: Dict[str, Any], render: bool = True) -> None: + def _save_outline(self, + paths: Dict[str, str], + outline: Dict[str, Any], + render: bool = True) -> None: """Save outline.json and render outline.md.""" outline['updated_at'] = _now_iso() _write_text(paths['outline_json'], _json_dumps(outline)) _write_text(paths['outline_md'], _render_outline_md(outline)) - _write_text(paths['outline_progress_md'], _render_outline_progress_md(outline)) + _write_text(paths['outline_progress_md'], + _render_outline_progress_md(outline)) if render: - render_markdown_todo(paths['outline_progress_md'], title='CURRENT REPORT OUTLINE', use_pager=False) + render_markdown_todo( + paths['outline_progress_md'], + title='CURRENT REPORT OUTLINE', + use_pager=False) def _load_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: """Load evidence index.""" @@ -542,9 +634,11 @@ def _load_full_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: data['analyses'] = legacy return data - def _load_note_content(self, paths: Dict[str, str], note_id: str) -> Optional[Dict[str, Any]]: + def _load_note_content(self, paths: Dict[str, str], + note_id: str) -> Optional[Dict[str, Any]]: """Load a single note's full content from markdown file.""" - note_path = os.path.join(paths['evidence_notes_dir'], f'note_{note_id}.md') + note_path = os.path.join(paths['evidence_notes_dir'], + f'note_{note_id}.md') if not os.path.exists(note_path): return None @@ -563,7 +657,8 @@ def _load_conflict(self, paths: Dict[str, str]) -> Dict[str, Any]: return {'updated_at': _now_iso(), 'conflicts': []} return data - def _save_conflict(self, paths: Dict[str, str], conflict: Dict[str, Any]) -> None: + def _save_conflict(self, paths: Dict[str, str], + conflict: Dict[str, Any]) -> None: """Save conflict.json.""" conflict['updated_at'] = _now_iso() _write_text(paths['conflict_json'], _json_dumps(conflict)) @@ -576,16 +671,14 @@ async def load_index(self) -> str: index = self._load_full_evidence_index(paths) notes = index.get('notes', {}) analyses = index.get('analyses', {}) - return _json_dumps( - { - 'status': 'ok', - 'updated_at': index.get('updated_at', ''), - 'total_notes': len(notes), - 'total_analyses': len(analyses), - 'notes': notes, - 'analyses': analyses, - } - ) + return _json_dumps({ + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + }) async def commit_outline( self, @@ -612,27 +705,34 @@ async def commit_outline( for idx, ch in enumerate(chapters_list, start=1): candidate_raw = ch.get('candidate_evidence', []) - kept, dropped = self._filter_candidate_evidence(paths, candidate_raw) + kept, dropped = self._filter_candidate_evidence( + paths, candidate_raw) if dropped: invalid_candidate_by_chapter[str(idx)] = dropped covered_evidence.update(kept) - outline_chapters.append( - { - 'chapter_id': idx, - 'title': ch.get('title', f'Chapter {idx}'), - 'goals': ch.get('goals', []), - 'sections_description': ch.get('sections_description', ''), - 'candidate_evidence': kept, - 'status': 'pending', - } - ) + outline_chapters.append({ + 'chapter_id': + idx, + 'title': + ch.get('title', f'Chapter {idx}'), + 'goals': + ch.get('goals', []), + 'sections_description': + ch.get('sections_description', ''), + 'candidate_evidence': + kept, + 'status': + 'pending', + }) # Check coverage uncovered = all_note_ids - covered_evidence coverage_warning = None if uncovered: - coverage_warning = f'Warning: the following evidence is not covered by any chapter: {list(uncovered)}' + coverage_warning = ( + f'Warning: the following evidence is not covered by any chapter: {list(uncovered)}' + ) outline = { 'title': title, @@ -646,7 +746,8 @@ async def commit_outline( result = { 'status': 'ok', - 'outline_path': os.path.relpath(paths['outline_json'], self.output_dir), + 'outline_path': os.path.relpath(paths['outline_json'], + self.output_dir), 'chapters_count': len(outline_chapters), 'total_evidence': len(all_note_ids), 'covered_evidence': len(covered_evidence), @@ -678,9 +779,12 @@ async def prepare_chapter_bundle( # Load outline outline = self._load_outline(paths) if outline is None: - return _json_dumps( - {'status': 'error', 'message': 'Outline not created yet. Please call commit_outline first.'} - ) + return _json_dumps({ + 'status': + 'error', + 'message': + 'Outline not created yet. Please call commit_outline first.' + }) # Find chapter chapter = None @@ -690,17 +794,25 @@ async def prepare_chapter_bundle( break if chapter is None: - return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) + return _json_dumps({ + 'status': 'error', + 'message': f'Chapter {chapter_id} not found.' + }) - cand_kept, cand_dropped = self._filter_candidate_evidence(paths, chapter.get('candidate_evidence', [])) - rel_kept, rel_dropped = self._filter_candidate_evidence(paths, relevant_evidence or []) + cand_kept, cand_dropped = self._filter_candidate_evidence( + paths, chapter.get('candidate_evidence', [])) + rel_kept, rel_dropped = self._filter_candidate_evidence( + paths, relevant_evidence or []) # Load evidence content evidence_index = self._load_evidence_index(paths) notes_meta = evidence_index.get('notes', {}) _known_sorted = sorted(notes_meta.keys()) _sample = _known_sorted[:48] - _note_id_hint = 'Known note ids in evidence index (sample): ' + (', '.join(_sample) if _sample else '(none)') + _note_id_hint = ( + 'Known note ids in evidence index (sample): ' + + (', '.join(_sample) if _sample else '(none)') + ) if len(_known_sorted) > len(_sample): _note_id_hint += f' … (+{len(_known_sorted) - len(_sample)} more)' @@ -724,18 +836,24 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: note_data = self._load_note_content(paths, note_id) if note_data: - notes_content.append( - { - 'note_id': note_id, - 'title': meta.get('title', note_data.get('title', '')), - 'content': note_data.get('content', ''), - 'contradicts': note_data.get('contradicts', ''), - 'summary': meta.get('summary', note_data.get('summary', '')), - 'sources': meta.get('sources', note_data.get('sources', [])), - 'quality_score': meta.get('quality_score', note_data.get('quality_score')), - 'tags': meta.get('tags', note_data.get('tags', [])), - } - ) + notes_content.append({ + 'note_id': + note_id, + 'title': + meta.get('title', note_data.get('title', '')), + 'content': + note_data.get('content', ''), + 'contradicts': + note_data.get('contradicts', ''), + 'summary': + meta.get('summary', note_data.get('summary', '')), + 'sources': + meta.get('sources', note_data.get('sources', [])), + 'quality_score': + meta.get('quality_score', note_data.get('quality_score')), + 'tags': + meta.get('tags', note_data.get('tags', [])), + }) else: notes_content.append(_missing_note_entry(note_id, meta)) @@ -747,23 +865,31 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: note_data = self._load_note_content(paths, note_id) if note_data: - notes_content.append( - { - 'note_id': note_id, - 'title': meta.get('title', note_data.get('title', '')), - 'content': note_data.get('content', ''), - 'contradicts': note_data.get('contradicts', ''), - 'summary': meta.get('summary', note_data.get('summary', '')), - 'sources': meta.get('sources', note_data.get('sources', [])), - 'quality_score': meta.get('quality_score', note_data.get('quality_score')), - 'tags': meta.get('tags', note_data.get('tags', [])), - } - ) + notes_content.append({ + 'note_id': + note_id, + 'title': + meta.get('title', note_data.get('title', '')), + 'content': + note_data.get('content', ''), + 'contradicts': + note_data.get('contradicts', ''), + 'summary': + meta.get('summary', note_data.get('summary', '')), + 'sources': + meta.get('sources', note_data.get('sources', [])), + 'quality_score': + meta.get('quality_score', + note_data.get('quality_score')), + 'tags': + meta.get('tags', note_data.get('tags', [])), + }) else: notes_content.append(_missing_note_entry(note_id, meta)) # Build meta (only ids that resolved to on-disk notes for this bundle) - candidate_evidence = list(dict.fromkeys(list(cand_kept) + list(rel_kept))) + candidate_evidence = list( + dict.fromkeys(list(cand_kept) + list(rel_kept))) meta = { 'chapter_id': chapter_id, 'chapter_title': chapter['title'], @@ -776,7 +902,8 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: } # Save meta.json - meta_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}_meta.json') + meta_path = os.path.join(paths['chapters_dir'], + f'chapter_{chapter_id:02d}_meta.json') with file_lock(paths['lock_dir'], f'chapter_{chapter_id}_meta'): _write_text(meta_path, _json_dumps(meta)) @@ -786,13 +913,20 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: self._save_outline(paths, outline) out_bundle: Dict[str, Any] = { - 'status': 'ok', - 'chapter_id': chapter_id, - 'chapter_title': chapter['title'], - 'chapter_goals': chapter.get('goals', []), - 'evidence_count': len(notes_content), - 'meta_path': os.path.relpath(meta_path, self.output_dir), - 'notes_content': notes_content, + 'status': + 'ok', + 'chapter_id': + chapter_id, + 'chapter_title': + chapter['title'], + 'chapter_goals': + chapter.get('goals', []), + 'evidence_count': + len(notes_content), + 'meta_path': + os.path.relpath(meta_path, self.output_dir), + 'notes_content': + notes_content, } skipped: Dict[str, List[str]] = {} if cand_dropped: @@ -821,7 +955,10 @@ async def commit_chapter( # Validate outline exists outline = self._load_outline(paths) if outline is None: - return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Outline not created yet.' + }) # Find and update chapter chapter_found = False @@ -836,10 +973,14 @@ async def commit_chapter( break if not chapter_found: - return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) + return _json_dumps({ + 'status': 'error', + 'message': f'Chapter {chapter_id} not found.' + }) # Write chapter file - chapter_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}.md') + chapter_path = os.path.join(paths['chapters_dir'], + f'chapter_{chapter_id:02d}.md') with file_lock(paths['lock_dir'], f'chapter_{chapter_id}'): _write_text(chapter_path, content) @@ -847,7 +988,8 @@ async def commit_chapter( with file_lock(paths['lock_dir'], 'report_outline'): self._save_outline(paths, outline) - meta_path = os.path.join(paths['chapters_dir'], f'chapter_{chapter_id:02d}_meta.json') + meta_path = os.path.join(paths['chapters_dir'], + f'chapter_{chapter_id:02d}_meta.json') meta = _safe_read_json(meta_path) meta = meta if isinstance(meta, dict) else {} meta['reranked_evidence'] = list(reranked_evidence or []) @@ -855,27 +997,31 @@ async def commit_chapter( with file_lock(paths['lock_dir'], f'chapter_{chapter_id}_meta'): _write_text(meta_path, _json_dumps(meta)) - return _json_dumps( - { - 'status': 'ok', - 'chapter_id': chapter_id, - 'chapter_title': chapter_title, - 'path': os.path.relpath(chapter_path, self.output_dir), - 'content_length': len(content), - 'reranked_evidence': reranked_evidence or [], - 'cited_urls': cited_urls or [], - } - ) + return _json_dumps({ + 'status': + 'ok', + 'chapter_id': + chapter_id, + 'chapter_title': + chapter_title, + 'path': + os.path.relpath(chapter_path, self.output_dir), + 'content_length': + len(content), + 'reranked_evidence': + reranked_evidence or [], + 'cited_urls': + cited_urls or [], + }) async def load_chunk(self, chunk_ids: List[str]) -> str: """Load raw chunk content. Reserved for future implementation.""" - return _json_dumps( - { - 'status': 'not_implemented', - 'message': 'Chunk storage not enabled in this version. Use evidence notes directly.', - 'chunk_ids': chunk_ids, - } - ) + return _json_dumps({ + 'status': 'not_implemented', + 'message': + 'Chunk storage not enabled in this version. Use evidence notes directly.', + 'chunk_ids': chunk_ids, + }) async def commit_conflict( self, @@ -904,14 +1050,16 @@ async def commit_conflict( conflicts['conflicts'].append(conflict_entry) self._save_conflict(paths, conflicts) - return _json_dumps( - { - 'status': 'ok', - 'conflict_id': conflict_id, - 'total_conflicts': len(conflicts['conflicts']), - 'conflict_path': os.path.relpath(paths['conflict_json'], self.output_dir), - } - ) + return _json_dumps({ + 'status': + 'ok', + 'conflict_id': + conflict_id, + 'total_conflicts': + len(conflicts['conflicts']), + 'conflict_path': + os.path.relpath(paths['conflict_json'], self.output_dir), + }) async def update_outline( self, @@ -925,7 +1073,10 @@ async def update_outline( with file_lock(paths['lock_dir'], 'report_outline'): outline = self._load_outline(paths) if outline is None: - return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Outline not created yet.' + }) chapter_found = False invalid_candidate_removed: List[str] = [] @@ -936,16 +1087,23 @@ async def update_outline( if 'goals' in updates: ch['goals'] = updates['goals'] if 'sections_description' in updates: - ch['sections_description'] = updates['sections_description'] + ch['sections_description'] = updates[ + 'sections_description'] if 'candidate_evidence' in updates: - kept, dropped = self._filter_candidate_evidence(paths, updates['candidate_evidence']) + kept, dropped = self._filter_candidate_evidence( + paths, updates['candidate_evidence']) ch['candidate_evidence'] = kept invalid_candidate_removed = dropped chapter_found = True break if not chapter_found: - return _json_dumps({'status': 'error', 'message': f'Chapter {chapter_id} not found.'}) + return _json_dumps({ + 'status': + 'error', + 'message': + f'Chapter {chapter_id} not found.' + }) self._save_outline(paths, outline) @@ -957,7 +1115,8 @@ async def update_outline( if invalid_candidate_removed: out['invalid_candidate_evidence_removed'] = invalid_candidate_removed out['invalid_candidate_evidence_note'] = ( - 'These ids were removed from candidate_evidence because no matching evidence/notes/note_.md exists.' + 'These ids were removed from candidate_evidence because no ' + 'matching evidence/notes/note_.md exists.' ) return _json_dumps(out) @@ -971,38 +1130,47 @@ async def assemble_draft( outline = self._load_outline(paths) if outline is None: - return _json_dumps({'status': 'error', 'message': 'Outline not created yet.'}) + return _json_dumps({ + 'status': 'error', + 'message': 'Outline not created yet.' + }) # Collect chapter contents chapters_content = [] missing_chapters = [] for ch in outline.get('chapters', []): - chapter_path = os.path.join(paths['chapters_dir'], f"chapter_{ch['chapter_id']:02d}.md") + chapter_path = os.path.join(paths['chapters_dir'], + f"chapter_{ch['chapter_id']:02d}.md") if os.path.exists(chapter_path): with open(chapter_path, 'r', encoding='utf-8') as f: - chapters_content.append( - { - 'id': ch['chapter_id'], - 'title': ch['title'], - 'content': f.read(), - 'reranked_evidence': ch.get('reranked_evidence', []), - 'cited_urls': ch.get('cited_urls', []), - } - ) + chapters_content.append({ + 'id': + ch['chapter_id'], + 'title': + ch['title'], + 'content': + f.read(), + 'reranked_evidence': + ch.get('reranked_evidence', []), + 'cited_urls': + ch.get('cited_urls', []), + }) else: missing_chapters.append(ch['chapter_id']) if missing_chapters: - return _json_dumps( - { - 'status': 'error', - 'message': f'The following chapters are not completed yet: {missing_chapters}', - } - ) + return _json_dumps({ + 'status': + 'error', + 'message': + f'The following chapters are not completed yet: {missing_chapters}', + }) # Build draft - draft_lines = [f"# {outline.get('title', 'Research Report')} (Draft)", ''] + draft_lines = [ + f"# {outline.get('title', 'Research Report')} (Draft)", '' + ] # Table of contents if include_toc: @@ -1010,7 +1178,8 @@ async def assemble_draft( draft_lines.append('') for ch in chapters_content: anchor = ch['title'].replace(' ', '-').lower() - draft_lines.append(f"- [Chapter {ch['id']} {ch['title']}](#{anchor})") + draft_lines.append( + f"- [Chapter {ch['id']} {ch['title']}](#{anchor})") draft_lines.append('') # Chapters @@ -1025,7 +1194,7 @@ async def assemble_draft( cited_urls = set() for ch in chapters_content: - for url in ch.get('cited_urls') or []: + for url in (ch.get('cited_urls') or []): cited_urls.add(url) all_cited = set() @@ -1057,31 +1226,33 @@ async def assemble_draft( conflicts_list = conflicts_data.get('conflicts', []) conflicts_summary = [] for c in conflicts_list: - conflicts_summary.append( - { - 'id': c.get('id'), - 'description': c.get('description'), - 'chapter_id': c.get('chapter_id'), - 'resolution': c.get('resolution'), - } - ) - - return _json_dumps( - { - 'status': 'ok', - 'draft_path': os.path.relpath(paths['draft_md'], self.output_dir), - 'chapters_count': len(chapters_content), - 'content_length': len(draft_content), - 'conflicts_count': len(conflicts_list), - 'conflicts_summary': conflicts_summary, - 'next_step_reminder': ( - 'Review the draft and conflicts, then generate the final report. ' - 'Note: the draft cannot be used as the final report; ' - 'do not replace report content with references or pointers to other content or files ' - '(e.g., "details are in chapter_2.md", "see draft.md for more details").' - ), - } - ) + conflicts_summary.append({ + 'id': c.get('id'), + 'description': c.get('description'), + 'chapter_id': c.get('chapter_id'), + 'resolution': c.get('resolution'), + }) + + return _json_dumps({ + 'status': + 'ok', + 'draft_path': + os.path.relpath(paths['draft_md'], self.output_dir), + 'chapters_count': + len(chapters_content), + 'content_length': + len(draft_content), + 'conflicts_count': + len(conflicts_list), + 'conflicts_summary': + conflicts_summary, + 'next_step_reminder': + ('Review the draft and conflicts, then generate the final report. ' + 'Note: the draft cannot be used as the final report; ' + 'do not replace report content with references or pointers to other content or files ' + '(e.g., "details are in chapter_2.md", "see draft.md for more details").' + ), + }) async def get_status(self) -> str: """Get current report generation progress.""" @@ -1091,40 +1262,52 @@ async def get_status(self) -> str: conflicts = self._load_conflict(paths) if outline is None: - return _json_dumps( - { - 'status': 'not_started', - 'outline_exists': False, - 'chapters': [], - 'conflicts_count': len(conflicts.get('conflicts', [])), - } - ) + return _json_dumps({ + 'status': + 'not_started', + 'outline_exists': + False, + 'chapters': [], + 'conflicts_count': + len(conflicts.get('conflicts', [])), + }) chapters_status = [] for ch in outline.get('chapters', []): - chapter_path = os.path.join(paths['chapters_dir'], f"chapter_{ch['chapter_id']:02d}.md") - chapters_status.append( - { - 'chapter_id': ch['chapter_id'], - 'title': ch['title'], - 'status': ch.get('status', 'pending'), - 'file_exists': os.path.exists(chapter_path), - 'candidate_evidence_count': len(ch.get('candidate_evidence', [])), - } - ) - - completed = sum(1 for ch in chapters_status if ch['status'] == 'completed') + chapter_path = os.path.join(paths['chapters_dir'], + f"chapter_{ch['chapter_id']:02d}.md") + chapters_status.append({ + 'chapter_id': + ch['chapter_id'], + 'title': + ch['title'], + 'status': + ch.get('status', 'pending'), + 'file_exists': + os.path.exists(chapter_path), + 'candidate_evidence_count': + len(ch.get('candidate_evidence', [])), + }) + + completed = sum(1 for ch in chapters_status + if ch['status'] == 'completed') total = len(chapters_status) - return _json_dumps( - { - 'status': 'in_progress' if completed < total else 'completed', - 'outline_exists': True, - 'report_title': outline.get('title', ''), - 'progress': f'{completed}/{total}', - 'chapters': chapters_status, - 'conflicts_count': len(conflicts.get('conflicts', [])), - 'draft_exists': os.path.exists(paths['draft_md']), - 'report_exists': os.path.exists(paths['report_md']), - } - ) + return _json_dumps({ + 'status': + 'in_progress' if completed < total else 'completed', + 'outline_exists': + True, + 'report_title': + outline.get('title', ''), + 'progress': + f'{completed}/{total}', + 'chapters': + chapters_status, + 'conflicts_count': + len(conflicts.get('conflicts', [])), + 'draft_exists': + os.path.exists(paths['draft_md']), + 'report_exists': + os.path.exists(paths['report_md']), + }) diff --git a/projects/fin_research/aggregator.py b/projects/fin_research/aggregator.py index 10ac55f40..2380e58e3 100644 --- a/projects/fin_research/aggregator.py +++ b/projects/fin_research/aggregator.py @@ -1,15 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os from typing import Any, AsyncGenerator, List, Union +import json from callbacks.file_parser import extract_code_blocks -from omegaconf import DictConfig - from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG +from omegaconf import DictConfig logger = get_logger() @@ -19,25 +18,32 @@ class AggregatorAgent(LLMAgent): Aggregator Agent that aggregates the reports from SearchAgent and CollectorAgent. """ - def __init__( - self, config: DictConfig = DictConfig({}), tag: str = DEFAULT_TAG, trust_remote_code: bool = False, **kwargs - ): + def __init__(self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) async def run( - self, inputs: Union[str, List[str], List[Message], List[List[Message]]], **kwargs + self, inputs: Union[str, List[str], List[Message], + List[List[Message]]], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: reports = {} # Dict of reports if isinstance(inputs, list): if isinstance(inputs[0], str): - refractory_inputs = [[Message(role='user', content=item)] for item in inputs] # multiple parent nodes + refractory_inputs = [[Message(role='user', content=item)] + for item in inputs + ] # multiple parent nodes elif isinstance(inputs[0], Message): refractory_inputs = [inputs] # single parent node elif len(inputs) > 1 and isinstance(inputs[0], list): refractory_inputs = inputs # multiple parent nodes else: - raise ValueError(f"Invalid input type: List[{type(inputs[0]) if inputs else 'empty list'}]") + raise ValueError( + f"Invalid input type: List[{type(inputs[0]) if inputs else 'empty list'}]" + ) elif isinstance(inputs, str): refractory_inputs = [[Message(role='user', content=inputs)]] else: @@ -68,11 +74,9 @@ async def run( with open(report_path, 'r', encoding='utf-8') as f: report = f.read() - report_type = ( - '**Financial Data Analysis Report**' - if 'analysis' in report_path - else '**Online Sentiment Analysis Report**' - ) + report_type = ('**Financial Data Analysis Report**' + if 'analysis' in report_path else + '**Online Sentiment Analysis Report**') reports[report_type] = report if report else message.content plan = {} @@ -84,15 +88,14 @@ async def run( if content: # Only load if file is not empty plan.update(json.loads(content)) except Exception as e: - logger.warning(f'Failed to load plan.json: {e}. Using empty plan.') + logger.warning( + f'Failed to load plan.json: {e}. Using empty plan.') return await super().run( - messages=( - f'The reports from the SearchAgent and AnalystAgent are as follows:\n' - f'{json.dumps(reports, ensure_ascii=False, indent=2)}\n' - f'Please integrate the reports into a comprehensive financial analysis report.\n' - f'Please review the original plan for the financial analysis task:\n' - f'{json.dumps(plan, ensure_ascii=False, indent=2)}\n' - ), - kwargs=kwargs, - ) + messages= + (f'The reports from the SearchAgent and AnalystAgent are as follows:\n' + f'{json.dumps(reports, ensure_ascii=False, indent=2)}\n' + f'Please integrate the reports into a comprehensive financial analysis report.\n' + f'Please review the original plan for the financial analysis task:\n' + f'{json.dumps(plan, ensure_ascii=False, indent=2)}\n'), + kwargs=kwargs) diff --git a/projects/fin_research/callbacks/aggregator_callback.py b/projects/fin_research/callbacks/aggregator_callback.py index 913a0252c..b827f18a6 100644 --- a/projects/fin_research/callbacks/aggregator_callback.py +++ b/projects/fin_research/callbacks/aggregator_callback.py @@ -3,19 +3,19 @@ import re from typing import List -from omegaconf import DictConfig - from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class AggregatorCallback(Callback): - """Save output plan to local disk.""" + """Save output plan to local disk. + """ def __init__(self, config: DictConfig): super().__init__(config) @@ -37,8 +37,7 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): r'\s*\[ACT=(?:outline|partial_report|final_report)\]\s*:?\s*.*?(?:\n|\.)', '', message.content, - flags=re.MULTILINE, - ).strip() + flags=re.MULTILINE).strip() f.write(filtered_content) break logger.info(f'Aggregator report saved to {self.report_path}') diff --git a/projects/fin_research/callbacks/analyst_callback.py b/projects/fin_research/callbacks/analyst_callback.py index b0952216d..31af7306b 100644 --- a/projects/fin_research/callbacks/analyst_callback.py +++ b/projects/fin_research/callbacks/analyst_callback.py @@ -1,30 +1,34 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os import re from pathlib import Path from typing import List -from omegaconf import DictConfig - +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class AnalystCallback(Callback): - """Save output plan to local disk.""" + """Save output plan to local disk. + """ def __init__(self, config: DictConfig): super().__init__(config) - self.report_path = self.config.get('report_path', os.path.join(self.config.output_dir, 'analysis_report.md')) + self.report_path = self.config.get( + 'report_path', + os.path.join(self.config.output_dir, 'analysis_report.md')) def _resolve_data_root(self) -> str: - code_exec_cfg = getattr(getattr(self.config, 'tools', {}), 'code_executor', None) - impl = getattr(code_exec_cfg, 'implementation', 'sandbox') if code_exec_cfg else 'sandbox' + code_exec_cfg = getattr( + getattr(self.config, 'tools', {}), 'code_executor', None) + impl = getattr(code_exec_cfg, 'implementation', + 'sandbox') if code_exec_cfg else 'sandbox' if isinstance(impl, str) and impl.lower() == 'sandbox': return '/data' @@ -36,7 +40,8 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): for message in messages: if message.role == 'system': message.content = message.content.replace('\\\n', '') - message.content = message.content.replace('', self._resolve_data_root()) + message.content = message.content.replace( + '', self._resolve_data_root()) elif message.role == 'assistant': if '[ACT=summary]' in message.content: summary_messages['collector_summary'] = message.content @@ -44,30 +49,32 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): summary_messages['collector_plan'] = message.content if os.path.exists(os.path.join(self.config.output_dir, 'plan.json')): - with open(os.path.join(self.config.output_dir, 'plan.json'), 'r') as f: + with open(os.path.join(self.config.output_dir, 'plan.json'), + 'r') as f: plan = json.load(f) if not plan: - logger.error('The plan.json file is empty, please check the file.') + logger.error( + 'The plan.json file is empty, please check the file.') user_message = Message( role='user', - content=( - f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' - f'Please follow the plan to complete the data analysis task.\n' - f'IMPORTANT: Review the input analysis specification provided under "financial_data_dimension"' - ), - ) + content= + (f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' + f'Please follow the plan to complete the data analysis task.\n' + f'IMPORTANT: Review the input analysis specification provided under "financial_data_dimension"' + )) else: user_message = Message( role='user', - content=( - 'Please conduct data analysis in accordance with the research plan followed during the data ' - 'collection phase and the results obtained from data collection.' - ), - ) + content= + ('Please conduct data analysis in accordance with the research plan followed during the data ' + 'collection phase and the results obtained from data collection.' + )) # Add the summary of the data collection phase to the user message (add plan if exists) if summary_messages['collector_summary']: - messages[:] = [message for message in messages if message.role == 'system'] + messages[:] = [ + message for message in messages if message.role == 'system' + ] summary_messages = ( f'The summary of the data collection phase is as follows:\n' f'{json.dumps(summary_messages, ensure_ascii=False, indent=2)}' @@ -79,11 +86,15 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): for message in messages[::-1]: if message.role == 'assistant' and not message.tool_calls: with open(self.report_path, 'w') as f: - filtered_content = re.sub(r'\s*\[ACT=(?:code|collect|report|fix)\]\s*', '', message.content).strip() + filtered_content = re.sub( + r'\s*\[ACT=(?:code|collect|report|fix)\]\s*', '', + message.content).strip() f.write(filtered_content) break user_message = Message( - role='user', content=json.dumps({'report_path': self.report_path}, ensure_ascii=False, indent=2) - ) + role='user', + content=json.dumps({'report_path': self.report_path}, + ensure_ascii=False, + indent=2)) messages.append(user_message) diff --git a/projects/fin_research/callbacks/collector_callback.py b/projects/fin_research/callbacks/collector_callback.py index 6aeb11bec..caffbcb71 100644 --- a/projects/fin_research/callbacks/collector_callback.py +++ b/projects/fin_research/callbacks/collector_callback.py @@ -1,28 +1,30 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os from pathlib import Path from typing import List -from omegaconf import DictConfig - +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class CollectorCallback(Callback): - """Save output plan to local disk.""" + """Save output plan to local disk. + """ def __init__(self, config: DictConfig): super().__init__(config) def _resolve_data_root(self) -> str: - code_exec_cfg = getattr(getattr(self.config, 'tools', {}), 'code_executor', None) - impl = getattr(code_exec_cfg, 'implementation', 'sandbox') if code_exec_cfg else 'sandbox' + code_exec_cfg = getattr( + getattr(self.config, 'tools', {}), 'code_executor', None) + impl = getattr(code_exec_cfg, 'implementation', + 'sandbox') if code_exec_cfg else 'sandbox' if isinstance(impl, str) and impl.lower() == 'sandbox': return '/data' @@ -33,13 +35,16 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): for message in messages: if message.role == 'system': message.content = message.content.replace('\\\n', '') - message.content = message.content.replace('', self._resolve_data_root()) + message.content = message.content.replace( + '', self._resolve_data_root()) if os.path.exists(os.path.join(self.config.output_dir, 'plan.json')): - with open(os.path.join(self.config.output_dir, 'plan.json'), 'r') as f: + with open(os.path.join(self.config.output_dir, 'plan.json'), + 'r') as f: plan = json.load(f) if not plan: - logger.error('The plan.json file is empty, please check the file.') + logger.error( + 'The plan.json file is empty, please check the file.') if messages[-1].role == 'user': messages[-1].content = ( f'The complete plan for the current overall financial analysis task is as follows:\n{plan}\n' @@ -48,23 +53,20 @@ async def on_task_begin(self, runtime: Runtime, messages: List[Message]): elif messages[-1].role in ('assistant', 'tool', 'system'): user_message = Message( role='user', - content=( - f'The complete plan for the current global financial analysis task is as follows:\n{plan}\n' - f'Please follow the plan to complete the data collection task.\n' - ), - ) + content= + (f'The complete plan for the current global financial analysis task is as follows:\n{plan}\n' + f'Please follow the plan to complete the data collection task.\n' + )) messages.append(user_message) messages[:] = [ - messages[i] - for i in range(len(messages)) - if (messages[i].role == 'system') or (i == (len(messages) - 1) and messages[i].role == 'user') + messages[i] for i in range(len(messages)) + if (messages[i].role == 'system') or ( + i == (len(messages) - 1) and messages[i].role == 'user') ] else: user_message = Message( role='user', - content=( - 'Please conduct data collection in accordance with the research plan ' - 'provided in orchestrator\'s output.' - ), - ) + content= + ('Please conduct data collection in accordance with the research plan ' + 'provided in orchestrator\'s output.')) messages.append(user_message) diff --git a/projects/fin_research/callbacks/file_parser.py b/projects/fin_research/callbacks/file_parser.py index 83465be55..67f1bdd69 100644 --- a/projects/fin_research/callbacks/file_parser.py +++ b/projects/fin_research/callbacks/file_parser.py @@ -2,7 +2,9 @@ from typing import List, Optional, Tuple -def extract_code_blocks(text: str, target_filename: Optional[str] = None) -> Tuple[List, str]: +def extract_code_blocks(text: str, + target_filename: Optional[str] = None + ) -> Tuple[List, str]: """Extract code blocks from the given text. ```py:a.py diff --git a/projects/fin_research/callbacks/orchestrator_callback.py b/projects/fin_research/callbacks/orchestrator_callback.py index d3456fba2..d9b507dce 100644 --- a/projects/fin_research/callbacks/orchestrator_callback.py +++ b/projects/fin_research/callbacks/orchestrator_callback.py @@ -3,19 +3,19 @@ from typing import List from file_parser import extract_code_blocks -from omegaconf import DictConfig - from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class OrchestratorCallback(Callback): - """Save output plan to local disk.""" + """Save output plan to local disk. + """ def __init__(self, config: DictConfig): super().__init__(config) @@ -37,7 +37,8 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): all_files, _ = extract_code_blocks(content) results = [] for f in all_files: - result = await self.file_system.write_file(f['filename'], f['code']) + result = await self.file_system.write_file(f['filename'], + f['code']) results.append(result) r = '\n'.join(results) diff --git a/projects/fin_research/searcher.py b/projects/fin_research/searcher.py index 4966e3873..23703fa64 100644 --- a/projects/fin_research/searcher.py +++ b/projects/fin_research/searcher.py @@ -1,16 +1,16 @@ -import json import os from typing import List, Union +import json from callbacks.file_parser import extract_code_blocks -from omegaconf import DictConfig - from ms_agent.agent.code_agent import CodeAgent from ms_agent.llm import Message from ms_agent.llm.openai import OpenAIChat from ms_agent.tools.search_engine import get_web_search_tool from ms_agent.utils import get_logger -from ms_agent.workflow.deep_research.research_workflow_beta import ResearchWorkflowBeta +from ms_agent.workflow.deep_research.research_workflow_beta import \ + ResearchWorkflowBeta +from omegaconf import DictConfig logger = get_logger() @@ -20,39 +20,51 @@ class SearchAgent(CodeAgent): """Agent wrapper that delegates work to ResearchWorkflowBeta.""" - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) if isinstance(self.config, DictConfig): if hasattr(self.config, 'llm'): llm_config = self.config.llm - api_key = getattr(llm_config, 'openai_api_key', '') or os.getenv('OPENAI_API_KEY') - base_url = getattr(llm_config, 'openai_base_url', '') or os.getenv('OPENAI_BASE_URL') - model = getattr(llm_config, 'model', '') or 'Qwen/Qwen3-235B-A22B-Instruct-2507' - self.chat_client = OpenAIChat(api_key=api_key, base_url=base_url, model=model) + api_key = getattr(llm_config, 'openai_api_key', + '') or os.getenv('OPENAI_API_KEY') + base_url = getattr(llm_config, 'openai_base_url', + '') or os.getenv('OPENAI_BASE_URL') + model = getattr(llm_config, 'model', + '') or 'Qwen/Qwen3-235B-A22B-Instruct-2507' + self.chat_client = OpenAIChat( + api_key=api_key, base_url=base_url, model=model) else: - raise ValueError('LLM configuration not found, SearchAgent requires OpenAI compatible API.') + raise ValueError( + 'LLM configuration not found, SearchAgent requires OpenAI compatible API.' + ) - if hasattr(self.config, 'tools') and hasattr(self.config.tools, 'search_engine'): + if hasattr(self.config, 'tools') and hasattr( + self.config.tools, 'search_engine'): self.search_engine = get_web_search_tool( - config_file=getattr(self.config.tools.search_engine, 'config_file', '') - ) + config_file=getattr(self.config.tools.search_engine, + 'config_file', '')) else: raise ValueError('Search engine configuration not found.') self.workdir = getattr(self.config, 'output_dir', './output') self.use_ray = getattr(self.config, 'use_ray', False) - self.report_prefix = getattr(self.config, 'report_prefix', 'sentiment_') + self.report_prefix = getattr(self.config, 'report_prefix', + 'sentiment_') - async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message]: + async def run(self, inputs: Union[str, List[Message]], + **kwargs) -> List[Message]: workflow = ResearchWorkflowBeta( client=self.chat_client, search_engine=self.search_engine, workdir=self.workdir, use_ray=self.use_ray, enable_multimodal=False, - report_prefix=self.report_prefix, - ) + report_prefix=self.report_prefix) if inputs is None: return [Message(role='assistant', content='')] @@ -62,24 +74,29 @@ async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message instruction = {} for message in inputs[::-1]: if message.role == 'assistant': - instruction = json.loads(extract_code_blocks(message.content)[0][0].get('code', {})) + instruction = json.loads( + extract_code_blocks(message.content)[0][0].get( + 'code', {})) break - if not instruction and os.path.exists(os.path.join(self.workdir, 'plan.json')): + if not instruction and os.path.exists( + os.path.join(self.workdir, 'plan.json')): with open(os.path.join(self.workdir, 'plan.json'), 'r') as f: instruction = json.load(f) user_prompt = json.dumps( { - 'public_sentiment_dimension': instruction.get('public_sentiment_dimension', {}), + 'public_sentiment_dimension': + instruction.get('public_sentiment_dimension', {}), }, ensure_ascii=False, - indent=2, - ) + indent=2) elif isinstance(inputs, str): user_prompt = inputs else: - raise ValueError('Invalid input type, SearchAgent requires a string or list of messages.') + raise ValueError( + 'Invalid input type, SearchAgent requires a string or list of messages.' + ) report_path = await workflow.run( user_prompt=user_prompt, @@ -90,5 +107,7 @@ async def run(self, inputs: Union[str, List[Message]], **kwargs) -> List[Message ) result_content = report_path if report_path else 'No report generated.' - result_content = json.dumps({'report_path': report_path}, ensure_ascii=False, indent=2) + result_content = json.dumps({'report_path': report_path}, + ensure_ascii=False, + indent=2) return [Message(role='user', content=result_content)] diff --git a/projects/fin_research/time_handler.py b/projects/fin_research/time_handler.py index 93c8b92b7..dd01bc8ae 100644 --- a/projects/fin_research/time_handler.py +++ b/projects/fin_research/time_handler.py @@ -2,9 +2,8 @@ from datetime import datetime from typing import Any -from omegaconf import DictConfig - from ms_agent.config.config import ConfigLifecycleHandler +from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): @@ -25,7 +24,8 @@ def task_begin(self, config: DictConfig, tag: str) -> DictConfig: def traverse_and_replace(_config: Any): if isinstance(_config, DictConfig): for name, value in _config.items(): - if isinstance(value, DictConfig) or isinstance(value, list): + if isinstance(value, DictConfig) or isinstance( + value, list): traverse_and_replace(value) elif isinstance(value, str): new_value = value @@ -33,7 +33,8 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace(placeholder, var_value) + new_value = new_value.replace( + placeholder, var_value) setattr(_config, name, new_value) elif isinstance(_config, list): @@ -46,7 +47,8 @@ def traverse_and_replace(_config: Any): for var_name, var_value in time_vars.items(): placeholder = f'<{var_name}>' if placeholder in new_value: - new_value = new_value.replace(placeholder, var_value) + new_value = new_value.replace( + placeholder, var_value) _config[i] = new_value traverse_and_replace(config) diff --git a/projects/fin_research/tools/principle_skill.py b/projects/fin_research/tools/principle_skill.py index 99e77ac9e..19882f497 100644 --- a/projects/fin_research/tools/principle_skill.py +++ b/projects/fin_research/tools/principle_skill.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # flake8: noqa -import json import os from typing import Any, Dict, List, Optional, Tuple +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger @@ -81,7 +81,8 @@ class PrincipleSkill(ToolBase): def __init__(self, config): super().__init__(config) - tools_cfg = getattr(config, 'tools', None) if config is not None else None + tools_cfg = getattr(config, 'tools', + None) if config is not None else None self.exclude_func(getattr(tools_cfg, 'principle_skill', None)) configured_dir = None @@ -95,7 +96,9 @@ def __init__(self, config): self.principle_dir = configured_dir or default_dir # Build a mapping from normalized user inputs to on-disk filenames and display names - self._name_to_file: Dict[str, Tuple[str, str]] = self._build_principle_index() + self._name_to_file: Dict[str, + Tuple[str, + str]] = self._build_principle_index() async def connect(self): # Warn once if the directory cannot be found; still operate to allow deferred config @@ -111,61 +114,66 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='load_principles', server_name='principle_skill', - description=( - f'Load one or more analysis principles (concept + how to apply to ' - f'financial analysis) and return their curated Markdown content.\n\n' - f'This is a single-aggregator tool designed to fetch multiple principles ' - f'in one call. Provide a list of requested principles via the "principles" ' - f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' - f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' - f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' - f'When format is "markdown" (default), the tool returns a single combined ' - f'Markdown string (optionally including section titles). When format is ' - f'"json", the tool returns a JSON object mapping principle to content.\n' - f'{PRINCIPLE_GUIDE}\n' - f'{ROUTING_GUIDE}\n' - ), + description= + (f'Load one or more analysis principles (concept + how to apply to ' + f'financial analysis) and return their curated Markdown content.\n\n' + f'This is a single-aggregator tool designed to fetch multiple principles ' + f'in one call. Provide a list of requested principles via the "principles" ' + f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' + f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' + f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' + f'When format is "markdown" (default), the tool returns a single combined ' + f'Markdown string (optionally including section titles). When format is ' + f'"json", the tool returns a JSON object mapping principle to content.\n' + f'{PRINCIPLE_GUIDE}\n' + f'{ROUTING_GUIDE}\n'), parameters={ 'type': 'object', 'properties': { 'principles': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'List of principles to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' - '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('List of principles to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' + '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' + ), }, 'format': { - 'type': 'string', + 'type': + 'string', 'enum': ['markdown', 'json'], - 'description': ( - 'Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping principle to content). Default: "markdown".' - ), + 'description': + ('Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping principle to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': 'boolean', - 'description': ( - 'When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical principle title. Default: true.' - ), + 'type': + 'boolean', + 'description': + ('When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical principle title. Default: true.' + ), }, 'join_with': { - 'type': 'string', - 'description': ( - 'When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".' - ), + 'type': + 'string', + 'description': + ('When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".'), }, 'strict': { - 'type': 'boolean', - 'description': ( - 'If true, unknown principles cause an error. If false, unknown ' - 'items are ignored with a note in the output. Default: false.' - ), + 'type': + 'boolean', + 'description': + ('If true, unknown principles cause an error. If false, unknown ' + 'items are ignored with a note in the output. Default: false.' + ), }, }, 'required': ['principles'], @@ -176,7 +184,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: } return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) async def load_principles( @@ -195,7 +204,10 @@ async def load_principles( if not principles: return json.dumps( - {'success': False, 'error': 'No principles provided.'}, + { + 'success': False, + 'error': 'No principles provided.' + }, ensure_ascii=False, indent=2, ) @@ -211,7 +223,12 @@ async def load_principles( if unknown and strict: return json.dumps( - {'success': False, 'error': 'Unknown principles (strict mode): ' + ', '.join(unknown)}, + { + 'success': + False, + 'error': + 'Unknown principles (strict mode): ' + ', '.join(unknown) + }, ensure_ascii=False, indent=2, ) @@ -224,11 +241,15 @@ async def load_principles( content = f.read().strip() loaded[canonical_title] = content except Exception as e: # noqa - loaded[canonical_title] = f'Failed to load {filename}: {str(e)}' + loaded[ + canonical_title] = f'Failed to load {filename}: {str(e)}' if not loaded: return json.dumps( - {'success': False, 'error': 'Failed to load any principles.'}, + { + 'success': False, + 'error': 'Failed to load any principles.' + }, ensure_ascii=False, indent=2, ) @@ -251,10 +272,14 @@ async def load_principles( sections.append(content) if unknown and not strict: - sections.append(f'> Note: Unknown principles ignored: {", ".join(unknown)}') + sections.append( + f'> Note: Unknown principles ignored: {", ".join(unknown)}') return json.dumps( - {'success': True, 'sections': sections}, + { + 'success': True, + 'sections': sections + }, ensure_ascii=False, indent=2, ) @@ -263,24 +288,23 @@ def _build_principle_index(self) -> Dict[str, Tuple[str, str]]: """Return mapping from normalized query → (filename, canonical title).""" entries: List[Tuple[List[str], str, str]] = [ # synonyms, filename, canonical title - (['mece', 'mutually exclusive and collectively exhaustive'], 'MECE.md', 'MECE'), - ( - ['pyramid', 'minto', 'minto pyramid', 'pyramid principle', 'minto_pyramid'], - 'Minto_Pyramid.md', - 'Pyramid (Minto Pyramid)', - ), + (['mece', 'mutually exclusive and collectively exhaustive'], + 'MECE.md', 'MECE'), + ([ + 'pyramid', 'minto', 'minto pyramid', 'pyramid principle', + 'minto_pyramid' + ], 'Minto_Pyramid.md', 'Pyramid (Minto Pyramid)'), (['swot', 'swot analysis'], 'SWOT.md', 'SWOT'), - (['value chain', 'value-chain', 'value_chain'], 'Value_Chain.md', 'Value Chain'), - ( - ['pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', '8020'], - 'Pareto_80-20.md', - 'Pareto (80/20 Rule)', - ), - ( - ['boston matrix', 'bcg', 'boston consulting group', 'boston_matrix', 'boston'], - 'Boston_Matrix.md', - 'Boston Matrix (BCG)', - ), + (['value chain', 'value-chain', + 'value_chain'], 'Value_Chain.md', 'Value Chain'), + ([ + 'pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', + '8020' + ], 'Pareto_80-20.md', 'Pareto (80/20 Rule)'), + ([ + 'boston matrix', 'bcg', 'boston consulting group', + 'boston_matrix', 'boston' + ], 'Boston_Matrix.md', 'Boston Matrix (BCG)'), ] index: Dict[str, Tuple[str, str]] = {} diff --git a/projects/fin_research/tools/spec_loader.py b/projects/fin_research/tools/spec_loader.py index a2c86bdcf..6ed8e1e21 100644 --- a/projects/fin_research/tools/spec_loader.py +++ b/projects/fin_research/tools/spec_loader.py @@ -1,13 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # flake8: noqa -import json import os -from spec_constant import PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE from typing import Any, Dict, List, Tuple +import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger +from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, + WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) logger = get_logger() @@ -35,11 +36,14 @@ class SpecLoader(ToolBase): def __init__(self, config): super().__init__(config) - tools_cfg = getattr(config, 'tools', None) if config is not None else None - spec_cfg = getattr(tools_cfg, 'spec_loader', None) if tools_cfg is not None else None + tools_cfg = getattr(config, 'tools', + None) if config is not None else None + spec_cfg = getattr(tools_cfg, 'spec_loader', + None) if tools_cfg is not None else None self.exclude_func(spec_cfg) - configured_dir = getattr(spec_cfg, 'spec_dir', None) if spec_cfg is not None else None + configured_dir = getattr(spec_cfg, 'spec_dir', + None) if spec_cfg is not None else None default_dir = os.path.join(os.getcwd(), self.SPEC_DIR) self.spec_dir = configured_dir or default_dir @@ -57,61 +61,66 @@ async def get_tools(self) -> Dict[str, Any]: Tool( tool_name='load_writing_specs', server_name='spec_loader', - description=( - 'Load one or more writing-style specs (rules + examples) and return ' - 'their curated Markdown content. Use this when you are unsure about ' - 'how to structure or phrase a financial report in an analyst-like style.\n\n' - 'Supported spec identifiers (case-insensitive, synonyms allowed):\n' - '- structure → section depth / headings\n' - '- methods → how much to expose MECE/SWOT/etc.\n' - '- bullets → bullets vs paragraphs\n' - '- focus → task focus and relevance\n' - '- tone → analyst-style voice\n' - '- density → length and information density control\n\n' - 'Provide a list of requested writing specs via the "writing_specs" parameter.\n\n' - f'{WRITING_SPEC_GUIDE}\n' - f'{WRITING_ROUTING_GUIDE}\n' - ), + description= + ('Load one or more writing-style specs (rules + examples) and return ' + 'their curated Markdown content. Use this when you are unsure about ' + 'how to structure or phrase a financial report in an analyst-like style.\n\n' + 'Supported spec identifiers (case-insensitive, synonyms allowed):\n' + '- structure → section depth / headings\n' + '- methods → how much to expose MECE/SWOT/etc.\n' + '- bullets → bullets vs paragraphs\n' + '- focus → task focus and relevance\n' + '- tone → analyst-style voice\n' + '- density → length and information density control\n\n' + 'Provide a list of requested writing specs via the "writing_specs" parameter.\n\n' + f'{WRITING_SPEC_GUIDE}\n' + f'{WRITING_ROUTING_GUIDE}\n'), parameters={ 'type': 'object', 'properties': { 'writing_specs': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'List of writing specs to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- structure\n- methods\n- bullets\n- focus\n- tone\n- density\n' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('List of writing specs to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- structure\n- methods\n- bullets\n- focus\n- tone\n- density\n' + ), }, 'format': { - 'type': 'string', + 'type': + 'string', 'enum': ['markdown', 'json'], - 'description': ( - 'Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping spec to content). Default: "markdown".' - ), + 'description': + ('Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping spec to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': 'boolean', - 'description': ( - 'When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical spec title. Default: false.' - ), + 'type': + 'boolean', + 'description': + ('When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical spec title. Default: false.' + ), }, 'join_with': { - 'type': 'string', - 'description': ( - 'When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".' - ), + 'type': + 'string', + 'description': + ('When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".'), }, 'strict': { - 'type': 'boolean', - 'description': ( - 'If true, unknown specs cause an error. If false, unknown items are ' - 'ignored with a note in the output. Default: false.' - ), + 'type': + 'boolean', + 'description': + ('If true, unknown specs cause an error. If false, unknown items are ' + 'ignored with a note in the output. Default: false.' + ), }, }, 'required': ['writing_specs'], @@ -121,83 +130,94 @@ async def get_tools(self) -> Dict[str, Any]: Tool( tool_name='load_principle_specs', server_name='spec_loader', - description=( - f'Load one or more analysis principles (concept + how to apply to ' - f'financial analysis) and return their curated Markdown content.\n\n' - f'This is a single-aggregator tool designed to fetch multiple principles ' - f'in one call. Provide a list of requested principles via the "principles" ' - f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' - f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' - f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' - f'When format is "markdown" (default), the tool returns a single combined ' - f'Markdown string (optionally including section titles). When format is ' - f'"json", the tool returns a JSON object mapping principle to content.\n' - f'{PRINCIPLE_SPEC_GUIDE}\n' - f'{PRINCIPLE_ROUTING_GUIDE}\n' - ), + description= + (f'Load one or more analysis principles (concept + how to apply to ' + f'financial analysis) and return their curated Markdown content.\n\n' + f'This is a single-aggregator tool designed to fetch multiple principles ' + f'in one call. Provide a list of requested principles via the "principles" ' + f'parameter. The tool supports common synonyms and is case-insensitive.\n\n' + f'Examples of valid principle identifiers: "MECE", "Pyramid", "Minto", ' + f'"SWOT", "Value Chain", "Pareto", "80-20", "80/20", "Boston Matrix", "BCG".\n\n' + f'When format is "markdown" (default), the tool returns a single combined ' + f'Markdown string (optionally including section titles). When format is ' + f'"json", the tool returns a JSON object mapping principle to content.\n' + f'{PRINCIPLE_SPEC_GUIDE}\n' + f'{PRINCIPLE_ROUTING_GUIDE}\n'), parameters={ 'type': 'object', 'properties': { 'principles': { - 'type': 'array', - 'items': {'type': 'string'}, - 'description': ( - 'List of principles to load. Case-insensitive; supports synonyms.\n' - 'Allowed identifiers include (non-exhaustive):\n' - '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' - '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' - ), + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + ('List of principles to load. Case-insensitive; supports synonyms.\n' + 'Allowed identifiers include (non-exhaustive):\n' + '- MECE\n- Pyramid\n- Minto\n- SWOT\n- Value Chain\n' + '- Pareto\n- 80-20\n- 80/20\n- Boston Matrix\n- BCG\n' + ), }, 'format': { - 'type': 'string', + 'type': + 'string', 'enum': ['markdown', 'json'], - 'description': ( - 'Output format: "markdown" (combined Markdown string) or "json" ' - '(JSON object mapping principle to content). Default: "markdown".' - ), + 'description': + ('Output format: "markdown" (combined Markdown string) or "json" ' + '(JSON object mapping principle to content). Default: "markdown".' + ), }, 'include_titles': { - 'type': 'boolean', - 'description': ( - 'When format="markdown", if true, each section is prefixed with a ' - 'Markdown heading of the canonical principle title. Default: false.' - ), + 'type': + 'boolean', + 'description': + ('When format="markdown", if true, each section is prefixed with a ' + 'Markdown heading of the canonical principle title. Default: false.' + ), }, 'join_with': { - 'type': 'string', - 'description': ( - 'When format="markdown", the delimiter used to join multiple ' - 'sections. Default: "\n\n---\n\n".' - ), + 'type': + 'string', + 'description': + ('When format="markdown", the delimiter used to join multiple ' + 'sections. Default: "\n\n---\n\n".'), }, 'strict': { - 'type': 'boolean', - 'description': ( - 'If true, unknown principles cause an error. If false, unknown ' - 'items are ignored with a note in the output. Default: false.' - ), + 'type': + 'boolean', + 'description': + ('If true, unknown principles cause an error. If false, unknown ' + 'items are ignored with a note in the output. Default: false.' + ), }, }, 'required': ['principles'], 'additionalProperties': False, }, - ), + ) ] } if hasattr(self, 'exclude_functions') and self.exclude_functions: - tools['spec_loader'] = [t for t in tools['spec_loader'] if t['tool_name'] not in self.exclude_functions] + tools['spec_loader'] = [ + t for t in tools['spec_loader'] + if t['tool_name'] not in self.exclude_functions + ] return tools - async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) - async def load_writing_specs(self, writing_specs: List[str], **kwargs) -> str: + async def load_writing_specs(self, writing_specs: List[str], + **kwargs) -> str: writing_spec_map = self._build_writing_spec_index() return await self.load_specs(writing_spec_map, writing_specs, **kwargs) - async def load_principle_specs(self, principles: List[str], **kwargs) -> str: + async def load_principle_specs(self, principles: List[str], + **kwargs) -> str: principle_map = self._build_principle_spec_index() return await self.load_specs(principle_map, principles, **kwargs) @@ -218,7 +238,10 @@ async def load_specs( if not specs: return json.dumps( - {'success': False, 'error': 'No specs provided.'}, + { + 'success': False, + 'error': 'No specs provided.' + }, ensure_ascii=False, indent=2, ) @@ -236,7 +259,8 @@ async def load_specs( return json.dumps( { 'success': False, - 'error': 'Unknown specs (strict mode): ' + ', '.join(unknown), + 'error': + 'Unknown specs (strict mode): ' + ', '.join(unknown), }, ensure_ascii=False, indent=2, @@ -250,11 +274,15 @@ async def load_specs( content = f.read().strip() loaded[canonical_title] = content except Exception as e: # noqa - loaded[canonical_title] = f'Failed to load {filename}: {str(e)}' + loaded[ + canonical_title] = f'Failed to load {filename}: {str(e)}' if not loaded: return json.dumps( - {'success': False, 'error': 'Failed to load any specs.'}, + { + 'success': False, + 'error': 'Failed to load any specs.' + }, ensure_ascii=False, indent=2, ) @@ -277,9 +305,16 @@ async def load_specs( sections.append(content) if unknown and not strict: - sections.append(f'> Note: Unknown specs ignored: {", ".join(unknown)}') + sections.append( + f'> Note: Unknown specs ignored: {", ".join(unknown)}') - return json.dumps({'success': True, 'sections': join_with.join(sections)}, ensure_ascii=False, indent=2) + return json.dumps( + { + 'success': True, + 'sections': join_with.join(sections) + }, + ensure_ascii=False, + indent=2) def _build_writing_spec_index(self) -> Dict[str, Tuple[str, str]]: """Return writing spec mapping from normalized query → (filename, canonical title).""" @@ -291,12 +326,18 @@ def _build_writing_spec_index(self) -> Dict[str, Tuple[str, str]]: 'Structure & Layering', ), ( - ['methods', 'methodology', 'framework exposure', 'methodology exposure'], + [ + 'methods', 'methodology', 'framework exposure', + 'methodology exposure' + ], 'writing_specs/Methodology_Exposure.md', 'Methodology Exposure', ), ( - ['bullets', 'bullet', 'bullets & paragraphs', 'paragraph rhythm'], + [ + 'bullets', 'bullet', 'bullets & paragraphs', + 'paragraph rhythm' + ], 'writing_specs/Bullets_Paragraph_Rhythm.md', 'Bullets & Paragraph Rhythm', ), @@ -327,24 +368,23 @@ def _build_principle_spec_index(self) -> Dict[str, Tuple[str, str]]: """Return principle spec mapping from normalized query → (filename, canonical title).""" entries: List[Tuple[List[str], str, str]] = [ # synonyms, filename, canonical title - (['mece', 'mutually exclusive and collectively exhaustive'], 'principle_specs/MECE.md', 'MECE'), - ( - ['pyramid', 'minto', 'minto pyramid', 'pyramid principle', 'minto_pyramid'], - 'principle_specs/Minto_Pyramid.md', - 'Pyramid (Minto Pyramid)', - ), + (['mece', 'mutually exclusive and collectively exhaustive'], + 'principle_specs/MECE.md', 'MECE'), + ([ + 'pyramid', 'minto', 'minto pyramid', 'pyramid principle', + 'minto_pyramid' + ], 'principle_specs/Minto_Pyramid.md', 'Pyramid (Minto Pyramid)'), (['swot', 'swot analysis'], 'principle_specs/SWOT.md', 'SWOT'), - (['value chain', 'value-chain', 'value_chain'], 'principle_specs/Value_Chain.md', 'Value Chain'), - ( - ['pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', '8020'], - 'principle_specs/Pareto_80-20.md', - 'Pareto (80/20 Rule)', - ), - ( - ['boston matrix', 'bcg', 'boston consulting group', 'boston_matrix', 'boston'], - 'principle_specs/Boston_Matrix.md', - 'Boston Matrix (BCG)', - ), + (['value chain', 'value-chain', + 'value_chain'], 'principle_specs/Value_Chain.md', 'Value Chain'), + ([ + 'pareto', '80-20', '80/20', 'pareto 80-20', 'pareto_80-20', + '8020' + ], 'principle_specs/Pareto_80-20.md', 'Pareto (80/20 Rule)'), + ([ + 'boston matrix', 'bcg', 'boston consulting group', + 'boston_matrix', 'boston' + ], 'principle_specs/Boston_Matrix.md', 'Boston Matrix (BCG)'), ] index: Dict[str, Tuple[str, str]] = {} diff --git a/projects/singularity_cinema/compose_video/agent.py b/projects/singularity_cinema/compose_video/agent.py index c2d5eae4f..09995c629 100644 --- a/projects/singularity_cinema/compose_video/agent.py +++ b/projects/singularity_cinema/compose_video/agent.py @@ -1,28 +1,33 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 -import json import math import os import shutil from copy import deepcopy +import json import moviepy as mp from moviepy import AudioClip -from omegaconf import DictConfig -from PIL import Image - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class ComposeVideo(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') - self.background_effect = getattr(self.config, 'background_effect', None) + self.background_effect = getattr(self.config, 'background_effect', + None) self.bg_path = os.path.join(self.work_dir, 'background.png') # Determine render directory based on engine @@ -40,17 +45,9 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False self.preset = getattr(self.config.video, 'preset', 'ultrafast') self.fps = getattr(self.config.video, 'fps', 24) - def compose_final_video( - self, - background_path, - foreground_paths, - audio_paths, - subtitle_paths, - illustration_paths, - video_paths, - segments, - output_path, - ): + def compose_final_video(self, background_path, foreground_paths, + audio_paths, subtitle_paths, illustration_paths, + video_paths, segments, output_path): segment_durations = [] logger.info('Composing the final video.') @@ -62,7 +59,8 @@ def compose_final_video( segment = segments[i] is_video_frame = 'video' in segment use_video_soundtrack = self.config.use_video_soundtrack and is_video_frame - if audio_path and os.path.exists(audio_path) and not use_video_soundtrack: + if audio_path and os.path.exists( + audio_path) and not use_video_soundtrack: try: audio_clip = mp.AudioFileClip(audio_path) # Use actual audio duration + small pause, no minimum enforcement @@ -73,10 +71,15 @@ def compose_final_video( else: actual_duration = None if not use_video_soundtrack: - raise ValueError(f'File {audio_path} does not exist, run again to generate it.') + raise ValueError( + f'File {audio_path} does not exist, run again to generate it.' + ) - if i < len(foreground_paths) and foreground_paths[i] and os.path.exists(foreground_paths[i]): - animation_clip = mp.VideoFileClip(foreground_paths[i], has_mask=True) + if i < len(foreground_paths + ) and foreground_paths[i] and os.path.exists( + foreground_paths[i]): + animation_clip = mp.VideoFileClip( + foreground_paths[i], has_mask=True) animation_duration = animation_clip.duration animation_clip.close() @@ -88,16 +91,20 @@ def compose_final_video( logger.info('Step1: Compose video for each segment.') segment_videos = [] - for i, (duration, segment) in enumerate(zip(segment_durations, segments)): + for i, (duration, + segment) in enumerate(zip(segment_durations, segments)): if duration is not None: - logger.info(f'Processing {i + 1} segment - {duration:.1f} seconds.') + logger.info( + f'Processing {i + 1} segment - {duration:.1f} seconds.') else: - logger.info(f'Processing {i + 1} segment - use video soundtrack.') + logger.info( + f'Processing {i + 1} segment - use video soundtrack.') current_video_clips = [] # Check if this segment uses generated video instead of illustration - use_generated_video = 'video' in segment and video_paths[i] and os.path.exists(video_paths[i]) + use_generated_video = 'video' in segment and video_paths[ + i] and os.path.exists(video_paths[i]) if use_generated_video: # Use generated video as base layer @@ -118,7 +125,8 @@ def compose_final_video( video_available_w, video_available_h = 1920, 1080 video_scale_w = video_available_w / video_original_w video_scale_h = video_available_h / video_original_h - video_scale = max(video_scale_w, video_scale_h) # Cover mode + video_scale = max(video_scale_w, + video_scale_h) # Cover mode video_new_w = int(video_original_w * video_scale) video_new_h = int(video_original_h * video_scale) @@ -128,13 +136,16 @@ def compose_final_video( video_new_h += 1 if video_new_w > 0 and video_new_h > 0: - video_clip = video_clip.resized((video_new_w, video_new_h)) + video_clip = video_clip.resized( + (video_new_w, video_new_h)) video_clip = video_clip.with_position('center') # Extract and preserve video audio before adjusting duration video_audio = None if video_clip.audio is not None: - logger.info(f'Extracting audio from generated video {i + 1}') + logger.info( + f'Extracting audio from generated video {i + 1}' + ) video_audio = video_clip.audio segment_video_audios.append(video_audio) @@ -144,32 +155,40 @@ def compose_final_video( assert duration is not None and duration > 0 # Adjust video duration to match segment duration if video_clip.duration < duration: - logger.info(f'Video {i + 1} is shorter than segment, extending to {duration:.1f}s') - video_clip = video_clip.with_duration(duration) + logger.info( + f'Video {i + 1} is shorter than segment, extending to {duration:.1f}s' + ) + video_clip = video_clip.with_duration( + duration) elif video_clip.duration > duration: - logger.info(f'Video {i + 1} is longer than segment, trimming to {duration:.1f}s') - video_clip = video_clip.subclipped(0, duration) + logger.info( + f'Video {i + 1} is longer than segment, trimming to {duration:.1f}s' + ) + video_clip = video_clip.subclipped( + 0, duration) current_video_clips.append(video_clip) else: - logger.error(f'Invalid scaled video dimensions: {video_new_w}x{video_new_h}') + logger.error( + f'Invalid scaled video dimensions: {video_new_w}x{video_new_h}' + ) video_clip.close() use_generated_video = False except Exception as e: - logger.error(f'Failed to process video for segment {i + 1}: {e}') + logger.error( + f'Failed to process video for segment {i + 1}: {e}') use_generated_video = False segment_video_audios.append(None) else: segment_video_audios.append(None) # Add illustration as base layer (if not using generated video) - if ( - not use_generated_video - and i < len(illustration_paths) - and illustration_paths[i] - and os.path.exists(illustration_paths[i]) - ): - illustration_clip = mp.ImageClip(illustration_paths[i], duration=duration) + if not use_generated_video and i < len( + illustration_paths + ) and illustration_paths[i] and os.path.exists( + illustration_paths[i]): + illustration_clip = mp.ImageClip( + illustration_paths[i], duration=duration) bg_original_w, bg_original_h = illustration_clip.size # Validate image dimensions @@ -195,10 +214,12 @@ def compose_final_video( # Ensure dimensions are positive if bg_new_w <= 0 or bg_new_h <= 0: - logger.error(f'Invalid scaled dimensions: {bg_new_w}x{bg_new_h}') + logger.error( + f'Invalid scaled dimensions: {bg_new_w}x{bg_new_h}') continue - illustration_clip = illustration_clip.resized((bg_new_w, bg_new_h)) + illustration_clip = illustration_clip.resized( + (bg_new_w, bg_new_h)) exit_duration = 1.0 start_animation_time = max(duration - exit_duration, 0) @@ -219,11 +240,13 @@ def make_ken_burns(t): progress = t / kb_duration if kb_duration > 0 else 0 progress = min(1.0, progress) # Cubic easing for smooth acceleration/deceleration - eased_progress = progress * progress * (3.0 - 2.0 * progress) + eased_progress = progress * progress * ( + 3.0 - 2.0 * progress) if eased_progress > 1.0: eased_progress = 1.0 # Calculate current zoom level - current_zoom = zoom_start + (zoom_end - zoom_start) * eased_progress + current_zoom = zoom_start + ( + zoom_end - zoom_start) * eased_progress # Calculate new dimensions with validation zoom_w = int(kb_base_w * current_zoom) zoom_h = int(kb_base_h * current_zoom) @@ -234,20 +257,26 @@ def make_ken_burns(t): return zoom_w, zoom_h # Apply the zoom effect with resizing over time - illustration_clip = illustration_clip.resized(make_ken_burns) + illustration_clip = illustration_clip.resized( + make_ken_burns) # Keep image centered and stable throughout the animation - illustration_clip = illustration_clip.with_position('center') + illustration_clip = illustration_clip.with_position( + 'center') elif self.background_effect == 'slide': # TODO legacy code, untested # Default slide left animation - def illustration_pos_factory(idx, start_x, end_x, bg_h, start_animation_time, exit_duration): + def illustration_pos_factory(idx, start_x, end_x, bg_h, + start_animation_time, + exit_duration): + def illustration_pos(t): y = (1080 - bg_h) // 2 if t < start_animation_time: x = start_x elif t < start_animation_time + exit_duration: - progress = (t - start_animation_time) / exit_duration + progress = ( + t - start_animation_time) / exit_duration progress = min(max(progress, 0), 1) x = start_x + (end_x - start_x) * progress else: @@ -257,19 +286,22 @@ def illustration_pos(t): return illustration_pos illustration_clip = illustration_clip.with_position( - illustration_pos_factory( - i, (1920 - bg_new_w) // 2, -bg_new_w, bg_new_h, start_animation_time, exit_duration - ) - ) + illustration_pos_factory(i, (1920 - bg_new_w) // 2, + -bg_new_w, bg_new_h, + start_animation_time, + exit_duration)) current_video_clips.append(illustration_clip) # Add foreground animation layer - if i < len(foreground_paths) and foreground_paths[i] and os.path.exists(foreground_paths[i]): + if i < len(foreground_paths + ) and foreground_paths[i] and os.path.exists( + foreground_paths[i]): fg_clip = mp.VideoFileClip(foreground_paths[i], has_mask=True) original_w, original_h = fg_clip.size - available_w, available_h = (1250, 700) if self.config.use_subtitle else (1450, 800) + available_w, available_h = ( + 1250, 700) if self.config.use_subtitle else (1450, 800) scale_w = available_w / original_w scale_h = available_h / original_h scale = min(scale_w, scale_h, 1.0) @@ -281,7 +313,9 @@ def illustration_pos(t): if new_w > 0 and new_h > 0: fg_clip = fg_clip.resized((new_w, new_h)) else: - logger.error(f'Invalid scaled foreground dimensions: {new_w}x{new_h}') + logger.error( + f'Invalid scaled foreground dimensions: {new_w}x{new_h}' + ) fg_clip.close() continue @@ -293,7 +327,8 @@ def illustration_pos(t): fg_clip = fg_clip.with_duration(duration) current_video_clips.append(fg_clip) if self.config.use_subtitle: - if duration is not None and i < len(subtitle_paths) and subtitle_paths[i]: + if duration is not None and i < len( + subtitle_paths) and subtitle_paths[i]: segment_subs = subtitle_paths[i] num_subs = len(segment_subs) sub_duration = duration / num_subs @@ -310,13 +345,17 @@ def illustration_pos(t): ) continue - subtitle_clip = mp.ImageClip(sub_path, duration=sub_duration) + subtitle_clip = mp.ImageClip( + sub_path, duration=sub_duration) subtitle_y = 900 - subtitle_clip = subtitle_clip.with_position(('center', subtitle_y)) - subtitle_clip = subtitle_clip.with_start(k * sub_duration) + subtitle_clip = subtitle_clip.with_position( + ('center', subtitle_y)) + subtitle_clip = subtitle_clip.with_start( + k * sub_duration) current_video_clips.append(subtitle_clip) except Exception as e: - logger.error(f'Failed to load subtitle {sub_path}: {e}') + logger.error( + f'Failed to load subtitle {sub_path}: {e}') # Add background as top layer (transparent PNG with decorative elements) if background_path and os.path.exists(background_path): @@ -325,24 +364,26 @@ def illustration_pos(t): current_video_clips.append(bg_clip) if current_video_clips: - segment_video = mp.CompositeVideoClip(current_video_clips, size=(1920, 1080)) + segment_video = mp.CompositeVideoClip( + current_video_clips, size=(1920, 1080)) segment_videos.append(segment_video) logger.info('Step2: Combine all video segments.') - final_video = mp.concatenate_videoclips(segment_videos, method='compose') + final_video = mp.concatenate_videoclips( + segment_videos, method='compose') logger.info('Step3: Compose audios.') if audio_paths: valid_audio_clips = [] - for i, (audio_path, duration, segment) in enumerate(zip(audio_paths, segment_durations, segments)): + for i, (audio_path, duration, segment) in enumerate( + zip(audio_paths, segment_durations, segments)): segment_audio = None # Check if this segment has generated video audio - if ( - i < len(segment_video_audios) - and segment_video_audios[i] is not None - and self.config.use_video_soundtrack - ): - logger.info(f'Using audio from generated video for segment {i + 1}') + if i < len(segment_video_audios) and segment_video_audios[ + i] is not None and self.config.use_video_soundtrack: + logger.info( + f'Using audio from generated video for segment {i + 1}' + ) segment_audio = segment_video_audios[i] elif audio_path and os.path.exists(audio_path): # Use TTS audio if no video audio available @@ -352,9 +393,14 @@ def illustration_pos(t): if audio_clip.duration > duration: audio_clip = audio_clip.subclipped(0, duration) elif audio_clip.duration < duration: - silence = AudioClip(lambda t: [0, 0], duration=duration - audio_clip.duration).with_fps(44100) + + silence = AudioClip( + lambda t: [0, 0], + duration=duration + - audio_clip.duration).with_fps(44100) # silence = silence.set_channels(2) - audio_clip = mp.concatenate_audioclips([audio_clip, silence]) + audio_clip = mp.concatenate_audioclips( + [audio_clip, silence]) segment_audio = audio_clip if segment_audio is not None: @@ -362,12 +408,18 @@ def illustration_pos(t): if valid_audio_clips: final_audio = mp.concatenate_audioclips(valid_audio_clips) - logger.info(f'Audio composing done: {final_audio.duration:.1f} seconds.') + logger.info( + f'Audio composing done: {final_audio.duration:.1f} seconds.' + ) if final_audio.duration > final_video.duration: - final_audio = final_audio.subclipped(0, final_video.duration) + final_audio = final_audio.subclipped( + 0, final_video.duration) elif final_audio.duration < final_video.duration: - silence = AudioClip(lambda t: [0, 0], duration=final_video.duration - final_audio.duration) - final_audio = mp.concatenate_audioclips([final_audio, silence]) + silence = AudioClip( + lambda t: [0, 0], + duration=final_video.duration - final_audio.duration) + final_audio = mp.concatenate_audioclips( + [final_audio, silence]) final_video = final_video.with_audio(final_audio) @@ -375,31 +427,43 @@ def illustration_pos(t): if os.path.exists(self.config.bg_audio_path): bg_music_path = self.config.bg_audio_path else: - bg_music_path = os.path.join(self.config.local_dir, self.config.bg_audio_path) + bg_music_path = os.path.join(self.config.local_dir, + self.config.bg_audio_path) else: bg_music_path = '' - if os.path.exists(bg_music_path) and not self.config.use_video_soundtrack: + if os.path.exists( + bg_music_path) and not self.config.use_video_soundtrack: bg_music = mp.AudioFileClip(bg_music_path) if bg_music.duration < final_video.duration: - repeat_times = int(final_video.duration / bg_music.duration) + 1 - bg_music = mp.concatenate_audioclips([bg_music] * repeat_times) + repeat_times = int( + final_video.duration / bg_music.duration) + 1 + bg_music = mp.concatenate_audioclips([bg_music] + * repeat_times) bg_music = bg_music.subclipped(0, final_video.duration) elif bg_music.duration > final_video.duration: bg_music = bg_music.subclipped(0, final_video.duration) - bg_music = bg_music.with_volume_scaled(self.config.bg_audio_volume) + bg_music = bg_music.with_volume_scaled( + self.config.bg_audio_volume) if final_video.audio: - tts_audio = final_video.audio.with_duration(final_video.duration).with_volume_scaled(1.0) + tts_audio = final_video.audio.with_duration( + final_video.duration).with_volume_scaled(1.0) bg_audio = bg_music.with_duration(final_video.duration) - mixed_audio = mp.CompositeAudioClip([tts_audio, bg_audio]).with_duration(final_video.duration) + mixed_audio = mp.CompositeAudioClip( + [tts_audio, + bg_audio]).with_duration(final_video.duration) else: - mixed_audio = bg_music.with_duration(final_video.duration).with_volume_scaled(0.3) + mixed_audio = bg_music.with_duration( + final_video.duration).with_volume_scaled(0.3) final_video = final_video.with_audio(mixed_audio) assert final_video is not None logger.info('Rendering final video...') - logger.info(f'Total video duration: {final_video.duration:.1f} seconds') + logger.info( + f'Total video duration: {final_video.duration:.1f} seconds') logger.info(f'Video resolution: {final_video.size}') - logger.info(f"Audio status: {'Has audio' if final_video.audio else 'No audio'}") + logger.info( + f"Audio status: {'Has audio' if final_video.audio else 'No audio'}" + ) logger.info(f'final_video type: {type(final_video)}') logger.info(f'final_video attributes: {dir(final_video)}') @@ -416,8 +480,7 @@ def illustration_pos(t): audio_bitrate='192k', audio_fps=44100, preset=self.preset, - write_logfile=False, - ) + write_logfile=False) logger.info(f'file saved: {output_path}') @@ -443,14 +506,19 @@ async def execute_code(self, messages, **kwargs): illustration_paths = [] video_paths = [] for i, segment in enumerate(segments): - illustration_paths.append(os.path.join(self.images_dir, f'illustration_{i + 1}.png')) - foreground_paths.append(os.path.join(self.render_dir, f'scene_{i + 1}', f'Scene{i + 1}.mov')) - audio_paths.append(os.path.join(self.tts_dir, f'segment_{i + 1}.mp3')) + illustration_paths.append( + os.path.join(self.images_dir, f'illustration_{i + 1}.png')) + foreground_paths.append( + os.path.join(self.render_dir, f'scene_{i + 1}', + f'Scene{i+1}.mov')) + audio_paths.append( + os.path.join(self.tts_dir, f'segment_{i + 1}.mp3')) segment_subtitles = [] j = 0 while True: - sub_path = os.path.join(self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') + sub_path = os.path.join(self.subtitle_dir, + f'bilingual_subtitle_{i + 1}_{j}.png') if os.path.exists(sub_path): segment_subtitles.append(sub_path) j += 1 @@ -458,7 +526,8 @@ async def execute_code(self, messages, **kwargs): break subtitle_paths.append(segment_subtitles) - video_paths.append(os.path.join(self.videos_dir, f'video_{i + 1}.mp4')) + video_paths.append( + os.path.join(self.videos_dir, f'video_{i + 1}.mp4')) self.compose_final_video( background_path=self.bg_path, @@ -468,6 +537,5 @@ async def execute_code(self, messages, **kwargs): illustration_paths=illustration_paths, video_paths=video_paths, segments=segments, - output_path=final_video_path, - ) + output_path=final_video_path) return messages diff --git a/projects/singularity_cinema/create_background/agent.py b/projects/singularity_cinema/create_background/agent.py index 3cc233ad7..44e510b13 100644 --- a/projects/singularity_cinema/create_background/agent.py +++ b/projects/singularity_cinema/create_background/agent.py @@ -3,19 +3,23 @@ import textwrap import matplotlib.font_manager as fm -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont logger = get_logger() class CreateBackground(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.bg_path = os.path.join(self.work_dir, 'background.png') @@ -26,10 +30,10 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False def get_font(self, size): candidates = list(self.fonts) import subprocess - for name in candidates: try: - font_path = subprocess.check_output(['fc-match', '-f', '%{file}\n', name], text=True).strip() + font_path = subprocess.check_output( + ['fc-match', '-f', '%{file}\n', name], text=True).strip() font = ImageFont.truetype(font_path, size) font.getmask('中') @@ -60,7 +64,7 @@ async def execute_code(self, messages, **kwargs): 'padding': 50, 'line_width': 8, 'subtitle_offset': 40, - 'line_position_offset': 140, + 'line_position_offset': 140 } # Create image with transparent background (RGBA mode) @@ -74,18 +78,27 @@ async def execute_code(self, messages, **kwargs): y_position = config['padding'] for line in title_lines: bbox = draw.textbbox((0, 0), line, font=title_font) - draw.text((config['padding'], y_position), line, font=title_font, fill=slogan_subtitle_color) + draw.text((config['padding'], y_position), + line, + font=title_font, + fill=slogan_subtitle_color) y_position += (bbox[3] - bbox[1]) + config['line_spacing'] subtitle_lines = self.slogan y_position = config['padding'] for i, line in enumerate(subtitle_lines): bbox = draw.textbbox((0, 0), line, font=subtitle_font) - x_offset = width - bbox[2] - (config['padding'] + 30) + (i * config['subtitle_offset']) - draw.text((x_offset, y_position), line, font=subtitle_font, fill=slogan_subtitle_color) + x_offset = width - bbox[2] - (config['padding'] + 30) + ( + i * config['subtitle_offset']) + draw.text((x_offset, y_position), + line, + font=subtitle_font, + fill=slogan_subtitle_color) y_position += bbox[3] - bbox[1] + 5 line_y = height - config['padding'] - config['line_position_offset'] if self.config.use_subtitle: - draw.line([(0, line_y), (width, line_y)], fill=slogan_subtitle_color, width=config['line_width']) + draw.line([(0, line_y), (width, line_y)], + fill=slogan_subtitle_color, + width=config['line_width']) image.save(self.bg_path) return messages diff --git a/projects/singularity_cinema/generate_animation/agent.py b/projects/singularity_cinema/generate_animation/agent.py index e273b2e84..26837050b 100644 --- a/projects/singularity_cinema/generate_animation/agent.py +++ b/projects/singularity_cinema/generate_animation/agent.py @@ -3,13 +3,17 @@ import os import sys -from omegaconf import DictConfig - from ms_agent.agent import CodeAgent +from omegaconf import DictConfig class GenerateAnimation(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) async def execute_code(self, messages, **kwargs): @@ -17,15 +21,15 @@ async def execute_code(self, messages, **kwargs): sys.path.insert(0, os.path.dirname(__file__)) if engine == 'manim': from generate_manim_code import GenerateManimCode - sys.path.pop(0) - agent = GenerateManimCode(self.config, self.tag, self.trust_remote_code, **kwargs) + agent = GenerateManimCode(self.config, self.tag, + self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) elif engine == 'remotion': from generate_remotion_code import GenerateRemotionCode - sys.path.pop(0) - agent = GenerateRemotionCode(self.config, self.tag, self.trust_remote_code, **kwargs) + agent = GenerateRemotionCode(self.config, self.tag, + self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) else: raise ValueError(f'Unknown animation engine: {engine}') diff --git a/projects/singularity_cinema/generate_animation/generate_manim_code.py b/projects/singularity_cinema/generate_animation/generate_manim_code.py index 43a6f768b..990bab72d 100644 --- a/projects/singularity_cinema/generate_animation/generate_manim_code.py +++ b/projects/singularity_cinema/generate_animation/generate_manim_code.py @@ -1,21 +1,25 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -from omegaconf import DictConfig -from PIL import Image - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class GenerateManimCode(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) @@ -23,7 +27,8 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False self.manim_code_dir = os.path.join(self.work_dir, 'manim_code') os.makedirs(self.manim_code_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -40,7 +45,8 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_manim_code_static, seg, dur, idx, self.config, self.images_dir): idx + executor.submit(self._generate_manim_code_static, seg, dur, + idx, self.config, self.images_dir): idx for seg, dur, idx in tasks } for future in as_completed(futures): @@ -48,16 +54,20 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L manim_code[idx] = future.result() for i, code in enumerate(manim_code): - manim_file = os.path.join(self.manim_code_dir, f'segment_{i + 1}.py') + manim_file = os.path.join(self.manim_code_dir, + f'segment_{i + 1}.py') with open(manim_file, 'w') as f: f.write(code) return messages @staticmethod - def _generate_manim_code_static(segment, audio_duration, i, config, image_dir): + def _generate_manim_code_static(segment, audio_duration, i, config, + image_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) - return GenerateManimCode._generate_manim_impl(llm, segment, audio_duration, i, image_dir, config) + return GenerateManimCode._generate_manim_impl(llm, segment, + audio_duration, i, + image_dir, config) @staticmethod def get_image_size(filename): @@ -72,7 +82,8 @@ def get_all_images_info(segment, i, image_dir): # Now check for files corresponding to these descriptions for idx, desc in enumerate(descriptions): - foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join( + image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): size = GenerateManimCode.get_image_size(foreground_image) @@ -83,7 +94,8 @@ def get_all_images_info(segment, i, image_dir): } all_images_info.append(image_info) - image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join( + os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -95,11 +107,13 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, config): + def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, + config): class_name = f'Scene{i + 1}' content = segment['content'] manim_requirement = segment['manim'] - images_info = GenerateManimCode.get_all_images_info(segment, i, image_dir) + images_info = GenerateManimCode.get_all_images_info( + segment, i, image_dir) if images_info: images_info = json.dumps(images_info, indent=4, ensure_ascii=False) else: @@ -163,7 +177,8 @@ def _generate_manim_impl(llm, segment, audio_duration, i, image_dir, config): """ logger.info(f'正在生成 manim 代码:{content}') - _response_message = llm.generate([Message(role='user', content=prompt)], temperature=0.3) + _response_message = llm.generate( + [Message(role='user', content=prompt)], temperature=0.3) response = _response_message.content if '```python' in response: manim_code = response.split('```python')[1].split('```')[0] diff --git a/projects/singularity_cinema/generate_animation/generate_remotion_code.py b/projects/singularity_cinema/generate_animation/generate_remotion_code.py index 11ea68acd..1007a60b1 100644 --- a/projects/singularity_cinema/generate_animation/generate_remotion_code.py +++ b/projects/singularity_cinema/generate_animation/generate_remotion_code.py @@ -1,23 +1,27 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import glob -import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -from omegaconf import DictConfig -from PIL import Image - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class GenerateRemotionCode(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) @@ -25,7 +29,8 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False self.remotion_code_dir = os.path.join(self.work_dir, 'remotion_code') os.makedirs(self.remotion_code_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -38,7 +43,8 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L animation_requirement = segment.get('remotion') if animation_requirement is not None: # Check if file already exists - remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, + f'Segment{i + 1}.tsx') if os.path.exists(remotion_file): continue tasks.append((segment, audio_info['audio_duration'], i)) @@ -47,14 +53,16 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L # Load existing files for skipped segments for i in range(len(segments)): - remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, + f'Segment{i + 1}.tsx') if os.path.exists(remotion_file): with open(remotion_file, 'r', encoding='utf-8') as f: remotion_code[i] = f.read() with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit(self._generate_remotion_code_static, seg, dur, idx, self.config, self.images_dir): idx + executor.submit(self._generate_remotion_code_static, seg, dur, + idx, self.config, self.images_dir): idx for seg, dur, idx in tasks } for future in as_completed(futures): @@ -62,16 +70,19 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L remotion_code[idx] = future.result() for i, code in enumerate(remotion_code): - remotion_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') + remotion_file = os.path.join(self.remotion_code_dir, + f'Segment{i + 1}.tsx') with open(remotion_file, 'w', encoding='utf-8') as f: f.write(code) return messages @staticmethod - def _generate_remotion_code_static(segment, audio_duration, i, config, image_dir): + def _generate_remotion_code_static(segment, audio_duration, i, config, + image_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) - return GenerateRemotionCode._generate_remotion_impl(llm, segment, audio_duration, i, image_dir, config) + return GenerateRemotionCode._generate_remotion_impl( + llm, segment, audio_duration, i, image_dir, config) @staticmethod def get_image_size(filename): @@ -84,17 +95,23 @@ def get_all_images_info(segment, i, image_dir): foreground = segment.get('foreground', []) for idx, _req in enumerate(foreground): - foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join( + image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): size = GenerateRemotionCode.get_image_size(foreground_image) image_info = { - 'filename': os.path.join('images', os.path.basename(foreground_image)), # Use basename for Remotion - 'size': size, - 'description': _req, + 'filename': + os.path.join('images', os.path.basename( + foreground_image)), # Use basename for Remotion + 'size': + size, + 'description': + _req, } all_images_info.append(image_info) - image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join( + os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -106,11 +123,13 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, config): + def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, + config): component_name = f'Segment{i + 1}' content = segment['content'] animation_requirement = segment['remotion'] - images_info = GenerateRemotionCode.get_all_images_info(segment, i, image_dir) + images_info = GenerateRemotionCode.get_all_images_info( + segment, i, image_dir) # Inject image info with code snippets. images_info_str = '' @@ -191,11 +210,14 @@ def _generate_remotion_impl(llm, segment, audio_duration, i, image_dir, config): """ logger.info(f'正在生成 remotion 代码:{content}') - _response_message = llm.generate([Message(role='user', content=prompt)], temperature=0.3) + _response_message = llm.generate( + [Message(role='user', content=prompt)], temperature=0.3) response = _response_message.content # Robust code extraction using regex - code_match = re.search(r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, re.DOTALL) + code_match = re.search( + r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, + re.DOTALL) if code_match: code = code_match.group(1) else: diff --git a/projects/singularity_cinema/generate_audio/agent.py b/projects/singularity_cinema/generate_audio/agent.py index 670454a60..499779c73 100644 --- a/projects/singularity_cinema/generate_audio/agent.py +++ b/projects/singularity_cinema/generate_audio/agent.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio -import json import os import shutil from copy import deepcopy @@ -8,28 +7,34 @@ from typing import List import edge_tts +import json import numpy as np from moviepy import AudioClip, AudioFileClip -from omegaconf import DictConfig - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.tools.audio_generator import AudioGenerator from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() @dataclass class Pattern: + name: str pattern: str tags: List[str] = field(default_factory=list) class GenerateAudio(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.llm: OpenAI = LLM.from_config(self.config) @@ -53,18 +58,17 @@ async def execute_code(self, messages, **kwargs): assert len(audio_durations) == len(audio_paths) audio_info = [] for audio_path, audio_duration in zip(audio_paths, audio_durations): - audio_info.append( - { - 'audio_path': audio_path, - 'audio_duration': audio_duration, - } - ) + audio_info.append({ + 'audio_path': audio_path, + 'audio_duration': audio_duration, + }) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'w') as f: f.write(json.dumps(audio_info, indent=4, ensure_ascii=False)) return messages @staticmethod async def create_silent_audio(output_path, duration=5.0): + def make_frame(t): return np.array([0.0, 0.0]) @@ -81,7 +85,8 @@ async def audio_generate(self, text, output_file, speaker='male'): os.makedirs(output_dir, exist_ok=True) _config = deepcopy(self.config) _config.tools.audio_generator = _config.audio_generator - _temp_file = await AudioGenerator(self.config).generate_audio(text, speaker=voice, rate=rate, pitch=pitch) + _temp_file = await AudioGenerator(self.config).generate_audio( + text, speaker=voice, rate=rate, pitch=pitch) shutil.move(_temp_file, output_file) @staticmethod diff --git a/projects/singularity_cinema/generate_illustration_prompts/agent.py b/projects/singularity_cinema/generate_illustration_prompts/agent.py index 13bf1c78d..5529dc67a 100644 --- a/projects/singularity_cinema/generate_illustration_prompts/agent.py +++ b/projects/singularity_cinema/generate_illustration_prompts/agent.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os import re import time @@ -7,23 +6,25 @@ from dataclasses import dataclass, field from typing import List, Optional, Union -from omegaconf import DictConfig - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() @dataclass class Pattern: + name: str pattern: str tags: List[str] = field(default_factory=list) class GenerateIllustrationPrompts(CodeAgent): + # Background prompt generator (t2i) system = """你是一名提示词工程师,负责为短视频生成一张背景图。 @@ -43,14 +44,20 @@ class GenerateIllustrationPrompts(CodeAgent): - 不要留白:使用适当的背景填充图像,尽量不要使用白色背景 """ - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) - self.illustration_prompts_dir = os.path.join(self.work_dir, 'illustration_prompts') + self.illustration_prompts_dir = os.path.join(self.work_dir, + 'illustration_prompts') os.makedirs(self.illustration_prompts_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) logger.info('Generating illustration prompts.') @@ -59,9 +66,9 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit( - self._generate_illustration_prompts_static, i, segment, self.config, self.illustration_prompts_dir - ): i + executor.submit(self._generate_illustration_prompts_static, i, + segment, self.config, + self.illustration_prompts_dir): i for i, segment in tasks } for future in as_completed(futures): @@ -69,14 +76,16 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L return messages @staticmethod - def _generate_illustration_prompts_static(i, segment, config, illustration_prompts_dir): + def _generate_illustration_prompts_static(i, segment, config, + illustration_prompts_dir): """Static method for multiprocessing""" llm = LLM.from_config(config) max_retries = 10 if config.background == 'image': for attempt in range(max_retries): try: - GenerateIllustrationPrompts._generate_illustration_impl(llm, i, segment, illustration_prompts_dir) + GenerateIllustrationPrompts._generate_illustration_impl( + llm, i, segment, illustration_prompts_dir) break except Exception: time.sleep(2) @@ -84,20 +93,26 @@ def _generate_illustration_prompts_static(i, segment, config, illustration_promp if config.foreground == 'image': for attempt in range(max_retries): try: - GenerateIllustrationPrompts._generate_foreground_impl(llm, i, segment, illustration_prompts_dir) + GenerateIllustrationPrompts._generate_foreground_impl( + llm, i, segment, illustration_prompts_dir) break except Exception: time.sleep(2) @staticmethod def _generate_illustration_impl(llm, i, segment, illustration_prompts_dir): - if os.path.exists(os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt')): + if os.path.exists( + os.path.join(illustration_prompts_dir, f'segment_{i+1}.txt')): return background_concept = segment.get('background') - logger.info(f'Generating background prompt from plan: {background_concept}') + logger.info( + f'Generating background prompt from plan: {background_concept}') - with open(os.path.join(os.path.dirname(illustration_prompts_dir), 'topic.txt'), 'r') as f: + with open( + os.path.join( + os.path.dirname(illustration_prompts_dir), 'topic.txt'), + 'r') as f: topic = f.read() query = ( f'User original topic: {topic}\n' @@ -111,35 +126,47 @@ def _generate_illustration_impl(llm, i, segment, illustration_prompts_dir): response = llm.generate(inputs).content.strip() # Strip thinking tags - response = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() + response = re.sub( + r'.*?', '', response, flags=re.DOTALL).strip() - with open(os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt'), 'w') as f: + with open( + os.path.join(illustration_prompts_dir, f'segment_{i + 1}.txt'), + 'w') as f: f.write(response) @staticmethod def _generate_foreground_impl(llm, i, segment, illustration_prompts_dir): foreground_assets = segment.get('foreground') for idx, asset_desc in enumerate(foreground_assets): - file_path = os.path.join(illustration_prompts_dir, f'segment_{i + 1}_foreground_{idx + 1}.txt') + file_path = os.path.join(illustration_prompts_dir, + f'segment_{i+1}_foreground_{idx+1}.txt') if os.path.exists(file_path): continue - logger.info(f'Generating foreground_{idx} prompt from plan: {asset_desc}') + logger.info( + f'Generating foreground_{idx} prompt from plan: {asset_desc}') - with open(os.path.join(os.path.dirname(illustration_prompts_dir), 'topic.txt'), 'r') as f: + with open( + os.path.join( + os.path.dirname(illustration_prompts_dir), + 'topic.txt'), 'r') as f: topic = f.read() - query = f'User original topic: {topic}\nDesign a single foreground asset: {asset_desc}\n' + query = (f'User original topic: {topic}\n' + f'Design a single foreground asset: {asset_desc}\n') inputs = [ - Message(role='system', content=GenerateIllustrationPrompts.system_foreground), + Message( + role='system', + content=GenerateIllustrationPrompts.system_foreground), Message(role='user', content=query), ] response = llm.generate(inputs).content.strip() # Strip thinking tags - response = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() + response = re.sub( + r'.*?', '', response, flags=re.DOTALL).strip() with open(file_path, 'w') as f: f.write(response) diff --git a/projects/singularity_cinema/generate_images/agent.py b/projects/singularity_cinema/generate_images/agent.py index 42a3b8d03..e21eb325d 100644 --- a/projects/singularity_cinema/generate_images/agent.py +++ b/projects/singularity_cinema/generate_images/agent.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio -import json import os import re import shutil @@ -10,42 +9,55 @@ from typing import List, Union import aiohttp +import json import numpy as np -from omegaconf import DictConfig -from PIL import Image - from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.image_generator import ImageGenerator from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class GenerateImages(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 't2i_num_parallel', 1) self.fusion = self.fade - self.illustration_prompts_dir = os.path.join(self.work_dir, 'illustration_prompts') + self.illustration_prompts_dir = os.path.join(self.work_dir, + 'illustration_prompts') self.images_dir = os.path.join(self.work_dir, 'images') os.makedirs(self.images_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) illustration_prompts = [] for i in range(len(segments)): - illustration_path = os.path.join(self.illustration_prompts_dir, f'segment_{i + 1}.txt') - if self.config.background == 'image' and os.path.exists(illustration_path): + illustration_path = os.path.join(self.illustration_prompts_dir, + f'segment_{i+1}.txt') + if self.config.background == 'image' and os.path.exists( + illustration_path): with open(illustration_path, 'r') as f: illustration_prompts.append(f.read()) else: illustration_prompts.append(None) logger.info('Generating images.') - tasks = [(i, segment, prompt) for i, (segment, prompt) in enumerate(zip(segments, illustration_prompts))] + tasks = [ + (i, segment, prompt) + for i, (segment, + prompt) in enumerate(zip(segments, illustration_prompts)) + ] # Use ThreadPoolExecutor for parallel execution with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: @@ -55,13 +67,16 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L final_prompt = prompt_text if final_prompt: # Remove thinking tags if present - final_prompt = re.sub(r'.*?', '', final_prompt, flags=re.DOTALL).strip() + final_prompt = re.sub( + r'.*?', + '', + final_prompt, + flags=re.DOTALL).strip() futures.append( - executor.submit( - self._process_single_illustration_static, i, segment, final_prompt, self.config, self.images_dir - ) - ) + executor.submit(self._process_single_illustration_static, + i, segment, final_prompt, self.config, + self.images_dir)) # Wait for all tasks to complete for future in futures: future.result() @@ -69,27 +84,30 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L return messages @staticmethod - def _process_single_illustration_static(i, segment, prompt, config, images_dir): + def _process_single_illustration_static(i, segment, prompt, config, + images_dir): """Static method for thread pool execution""" # Create new event loop for this thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: loop.run_until_complete( - GenerateImages._process_single_illustration_impl(i, segment, prompt, config, images_dir) - ) + GenerateImages._process_single_illustration_impl( + i, segment, prompt, config, images_dir)) loop.run_until_complete( - GenerateImages._process_foreground_illustration_impl(i, segment, config, images_dir) - ) + GenerateImages._process_foreground_illustration_impl( + i, segment, config, images_dir)) finally: loop.close() @staticmethod - async def _process_single_illustration_impl(i, segment, prompt, config, images_dir): + async def _process_single_illustration_impl(i, segment, prompt, config, + images_dir): """Implementation of single illustration processing""" if config.background != 'image': # Generate a 2000x2000 solid color image - logger.info(f'Generating solid color background for segment {i + 1}.') + logger.info( + f'Generating solid color background for segment {i + 1}.') output_path = os.path.join(images_dir, f'illustration_{i + 1}.png') if not os.path.exists(output_path): # Create a 2000x2000 image with the color defined in config.background @@ -97,7 +115,8 @@ async def _process_single_illustration_impl(i, segment, prompt, config, images_d img.save(output_path) else: logger.info(f'Generating image for: {prompt}.') - img_path = os.path.join(images_dir, f'illustration_{i + 1}_origin.png') + img_path = os.path.join(images_dir, + f'illustration_{i + 1}_origin.png') output_path = os.path.join(images_dir, f'illustration_{i + 1}.png') if os.path.exists(output_path): return @@ -114,13 +133,16 @@ async def _process_single_illustration_impl(i, segment, prompt, config, images_d elif hasattr(_config.image_generator, 'size'): kwargs['size'] = _config.image_generator.size - logger.info(f'Generating image. Prompt: {prompt[:50]}... kwargs: {kwargs}') + logger.info( + f'Generating image. Prompt: {prompt[:50]}... kwargs: {kwargs}') _temp_file = await image_generator.generate_image(prompt, **kwargs) # Check directly if the return is a valid file path if not _temp_file or not os.path.exists(_temp_file): - logger.error(f'Background image generation failed for segment {i + 1}. Result: {_temp_file}') + logger.error( + f'Background image generation failed for segment {i + 1}. Result: {_temp_file}' + ) return shutil.move(_temp_file, img_path) @@ -132,22 +154,27 @@ async def _process_single_illustration_impl(i, segment, prompt, config, images_d pass @staticmethod - async def _process_foreground_illustration_impl(i, segment, config, images_dir): + async def _process_foreground_illustration_impl(i, segment, config, + images_dir): """Implementation of foreground illustration processing""" if config.foreground != 'image': return logger.info(f'Generating foreground image for: segment {i}.') work_dir = getattr(config, 'output_dir', 'output') - illustration_prompts_dir = os.path.join(work_dir, 'illustration_prompts') + illustration_prompts_dir = os.path.join(work_dir, + 'illustration_prompts') foreground_assets = segment.get('foreground', []) for idx, _req in enumerate(foreground_assets): - foreground_image = os.path.join(images_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join( + images_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') if os.path.exists(foreground_image): continue - foreground_prompt_path = os.path.join(illustration_prompts_dir, f'segment_{i + 1}_foreground_{idx + 1}.txt') + foreground_prompt_path = os.path.join( + illustration_prompts_dir, + f'segment_{i+1}_foreground_{idx+1}.txt') assert os.path.exists(foreground_prompt_path) @@ -155,7 +182,9 @@ async def _process_foreground_illustration_impl(i, segment, config, images_dir): prompt_text = f.read() # Clean Prompt from Thinking process - prompt = re.sub(r'.*?', '', prompt_text, flags=re.DOTALL).strip() + prompt = re.sub( + r'.*?', '', prompt_text, + flags=re.DOTALL).strip() _config = deepcopy(config) _config.tools.image_generator = _config.image_generator @@ -176,12 +205,18 @@ async def _process_foreground_illustration_impl(i, segment, config, images_dir): os.remove(_temp_file) @staticmethod - def fade(input_image, output_image, segment, fade_factor=0.3, brightness_boost=80, opacity=1.0): + def fade(input_image, + output_image, + segment, + fade_factor=0.3, + brightness_boost=80, + opacity=1.0): # Support both 'manim' and 'remotion' keys for animation detection has_animation = segment.get('manim') or segment.get('remotion') img = Image.open(input_image).convert('RGBA') if has_animation: - logger.info('Applying fade effect to background image (Animation present)') + logger.info( + 'Applying fade effect to background image (Animation present)') arr = np.array(img, dtype=np.float32) arr[..., :3] = arr[..., :3] * fade_factor + brightness_boost arr[..., :3] = np.clip(arr[..., :3], 0, 255) diff --git a/projects/singularity_cinema/generate_script/agent.py b/projects/singularity_cinema/generate_script/agent.py index 694f626e2..fc0a726cc 100644 --- a/projects/singularity_cinema/generate_script/agent.py +++ b/projects/singularity_cinema/generate_script/agent.py @@ -3,17 +3,21 @@ from copy import deepcopy from typing import List -from omegaconf import DictConfig - from ms_agent import LLMAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class GenerateScript(LLMAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') extra = getattr(self.config, 'extra_requirement', '') diff --git a/projects/singularity_cinema/generate_subtitle/agent.py b/projects/singularity_cinema/generate_subtitle/agent.py index 64904af99..efa873851 100644 --- a/projects/singularity_cinema/generate_subtitle/agent.py +++ b/projects/singularity_cinema/generate_subtitle/agent.py @@ -1,17 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os import re from typing import List +import json import matplotlib.font_manager as fm -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont logger = get_logger() @@ -39,7 +38,7 @@ def _chunk_tokens(tokens: List[str], max_len: int) -> List[str]: if len(t) > max_len: # If a single token exceeds max_len, split it for i in range(0, len(t), max_len): - sub = t[i : i + max_len] + sub = t[i:i + max_len] if cur: chunks.append(cur.strip()) cur = '' @@ -52,7 +51,8 @@ def _chunk_tokens(tokens: List[str], max_len: int) -> List[str]: continue # If t is punctuation and can be merged with previous chunk (allowing slight overflow) - if _is_punct(t) and cur and len(cur) + len(t) <= max_len + PUNCTUATION_OVERFLOW_ALLOWANCE: + if _is_punct(t) and cur and len(cur) + len( + t) <= max_len + PUNCTUATION_OVERFLOW_ALLOWANCE: cur = (cur + t).strip() + ' ' continue @@ -81,11 +81,17 @@ def _clean_chunks(chunks: List[str], max_len: int) -> List[str]: class GenerateSubtitle(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.llm: OpenAI = LLM.from_config(self.config) - self.subtitle_translate = getattr(self.config, 'subtitle_translate', None) + self.subtitle_translate = getattr(self.config, 'subtitle_translate', + None) self.subtitle_dir = os.path.join(self.work_dir, 'subtitles') os.makedirs(self.subtitle_dir, exist_ok=True) self.fonts = self.config.fonts @@ -102,15 +108,20 @@ async def execute_code(self, messages, **kwargs): for j, chunk_text in enumerate(text_chunks): subtitle = None if self.subtitle_translate: - subtitle = await self.translate_text(chunk_text, self.subtitle_translate) + subtitle = await self.translate_text( + chunk_text, self.subtitle_translate) - output_file = os.path.join(self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') + output_file = os.path.join( + self.subtitle_dir, f'bilingual_subtitle_{i + 1}_{j}.png') if os.path.exists(output_file): continue self.create_bilingual_subtitle_image( - source=chunk_text, target=subtitle, output_file=output_file, width=1720, height=180 - ) + source=chunk_text, + target=subtitle, + output_file=output_file, + width=1720, + height=180) return messages def split_text_to_chunks(self, text, max_len: int = 30): @@ -122,6 +133,7 @@ def split_text_to_chunks(self, text, max_len: int = 30): return _clean_chunks(chunks, max_len) async def translate_text(self, text, to_lang): + prompt = f"""You are a professional translation expert specializing in accurately and fluently translating text into {to_lang}. ## Skills @@ -135,7 +147,7 @@ async def translate_text(self, text, to_lang): - Output only the translation result without any explanations. Now translate: -""" # noqa +""" # noqa messages = [ Message(role='system', content=prompt), Message(role='user', content=text), @@ -189,16 +201,14 @@ def smart_wrap_text(self, text, max_lines=2, chars_per_line=50): return lines if lines else [text] - def create_subtitle_image( - self, - text, - width=1720, - height=120, - font_size=28, - text_color='black', - bg_color='rgba(0,0,0,0)', - chars_per_line=50, - ): + def create_subtitle_image(self, + text, + width=1720, + height=120, + font_size=28, + text_color='black', + bg_color='rgba(0,0,0,0)', + chars_per_line=50): font = self.get_font(font_size) min_font_size = 18 max_height = 500 @@ -207,13 +217,15 @@ def create_subtitle_image( while font_size >= min_font_size: if font_size != original_font_size: font = self.get_font(font_size) - lines = self.smart_wrap_text(text, max_lines=2, chars_per_line=chars_per_line) + lines = self.smart_wrap_text( + text, max_lines=2, chars_per_line=chars_per_line) line_height = font_size + 8 total_text_height = len(lines) * line_height all_lines_fit = True for line in lines: - bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), line, font=font) + bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox( + (0, 0), line, font=font) line_width = bbox[2] - bbox[0] if line_width > width * 0.95: all_lines_fit = False @@ -245,18 +257,28 @@ def create_subtitle_image( draw.text((x, y), line, fill=text_color, font=font) return img, actual_height - def create_bilingual_subtitle_image(self, source, output_file, target='', width=1720, height=180): + def create_bilingual_subtitle_image(self, + source, + output_file, + target='', + width=1720, + height=180): main_font_size = 32 target_font_size = 22 main_target_gap = 6 pattern = r'^[a-zA-Z0-9\s.,!?;:\'"()-]+$' chars_per_line = 50 if not bool(re.match(pattern, source)) else 100 if target: - target_chars_per_line = 50 if not bool(re.match(pattern, target)) else 100 + target_chars_per_line = 50 if not bool(re.match(pattern, + target)) else 100 main_img, main_height = self.create_subtitle_image( - source, width, height, main_font_size, 'black', chars_per_line=chars_per_line - ) + source, + width, + height, + main_font_size, + 'black', + chars_per_line=chars_per_line) if target and target.strip(): target_chars_per_line = 100 @@ -266,12 +288,13 @@ def create_bilingual_subtitle_image(self, source, output_file, target='', width= height, target_font_size, '#404040', # Darker gray for better visibility - chars_per_line=target_chars_per_line, - ) + chars_per_line=target_chars_per_line) total_height = main_height + target_height + main_target_gap - combined_img = Image.new('RGBA', (width, total_height), (0, 0, 0, 0)) + combined_img = Image.new('RGBA', (width, total_height), + (0, 0, 0, 0)) combined_img.paste(main_img, (0, 0), main_img) - combined_img.paste(target_img, (0, main_height + main_target_gap), target_img) + combined_img.paste(target_img, (0, main_height + main_target_gap), + target_img) final_img = combined_img final_height = total_height else: diff --git a/projects/singularity_cinema/generate_video/agent.py b/projects/singularity_cinema/generate_video/agent.py index f58e34c3c..aaecd7d5f 100644 --- a/projects/singularity_cinema/generate_video/agent.py +++ b/projects/singularity_cinema/generate_video/agent.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio -import json import os import shutil from concurrent.futures import ThreadPoolExecutor @@ -8,18 +7,23 @@ from typing import List, Union import aiohttp -from omegaconf import DictConfig - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class GenerateVideo(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 't2v_num_parallel', 1) @@ -27,24 +31,30 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False self.videos_dir = os.path.join(self.work_dir, 'videos') os.makedirs(self.videos_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) video_prompts = [] for i in range(len(segments)): if 'video' in segments[i]: - with open(os.path.join(self.video_prompts_dir, f'segment_{i + 1}.txt'), 'r') as f: + with open( + os.path.join(self.video_prompts_dir, + f'segment_{i + 1}.txt'), 'r') as f: video_prompts.append(f.read()) else: video_prompts.append(None) logger.info('Generating videos.') - tasks = [(i, segment, prompt) for i, (segment, prompt) in enumerate(zip(segments, video_prompts))] + tasks = [(i, segment, prompt) + for i, (segment, + prompt) in enumerate(zip(segments, video_prompts))] # Use ThreadPoolExecutor for parallel execution with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = [ - executor.submit(self._process_single_video_static, i, segment, prompt, self.config, self.videos_dir) + executor.submit(self._process_single_video_static, i, segment, + prompt, self.config, self.videos_dir) for i, segment, prompt in tasks ] # Wait for all tasks to complete @@ -60,19 +70,25 @@ def _process_single_video_static(i, segment, prompt, config, videos_dir): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete(GenerateVideo._process_single_video_impl(i, segment, prompt, config, videos_dir)) + loop.run_until_complete( + GenerateVideo._process_single_video_impl( + i, segment, prompt, config, videos_dir)) finally: loop.close() @staticmethod - async def _process_single_video_impl(i, segment, prompt, config, videos_dir): + async def _process_single_video_impl(i, segment, prompt, config, + videos_dir): if prompt is None: - logger.info(f'Skipping video generation for segment {i + 1} (no video prompt).') + logger.info( + f'Skipping video generation for segment {i + 1} (no video prompt).' + ) return output_path = os.path.join(videos_dir, f'video_{i + 1}.mp4') if os.path.exists(output_path): - logger.info(f'Video already exists for segment {i + 1}: {output_path}') + logger.info( + f'Video already exists for segment {i + 1}: {output_path}') return logger.info(f'Generating video for segment {i + 1}: {prompt}') @@ -91,5 +107,6 @@ async def _process_single_video_impl(i, segment, prompt, config, videos_dir): _config.tools.video_generator = _config.video_generator video_generator = VideoGenerator(_config) - _temp_file = await video_generator.generate_video(prompt, seconds=fit_duration) + _temp_file = await video_generator.generate_video( + prompt, seconds=fit_duration) shutil.move(_temp_file, output_path) diff --git a/projects/singularity_cinema/generate_video_prompts/agent.py b/projects/singularity_cinema/generate_video_prompts/agent.py index 6a66e562d..2cd091f3b 100644 --- a/projects/singularity_cinema/generate_video_prompts/agent.py +++ b/projects/singularity_cinema/generate_video_prompts/agent.py @@ -1,20 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union -from omegaconf import DictConfig - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class GenerateVideoPrompts(CodeAgent): - system = """ + + system = (""" You are an expert in creating scene descriptions for video generation. Based on given knowledge points or storyboard scripts, generate detailed English descriptions for creating text-to-video content that align with specified themes and styles. @@ -33,16 +33,21 @@ class GenerateVideoPrompts(CodeAgent): - Output approximately 200 words in English. - Return ONLY the prompt description. Do not include style keywords unless requested, and do not add explanations or markers. - """ + """) - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) self.video_prompts_dir = os.path.join(self.work_dir, 'video_prompts') os.makedirs(self.video_prompts_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: if not self.config.use_text2video: return messages with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: @@ -55,30 +60,27 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit( - self._generate_video_prompts_static, - i, - segment, - self.config, - topic, - self.system, - self.video_prompts_dir, - ): i - for i, segment in tasks - if 'video' in segment + executor.submit(self._generate_video_prompts_static, i, + segment, self.config, topic, self.system, + self.video_prompts_dir): i + for i, segment in tasks if 'video' in segment } for future in as_completed(futures): future.result() return messages @staticmethod - def _generate_video_prompts_static(i, segment, config, topic, system, video_prompts_dir): + def _generate_video_prompts_static(i, segment, config, topic, system, + video_prompts_dir): llm = LLM.from_config(config) - GenerateVideoPrompts._generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_dir, config) + GenerateVideoPrompts._generate_video_prompt_impl( + llm, i, segment, topic, system, video_prompts_dir, config) @staticmethod - def _generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_dir, config): - if os.path.exists(os.path.join(video_prompts_dir, f'segment_{i + 1}.txt')): + def _generate_video_prompt_impl(llm, i, segment, topic, system, + video_prompts_dir, config): + if os.path.exists( + os.path.join(video_prompts_dir, f'segment_{i+1}.txt')): return work_dir = os.path.dirname(video_prompts_dir) @@ -93,12 +95,10 @@ def _generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_di break video = segment['video'] - query = ( - f'The user original request is: {topic}, ' - f'illustration based on: {segment["content"]}, ' - f'Video duration: {fit_duration}, ' - f'Requirements from the storyboard designer: {video}' - ) + query = (f'The user original request is: {topic}, ' + f'illustration based on: {segment["content"]}, ' + f'Video duration: {fit_duration}, ' + f'Requirements from the storyboard designer: {video}') logger.info(f'Generating video prompt for : {segment["content"]}.') inputs = [ Message(role='system', content=system), @@ -107,5 +107,7 @@ def _generate_video_prompt_impl(llm, i, segment, topic, system, video_prompts_di _response_message = llm.generate(inputs) response = _response_message.content prompt = response.strip() - with open(os.path.join(video_prompts_dir, f'segment_{i + 1}.txt'), 'w') as f: + with open( + os.path.join(video_prompts_dir, f'segment_{i + 1}.txt'), + 'w') as f: f.write(prompt) diff --git a/projects/singularity_cinema/parse_images/agent.py b/projects/singularity_cinema/parse_images/agent.py index 8a5e324c7..138803dc0 100644 --- a/projects/singularity_cinema/parse_images/agent.py +++ b/projects/singularity_cinema/parse_images/agent.py @@ -1,39 +1,45 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib -import json import os import re from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from urllib.request import urlretrieve -from omegaconf import DictConfig -from PIL import Image - +import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class ParseImages(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') _config = deepcopy(config) delattr(_config, 'llm') _config.llm = DictConfig({}) for key, value in _config.mllm.items(): - key = key[len('mllm_') :] + key = key[len('mllm_'):] setattr(_config.llm, key, value) _config.generation_config = DictConfig({'temperature': 0.3}) if 'extra_body' in config.generation_config: _config.generation_config.extra_body = config.generation_config.extra_body self.mllm: OpenAI = LLM.from_config(_config) - logger.info(f"Using MLLM for image parsing: {getattr(self.mllm, 'model', None)}") + logger.info( + f"Using MLLM for image parsing: {getattr(self.mllm, 'model', None)}" + ) self.image_dir = os.path.join(self.work_dir, 'images') os.makedirs(self.image_dir, exist_ok=True) @@ -124,18 +130,22 @@ def get_image_description(self, filename): image_data = image_file.read() base64_image = base64.b64encode(image_data).decode('utf-8') - _content = [ - { - 'type': 'text', - 'text': ( - 'Describe this image in under 50 words. Be objective and accurate. For charts/graphs, ' - 'analyze axis labels and data to explain what the chart shows and its purpose, ' - 'not just the chart type. Provide enough detail to distinguish it from other images.' - 'Return only the requested image description. Do not add any other content.' - ), - }, - {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}', 'detail': 'high'}}, - ] + _content = [{ + 'type': + 'text', + 'text': + ('Describe this image in under 50 words. Be objective and accurate. For charts/graphs, ' + 'analyze axis labels and data to explain what the chart shows and its purpose, ' + 'not just the chart type. Provide enough detail to distinguish it from other images.' + 'Return only the requested image description. Do not add any other content.' + ) + }, { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}', + 'detail': 'high' + } + }] messages = [ Message(role='user', content=_content), diff --git a/projects/singularity_cinema/render_animation/agent.py b/projects/singularity_cinema/render_animation/agent.py index 1340608a6..1b3007c66 100644 --- a/projects/singularity_cinema/render_animation/agent.py +++ b/projects/singularity_cinema/render_animation/agent.py @@ -3,13 +3,17 @@ import os import sys -from omegaconf import DictConfig - from ms_agent.agent import CodeAgent +from omegaconf import DictConfig class RenderAnimation(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) async def execute_code(self, messages, **kwargs): @@ -17,15 +21,15 @@ async def execute_code(self, messages, **kwargs): sys.path.insert(0, os.path.dirname(__file__)) if engine == 'manim': from render_manim import RenderManim - sys.path.pop(0) - agent = RenderManim(self.config, self.tag, self.trust_remote_code, **kwargs) + agent = RenderManim(self.config, self.tag, self.trust_remote_code, + **kwargs) return await agent.execute_code(messages, **kwargs) elif engine == 'remotion': from render_remotion import RenderRemotion - sys.path.pop(0) - agent = RenderRemotion(self.config, self.tag, self.trust_remote_code, **kwargs) + agent = RenderRemotion(self.config, self.tag, + self.trust_remote_code, **kwargs) return await agent.execute_code(messages, **kwargs) else: raise ValueError(f'Unknown animation engine: {engine}') diff --git a/projects/singularity_cinema/render_animation/render_manim.py b/projects/singularity_cinema/render_animation/render_manim.py index 59686e226..19b9abc27 100644 --- a/projects/singularity_cinema/render_animation/render_manim.py +++ b/projects/singularity_cinema/render_animation/render_manim.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 -import json import os import re import shutil @@ -10,41 +9,48 @@ from os import getcwd from typing import List, Union +import json from moviepy import VideoFileClip -from omegaconf import DictConfig -from PIL import Image - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig +from PIL import Image logger = get_logger() class RenderManim(CodeAgent): + window_size = (1250, 700) - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) if not self.config.use_subtitle: self.window_size = (1450, 800) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 10) self.manim_render_timeout = getattr( - self.config, 'animation_render_timeout', getattr(self.config, 'manim_render_timeout', 300) - ) + self.config, 'animation_render_timeout', + getattr(self.config, 'manim_render_timeout', 300)) self.render_dir = os.path.join(self.work_dir, 'manim_render') self.code_fix_round = getattr(self.config, 'code_fix_round', 5) self.mllm_check_round = getattr(self.config, 'mllm_fix_round', 1) os.makedirs(self.render_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) manim_code_dir = os.path.join(self.work_dir, 'manim_code') manim_code = [] for i in range(len(segments)): - with open(os.path.join(manim_code_dir, f'segment_{i + 1}.py'), 'r') as f: + with open(os.path.join(manim_code_dir, f'segment_{i+1}.py'), + 'r') as f: manim_code.append(f.read()) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: audio_infos = json.load(f) @@ -53,25 +59,17 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L tasks = [ (i, segment, code, audio_info['audio_duration']) - for i, (segment, code, audio_info) in enumerate(zip(segments, manim_code, audio_infos)) + for i, (segment, code, audio_info + ) in enumerate(zip(segments, manim_code, audio_infos)) ] with ThreadPoolExecutor(max_workers=self.num_parallel) as executor: futures = { - executor.submit( - self._render_manim_scene_static, - i, - segment, - code, - duration, - self.config, - self.work_dir, - self.render_dir, - self.window_size, - self.manim_render_timeout, - self.code_fix_round, - self.mllm_check_round, - ): i + executor.submit(self._render_manim_scene_static, i, segment, + code, duration, self.config, self.work_dir, + self.render_dir, self.window_size, + self.manim_render_timeout, self.code_fix_round, + self.mllm_check_round): i for i, segment, code, duration in tasks } for future in as_completed(futures): @@ -80,54 +78,23 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L return messages @staticmethod - def _render_manim_scene_static( - i, - segment, - code, - audio_duration, - config, - work_dir, - render_dir, - window_size, - manim_render_timeout, - code_fix_round, - mllm_check_round, - ): + def _render_manim_scene_static(i, segment, code, audio_duration, config, + work_dir, render_dir, window_size, + manim_render_timeout, code_fix_round, + mllm_check_round): """Static method for multiprocessing""" llm = LLM.from_config(config) - return RenderManim._render_manim_impl( - llm, - i, - segment, - code, - audio_duration, - work_dir, - render_dir, - window_size, - manim_render_timeout, - config, - code_fix_round, - mllm_check_round, - ) + return RenderManim._render_manim_impl(llm, i, segment, code, + audio_duration, work_dir, + render_dir, window_size, + manim_render_timeout, config, + code_fix_round, mllm_check_round) @staticmethod - def _render_manim_impl( - llm, - i, - segment, - code, - audio_duration, - work_dir, - render_dir, - window_size, - manim_render_timeout, - config, - code_fix_round, - mllm_check_round, - ): - scene_name = ( - f'Scene{i + 1}' # sometimes actual_scene_name cannot find matched class, so do not change this name - ) + def _render_manim_impl(llm, i, segment, code, audio_duration, work_dir, + render_dir, window_size, manim_render_timeout, + config, code_fix_round, mllm_check_round): + scene_name = f'Scene{i+1}' # sometimes actual_scene_name cannot find matched class, so do not change this name logger.info(f'Rendering manim code for: scene_{i + 1}') output_dir = os.path.join(render_dir, f'scene_{i + 1}') os.makedirs(output_dir, exist_ok=True) @@ -159,16 +126,10 @@ def _render_manim_impl( env['LC_ALL'] = 'zh_CN.UTF-8' window_size_str = ','.join([str(x) for x in window_size]) cmd = [ - 'manim', - 'render', - '-ql', - '--transparent', - '--format=mov', - f'--resolution={window_size_str}', - '--disable_caching', - f'--media_dir={os.path.dirname(code_file)}', - code_file, - actual_scene_name, + 'manim', 'render', '-ql', '--transparent', '--format=mov', + f'--resolution={window_size_str}', '--disable_caching', + f'--media_dir={os.path.dirname(code_file)}', code_file, + actual_scene_name ] try: @@ -180,14 +141,15 @@ def _render_manim_impl( text=True, encoding='utf-8', errors='ignore', - env=env, - ) + env=env) # Wait for process to complete with timeout - stdout, stderr = process.communicate(timeout=manim_render_timeout) + stdout, stderr = process.communicate( + timeout=manim_render_timeout) # Create result object compatible with original logic class Result: + def __init__(self, returncode, stdout, stderr): self.returncode = returncode self.stdout = stdout @@ -196,64 +158,41 @@ def __init__(self, returncode, stdout, stderr): result = Result(process.returncode, stdout, stderr) output_text = (result.stdout or '') + (result.stderr or '') except subprocess.TimeoutExpired as e: - output_text = (e.stdout.decode('utf-8', errors='ignore') if e.stdout else '') + ( - e.stderr.decode('utf-8', errors='ignore') if e.stderr else '' - ) # noqa + output_text = (e.stdout.decode('utf-8', errors='ignore') + if e.stdout else '') + ( + e.stderr.decode('utf-8', errors='ignore') + if e.stderr else '') # noqa logger.error( f'Manim rendering timed out after {manim_render_timeout} ' - f'seconds for {actual_scene_name}, output: {output_text}' - ) + f'seconds for {actual_scene_name}, output: {output_text}') logger.info('Trying to fix manim code.') code, fix_history = RenderManim._fix_manim_code_impl( - llm, - output_text, - fix_history, - code, - manim_requirement, - class_name, - content, - audio_duration, - segment, - i, - work_dir, - ) + llm, output_text, fix_history, code, manim_requirement, + class_name, content, audio_duration, segment, i, work_dir) continue if result.returncode != 0: - logger.warning(f'Manim command exited with code {result.returncode}') + logger.warning( + f'Manim command exited with code {result.returncode}') logger.warning(f'Output: {output_text}') real_error_indicators = [ - 'SyntaxError', - 'NameError', - 'ImportError', - 'AttributeError', - 'TypeError', - 'ValueError', - 'ModuleNotFoundError', - 'Traceback', - 'Error:', - 'Failed to render', - 'unexpected keyword argument', - 'got an unexpected', - 'invalid syntax', + 'SyntaxError', 'NameError', 'ImportError', + 'AttributeError', 'TypeError', 'ValueError', + 'ModuleNotFoundError', 'Traceback', 'Error:', + 'Failed to render', 'unexpected keyword argument', + 'got an unexpected', 'invalid syntax' ] - if any([error_indicator in output_text for error_indicator in real_error_indicators]): + if any([ + error_indicator in output_text + for error_indicator in real_error_indicators + ]): logger.info('Trying to fix manim code.') code, fix_history = RenderManim._fix_manim_code_impl( - llm, - output_text, - fix_history, - code, - manim_requirement, - class_name, - content, - audio_duration, - segment, - i, - work_dir, - ) + llm, output_text, fix_history, code, manim_requirement, + class_name, content, audio_duration, segment, i, + work_dir) continue for root, dirs, files in os.walk(output_dir): @@ -261,37 +200,31 @@ def __init__(self, returncode, stdout, stderr): if file == f'{actual_scene_name}.mov': found_file = os.path.join(root, file) if not RenderManim.verify_and_fix_mov_file(found_file): - fixed_path = RenderManim.convert_mov_to_compatible(found_file) + fixed_path = RenderManim.convert_mov_to_compatible( + found_file) if fixed_path: found_file = fixed_path shutil.copy2(found_file, output_path) - scaled_path = RenderManim.scale_video_to_fit(output_path, target_size=window_size) + scaled_path = RenderManim.scale_video_to_fit( + output_path, target_size=window_size) if scaled_path and scaled_path != output_path: shutil.rmtree(output_path, ignore_errors=True) shutil.copy2(scaled_path, output_path) final_file_path = output_path if not final_file_path: - logger.error(f'Manim file: {class_name} not found, trying to fix manim code.') - code, fix_history = RenderManim._fix_manim_code_impl( - llm, - output_text, - fix_history, - code, - manim_requirement, - class_name, - content, - audio_duration, - segment, - i, - work_dir, + logger.error( + f'Manim file: {class_name} not found, trying to fix manim code.' ) + code, fix_history = RenderManim._fix_manim_code_impl( + llm, output_text, fix_history, code, manim_requirement, + class_name, content, audio_duration, segment, i, work_dir) else: if cur_check_round >= mllm_max_check_round: break output_text = RenderManim.check_manim_quality( - final_file_path, work_dir, i, config, segment, cur_check_round - ) + final_file_path, work_dir, i, config, segment, + cur_check_round) cur_check_round += 1 if output_text: try: @@ -300,26 +233,18 @@ def __init__(self, returncode, stdout, stderr): except OSError: pass logger.info( - f'Trying to fix manim code of segment {i + 1}, because model checking not passed: \n{output_text}' + f'Trying to fix manim code of segment {i+1}, because model checking not passed: \n{output_text}' ) code, fix_history = RenderManim._fix_manim_code_impl( - llm, - output_text, - fix_history, - code, - manim_requirement, - class_name, - content, - audio_duration, - segment, - i, - work_dir, - ) + llm, output_text, fix_history, code, manim_requirement, + class_name, content, audio_duration, segment, i, + work_dir) continue else: break if final_file_path: - RenderManim._extract_preview_frames_static(final_file_path, i, work_dir, 'final') + RenderManim._extract_preview_frames_static(final_file_path, i, + work_dir, 'final') manim_code_dir = os.path.join(work_dir, 'manim_code') manim_file = os.path.join(manim_code_dir, f'segment_{i + 1}.py') with open(manim_file, 'w') as f: @@ -328,13 +253,14 @@ def __init__(self, returncode, stdout, stderr): raise FileNotFoundError(final_file_path) @staticmethod - def check_manim_quality(final_file_path, work_dir, i, config, segment, cur_check_round): + def check_manim_quality(final_file_path, work_dir, i, config, segment, + cur_check_round): _mm_config = deepcopy(config) delattr(_mm_config, 'llm') _mm_config.llm = DictConfig({}) _mm_config.generation_config = DictConfig({'temperature': 0.3}) for key, value in _mm_config.mllm.items(): - key = key[len('mllm_') :] + key = key[len('mllm_'):] setattr(_mm_config.llm, key, value) test_system = """**Role Definition** @@ -388,34 +314,44 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, cur_check The right component is squeezed to the edge. Fix suggestion: Reduce the width of the four left components, move the right component further right... ``` -""" # noqa +"""# noqa - test_images = RenderManim._extract_preview_frames_static(final_file_path, i, work_dir, cur_check_round) + test_images = RenderManim._extract_preview_frames_static( + final_file_path, i, work_dir, cur_check_round) llm = LLM.from_config(_mm_config) - logger.info(f"Using mllm model for manim quality check: {getattr(llm, 'model', None)}") + logger.info( + f"Using mllm model for manim quality check: {getattr(llm, 'model', None)}" + ) - frame_names = ['the middle frame of the animation', 'the last frame of the animation'] + frame_names = [ + 'the middle frame of the animation', + 'the last frame of the animation' + ] content = segment['content'] manim_requirement = segment['manim'] all_issues = [] - for idx, (image_path, frame_name) in enumerate(zip(test_images, frame_names)): + for idx, (image_path, + frame_name) in enumerate(zip(test_images, frame_names)): with open(image_path, 'rb') as image_file: image_data = image_file.read() base64_image = base64.b64encode(image_data).decode('utf-8') - _content = [ - { - 'type': 'text', - 'text': ( - f'The checked frame is: {frame_name} of this animation\n' - f'The content of this animation: {content}\n' - f'The manim animation requirement: {manim_requirement}\n' - f'You must carefully check the animation layout issues.' - ), - }, - {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}', 'detail': 'high'}}, - ] + _content = [{ + 'type': + 'text', + 'text': + (f'The checked frame is: {frame_name} of this animation\n' + f'The content of this animation: {content}\n' + f'The manim animation requirement: {manim_requirement}\n' + f'You must carefully check the animation layout issues.') + }, { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/png;base64,{base64_image}', + 'detail': 'high' + } + }] messages = [ Message(role='system', content=test_system), @@ -430,7 +366,8 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, cur_check issues.append(issue) issues = '\n'.join(issues).strip() if issues: - issues = f'The checked frame is: {frame_name}\nProblems found: {issues}\n' + issues = (f'The checked frame is: {frame_name}\n' + f'Problems found: {issues}\n') pattern = r'(.*?)' desc = [] @@ -438,14 +375,17 @@ def check_manim_quality(final_file_path, work_dir, i, config, segment, cur_check desc.append(_desc) desc = '\n'.join(desc).strip() if issues and desc: - issues = f'{issues}The detail description of this frame: {desc}\n' + issues = (f'{issues}' + f'The detail description of this frame: {desc}\n') all_issues.append(issues) all_issues = '\n\n'.join(all_issues).strip() return all_issues @staticmethod - def _extract_preview_frames_static(video_path, segment_id, work_dir, cur_check_round): + def _extract_preview_frames_static(video_path, segment_id, work_dir, + cur_check_round): + test_dir = os.path.join(work_dir, 'manim_test') os.makedirs(test_dir, exist_ok=True) video = VideoFileClip(video_path) @@ -455,7 +395,10 @@ def _extract_preview_frames_static(video_path, segment_id, work_dir, cur_check_r preview_paths = [] for frame_idx, timestamp in timestamps.items(): - output_path = os.path.join(test_dir, f'segment_{segment_id + 1}_round{cur_check_round}_{frame_idx}.png') + output_path = os.path.join( + test_dir, + f'segment_{segment_id + 1}_round{cur_check_round}_{frame_idx}.png' + ) video.save_frame(output_path, t=timestamp) preview_paths.append(output_path) video.close() @@ -471,7 +414,8 @@ def get_all_images_info(segment, i, image_dir): all_images_info = [] foreground = segment.get('foreground', []) for idx, _req in enumerate(foreground): - foreground_image = os.path.join(image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') + foreground_image = os.path.join( + image_dir, f'illustration_{i + 1}_foreground_{idx + 1}.png') size = RenderManim.get_image_size(foreground_image) image_info = { 'filename': foreground_image, @@ -480,7 +424,8 @@ def get_all_images_info(segment, i, image_dir): } all_images_info.append(image_info) - image_info_file = os.path.join(os.path.dirname(image_dir), 'image_info.txt') + image_info_file = os.path.join( + os.path.dirname(image_dir), 'image_info.txt') if os.path.exists(image_info_file): with open(image_info_file, 'r') as f: for line in f.readlines(): @@ -492,19 +437,9 @@ def get_all_images_info(segment, i, image_dir): return all_images_info @staticmethod - def _fix_manim_code_impl( - llm, - error_log, - fix_history, - manim_code, - manim_requirement, - class_name, - content, - audio_duration, - segment, - i, - work_dir, - ): + def _fix_manim_code_impl(llm, error_log, fix_history, manim_code, + manim_requirement, class_name, content, + audio_duration, segment, i, work_dir): image_dir = os.path.join(work_dir, 'images') images_info = RenderManim.get_all_images_info(segment, i, image_dir) @@ -520,7 +455,7 @@ def _fix_manim_code_impl( * Scale the images. Do not use the original size, carefully rescale the images to match the requirements below: * The image size on the canvas depend on its importance, important image occupies more spaces * Use 1/4 space of the canvas for each image -""" # noqa +""" # noqa else: image_prompt = '' @@ -586,7 +521,7 @@ def _fix_manim_code_impl( - **don't remove any image or its effects when making modifications** Please precisely fix the detected issues while maintaining the richness and creativity of the animation. -""" # noqa +""" # noqa inputs = [Message(role='user', content=fix_request)] _response_message = llm.generate(inputs) response = _response_message.content @@ -599,8 +534,7 @@ def _fix_manim_code_impl( fix_history = ( f'You have a fix history which generates the code which is given to you:\n\n{fix_request}\n\n' f'If last error is same with latest error, **You probably find a wrong root cause**, ' - f'Check carefully and fix it again.**' - ) + f'Check carefully and fix it again.**') return manim_code, fix_history @staticmethod @@ -622,8 +556,7 @@ def convert_mov_to_compatible(mov_path): fps=24, verbose=False, logger=None, - ffmpeg_params=['-pix_fmt', 'yuva420p'], - ) + ffmpeg_params=['-pix_fmt', 'yuva420p']) clip.close() if RenderManim.verify_and_fix_mov_file(fixed_path): @@ -659,8 +592,7 @@ def scale_video_to_fit(video_path, target_size=None): audio_codec='aac' if scaled_clip.audio else None, fps=24, verbose=False, - logger=None, - ) + logger=None) clip.close() scaled_clip.close() diff --git a/projects/singularity_cinema/render_animation/render_remotion.py b/projects/singularity_cinema/render_animation/render_remotion.py index cb3c75f7a..e58e2927e 100644 --- a/projects/singularity_cinema/render_animation/render_remotion.py +++ b/projects/singularity_cinema/render_animation/render_remotion.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os import re import shutil @@ -9,34 +8,42 @@ from collections import defaultdict from typing import List, Optional, Tuple, Union +import json from moviepy import VideoFileClip -from omegaconf import DictConfig - from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class RenderRemotion(CodeAgent): - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): super().__init__(config, tag, trust_remote_code, **kwargs) self.work_dir = getattr(self.config, 'output_dir', 'output') self.num_parallel = getattr(self.config, 'llm_num_parallel', 5) # When enabled, render compositions one-by-one and attempt a fix immediately on failure. # This reduces wasted work when one broken Segment TSX causes global bundler failure. - self.render_immediate_fix = getattr(self.config, 'render_immediate_fix', True) + self.render_immediate_fix = getattr(self.config, + 'render_immediate_fix', True) self.render_dir = os.path.join(self.work_dir, 'remotion_render') - self.remotion_project_dir = os.path.join(self.work_dir, 'remotion_project') + self.remotion_project_dir = os.path.join(self.work_dir, + 'remotion_project') self.remotion_code_dir = os.path.join(self.work_dir, 'remotion_code') self.images_dir = os.path.join(self.work_dir, 'images') self.code_fix_round = getattr(self.config, 'code_fix_round', 3) # Default to 1 to ensure visual quality check runs at least once unless explicitly disabled (-1) self.mllm_check_round = getattr(self.config, 'mllm_fix_round', 1) # Maximum times to attempt automatic visual fixes per segment - self.max_visual_fix_rounds = getattr(self.config, 'max_visual_fix_rounds', 2) + self.max_visual_fix_rounds = getattr(self.config, + 'max_visual_fix_rounds', 2) # Track per-segment visual failure counts self.visual_fail_counts = defaultdict(int) # Track scale per segment for edge clipping retry @@ -44,7 +51,8 @@ def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False os.makedirs(self.render_dir, exist_ok=True) - async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> List[Message]: + async def execute_code(self, messages: Union[str, List[Message]], + **kwargs) -> List[Message]: with open(os.path.join(self.work_dir, 'segments.txt'), 'r') as f: segments = json.load(f) with open(os.path.join(self.work_dir, 'audio_info.txt'), 'r') as f: @@ -57,37 +65,51 @@ async def execute_code(self, messages: Union[str, List[Message]], **kwargs) -> L self._ensure_browser(self.remotion_project_dir) logger.info('Installing dependencies...') - subprocess.run('npm install', cwd=self.remotion_project_dir, shell=True, check=True) + subprocess.run( + 'npm install', + cwd=self.remotion_project_dir, + shell=True, + check=True) segment_status = { - i: os.path.exists(os.path.join(self.render_dir, f'scene_{i + 1}', f'Scene{i + 1}.mov')) + i: os.path.exists( + os.path.join(self.render_dir, f'scene_{i+1}', + f'Scene{i+1}.mov')) for i in range(len(segments)) } for round_idx in range(self.code_fix_round + 1): # Identify segments needing render (all initially, then only failed ones) - segments_to_render = [i for i, status in segment_status.items() if status is not True] + segments_to_render = [ + i for i, status in segment_status.items() if status is not True + ] if not segments_to_render: logger.info('All segments rendered successfully.') break - logger.info(f'Round {round_idx + 1}: Rendering {len(segments_to_render)} segments...') + logger.info( + f'Round {round_idx + 1}: Rendering {len(segments_to_render)} segments...' + ) results = {} def _read_current_code(seg_i: int) -> str: - code_path = os.path.join(self.remotion_code_dir, f'Segment{seg_i + 1}.tsx') + code_path = os.path.join(self.remotion_code_dir, + f'Segment{seg_i+1}.tsx') if os.path.exists(code_path): with open(code_path, 'r', encoding='utf-8') as f: return f.read() - project_code_path = os.path.join(self.remotion_project_dir, 'src', f'Segment{seg_i + 1}.tsx') + project_code_path = os.path.join(self.remotion_project_dir, + 'src', + f'Segment{seg_i+1}.tsx') if os.path.exists(project_code_path): with open(project_code_path, 'r', encoding='utf-8') as f: return f.read() return '' - def _extract_error_segment_indices(log_text: Optional[str]) -> List[int]: + def _extract_error_segment_indices( + log_text: Optional[str]) -> List[int]: if not log_text: return [] # esbuild/webpack error lines usually include: ...\src\Segment15.tsx:... @@ -122,7 +144,9 @@ def _extract_error_segment_indices(log_text: Optional[str]) -> List[int]: if not success and error_log and 'EDGE_CLIPPING' in error_log: new_scale = 0.8 self.segment_scales[i] = new_scale - logger.info(f'Edge clipping detected for segment {i + 1}, reducing scale to {new_scale}') + logger.info( + f'Edge clipping detected for segment {i+1}, reducing scale to {new_scale}' + ) # Update Root.tsx with new scale self._update_root_tsx_for_segment(i) segment_status[i] = False # Force retry @@ -134,12 +158,14 @@ def _extract_error_segment_indices(log_text: Optional[str]) -> List[int]: # If bundler fails globally, error_log points to the culprit file. culprit_indices = _extract_error_segment_indices(error_log) to_fix = culprit_indices if culprit_indices else [i] - to_fix = sorted({idx for idx in to_fix if 0 <= idx < len(segments)}) + to_fix = sorted( + {idx + for idx in to_fix if 0 <= idx < len(segments)}) # If the error points to OTHER segments, it means the current segment failed due to global breakage. # Pause and fix the root cause first. logger.info( - f'Immediate fix triggered by failure on segment {i + 1}. Fix targets: {[x + 1 for x in to_fix]}' + f'Immediate fix triggered by failure on segment {i+1}. Fix targets: {[x+1 for x in to_fix]}' ) # Apply fixes @@ -147,40 +173,49 @@ def _extract_error_segment_indices(log_text: Optional[str]) -> List[int]: err_text = error_log or 'Unknown error' current_code = _read_current_code(fix_i) _, fixed_code = self._fix_code_static( - fix_i, err_text, current_code, self.config, self.remotion_project_dir - ) + fix_i, err_text, current_code, self.config, + self.remotion_project_dir) if fixed_code: self._update_segment_code(fix_i, fixed_code) # If we fixed a different segment, we should probably reset its status too if fix_i != i: - segment_status[fix_i] = False # Force re-render of the culprit later if it was skipped + segment_status[ + fix_i] = False # Force re-render of the culprit later if it was skipped return messages def _update_segment_code(self, i, code): # Update in remotion_code_dir (source of truth) - src_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') + src_file = os.path.join(self.remotion_code_dir, f'Segment{i+1}.tsx') with open(src_file, 'w', encoding='utf-8') as f: f.write(code) # Update in remotion_project_dir (execution env) - dst_file = os.path.join(self.remotion_project_dir, 'src', f'Segment{i + 1}.tsx') + dst_file = os.path.join(self.remotion_project_dir, 'src', + f'Segment{i+1}.tsx') with open(dst_file, 'w', encoding='utf-8') as f: f.write(code) def _setup_remotion_project(self, segments, audio_infos): # 1. Create project structure - os.makedirs(os.path.join(self.remotion_project_dir, 'src'), exist_ok=True) - os.makedirs(os.path.join(self.remotion_project_dir, 'public', 'images'), exist_ok=True) + os.makedirs( + os.path.join(self.remotion_project_dir, 'src'), exist_ok=True) + os.makedirs( + os.path.join(self.remotion_project_dir, 'public', 'images'), + exist_ok=True) # Some generated TSX may import assets via relative paths like `./images/foo.png`. # Keep a mirrored copy under `src/images` to avoid bundler module resolution failures. - os.makedirs(os.path.join(self.remotion_project_dir, 'src', 'images'), exist_ok=True) + os.makedirs( + os.path.join(self.remotion_project_dir, 'src', 'images'), + exist_ok=True) if os.path.exists(self.images_dir): for file in os.listdir(self.images_dir): src = os.path.join(self.images_dir, file) - dst_public = os.path.join(self.remotion_project_dir, 'public', 'images', file) - dst_src = os.path.join(self.remotion_project_dir, 'src', 'images', file) + dst_public = os.path.join(self.remotion_project_dir, 'public', + 'images', file) + dst_src = os.path.join(self.remotion_project_dir, 'src', + 'images', file) for dst in (dst_public, dst_src): shutil.copy(src, dst) @@ -196,18 +231,24 @@ def _setup_remotion_project(self, segments, audio_infos): # Extract filename from absolute path filename = os.path.basename(original_path) # Copy to public/images and src/images - dst_public = os.path.join(self.remotion_project_dir, 'public', 'images', filename) - dst_src = os.path.join(self.remotion_project_dir, 'src', 'images', filename) + dst_public = os.path.join(self.remotion_project_dir, + 'public', 'images', filename) + dst_src = os.path.join(self.remotion_project_dir, 'src', + 'images', filename) shutil.copy(original_path, dst_public) shutil.copy(original_path, dst_src) # Store mapping for path replacement user_image_mapping[original_path] = f'images/{filename}' - logger.info(f'Copied user image: {original_path} -> images/{filename}') + logger.info( + f'Copied user image: {original_path} -> images/{filename}' + ) # 3. Copy generated code and replace absolute paths for i in range(len(segments)): - src_file = os.path.join(self.remotion_code_dir, f'Segment{i + 1}.tsx') - dst_file = os.path.join(self.remotion_project_dir, 'src', f'Segment{i + 1}.tsx') + src_file = os.path.join(self.remotion_code_dir, + f'Segment{i+1}.tsx') + dst_file = os.path.join(self.remotion_project_dir, 'src', + f'Segment{i+1}.tsx') if os.path.exists(src_file): # Read file content with open(src_file, 'r', encoding='utf-8') as f: @@ -223,13 +264,13 @@ def _setup_remotion_project(self, segments, audio_infos): else: with open(dst_file, 'w') as f: f.write( - f"import React from 'react';\nexport const Segment{i + 1} = () =>
Missing Segment
;" + f"import React from 'react';\nexport const Segment{i+1} = () =>
Missing Segment
;" ) else: # Create a dummy file if missing to prevent build failure with open(dst_file, 'w') as f: f.write( - f"import React from 'react';\nexport const Segment{i + 1} = () =>
Missing Segment
;" + f"import React from 'react';\nexport const Segment{i+1} = () =>
Missing Segment
;" ) # 4. Create package.json with locked versions @@ -244,14 +285,18 @@ def _setup_remotion_project(self, segments, audio_infos): '@remotion/bundler': '^4.0.0', '@remotion/renderer': '^4.0.0', '@remotion/shapes': '^4.0.0', - '@remotion/media-utils': '^4.0.0', - }, + '@remotion/media-utils': '^4.0.0' + } } - with open(os.path.join(self.remotion_project_dir, 'package.json'), 'w') as f: + with open( + os.path.join(self.remotion_project_dir, 'package.json'), + 'w') as f: json.dump(package_json, f, indent=2) # 5. Create src/index.ts - with open(os.path.join(self.remotion_project_dir, 'src', 'index.ts'), 'w') as f: + with open( + os.path.join(self.remotion_project_dir, 'src', 'index.ts'), + 'w') as f: f.write("import { registerRoot } from 'remotion';\n") f.write("import { RemotionRoot } from './Root';\n") f.write('registerRoot(RemotionRoot);\n') @@ -260,7 +305,6 @@ def _setup_remotion_project(self, segments, audio_infos): self._generate_root_tsx(segments, audio_infos) # 7. Create tsconfig.json - def _generate_root_tsx(self, segments, audio_infos): """Generate Root.tsx with dynamic scale support""" fps = self.config.video.fps @@ -274,14 +318,13 @@ def _generate_root_tsx(self, segments, audio_infos): root_content = "import React from 'react';\n" root_content += "import { Composition } from 'remotion';\n" for i in range(len(segments)): - root_content += f"import * as Segment{i + 1}_NS from './Segment{i + 1}';\n" + root_content += f"import * as Segment{i+1}_NS from './Segment{i+1}';\n" root_content += '\nexport const RemotionRoot: React.FC = () => {\n' for i in range(len(segments)): root_content += ( - f' const Segment{i + 1} = Segment{i + 1}_NS.default || ' - f'Segment{i + 1}_NS.Segment{i + 1} || (() => null);\n' - ) + f' const Segment{i+1} = Segment{i+1}_NS.default || ' + f'Segment{i+1}_NS.Segment{i+1} || (() => null);\n') root_content += ' return (\n' root_content += ' <>\n' @@ -291,8 +334,8 @@ def _generate_root_tsx(self, segments, audio_infos): # Get scale from tracking dict or use default scale = self.segment_scales.get(i, 0.9) root_content += ' Tuple[int, bool, Optional[str]]: + i, + segment, + duration, + config, + work_dir, + render_dir, + remotion_project_dir, + mllm_check_round=0, + scale=0.9) -> Tuple[int, bool, Optional[str]]: """Static method for multiprocessing""" - composition_id = f'Segment{i + 1}' - output_dir_scene = os.path.join(render_dir, f'scene_{i + 1}') + composition_id = f'Segment{i+1}' + output_dir_scene = os.path.join(render_dir, f'scene_{i+1}') os.makedirs(output_dir_scene, exist_ok=True) - output_path = os.path.abspath(os.path.join(output_dir_scene, f'Scene{i + 1}.mov')) + output_path = os.path.abspath( + os.path.join(output_dir_scene, f'Scene{i+1}.mov')) logger.info(f'Rendering {composition_id} to {output_path}') # Determine remotion command if os.name == 'nt': - remotion_cmd = os.path.abspath(os.path.join(remotion_project_dir, 'node_modules', '.bin', 'remotion.cmd')) + remotion_cmd = os.path.abspath( + os.path.join(remotion_project_dir, 'node_modules', '.bin', + 'remotion.cmd')) else: - remotion_cmd = os.path.abspath(os.path.join(remotion_project_dir, 'node_modules', '.bin', 'remotion')) + remotion_cmd = os.path.abspath( + os.path.join(remotion_project_dir, 'node_modules', '.bin', + 'remotion')) if not os.path.exists(remotion_cmd): remotion_cmd = 'npx remotion' @@ -515,21 +594,24 @@ def _render_remotion_scene_static( '--prores-profile=4444', '--pixel-format=yuva444p10le', '--image-format=png', - '--every-nth-frame=1', # Render every frame for smooth animation + '--every-nth-frame=1' # Render every frame for smooth animation ] # Try to find browser executable (Local > System) browser_executable = None - remotion_cache_dir = os.path.join(remotion_project_dir, 'node_modules', '.remotion') + remotion_cache_dir = os.path.join(remotion_project_dir, 'node_modules', + '.remotion') # 1. Check Local Cache if os.path.exists(remotion_cache_dir): for root, _, files in os.walk(remotion_cache_dir): if 'chrome-headless-shell.exe' in files: - browser_executable = os.path.abspath(os.path.join(root, 'chrome-headless-shell.exe')) + browser_executable = os.path.abspath( + os.path.join(root, 'chrome-headless-shell.exe')) break elif 'chrome-headless-shell' in files: - browser_executable = os.path.abspath(os.path.join(root, 'chrome-headless-shell')) + browser_executable = os.path.abspath( + os.path.join(root, 'chrome-headless-shell')) break # 2. Check System Chrome if not found locally @@ -539,20 +621,22 @@ def _render_remotion_scene_static( if not browser_executable: # shutil is imported at module level browser_executable = ( - shutil.which('chrome') - or shutil.which('google-chrome') + shutil.which('chrome') or shutil.which('google-chrome') or shutil.which('chromium') - or shutil.which('chromium-browser') - ) + or shutil.which('chromium-browser')) if not browser_executable and os.name == 'nt': possible_paths = [ r'C:\Program Files\Google\Chrome\Application\chrome.exe', r'C:\Program Files (x86)\Google\Chrome\Application\chrome.exe', - os.path.expandvars(r'%LOCALAPPDATA%\Google\Chrome\Application\chrome.exe'), + os.path.expandvars( + r'%LOCALAPPDATA%\Google\Chrome\Application\chrome.exe' + ), r'C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe', r'C:\Program Files\Microsoft\Edge\Application\msedge.exe', - os.path.expandvars(r'%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe'), + os.path.expandvars( + r'%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe' + ), ] for p in possible_paths: if os.path.exists(p): @@ -563,9 +647,10 @@ def _render_remotion_scene_static( logger.info(f'Using browser executable: {browser_executable}') base_cmd.extend(['--browser-executable', browser_executable]) # Add stability flags - base_cmd.extend( - ['--chromium-options', 'no-sandbox,disable-setuid-sandbox,disable-gpu,disable-dev-shm-usage'] - ) + base_cmd.extend([ + '--chromium-options', + 'no-sandbox,disable-setuid-sandbox,disable-gpu,disable-dev-shm-usage' + ]) if os.name == 'nt' and 'remotion.cmd' in remotion_cmd: cmd = [remotion_cmd] + base_cmd @@ -581,7 +666,8 @@ def _render_remotion_scene_static( # But subprocess.run with shell=True and a list is tricky. # It's safer to join the command into a string for shell=True on Windows. if os.name == 'nt': - cmd_str = ' '.join([f'"{arg}"' if ' ' in arg else arg for arg in cmd]) + cmd_str = ' '.join( + [f'"{arg}"' if ' ' in arg else arg for arg in cmd]) result = subprocess.run( cmd_str, cwd=remotion_project_dir, @@ -589,8 +675,7 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore', - ) + errors='ignore') else: result = subprocess.run( cmd, @@ -599,8 +684,7 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore', - ) + errors='ignore') else: result = subprocess.run( cmd, @@ -609,13 +693,15 @@ def _render_remotion_scene_static( capture_output=True, text=True, encoding='utf-8', - errors='ignore', - ) + errors='ignore') if result.returncode != 0: # Capture output was set to True to allow smart error detection. - log_content = (result.stderr or '') + '\n' + (result.stdout or '') - logger.warning(f'Rendering failed for {composition_id}. Log (except): {log_content[:500]}...') + log_content = (result.stderr or '') + '\n' + ( + result.stdout or '') + logger.warning( + f'Rendering failed for {composition_id}. Log (except): {log_content[:500]}...' + ) return i, False, log_content else: logger.info(f'Rendered {composition_id} successfully.') @@ -632,8 +718,8 @@ def _check_edge_clipping(frame_path, threshold=10): Returns True if clipping detected (colored pixels at edges). """ try: - import numpy as np from PIL import Image + import numpy as np img = Image.open(frame_path).convert('RGB') pixels = np.array(img) @@ -645,7 +731,8 @@ def _check_edge_clipping(frame_path, threshold=10): left_edge = pixels[:, 0, :] right_edge = pixels[:, width - 1, :] - edges = np.concatenate([top_edge, bottom_edge, left_edge, right_edge]) + edges = np.concatenate( + [top_edge, bottom_edge, left_edge, right_edge]) # Check if pixels are near black (0,0,0) or white (255,255,255) near_black = np.all(edges < threshold, axis=1) @@ -662,6 +749,7 @@ def _check_edge_clipping(frame_path, threshold=10): @staticmethod def _extract_preview_frames_static(video_path, segment_id, work_dir): + test_dir = os.path.join(work_dir, 'remotion_test') os.makedirs(test_dir, exist_ok=True) video = VideoFileClip(video_path) @@ -671,31 +759,39 @@ def _extract_preview_frames_static(video_path, segment_id, work_dir): preview_paths = [] for frame_idx, timestamp in timestamps.items(): - output_path = os.path.join(test_dir, f'segment_{segment_id + 1}_{frame_idx}.png') + output_path = os.path.join( + test_dir, f'segment_{segment_id + 1}_{frame_idx}.png') video.save_frame(output_path, t=timestamp) preview_paths.append(output_path) video.close() return preview_paths @staticmethod - def _fix_code_static(i, error_log, code, config, remotion_project_dir=None): + def _fix_code_static(i, + error_log, + code, + config, + remotion_project_dir=None): """Static method for multiprocessing fix""" if not code: return i, '' # 3. Use LLM to fix remaining issues. llm = LLM.from_config(config) - logger.info(f'Fixing code for segment {i + 1} with LLM...') - return i, RenderRemotion._fix_code_impl(llm, error_log, code, remotion_project_dir) + logger.info(f'Fixing code for segment {i+1} with LLM...') + return i, RenderRemotion._fix_code_impl(llm, error_log, code, + remotion_project_dir) @staticmethod def _fix_code_impl(llm, error_log, code, remotion_project_dir=None): available_images_info = '' if remotion_project_dir: - images_path = os.path.join(remotion_project_dir, 'public', 'images') + images_path = os.path.join(remotion_project_dir, 'public', + 'images') if os.path.exists(images_path): files = sorted(os.listdir(images_path)) - available_images_info = '\nAvailable images in public/images/:\n' + '\n'.join([f'- {f}' for f in files]) + available_images_info = '\nAvailable images in public/images/:\n' + '\n'.join( + [f'- {f}' for f in files]) if 'VISUAL CHECK FAILED' in error_log: fix_prompt = f""" @@ -762,7 +858,9 @@ def _fix_code_impl(llm, error_log, code, remotion_project_dir=None): response = _response_message.content # Robust code extraction using regex - code_match = re.search(r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, re.DOTALL) + code_match = re.search( + r'```(?:typescript|tsx|js|javascript)?\s*(.*?)```', response, + re.DOTALL) if code_match: code = code_match.group(1) else: diff --git a/projects/singularity_cinema/segment/agent.py b/projects/singularity_cinema/segment/agent.py index a72804a25..35df82260 100644 --- a/projects/singularity_cinema/segment/agent.py +++ b/projects/singularity_cinema/segment/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import json import os from copy import deepcopy -from omegaconf import DictConfig - +import json from ms_agent.agent import LLMAgent from ms_agent.llm import Message from ms_agent.utils import get_logger +from omegaconf import DictConfig logger = get_logger() class Segment(LLMAgent): + system = """你是一名动画分镜设计师。现在有一个短视频场景需要进行分镜设计。分镜需要满足以下条件: - 每个分镜包含: @@ -60,7 +60,7 @@ class Segment(LLMAgent): ... ] ``` -""" # noqa +""" # noqa video_prompt = """- 你可以使用文生视频功能来渲染某些镜头,这可以增强短视频的整体趣味性和可读性 * 当使用文生视频渲染某些镜头时,返回的结构应该只包含三个字段:index、content和video。不要包含其他字段如{animation_engine}、background等。换句话说,文生视频镜头不应该包含动画引擎或背景图片 @@ -69,9 +69,13 @@ class Segment(LLMAgent): * **生成具有强动态效果的视频,而不是只有镜头移动的静态场景。你需要在视频中讲好你的故事** * video字段包含你对文生视频生成的要求。注意生成的视频如何与前后镜头协调 * 如果你使用多个文生视频镜头,注意保持角色、建筑、动物等的ID一致性 - * 需要叙述摄像机和镜头信息,集中于讲述故事、推进情节和深化主题""" # noqa + * 需要叙述摄像机和镜头信息,集中于讲述故事、推进情节和深化主题""" # noqa - def __init__(self, config: DictConfig, tag: str, trust_remote_code: bool = False, **kwargs): + def __init__(self, + config: DictConfig, + tag: str, + trust_remote_code: bool = False, + **kwargs): _config = deepcopy(config) _config.tools = DictConfig({}) super().__init__(_config, tag, trust_remote_code, **kwargs) @@ -85,7 +89,8 @@ async def create_messages(self, messages): video_prompt = self.video_prompt if self.config.use_text2video else '' video_prompt = video_prompt.format(animation_engine=self.engine) - system = system.format(video_prompt=video_prompt, animation_engine=self.engine) + system = system.format( + video_prompt=video_prompt, animation_engine=self.engine) return [ Message(role='system', content=system), @@ -105,7 +110,10 @@ async def run(self, messages, **kwargs): if self.config.background != 'image': image_prompt = f'\n\n背景图片无需生成,是纯色:{self.config.background}\n\n' - query = f'原始主题:\n\n{topic}\n\n原始脚本:\n\n{script}\n\n{image_prompt}请完成你的动画分镜设计:\n' + query = (f'原始主题:\n\n{topic}\n\n' + f'原始脚本:\n\n{script}\n\n' + f'{image_prompt}' + f'请完成你的动画分镜设计:\n') messages = await super().run(query, **kwargs) response = messages[-1].content if '```json' in response: @@ -139,10 +147,9 @@ async def run(self, messages, **kwargs): return messages async def add_images(self, segments, topic, script, **kwargs): - video_prompt = ( - '注意:不需要修改包含video字段的镜头。这些镜头是文生视频镜头,它不需要背景、动画或前景图片。' - '只需在返回值中保留并返回这些镜头的index即可。' - ) + + video_prompt = ('注意:不需要修改包含video字段的镜头。这些镜头是文生视频镜头,它不需要背景、动画或前景图片。' + '只需在返回值中保留并返回这些镜头的index即可。') if not self.config.use_text2video: video_prompt = '' @@ -201,13 +208,14 @@ async def add_images(self, segments, topic, script, **kwargs): ] 现在开始: -""" # noqa +""" # noqa # Format the system prompt with the actual engine name animation_engine = self.engine animation_engine_cap = animation_engine.capitalize() system = system.format( - video_prompt=video_prompt, animation_engine=animation_engine, animation_engine_cap=animation_engine_cap - ) + video_prompt=video_prompt, + animation_engine=animation_engine, + animation_engine_cap=animation_engine_cap) new_image_info = '未提供图片。' name_mapping = {} @@ -215,7 +223,9 @@ async def add_images(self, segments, topic, script, **kwargs): with open(os.path.join(self.work_dir, 'image_info.txt'), 'r') as f: image_info = f.readlines() - image_info = [image.strip() for image in image_info if image.strip()] + image_info = [ + image.strip() for image in image_info if image.strip() + ] image_list = [] for i, info in enumerate(image_info): info = json.loads(info) @@ -232,8 +242,7 @@ async def add_images(self, segments, topic, script, **kwargs): f'原始脚本:\n\n{script}\n\n' f'原始分镜:\n\n{json.dumps(segments, ensure_ascii=False, indent=4)}\n\n' f'用户提供的图片:\n\n{new_image_info}\n\n' - f'请完成你的图片设计:\n' - ) + f'请完成你的图片设计:\n') messages = [ Message(role='system', content=system), Message(role='user', content=query), diff --git a/setup.py b/setup.py index a79918b67..1bded7e04 100644 --- a/setup.py +++ b/setup.py @@ -2,10 +2,9 @@ # !/usr/bin/env python import os import shutil -from typing import List - from setuptools import find_packages, setup from setuptools.command.build_py import build_py as _build_py +from typing import List def readme(): @@ -42,7 +41,6 @@ def parse_requirements(fname='requirements.txt', with_version=True): import re import sys from os.path import exists - require_fpath = fname def parse_line(line): @@ -72,7 +70,8 @@ def parse_line(line): if ';' in rest: # Handle platform specific dependencies # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, rest.split(';')) + version, platform_deps = map(str.strip, + rest.split(';')) info['platform_deps'] = platform_deps else: version = rest # NOQA @@ -86,7 +85,8 @@ def parse_require_file(fpath): if line.startswith('http'): print('skip http requirements %s' % line) continue - if line and not line.startswith('#') and not line.startswith('--'): + if line and not line.startswith('#') and not line.startswith( + '--'): for info in parse_line(line): yield info elif line and line.startswith('--find-links'): @@ -121,6 +121,7 @@ def gen_packages_items(): class build_py(_build_py): + def run(self): super().run() @@ -146,7 +147,8 @@ def _build_and_copy_webui(self): webui_src = os.path.join(repo_root, 'webui') if not os.path.isdir(webui_src): - print('Warning: webui directory not found, skipping webui packaging') + print( + 'Warning: webui directory not found, skipping webui packaging') return frontend_src = os.path.join(webui_src, 'frontend') @@ -154,11 +156,17 @@ def _build_and_copy_webui(self): # Check if npm is available try: - subprocess.run(['npm', '--version'], capture_output=True, check=True, timeout=5) + subprocess.run(['npm', '--version'], + capture_output=True, + check=True, + timeout=5) npm_available = True - except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): + except (subprocess.CalledProcessError, FileNotFoundError, + subprocess.TimeoutExpired): npm_available = False - print('Warning: npm not found, cannot build frontend. WebUI may not work properly.') + print( + 'Warning: npm not found, cannot build frontend. WebUI may not work properly.' + ) # Build frontend if npm is available if npm_available and os.path.isdir(frontend_src): @@ -169,16 +177,24 @@ def _build_and_copy_webui(self): if not os.path.exists(node_modules): print('Installing frontend dependencies...') try: - subprocess.run(['npm', 'install'], cwd=frontend_src, check=True, timeout=300) - except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: + subprocess.run(['npm', 'install'], + cwd=frontend_src, + check=True, + timeout=300) + except (subprocess.CalledProcessError, + subprocess.TimeoutExpired) as e: print(f'Warning: npm install failed: {e}') return # Build frontend try: - subprocess.run(['npm', 'run', 'build'], cwd=frontend_src, check=True, timeout=300) + subprocess.run(['npm', 'run', 'build'], + cwd=frontend_src, + check=True, + timeout=300) print('Frontend built successfully') - except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: + except (subprocess.CalledProcessError, + subprocess.TimeoutExpired) as e: print(f'Warning: npm build failed: {e}') return @@ -203,17 +219,23 @@ def _build_and_copy_webui(self): shutil.copytree(frontend_dist_src, frontend_dst) print(f'Copied frontend dist to {frontend_dst}') else: - print('Warning: frontend dist not found, WebUI may not work in production mode') + print( + 'Warning: frontend dist not found, WebUI may not work in production mode' + ) if __name__ == '__main__': - print('Usage: `python setup.py sdist bdist_wheel` or `pip install .[framework]` from source code') + print( + 'Usage: `python setup.py sdist bdist_wheel` or `pip install .[framework]` from source code' + ) - install_requires, deps_link = parse_requirements('requirements/framework.txt') + install_requires, deps_link = parse_requirements( + 'requirements/framework.txt') extra_requires = {} all_requires = [] - extra_requires['research'], _ = parse_requirements('requirements/research.txt') + extra_requires['research'], _ = parse_requirements( + 'requirements/research.txt') extra_requires['code'], _ = parse_requirements('requirements/code.txt') extra_requires['webui'], _ = parse_requirements('requirements/webui.txt') all_requires.extend(install_requires) @@ -225,7 +247,8 @@ def _build_and_copy_webui(self): setup( name='ms-agent', version=get_version(), - description='MS-Agent: Lightweight Framework for Empowering Agents with Autonomous Exploration', + description= + 'MS-Agent: Lightweight Framework for Empowering Agents with Autonomous Exploration', long_description=readme(), long_description_content_type='text/markdown', author='The ModelScope teams', @@ -257,7 +280,8 @@ def _build_and_copy_webui(self): license='Apache License 2.0', install_requires=install_requires, extras_require=extra_requires, - entry_points={'console_scripts': ['ms-agent=ms_agent.cli.cli:run_cmd']}, + entry_points={ + 'console_scripts': ['ms-agent=ms_agent.cli.cli:run_cmd'] + }, dependency_links=deps_link, - zip_safe=False, - ) + zip_safe=False) diff --git a/shell-grep-glob-workspace-policy.md b/shell-grep-glob-workspace-policy.md new file mode 100644 index 000000000..ac4e3f912 --- /dev/null +++ b/shell-grep-glob-workspace-policy.md @@ -0,0 +1,225 @@ +# Shell / Grep / Glob 与策略内核架构方案 + +本文档描述在 modelscope-agent 中为 **Shell**、**Grep**、**Glob** 提供统一的安全、权限、沙箱与产物管理的设计,以及与 **`feat/agent-tool-overhaul`** 分支中 **TaskManager**(后台 Agent、预留 Shell)的兼容方式。 + +--- + +## 1. 目标与边界 + +### 目标 + +- 在「同一工作区、同一沙箱视图」下,为 **Shell / Grep / Glob** 提供统一的: + - **安全**(命令与路径约束) + - **权限**(只读 / 写工作区 / 网络等分级) + - **沙箱**(本地子进程 vs Docker enclave 等与现有 `CodeExecutionTool` 对齐) + - **产物管理**(大 stdout/stderr 落盘、预览、配额) +- **默认 `allow_list`(允许根路径)包含 `output_dir`**(及其规范化的绝对路径),可配置追加其它根。 + +### 边界 + +- **不替代** `FileSystemTool` 的精确编辑与读缓存等语义;Shell 面向构建、包管理、复杂管道。 +- **Grep / Glob** 作为**只读发现面**的独立工具,减少对裸 shell 的依赖;复杂 `find -exec` 等仍可由受控 Shell 在更高权限模式下完成(若产品允许)。 + +--- + +## 2. 分层架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Tool Facade 层 │ +│ ShellTool │ GrepTool │ GlobTool (独立 JSON Schema) │ +└────────────┬───────────────────────────────┬────────────────┘ + │ │ +┌────────────▼───────────────────────────────▼────────────────┐ +│ WorkspacePolicyKernel(策略内核,纯逻辑、可单测) │ +│ - roots: 默认含 canonical(output_dir),可配置追加 │ +│ - allow_list / deny_list 合并与优先级 │ +│ - resolve_path(rel|abs) → 必须在 allow_roots 下 │ +│ - classify(op): read | search | mutate | exec | network_hint │ +└────────────┬────────────────────────────────────────────────┘ + │ +┌────────────▼────────────────────────────────────────────────┐ +│ SandboxRuntime(执行面,可替换实现) │ +│ - LocalProcessRuntime(asyncio subprocess,cwd=workspace) │ +│ - EnclaveRuntime(现有 ms_enclave / CodeExecutionTool 路径) │ +│ - 会话级 sandbox_id / working_dir 与挂载点一致 │ +└────────────┬────────────────────────────────────────────────┘ + │ +┌────────────▼────────────────────────────────────────────────┐ +│ ArtifactManager(产物管理) │ +│ - 超阈值 stdout/stderr → 落盘 + preview + 相对路径引用 │ +│ - 按 task_id / tool_call_id 分目录 │ +│ - TTL / 总配额(建议:output_dir/.ms_agent_artifacts/) │ +└─────────────────────────────────────────────────────────────┘ +``` + +**原则**:Grep/Glob 的**主路径**不是「拼一条 shell 给模型」;内部可调用 `rg` 或文件系统 walk,但必须经过 **PolicyKernel** 与 **SandboxRuntime**,输出经 **ArtifactManager**。 + +--- + +## 3. WorkspacePolicyKernel(共享策略内核) + +### 3.1 默认 allow_list(允许根集合) + +- 初始化:`allow_roots = { canonical_abs(output_dir) }`。 +- 配置可追加,例如:`tools.code_executor.extra_allow_roots` 或 `tools.workspace_policy.allow`(列表),合并去重。 +- Shell / Grep / Glob 涉及的 **`path`、`cwd`、搜索根目录** 均先执行 `resolve_under_allow_roots()`;失败则**拒绝**并返回结构化错误(不静默改路径到其它目录)。 + +### 3.2 权限与操作分类(建议) + +| 类别 | 示例 | Shell | Grep | Glob | +|------|------|-------|------|------| +| read | 读取工作区内文件 | 受模式 + 策略约束 | ✓ | ✓ | +| search | 内容/文件名发现 | 可引导至 Grep/Glob | ✓ | ✓ | +| mutate | rm、chmod、git 写入等 | 需 `workspace_write` | — | — | +| network | curl、pip 等 | 需显式 **network** 能力位 | — | — | + +Shell 在 **`read_only`** 模式下:仅允许白名单类命令(如 `git status`/`diff`/`log`、只读参数的 `rg` 等),并对重定向、写入工作区外等行为做拒绝或降级(可用前缀表 + 危险模式黑名单,必要时辅以轻量解析)。 + +### 3.3 Shell 安全补充 + +- **固定 cwd**:默认 `workspace_root`(与 `output_dir` 或沙箱内挂载点一致)。 +- **环境变量**:最小集或白名单继承;避免将宿主敏感变量原样传入。 +- **命令预处理**:与现有 `CodeExecutionTool.shell_executor` 思路一致——含 `| && ; > <` 等时使用 `sh -lc` 与安全 quoting;另加**命令长度上限**、**可配置的危险构造限制**(如嵌套命令替换,按产品分级)。 +- **(暂时不做)** 与 `FileSystemTool` 的「写前必读 / staleness」策略对齐:对会修改工作区文件的 Shell 子类共享元数据(若产品需要强一致)。 + +--- + +## 4. SandboxRuntime(共享沙箱) + +- **会话级**:每个 Agent 运行周期内一个 `SandboxSession`(或复用现有 `sandbox_id`)。 +- **Shell / Grep / Glob** 共用同一 **`working_dir` / 挂载视图** 与同一 **`SandboxRuntime` 实现**(本地 `asyncio` 子进程 vs Docker enclave,由 `implementation: sandbox | python_env` 等与现有一致)。 +- **Grep**:在 enclave 内调用 `rg` 或使用宿主 `ripgrep` 库(由部署二选一);**Glob**:在策略解析后的根上做目录遍历或 `pathspec`,避免默认可执行任意 `find -exec`。 + +--- + +## 5. ArtifactManager(产物管理) + +- **阈值**:例如 stdout+stderr 合计超过 N KB 则 spill 至 + `{output_dir}/.ms_agent_artifacts/{tool_name}/{task_or_call_id}.txt`(路径可配置)。 +- **返回**:JSON 中包含 `preview`(首尾若干字符/行)、`artifact_path`(相对 `output_dir`)、`truncated: true`。 +- **与 TaskManager 配合**:后台任务完成时,`TaskManager.complete(task_id, result)` 的 `result` 宜为「短摘要 + artifact 路径」,避免通知与下一轮上下文被撑爆。 + +--- + +## 6. GrepTool / GlobTool(独立工具、共享内核) + +- **输入**:结构化字段(如 pattern、path、glob、head_limit、offset、output_mode),不把「整条 shell」作为唯一 API。 +- **实现**:内部调用 `SandboxRuntime.exec_rg(...)` 或在策略内核限定根上的 glob 遍历;**禁止**由用户可控字符串直接拼接未校验的 shell。 +- **共享**:同一 `WorkspacePolicyKernel` + `SandboxRuntime` + `ArtifactManager`(由 `ToolManager` 或执行类工具在初始化时注入)。 +- **注册**:在 `ToolManager` 中作为独立 `ToolBase`(可一个 server 多个 tool,或两个 server);与 `file_system` 解耦,保持 `file_system` 精简。 + +--- + +## 7. 与 `feat/agent-tool-overhaul` 的 Task 体系兼容 + +### 7.1 分支中的现状(摘要) + +- **`TaskManager`**(`ms_agent/utils/task_manager.py`):进程级后台任务注册表;`BackgroundTask` 中 **`task_type` 注释已包含 `'agent' | 'shell'`**。 +- **`AgentTool`**:`run_in_background` 时 `register(task_type='agent', proc=mp.Process, ...)`,watcher 在子进程结束后调用 `complete` / `fail`;`LLMAgent` 通过 `set_task_manager` 注入同一 `TaskManager`,每轮 `drain_notifications()` 将完成事件注入对话。 + +### 7.2 Shell 后台(与 Agent 对称) + +**建议接口** + +- **同步**:`shell_executor(command, timeout)` → 行为与现网接近,但走 PolicyKernel + ArtifactManager。 +- **后台**:增加 `run_in_background: bool`(或等价命名), **`__call_id`**(与 `AgentTool` 注入一致,便于对账与「推后台」扩展)。 + +**后台行为** + +1. `task_id = task_manager.register(task_type='shell', tool_name='shell_executor', description=command[:200], proc=...)` +2. `proc` 可为 **`asyncio.create_subprocess_*` 返回的 `Process`**(与 Agent 的 `mp.Process` 不同,需在 **`TaskManager.kill` / `kill_all` 中扩展**:对 `asyncio.subprocess.Process` 调用 `kill()` / `terminate()`,并处理已结束进程)。 +3. `asyncio.create_task(watcher)`:等待结束 → `ArtifactManager.maybe_spill` → `await task_manager.complete(task_id, result_str)`(失败则 `fail`)。 + +**立即返回 JSON**(与 Agent 后台对齐,便于统一文档与客户端): + +```json +{ + "status": "async_launched", + "task_id": "", + "tool_name": "shell_executor" +} +``` + +### 7.3 LLMAgent 接线 + +- 与 overhaul 一致:构造 `TaskManager()`,遍历 `extra_tools`,若实现 **`set_task_manager(self.task_manager)`** 则注入。 +- **`LocalCodeExecutionTool` / 未来的 `SecureShellTool`** 实现 `set_task_manager`,与 `AgentTool` 共享同一 `TaskManager` 实例。 + +### 7.4 长同步 Shell → Escape 到后台 + +- 与 `AgentTool._run_sync_escapable` 类似:同步 Shell 带 `sync_timeout_s`,超时或显式信号后取消当前子进程并改为 `register(task_type='shell', ...)` 后台重跑或仅保留已产出部分(产品二选一)。 +- 若存在 **TaskControlTool** 类机制,可复用「`__call_id` + escape 事件」模式,Shell 侧维护 `call_id → Process` 映射以支持 **kill / escape**。 + +### 7.5 兼容对照表 + +| 能力 | overhaul 行为 | 本方案落点 | +|------|----------------|------------| +| 后台 Agent | `register(task_type='agent', proc=Process)` | 不变 | +| 预留 Shell | `task_type` 含 `'shell'` | `shell_executor(run_in_background=true)` 走同一 register / complete / fail | +| 回合内通知 | `drain_notifications()` | Shell 完成同样入队 | +| Kill / 清理 | `kill` / `kill_all` | 扩展支持 asyncio 子进程;watcher `finally` 释放资源 | + +--- + +## 8. 配置示例(OmegaConf / YAML 意向) + +```yaml +tools: + workspace_policy: + allow_roots: [] # 追加;默认已含 output_dir + deny_globs: ['**/.git/**'] + code_executor: + implementation: python_env # or sandbox + shell: + default_mode: workspace_write # read_only | workspace_write + max_output_kb: 256 + wall_time_s: 900 + grep: + default_head_limit: 250 + glob: + max_files: 100 +``` + +--- + +## 9. 实施顺序建议 + +1. 抽出 **`WorkspacePolicyKernel`** + 单元测试(路径解析、默认 `output_dir`、追加 allow)。 +2. 实现 **`ArtifactManager`**,接到现有 `shell_executor` 返回(先本地工具、后接沙箱)。 +3. 将 **`TaskManager`**(overhaul)合入主线并 **扩展 `kill` 支持 `asyncio.subprocess.Process`**。 +4. **`LocalCodeExecutionTool.set_task_manager` + `run_in_background` 的 `shell_executor`**。 +5. 新增 **GrepTool / GlobTool** façade,共享上述内核与运行时。 +6. 更新文档与系统提示:默认 **发现用 Grep/Glob,构建用 Shell,改文件用 file_system**。 + +--- + +## 10. 设计取舍小结 + +- **Shell**:强约束的通用执行面 + 后台,与 **TaskManager** 统一生命周期与通知。 +- **Grep / Glob**:独立 Schema、只读、易截断,与 Shell **共享策略与沙箱**,避免把一切搜索都绑在一条 shell 字符串上。 +- **默认 allow_roots 含 `output_dir`**:与现有 Agent 工作区模型一致,减少越权访问宿主路径的风险。 + +--- + +## 修订记录 + +| 日期 | 说明 | +|------|------| +| 2026-04-13 | 初版:根据设计与 `feat/agent-tool-overhaul` 中 TaskManager / AgentTool 后台模型整理成文。 | +| 2026-04-13 | 实现落地:见下文「实现映射」。 | + +## 11. 实现映射(代码位置) + +| 组件 | 路径 | +|------|------| +| WorkspacePolicyKernel | `ms_agent/utils/workspace_policy.py` | +| ArtifactManager | `ms_agent/utils/artifact_manager.py` | +| TaskManager | `ms_agent/utils/task_manager.py` | +| Shell 策略 / 产物 / 后台 | `ms_agent/tools/code/local_code_executor.py`(`set_task_manager`、`shell_executor`) | +| Grep / Glob | `ms_agent/tools/filesystem_tool.py` 中 `grep` / `glob` 工具(与 `read_file` / `edit_file` / `write_file` 同属 `file_system` server;用 `tools.file_system.include` / `exclude` 控制)。可选键:`grep_timeout_s`、`grep_head_limit`、`glob_max_files`;`include` 短名 `read` / `edit` / `write` 分别等价 `read_file` / `edit_file` / `write_file`。 | +| `__call_id` 注入 shell | `ms_agent/tools/tool_manager.py` | +| TaskManager 与通知 | `ms_agent/agent/llm_agent.py`(`prepare_tools` / `cleanup_tools` / `_append_task_notifications`) | +| 单测 | `tests/utils/test_workspace_policy.py` | + +**未在本阶段实现**:文档 §7.4 长同步 Shell escape 到后台;Docker `CodeExecutionTool` 侧 shell 与策略对齐(仍沿用原沙箱实现)。 diff --git a/webui/backend/agent_runner.py b/webui/backend/agent_runner.py index 37d3ec3b7..1c122ff27 100644 --- a/webui/backend/agent_runner.py +++ b/webui/backend/agent_runner.py @@ -3,7 +3,6 @@ Agent runner for MS-Agent Web UI Manages the execution of ms-agent through subprocess with log streaming. """ - import asyncio import os import re @@ -19,18 +18,16 @@ class AgentRunner: """Runs ms-agent as a subprocess with output streaming""" - def __init__( - self, - session_id: str, - project: Dict[str, Any], - config_manager, - on_output: Callable[[Dict[str, Any]], None] = None, - on_log: Callable[[Dict[str, Any]], None] = None, - on_progress: Callable[[Dict[str, Any]], None] = None, - on_complete: Callable[[Dict[str, Any]], None] = None, - on_error: Callable[[Dict[str, Any]], None] = None, - workflow_type: str = 'standard', - ): + def __init__(self, + session_id: str, + project: Dict[str, Any], + config_manager, + on_output: Callable[[Dict[str, Any]], None] = None, + on_log: Callable[[Dict[str, Any]], None] = None, + on_progress: Callable[[Dict[str, Any]], None] = None, + on_complete: Callable[[Dict[str, Any]], None] = None, + on_error: Callable[[Dict[str, Any]], None] = None, + workflow_type: str = 'standard'): self.session_id = session_id self.project = project self.config_manager = config_manager @@ -56,7 +53,8 @@ def __init__( self._current_tool_args = None # Current tool arguments self._current_tool_result = None # Current tool result self._tool_call_json_buffer = '' # Buffer for collecting multi-line JSON tool call info - self._is_chat_mode = project.get('id') == '__chat__' # Simple chat mode flag + self._is_chat_mode = project.get( + 'id') == '__chat__' # Simple chat mode flag self._chat_response_buffer = '' # Buffer for chat mode responses async def start(self, query: str): @@ -75,13 +73,11 @@ async def start(self, query: str): # Log the command if self.on_log: - self.on_log( - { - 'level': 'info', - 'message': f'Starting agent: {" ".join(cmd[:5])}...', - 'timestamp': datetime.now().isoformat(), - } - ) + self.on_log({ + 'level': 'info', + 'message': f'Starting agent: {" ".join(cmd[:5])}...', + 'timestamp': datetime.now().isoformat() + }) # Start subprocess self.process = await asyncio.create_subprocess_exec( @@ -91,8 +87,7 @@ async def start(self, query: str): stdin=asyncio.subprocess.PIPE, env=env, cwd=self.project['path'], - start_new_session=True, - ) + start_new_session=True) print(f'[Runner] Process started with PID: {self.process.pid}') @@ -102,7 +97,6 @@ async def start(self, query: str): except Exception as e: print(f'[Runner] ERROR: {e}') import traceback - traceback.print_exc() if self.on_error: self.on_error({'message': str(e), 'type': 'startup_error'}) @@ -148,25 +142,35 @@ async def send_input(self, text: str): if not self.process: print('[Runner] ERROR: Process is None, cannot send input') if self.on_error: - self.on_error( - {'message': 'Agent process is not running. Please start a new conversation.', 'type': 'input_error'} - ) + self.on_error({ + 'message': + 'Agent process is not running. Please start a new conversation.', + 'type': 'input_error' + }) return # Check if process has exited if self.process.returncode is not None: - print(f'[Runner] ERROR: Process has exited with code {self.process.returncode}, cannot send input') + print( + f'[Runner] ERROR: Process has exited with code {self.process.returncode}, cannot send input' + ) if self.on_error: - self.on_error( - {'message': 'Agent process has terminated. Please start a new conversation.', 'type': 'input_error'} - ) + self.on_error({ + 'message': + 'Agent process has terminated. Please start a new conversation.', + 'type': 'input_error' + }) return # Check if stdin is available if not self.process.stdin: print('[Runner] ERROR: Process stdin is None, cannot send input') if self.on_error: - self.on_error({'message': 'Cannot send input: process stdin is not available.', 'type': 'input_error'}) + self.on_error({ + 'message': + 'Cannot send input: process stdin is not available.', + 'type': 'input_error' + }) return print(f'[Runner] Sending input to agent: {text[:100]}...') @@ -184,12 +188,11 @@ async def send_input(self, text: str): except (BrokenPipeError, RuntimeError, OSError) as e: print(f'[Runner] ERROR: Failed to send input: {e}') if self.on_error: - self.on_error( - { - 'message': f'Failed to send input: Process may have terminated. Error: {str(e)}', - 'type': 'input_error', - } - ) + self.on_error({ + 'message': + f'Failed to send input: Process may have terminated. Error: {str(e)}', + 'type': 'input_error' + }) # Mark process as not running self.is_running = False self._waiting_for_input = False @@ -205,7 +208,8 @@ def _build_command(self, query: str) -> list: workflow_type = getattr(self, '_workflow_type', 'standard') if workflow_type == 'simple' and project_type == 'workflow': # For code_genesis with simple workflow, use simple_workflow.yaml - simple_config_file = os.path.join(project_path, 'simple_workflow.yaml') + simple_config_file = os.path.join(project_path, + 'simple_workflow.yaml') if os.path.exists(simple_config_file): config_file = simple_config_file @@ -217,7 +221,10 @@ def _build_command(self, query: str) -> list: if project_type == 'workflow' or project_type == 'agent': # Use ms-agent CLI command (installed via entry point) - cmd = ['ms-agent', 'run', '--config', config_file, '--trust_remote_code', 'true'] + cmd = [ + 'ms-agent', 'run', '--config', config_file, + '--trust_remote_code', 'true' + ] if query: cmd.extend(['--query', query]) @@ -227,85 +234,126 @@ def _build_command(self, query: str) -> list: # Add LLM config from user settings llm_config = self.config_manager.get_llm_config() - temperature_enabled = bool(llm_config.get('temperature_enabled', False)) + temperature_enabled = bool( + llm_config.get('temperature_enabled', False)) if llm_config.get('api_key'): provider = llm_config.get('provider', 'modelscope') if provider == 'modelscope': - cmd.extend(['--llm.modelscope_api_key', llm_config['api_key']]) + cmd.extend( + ['--llm.modelscope_api_key', llm_config['api_key']]) # Set llm.service to modelscope to ensure the correct service is used cmd.extend(['--llm.service', 'modelscope']) # Pass base_url if set by user if llm_config.get('base_url'): - cmd.extend(['--llm.modelscope_base_url', llm_config['base_url']]) + cmd.extend([ + '--llm.modelscope_base_url', llm_config['base_url'] + ]) # Pass model if set by user if llm_config.get('model'): cmd.extend(['--llm.model', llm_config['model']]) # Pass temperature if set by user (in generation_config) - if temperature_enabled and llm_config.get('temperature') is not None: - cmd.extend(['--generation_config.temperature', str(llm_config['temperature'])]) + if temperature_enabled and llm_config.get( + 'temperature') is not None: + cmd.extend([ + '--generation_config.temperature', + str(llm_config['temperature']) + ]) # Pass max_tokens if set by user (in generation_config) if llm_config.get('max_tokens'): - cmd.extend(['--generation_config.max_tokens', str(llm_config['max_tokens'])]) + cmd.extend([ + '--generation_config.max_tokens', + str(llm_config['max_tokens']) + ]) elif provider == 'openai': cmd.extend(['--llm.openai_api_key', llm_config['api_key']]) # Set llm.service to openai to ensure the correct service is used cmd.extend(['--llm.service', 'openai']) # Pass base_url if set by user if llm_config.get('base_url'): - cmd.extend(['--llm.openai_base_url', llm_config['base_url']]) + cmd.extend( + ['--llm.openai_base_url', llm_config['base_url']]) # Pass model if set by user if llm_config.get('model'): cmd.extend(['--llm.model', llm_config['model']]) # Pass temperature if set by user (in generation_config) - if temperature_enabled and llm_config.get('temperature') is not None: - cmd.extend(['--generation_config.temperature', str(llm_config['temperature'])]) + if temperature_enabled and llm_config.get( + 'temperature') is not None: + cmd.extend([ + '--generation_config.temperature', + str(llm_config['temperature']) + ]) # Pass max_tokens if set by user (in generation_config) if llm_config.get('max_tokens'): - cmd.extend(['--generation_config.max_tokens', str(llm_config['max_tokens'])]) + cmd.extend([ + '--generation_config.max_tokens', + str(llm_config['max_tokens']) + ]) # Add edit_file_config from user settings (skip for chat mode) if self.project.get('id') != '__chat__': edit_file_config = self.config_manager.get_edit_file_config() if edit_file_config.get('api_key'): # If API key is provided, pass edit_file_config - cmd.extend(['--tools.file_system.edit_file_config.api_key', edit_file_config['api_key']]) + cmd.extend([ + '--tools.file_system.edit_file_config.api_key', + edit_file_config['api_key'] + ]) if edit_file_config.get('base_url'): - cmd.extend(['--tools.file_system.edit_file_config.base_url', edit_file_config['base_url']]) + cmd.extend([ + '--tools.file_system.edit_file_config.base_url', + edit_file_config['base_url'] + ]) if edit_file_config.get('diff_model'): - cmd.extend(['--tools.file_system.edit_file_config.diff_model', edit_file_config['diff_model']]) + cmd.extend([ + '--tools.file_system.edit_file_config.diff_model', + edit_file_config['diff_model'] + ]) else: # If no API key, exclude edit_file from tools # Read the current include list from config file and remove edit_file try: with open(config_file, 'r', encoding='utf-8') as f: config_data = yaml.safe_load(f) - if config_data and 'tools' in config_data and 'file_system' in config_data['tools']: - include_list = config_data['tools']['file_system'].get('include', []) - if isinstance(include_list, list) and 'edit_file' in include_list: + if config_data and 'tools' in config_data and 'file_system' in config_data[ + 'tools']: + include_list = config_data['tools'][ + 'file_system'].get('include', []) + if isinstance( + include_list, + list) and 'edit_file' in include_list: # Remove edit_file from the list - filtered_include = [tool for tool in include_list if tool != 'edit_file'] + filtered_include = [ + tool for tool in include_list + if tool != 'edit_file' + ] # Pass the filtered list as comma-separated string - cmd.extend(['--tools.file_system.include', ','.join(filtered_include)]) + cmd.extend([ + '--tools.file_system.include', + ','.join(filtered_include) + ]) except Exception as e: - print(f'[Runner] Warning: Could not read config file to exclude edit_file: {e}') + print( + f'[Runner] Warning: Could not read config file to exclude edit_file: {e}' + ) # Fallback: explicitly exclude edit_file - cmd.extend(['--tools.file_system.exclude', 'edit_file']) + cmd.extend( + ['--tools.file_system.exclude', 'edit_file']) # Add EdgeOne Pages API token and project name from user settings - edgeone_pages_config = self.config_manager.get_edgeone_pages_config() + edgeone_pages_config = self.config_manager.get_edgeone_pages_config( + ) if edgeone_pages_config.get('api_token'): # If API token is provided, pass it to the MCP server config - cmd.extend( - ['--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_API_TOKEN', edgeone_pages_config['api_token']] - ) + cmd.extend([ + '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_API_TOKEN', + edgeone_pages_config['api_token'] + ]) if edgeone_pages_config.get('project_name'): # If project name is provided, pass it to the MCP server config - cmd.extend( - [ - '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_PROJECT_NAME', - edgeone_pages_config['project_name'], - ] - ) + cmd.extend([ + '--tools.edgeone-pages-mcp.env.EDGEONE_PAGES_PROJECT_NAME', + edgeone_pages_config['project_name'] + ]) elif project_type == 'script': # Run the script directly @@ -338,7 +386,9 @@ async def _read_output(self): # Check if process has exited if self.process.returncode is not None and not process_exited: process_exited = True - print(f'[Runner] Process exited with code: {self.process.returncode}') + print( + f'[Runner] Process exited with code: {self.process.returncode}' + ) # Continue reading remaining output even after process exits # This ensures we don't miss any URLs or important messages if not self.process.stdout: @@ -355,7 +405,8 @@ async def _read_output(self): try: # Use shorter timeout after process exits to read remaining data faster timeout = 0.1 if process_exited else 1.0 - line = await asyncio.wait_for(self.process.stdout.readline(), timeout=timeout) + line = await asyncio.wait_for( + self.process.stdout.readline(), timeout=timeout) except asyncio.TimeoutError: # Timeout - check if we're waiting for input if self._waiting_for_input: @@ -366,14 +417,14 @@ async def _read_output(self): self._flush_chat_response() # Send waiting_input message to enable frontend input if self.on_output and not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': '', - 'role': 'system', - 'metadata': {'waiting': True}, + self.on_output({ + 'type': 'waiting_input', + 'content': '', + 'role': 'system', + 'metadata': { + 'waiting': True } - ) + }) self._waiting_input_sent = True # Process is still alive, continue waiting continue @@ -412,21 +463,26 @@ async def _read_output(self): self._flush_chat_response() # Send waiting_input message to enable frontend input if self.on_output and not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': '', - 'role': 'system', - 'metadata': {'waiting': True}, + self.on_output({ + 'type': 'waiting_input', + 'content': '', + 'role': 'system', + 'metadata': { + 'waiting': True } - ) + }) self._waiting_input_sent = True - print('[Runner] Agent is waiting for user input, keeping process alive...') + print( + '[Runner] Agent is waiting for user input, keeping process alive...' + ) # Keep process alive and wait for input - await asyncio.sleep(0.5) # Small delay to avoid busy waiting + await asyncio.sleep( + 0.5) # Small delay to avoid busy waiting continue else: - print('[Runner] Process exited while waiting for input') + print( + '[Runner] Process exited while waiting for input' + ) # Process exited, but continue reading any remaining output # Don't break yet - there might be more data in stdout buffer process_exited = True @@ -437,13 +493,13 @@ async def _read_output(self): # Reset empty line count when we get actual data empty_line_count = 0 text = line.decode('utf-8', errors='replace').rstrip() - print(f'[Runner] Output: {text[:200]}' if len(text) > 200 else f'[Runner] Output: {text}') + print(f'[Runner] Output: {text[:200]}' + if len(text) > 200 else f'[Runner] Output: {text}') try: await self._process_line(text) except Exception as e: print(f'[Runner] ERROR processing line: {e}') import traceback - traceback.print_exc() # Wait for process to complete and handle completion @@ -460,46 +516,50 @@ async def _read_output(self): self._flush_chat_response() # Flush any accumulated assistant output before handling completion - if self._collecting_assistant_output and self._accumulated_output.strip(): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) + if self._collecting_assistant_output and self._accumulated_output.strip( + ): + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', + self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) - print(f'[Runner] Flushing accumulated output on process exit: {cleaned[:200]}...') + print( + f'[Runner] Flushing accumulated output on process exit: {cleaned[:200]}...' + ) if cleaned and self.on_output: - self.on_output( - { - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': {'agent': self._current_step or 'agent'}, + self.on_output({ + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': { + 'agent': self._current_step or 'agent' } - ) + }) self._accumulated_output = '' self._collecting_assistant_output = False # If stop was requested, do not report as completion/error if self._stop_requested: if self.on_log: - self.on_log( - { - 'level': 'info', - 'message': 'Agent stopped by user', - 'timestamp': datetime.now().isoformat(), - } - ) + self.on_log({ + 'level': 'info', + 'message': 'Agent stopped by user', + 'timestamp': datetime.now().isoformat() + }) return # Complete current step if any before handling exit if self._current_step and self.on_output: - self.on_output( - { - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': {'step': self._current_step, 'status': 'completed'}, + self.on_output({ + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': { + 'step': self._current_step, + 'status': 'completed' } - ) + }) # If Refine step completes successfully, it should be waiting for input - if return_code == 0 and self._current_step.lower() == 'refine': + if return_code == 0 and self._current_step.lower( + ) == 'refine': self._waiting_for_input = True self._current_step = None @@ -510,48 +570,57 @@ async def _read_output(self): if return_code == 0: # Send waiting_input message if not already sent if self.on_output and not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': ( - '✅ Initial refinement completed. ' - 'You can now provide additional feedback or modifications.' - ), - 'role': 'system', - 'metadata': {'waiting': True}, + self.on_output({ + 'type': + 'waiting_input', + 'content': + ('✅ Initial refinement completed. ' + 'You can now provide additional feedback or modifications.' + ), + 'role': + 'system', + 'metadata': { + 'waiting': True } - ) + }) self._waiting_input_sent = True if self.on_complete: - self.on_complete({'status': 'success', 'message': 'Agent completed successfully'}) + self.on_complete({ + 'status': + 'success', + 'message': + 'Agent completed successfully' + }) else: if self.on_error: - self.on_error( - { - 'message': ( - f'Agent process terminated while waiting for input. Exit code: {return_code}' - ), - 'type': 'process_exit_error', - 'code': return_code, - } - ) + self.on_error({ + 'message': + ('Agent process terminated while waiting for input. ' + f'Exit code: {return_code}'), + 'type': + 'process_exit_error', + 'code': + return_code + }) elif return_code == 0: if self.on_complete: - self.on_complete({'status': 'success', 'message': 'Agent completed successfully'}) + self.on_complete({ + 'status': + 'success', + 'message': + 'Agent completed successfully' + }) else: if self.on_error: - self.on_error( - { - 'message': f'Agent exited with code {return_code}', - 'type': 'exit_error', - 'code': return_code, - } - ) + self.on_error({ + 'message': f'Agent exited with code {return_code}', + 'type': 'exit_error', + 'code': return_code + }) except Exception as e: print(f'[Runner] Read error: {e}') import traceback - traceback.print_exc() if not self._stop_requested and self.on_error: self.on_error({'message': str(e), 'type': 'read_error'}) @@ -590,7 +659,8 @@ async def _process_chat_line(self, line: str): else: self._tool_call_json_buffer = cleaned # Check if we have a complete JSON object - if cleaned == '}' and self._tool_call_json_buffer.strip().startswith('{'): + if cleaned == '}' and self._tool_call_json_buffer.strip( + ).startswith('{'): self._flush_tool_call() return @@ -598,9 +668,14 @@ async def _process_chat_line(self, line: str): if 'execute tool call' in line: if self.on_output: is_error = 'error' in line.lower() - self.on_output( - {'type': 'tool_result', 'content': cleaned, 'role': 'assistant', 'metadata': {'is_error': is_error}} - ) + self.on_output({ + 'type': 'tool_result', + 'content': cleaned, + 'role': 'assistant', + 'metadata': { + 'is_error': is_error + } + }) return # Detect [assistant]: marker - start collecting @@ -629,25 +704,23 @@ async def _process_chat_line(self, line: str): def _flush_tool_call(self): """Send tool call information to frontend""" - if self._is_chat_mode and self._tool_call_json_buffer.strip() and self.on_output: + if self._is_chat_mode and self._tool_call_json_buffer.strip( + ) and self.on_output: try: import json - tool_data = json.loads(self._tool_call_json_buffer) tool_name = tool_data.get('tool_name', 'unknown') print(f'[Runner] Tool call: {tool_name}') - self.on_output( - { - 'type': 'tool_call', - 'content': '', - 'role': 'assistant', - 'metadata': { - 'tool_name': tool_name, - 'arguments': tool_data.get('arguments', {}), - 'id': tool_data.get('id', ''), - }, + self.on_output({ + 'type': 'tool_call', + 'content': '', + 'role': 'assistant', + 'metadata': { + 'tool_name': tool_name, + 'arguments': tool_data.get('arguments', {}), + 'id': tool_data.get('id', '') } - ) + }) except json.JSONDecodeError: print('[Runner] Failed to parse tool call JSON') self._tool_call_json_buffer = '' @@ -655,11 +728,17 @@ def _flush_tool_call(self): def _flush_chat_response(self): """Send final chat response with done=True""" - if self._is_chat_mode and self._chat_response_buffer.strip() and self.on_output: - print(f'[Runner] Chat complete: {len(self._chat_response_buffer)} chars') - self.on_output( - {'type': 'stream', 'content': self._chat_response_buffer.strip(), 'role': 'assistant', 'done': True} + if self._is_chat_mode and self._chat_response_buffer.strip( + ) and self.on_output: + print( + f'[Runner] Chat complete: {len(self._chat_response_buffer)} chars' ) + self.on_output({ + 'type': 'stream', + 'content': self._chat_response_buffer.strip(), + 'role': 'assistant', + 'done': True + }) self._chat_response_buffer = '' # Don't reset _collecting_assistant_output here - more content may come # It will be reset when we see [tool_calling]: or [user]: or process exits @@ -680,7 +759,6 @@ async def _process_line(self, line: str): if '[INFO:ms_agent]' in line: # Check if there's an agent name tag [xxx] after [INFO:ms_agent] import re - if not re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', line): return @@ -688,13 +766,14 @@ async def _process_line(self, line: str): if self.on_log: log_level = self._detect_log_level(line) cleaned_message = self._clean_log_prefix(line) - await self.on_log( - { - 'level': log_level, - 'message': cleaned_message if cleaned_message else line, - 'timestamp': datetime.now().isoformat(), - } - ) + await self.on_log({ + 'level': + log_level, + 'message': + cleaned_message if cleaned_message else line, + 'timestamp': + datetime.now().isoformat() + }) # Parse for special patterns (use original line for pattern matching) await self._detect_patterns(line) @@ -743,23 +822,25 @@ def _scan_and_send_output_files(self, programmer_step=None): file_path = line.split(':')[0].strip() generated_files.append(file_path) - print(f'[Runner] Found {len(generated_files)} files in tasks.txt: {generated_files}') + print( + f'[Runner] Found {len(generated_files)} files in tasks.txt: {generated_files}' + ) # Send all files in one batch if generated_files and self.on_output: - self.on_output( - { - 'type': 'file_output', - 'content': generated_files, # Send as array - 'role': 'assistant', - 'metadata': {'files': generated_files, 'source': 'tasks.txt'}, + self.on_output({ + 'type': 'file_output', + 'content': generated_files, # Send as array + 'role': 'assistant', + 'metadata': { + 'files': generated_files, + 'source': 'tasks.txt' } - ) + }) except Exception as e: print(f'[Runner] Error reading tasks.txt: {e}') import traceback - traceback.print_exc() async def _detect_patterns(self, line: str): @@ -771,34 +852,40 @@ async def _detect_patterns(self, line: str): url_match = re.search(r'"url":\s*"(https?://[^"]+)"', line) # Pattern 2: Direct URL like "https://mcp.edgeone.site/share/..." if not url_match: - url_match = re.search(r'(https?://mcp\.edgeone\.site/[^\s]+)', line) + url_match = re.search(r'(https?://mcp\.edgeone\.site/[^\s]+)', + line) # Pattern 3: EdgeOne Pages URL like "https://...edgeone.cool?..." # BUT skip if this is a curl command line (testing command, not actual deployment URL) if not url_match and 'curl -s' not in line and 'curl ' not in line: - url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', line) + url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', + line) # Pattern 4: Also check for edgeone.site URLs in any format (fallback) # BUT skip if this is a curl command line if not url_match and 'curl -s' not in line and 'curl ' not in line: - url_match = re.search(r'(https?://[^\s]*edgeone\.site[^\s]*)', line) + url_match = re.search(r'(https?://[^\s]*edgeone\.site[^\s]*)', + line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) deployment_url = deployment_url.replace('\\&', '&') - print(f'[Runner] Detected deployment URL (early): {deployment_url} from line: {line[:100]}') + print( + f'[Runner] Detected deployment URL (early): {deployment_url} from line: {line[:100]}' + ) if self.on_output: - self.on_output( - { - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': {'url': deployment_url}, + self.on_output({ + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': { + 'url': deployment_url } - ) + }) # Continue processing - don't return yet, other patterns might also match # Detect OpenAI API errors and other API errors # Check for OpenAI error patterns - if 'openai.' in line.lower() and ('error' in line.lower() or 'Error' in line): + if 'openai.' in line.lower() and ('error' in line.lower() + or 'Error' in line): error_message = line.strip() # Try to extract error details from the line # Pattern: openai.NotFoundError: Error code: 404 - {'error': {'message': '...', ...}} @@ -806,11 +893,12 @@ async def _detect_patterns(self, line: str): if json_match: try: import json - error_data = json.loads(json_match.group(0)) - if 'error' in error_data and 'message' in error_data['error']: + if 'error' in error_data and 'message' in error_data[ + 'error']: error_msg = error_data['error']['message'] - error_type = error_data['error'].get('type', 'API Error') + error_type = error_data['error'].get( + 'type', 'API Error') error_message = f'**{error_type}**: {error_msg}' except Exception: pass @@ -820,14 +908,14 @@ async def _detect_patterns(self, line: str): self.on_error({'message': error_message, 'type': 'api_error'}) # Also send as output message so it appears in the conversation if self.on_output: - self.on_output( - { - 'type': 'error', - 'content': error_message, - 'role': 'system', - 'metadata': {'error_type': 'api_error'}, + self.on_output({ + 'type': 'error', + 'content': error_message, + 'role': 'system', + 'metadata': { + 'error_type': 'api_error' } - ) + }) return # Detect other error patterns @@ -844,72 +932,78 @@ async def _detect_patterns(self, line: str): if json_match: try: import json - error_data = json.loads(json_match.group(0)) - if 'error' in error_data and 'message' in error_data['error']: + if 'error' in error_data and 'message' in error_data[ + 'error']: error_msg = error_data['error']['message'] - error_type = error_data['error'].get('type', 'API Error') + error_type = error_data['error'].get( + 'type', 'API Error') error_message = f'**{error_type}**: {error_msg}' except Exception: pass print(f'[Runner] Detected API error: {error_message}') if self.on_error: - self.on_error( - { - 'message': error_message, - 'type': 'api_error', - 'code': error_match.group(1) if error_match.groups() else None, - } - ) + self.on_error({ + 'message': + error_message, + 'type': + 'api_error', + 'code': + error_match.group(1) if error_match.groups() else None + }) # Also send as output message so it appears in the conversation if self.on_output: - self.on_output( - { - 'type': 'error', - 'content': error_message, - 'role': 'system', - 'metadata': {'error_type': 'api_error'}, + self.on_output({ + 'type': 'error', + 'content': error_message, + 'role': 'system', + 'metadata': { + 'error_type': 'api_error' } - ) + }) return # Detect workflow step beginning: "[tag] Agent tag task beginning." - begin_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+beginning', line) + begin_match = re.search( + r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+beginning', line) if begin_match: step_name = begin_match.group(1) # Skip sub-steps and programmer agents (handled separately) - if ('-r' in step_name and '-' in step_name.split('-r')[-1]) or step_name.startswith('programmer-'): + if (('-r' in step_name and '-' in step_name.split('-r')[-1]) + or step_name.startswith('programmer-')): return print(f'[Runner] Step beginning: {step_name}') # Flush previous step if exists if self._current_step and self._accumulated_output.strip(): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', + self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) if cleaned and self.on_output: - self.on_output( - { - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': {'agent': self._current_step}, + self.on_output({ + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': { + 'agent': self._current_step } - ) + }) self._accumulated_output = '' self._collecting_assistant_output = False if self._current_step and self.on_output: - self.on_output( - { - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': {'step': self._current_step, 'status': 'completed'}, + self.on_output({ + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': { + 'step': self._current_step, + 'status': 'completed' } - ) + }) # Start new step self._current_step = step_name @@ -917,35 +1011,29 @@ async def _detect_patterns(self, line: str): self._workflow_steps.append(step_name) step_status = { - s: ( - 'completed' - if i < self._workflow_steps.index(step_name) - else 'running' - if s == step_name - else 'pending' - ) + s: ('completed' if i < self._workflow_steps.index(step_name) + else 'running' if s == step_name else 'pending') for i, s in enumerate(self._workflow_steps) } if self.on_progress: - self.on_progress( - { - 'type': 'workflow', - 'current_step': step_name, - 'steps': self._workflow_steps.copy(), - 'step_status': step_status, - } - ) + self.on_progress({ + 'type': 'workflow', + 'current_step': step_name, + 'steps': self._workflow_steps.copy(), + 'step_status': step_status + }) if self.on_output: - self.on_output( - { - 'type': 'step_start', - 'content': step_name, - 'role': 'assistant', - 'metadata': {'step': step_name, 'status': 'running'}, + self.on_output({ + 'type': 'step_start', + 'content': step_name, + 'role': 'assistant', + 'metadata': { + 'step': step_name, + 'status': 'running' } - ) + }) # If Refine step is starting, scan tasks.txt for all generated files # This ensures files are detected after Coding phase completes @@ -960,35 +1048,40 @@ async def _detect_patterns(self, line: str): programmer_agent = f'programmer-{programmer_match.group(1)}' # If this is FIRST programmer agent, trigger coding step start - if not self._current_step or not self._current_step.startswith('programmer-'): - print(f'[Runner] First programmer agent detected: {programmer_agent} - starting coding step') + if not self._current_step or not self._current_step.startswith( + 'programmer-'): + print( + f'[Runner] First programmer agent detected: {programmer_agent} - starting coding step' + ) # Flush previous step's output if self._current_step and self._accumulated_output.strip(): - cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) + cleaned = re.sub(r'\[INFO:ms_agent\]\s*', '', + self._accumulated_output.strip()) cleaned = re.sub(r'\[([^\]]+)\]\s*', '', cleaned, count=1) if cleaned and self.on_output: - self.on_output( - { - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': {'agent': self._current_step}, + self.on_output({ + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': { + 'agent': self._current_step } - ) + }) self._accumulated_output = '' self._collecting_assistant_output = False # Mark previous step complete if self._current_step and self.on_output: - self.on_output( - { - 'type': 'step_complete', - 'content': self._current_step, - 'role': 'assistant', - 'metadata': {'step': self._current_step, 'status': 'completed'}, + self.on_output({ + 'type': 'step_complete', + 'content': self._current_step, + 'role': 'assistant', + 'metadata': { + 'step': self._current_step, + 'status': 'completed' } - ) + }) # Start coding step self._current_step = programmer_agent @@ -996,35 +1089,29 @@ async def _detect_patterns(self, line: str): self._workflow_steps.append('coding') step_status = { - s: ( - 'completed' - if i < self._workflow_steps.index('coding') - else 'running' - if s == 'coding' - else 'pending' - ) + s: ('completed' if i < self._workflow_steps.index('coding') + else 'running' if s == 'coding' else 'pending') for i, s in enumerate(self._workflow_steps) } if self.on_progress: - self.on_progress( - { - 'type': 'workflow', - 'current_step': 'coding', - 'steps': self._workflow_steps.copy(), - 'step_status': step_status, - } - ) + self.on_progress({ + 'type': 'workflow', + 'current_step': 'coding', + 'steps': self._workflow_steps.copy(), + 'step_status': step_status + }) if self.on_output: - self.on_output( - { - 'type': 'step_start', - 'content': 'coding', - 'role': 'assistant', - 'metadata': {'step': 'coding', 'status': 'running'}, + self.on_output({ + 'type': 'step_start', + 'content': 'coding', + 'role': 'assistant', + 'metadata': { + 'step': 'coding', + 'status': 'running' } - ) + }) # Update current programmer agent elif programmer_agent != self._current_step: @@ -1032,21 +1119,23 @@ async def _detect_patterns(self, line: str): # Helper to flush accumulated assistant output def flush_accumulated_output(): - print( - f'[Runner] flush_accumulated_output called: ' - f'collecting={self._collecting_assistant_output}, ' - f'buffer_len={len(self._accumulated_output)}' - ) + print(f'[Runner] flush_accumulated_output called: ' + f'collecting={self._collecting_assistant_output}, ' + f'buffer_len={len(self._accumulated_output)}') print( f'[Runner] Buffer content: {self._accumulated_output[:200]}...' - if len(self._accumulated_output) > 200 - else f'[Runner] Buffer content: {self._accumulated_output}' - ) - if self._collecting_assistant_output and self._accumulated_output.strip(): + if len(self._accumulated_output) > 200 else + f'[Runner] Buffer content: {self._accumulated_output}') + if self._collecting_assistant_output and self._accumulated_output.strip( + ): # Clean log prefixes - cleaned_content = re.sub(r'\[INFO:ms_agent\]\s*', '', self._accumulated_output.strip()) - cleaned_content = re.sub(r'\[([^\]]+)\]\s*', '', cleaned_content, count=1) - print(f'[Runner] Flushing assistant output: {cleaned_content[:100]}...') + cleaned_content = re.sub(r'\[INFO:ms_agent\]\s*', '', + self._accumulated_output.strip()) + cleaned_content = re.sub( + r'\[([^\]]+)\]\s*', '', cleaned_content, count=1) + print( + f'[Runner] Flushing assistant output: {cleaned_content[:100]}...' + ) # Map agent name for display agent_name = self._current_step or 'agent' @@ -1055,30 +1144,30 @@ def flush_accumulated_output(): display_agent = 'coding' if cleaned_content and self.on_output: - self.on_output( - { - 'type': 'agent_output', - 'content': cleaned_content, - 'role': 'assistant', - 'metadata': {'agent': display_agent}, + self.on_output({ + 'type': 'agent_output', + 'content': cleaned_content, + 'role': 'assistant', + 'metadata': { + 'agent': display_agent } - ) + }) self._accumulated_output = '' self._collecting_assistant_output = False else: - print( - f'[Runner] flush_accumulated_output skipped: ' - f'collecting={self._collecting_assistant_output}, ' - f'has_content={bool(self._accumulated_output.strip())}' - ) + print(f'[Runner] flush_accumulated_output skipped: ' + f'collecting={self._collecting_assistant_output}, ' + f'has_content={bool(self._accumulated_output.strip())}') # Detect workflow step finished: "[tag] Agent tag task finished." - end_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+finished', line) + end_match = re.search(r'\[([^\]]+)\]\s*Agent\s+\S+\s+task\s+finished', + line) if end_match: step_name = end_match.group(1) # Skip install (handled by programmer detection) and sub-steps - if step_name == 'install' or ('-r' in step_name and '-' in step_name.split('-r')[-1]): + if step_name == 'install' or ('-r' in step_name and '-' + in step_name.split('-r')[-1]): return # Skip flush for refine (already flushed during collection) @@ -1111,30 +1200,28 @@ def flush_accumulated_output(): # Build step status dict - all steps up to current are completed step_status = {} for s in self._workflow_steps: - step_status[s] = ( - 'completed' if self._workflow_steps.index(s) <= self._workflow_steps.index(step_name) else 'pending' - ) + step_status[s] = 'completed' if self._workflow_steps.index( + s) <= self._workflow_steps.index(step_name) else 'pending' if self.on_progress: - self.on_progress( - { - 'type': 'workflow', - 'current_step': step_name, - 'steps': self._workflow_steps.copy(), - 'step_status': step_status, - } - ) + self.on_progress({ + 'type': 'workflow', + 'current_step': step_name, + 'steps': self._workflow_steps.copy(), + 'step_status': step_status + }) # Send step complete message if self.on_output: - self.on_output( - { - 'type': 'step_complete', - 'content': step_name, - 'role': 'assistant', - 'metadata': {'step': step_name, 'status': 'completed'}, + self.on_output({ + 'type': 'step_complete', + 'content': step_name, + 'role': 'assistant', + 'metadata': { + 'step': step_name, + 'status': 'completed' } - ) + }) # Clear current step since it's completed self._current_step = None @@ -1143,7 +1230,8 @@ def flush_accumulated_output(): # Clean log prefixes from line # Detect assistant output: "[tag] [assistant]:" if '[assistant]:' in line: - in_coding = self._current_step and self._current_step.startswith('programmer-') + in_coding = self._current_step and self._current_step.startswith( + 'programmer-') if not in_coding: # Start collecting (don't send first line immediately) @@ -1164,24 +1252,27 @@ def flush_accumulated_output(): # Continue collecting assistant output elif self._collecting_assistant_output: # Skip if in coding phase - if self._current_step and self._current_step.startswith('programmer-'): + if self._current_step and self._current_step.startswith( + 'programmer-'): self._collecting_assistant_output = False self._accumulated_output = '' # Don't return - continue processing else: # Check if new pattern starts - if '[tool_calling]:' in line or ('[assistant]:' in line and 'Agent' not in line): + if '[tool_calling]:' in line or ('[assistant]:' in line + and 'Agent' not in line): if self._accumulated_output.strip(): - cleaned = self._clean_log_prefix(self._accumulated_output.strip()) + cleaned = self._clean_log_prefix( + self._accumulated_output.strip()) if cleaned and self.on_output: - self.on_output( - { - 'type': 'agent_output', - 'content': cleaned, - 'role': 'assistant', - 'metadata': {'agent': self._current_step or 'agent'}, + self.on_output({ + 'type': 'agent_output', + 'content': cleaned, + 'role': 'assistant', + 'metadata': { + 'agent': self._current_step or 'agent' } - ) + }) self._accumulated_output = '' self._collecting_assistant_output = False else: @@ -1191,37 +1282,46 @@ def flush_accumulated_output(): if cleaned_line: self._accumulated_output += cleaned_line + '\n' # Check for EdgeOne deployment URL in this line - url_match = re.search(r'(https?://[^\s]*edgeone\.cool[^\s]*)', cleaned_line) + url_match = re.search( + r'(https?://[^\s]*edgeone\.cool[^\s]*)', + cleaned_line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) - deployment_url = deployment_url.replace('\\&', '&') - print(f'[Runner] Detected deployment URL in assistant: {deployment_url}') + deployment_url = deployment_url.replace( + '\\&', '&') + print( + f'[Runner] Detected deployment URL in assistant: {deployment_url}' + ) if self.on_output: - self.on_output( - { - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': {'url': deployment_url}, + self.on_output({ + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': { + 'url': deployment_url } - ) + }) # Check for waiting for input pattern - if 'Waiting for user feedback' in line or 'Waiting for user input from stdin' in line: + if ('Waiting for user feedback' in line + or 'Waiting for user input from stdin' + in line): print('[Runner] Agent waiting for user input') self._waiting_for_input = True if self.on_output and not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': ( - '✅ Initial refinement completed. ' - 'You can now provide additional feedback or modifications.' - ), - 'role': 'system', - 'metadata': {'waiting': True}, + self.on_output({ + 'type': + 'waiting_input', + 'content': + ('✅ Initial refinement completed. ' + 'You can now provide additional feedback or modifications.' + ), + 'role': + 'system', + 'metadata': { + 'waiting': True } - ) + }) self._waiting_input_sent = True return @@ -1240,7 +1340,8 @@ def flush_accumulated_output(): self._tool_call_json_buffer = json_part elif json_part: # Try to extract tool name directly if it's not JSON format - tool_match = re.search(r'([\w\-]+(?:---[\w\-]+)?)', json_part) + tool_match = re.search(r'([\w\-]+(?:---[\w\-]+)?)', + json_part) if tool_match: self._current_tool_name = tool_match.group(1) return @@ -1250,7 +1351,8 @@ def flush_accumulated_output(): # Extract agent name from line if available (for better matching) agent_name_from_line = None if '[INFO:ms_agent]' in line: - agent_match = re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', line) + agent_match = re.search(r'\[INFO:ms_agent\]\s*\[([^\]]+)\]', + line) if agent_match: agent_name_from_line = agent_match.group(1) @@ -1278,43 +1380,49 @@ def flush_accumulated_output(): self._tool_call_json_buffer += cleaned_line # Only try to parse when buffer contains tool_name and ends with } - if ( - self._tool_call_json_buffer - and '"tool_name"' in self._tool_call_json_buffer - and self._tool_call_json_buffer.strip().endswith('}') - ): + if (self._tool_call_json_buffer + and '"tool_name"' in self._tool_call_json_buffer + and self._tool_call_json_buffer.strip().endswith('}')): try: import json - tool_info = json.loads(self._tool_call_json_buffer) print('[Runner] Parsed tool JSON successfully') - tool_name = tool_info.get('tool_name') or tool_info.get('name', 'unknown') + tool_name = tool_info.get('tool_name') or tool_info.get( + 'name', 'unknown') tool_args = tool_info.get('arguments', {}) print(f'[Runner] Extracted tool_name: {tool_name}') if tool_name and tool_name != 'unknown': self._current_tool_name = tool_name self._current_tool_args = tool_args agent_name = agent_name_from_line or self._current_step or 'agent' - print(f'[Runner] Sending tool call: {tool_name}, agent: {agent_name}') + print( + f'[Runner] Sending tool call: {tool_name}, agent: {agent_name}' + ) if self.on_output: - self.on_output( - { - 'type': 'tool_call', - 'content': f'调用工具: {tool_name}', - 'role': 'assistant', - 'metadata': {'tool_name': tool_name, 'tool_args': tool_args, 'agent': agent_name}, + self.on_output({ + 'type': 'tool_call', + 'content': f'调用工具: {tool_name}', + 'role': 'assistant', + 'metadata': { + 'tool_name': tool_name, + 'tool_args': tool_args, + 'agent': agent_name } - ) + }) # Clear buffer but KEEP collecting - there may be more tool calls self._tool_call_json_buffer = '' # Don't return or stop collecting - next line might be another tool call JSON else: - print(f'[Runner] WARNING: Invalid tool_name: {tool_name}') + print( + f'[Runner] WARNING: Invalid tool_name: {tool_name}' + ) except json.JSONDecodeError as e: # JSON not complete yet, keep collecting # Only log if we have tool_name - helps debug parsing issues if '"tool_name"' in self._tool_call_json_buffer: - print(f'[Runner] JSON incomplete, continuing... (error: {str(e)[:50]})') + print( + f'[Runner] JSON incomplete, continuing... (error: {str(e)[:50]})' + ) except Exception as e: print(f'[Runner] Error parsing tool JSON: {e}') @@ -1322,18 +1430,23 @@ def flush_accumulated_output(): if '[assistant]:' in line or 'Agent' in line and 'task' in line or '[tool_result]:' in line: # If we have partial data, try to send it if self._tool_call_json_buffer: - tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', self._tool_call_json_buffer) + tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', + self._tool_call_json_buffer) if tool_name_match: tool_name = tool_name_match.group(1) # Try to extract arguments - handle nested JSON objects - args_start = self._tool_call_json_buffer.find('"arguments"') + args_start = self._tool_call_json_buffer.find( + '"arguments"') tool_args = {} if args_start != -1: - brace_start = self._tool_call_json_buffer.find('{', args_start) + brace_start = self._tool_call_json_buffer.find( + '{', args_start) if brace_start != -1: brace_count = 0 brace_end = brace_start - for i in range(brace_start, len(self._tool_call_json_buffer)): + for i in range( + brace_start, + len(self._tool_call_json_buffer)): if self._tool_call_json_buffer[i] == '{': brace_count += 1 elif self._tool_call_json_buffer[i] == '}': @@ -1343,7 +1456,8 @@ def flush_accumulated_output(): break if brace_end > brace_start: - args_str = self._tool_call_json_buffer[brace_start:brace_end] + args_str = self._tool_call_json_buffer[ + brace_start:brace_end] try: tool_args = json.loads(args_str) except Exception: @@ -1356,14 +1470,16 @@ def flush_accumulated_output(): f'{tool_name}, agent: {agent_name}, args: {tool_args}' ) if self.on_output: - self.on_output( - { - 'type': 'tool_call', - 'content': f'调用工具: {tool_name}', - 'role': 'assistant', - 'metadata': {'tool_name': tool_name, 'tool_args': tool_args, 'agent': agent_name}, + self.on_output({ + 'type': 'tool_call', + 'content': f'调用工具: {tool_name}', + 'role': 'assistant', + 'metadata': { + 'tool_name': tool_name, + 'tool_args': tool_args, + 'agent': agent_name } - ) + }) self._collecting_tool_call = False self._tool_call_json_buffer = '' return @@ -1379,18 +1495,16 @@ def flush_accumulated_output(): self._current_tool_result = result_content # Send tool result immediately if we have tool name if self._current_tool_name and self.on_output: - self.on_output( - { - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': result_content, - 'agent': self._current_step or 'agent', - }, + self.on_output({ + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': result_content, + 'agent': self._current_step or 'agent' } - ) + }) # Reset tool info self._current_tool_name = None self._current_tool_result = None @@ -1408,51 +1522,57 @@ def flush_accumulated_output(): # Check for EdgeOne deployment URL in tool result # Pattern 1: JSON format with edgeone.cool or edgeone.site - url_match = re.search(r'"url":\s*"(https?://[^"]+edgeone\.(cool|site)[^"]+)"', line) + url_match = re.search( + r'"url":\s*"(https?://[^"]+edgeone\.(cool|site)[^"]+)"', + line) # Pattern 2: Direct URL with edgeone.cool or edgeone.site if not url_match: - url_match = re.search(r'(https?://[^\s]*edgeone\.(cool|site)[^\s]*)', line) + url_match = re.search( + r'(https?://[^\s]*edgeone\.(cool|site)[^\s]*)', line) if url_match: deployment_url = url_match.group(1) # Clean up escaped characters in URL (e.g., \& -> &) deployment_url = deployment_url.replace('\\&', '&') - print(f'[Runner] Detected deployment URL in tool result: {deployment_url}') + print( + f'[Runner] Detected deployment URL in tool result: {deployment_url}' + ) if self.on_output: - self.on_output( - { - 'type': 'deployment_url', - 'content': deployment_url, - 'role': 'assistant', - 'metadata': {'url': deployment_url}, + self.on_output({ + 'type': 'deployment_url', + 'content': deployment_url, + 'role': 'assistant', + 'metadata': { + 'url': deployment_url } - ) + }) # After deployment success, prompt user for further input self._waiting_for_input = True if not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': 'You can now provide additional feedback or visit the deployed site.', - 'role': 'system', - 'metadata': {'waiting': True, 'deployment_complete': True}, + self.on_output({ + 'type': 'waiting_input', + 'content': + 'You can now provide additional feedback or visit the deployed site.', + 'role': 'system', + 'metadata': { + 'waiting': True, + 'deployment_complete': True } - ) + }) self._waiting_input_sent = True # Send result if we have tool name and accumulated enough content - if self._current_tool_name and len(self._current_tool_result) > 100 and self.on_output: - self.on_output( - { - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': self._current_tool_result, - 'agent': self._current_step or 'agent', - }, + if self._current_tool_name and len( + self._current_tool_result) > 100 and self.on_output: + self.on_output({ + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': self._current_tool_result, + 'agent': self._current_step or 'agent' } - ) + }) # Reset self._current_tool_name = None self._current_tool_result = None @@ -1460,52 +1580,68 @@ def flush_accumulated_output(): elif '[assistant]:' in line or '[tool_calling]:' in line or 'Agent' in line and 'task' in line: # Hit a new pattern, send accumulated result if self._current_tool_name and self._current_tool_result and self.on_output: - self.on_output( - { - 'type': 'tool_result', - 'content': f'工具 {self._current_tool_name} 执行完成', - 'role': 'assistant', - 'metadata': { - 'tool_name': self._current_tool_name, - 'tool_result': self._current_tool_result, - 'agent': self._current_step or 'agent', - }, + self.on_output({ + 'type': 'tool_result', + 'content': f'工具 {self._current_tool_name} 执行完成', + 'role': 'assistant', + 'metadata': { + 'tool_name': self._current_tool_name, + 'tool_result': self._current_tool_result, + 'agent': self._current_step or 'agent' } - ) + }) self._current_tool_name = None self._current_tool_result = None self._collecting_tool_result = False return # Detect file writing - file_match = re.search(r'writing file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) + file_match = re.search(r'writing file:?\s*["\']?([^\s"\']+)["\']?', + line.lower()) if not file_match: - file_match = re.search(r'creating file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) + file_match = re.search( + r'creating file:?\s*["\']?([^\s"\']+)["\']?', line.lower()) if file_match and self.on_progress: filename = file_match.group(1) - self.on_progress({'type': 'file', 'file': filename, 'status': 'writing'}) + self.on_progress({ + 'type': 'file', + 'file': filename, + 'status': 'writing' + }) return # Detect file written/created/saved - multiple patterns - file_keywords = ['file created', 'file written', 'file saved', 'saved to:', 'wrote to', 'generated:', 'output:'] + file_keywords = [ + 'file created', 'file written', 'file saved', 'saved to:', + 'wrote to', 'generated:', 'output:' + ] if any(keyword in line.lower() for keyword in file_keywords): # Try to extract filename with extension # More strict pattern: must have a proper filename with extension, not just numbers - file_match = re.search(r'["\']?([a-zA-Z0-9_\-][^\s"\'\/\[\]]*\.[a-zA-Z0-9]+)["\']?', line) + file_match = re.search( + r'["\']?([a-zA-Z0-9_\-][^\s"\'\/\[\]]*\.[a-zA-Z0-9]+)["\']?', + line) if file_match and self.on_progress: filename = file_match.group(1) # Validate filename: must not be just numbers or version numbers like "0.0" - if filename and not re.match(r'^\d+\.\d+$', filename) and len(filename) > 2: + if filename and not re.match(r'^\d+\.\d+$', + filename) and len(filename) > 2: # Strip 'programmer-' prefix from filename if filename.startswith('programmer-'): - filename = filename[len('programmer-') :] + filename = filename[len('programmer-'):] print(f'[Runner] Detected file output: {filename}') # Only send progress update (file_output will be sent from tasks.txt) - self.on_progress({'type': 'file', 'file': filename, 'status': 'completed'}) + self.on_progress({ + 'type': 'file', + 'file': filename, + 'status': 'completed' + }) return # Detect output file paths (e.g., "output/user_story.txt" standalone) - output_path_match = re.search(r'(?:^|\s)((?:output|projects)/[^\s]+\.[a-zA-Z0-9]+)(?:\s|$)', line) + output_path_match = re.search( + r'(?:^|\s)((?:output|projects)/[^\s]+\.[a-zA-Z0-9]+)(?:\s|$)', + line) if output_path_match and self.on_progress: filename = output_path_match.group(1) # Strip 'programmer-' prefix from basename only (not from path) @@ -1513,13 +1649,17 @@ def flush_accumulated_output(): if '/' in filename: parts = filename.rsplit('/', 1) if len(parts) == 2 and parts[1].startswith('programmer-'): - parts[1] = parts[1][len('programmer-') :] + parts[1] = parts[1][len('programmer-'):] filename = '/'.join(parts) elif filename.startswith('programmer-'): - filename = filename[len('programmer-') :] + filename = filename[len('programmer-'):] print(f'[Runner] Detected output path: {filename}') # Only send progress update (file_output will be sent from tasks.txt) - self.on_progress({'type': 'file', 'file': filename, 'status': 'completed'}) + self.on_progress({ + 'type': 'file', + 'file': filename, + 'status': 'completed' + }) return # Deployment URL detection moved to the beginning of _detect_patterns @@ -1529,23 +1669,22 @@ def flush_accumulated_output(): # Pattern: "✅ Initial refinement completed. You can now provide..." # Also detect: "Agent completed initial refinement. Waiting for user feedback." # Also detect: "Waiting for user input from stdin..." - if ( - 'Initial refinement completed' in line - or 'provide additional feedback' in line - or 'Waiting for user feedback' in line - or 'Agent completed initial refinement' in line - or 'Waiting for user input from stdin' in line - ): + if ('Initial refinement completed' in line + or 'provide additional feedback' in line + or 'Waiting for user feedback' in line + or 'Agent completed initial refinement' in line + or 'Waiting for user input from stdin' in line): print('[Runner] Agent waiting for user input') self._waiting_for_input = True # Mark that agent is waiting for input if self.on_output and not self._waiting_input_sent: - self.on_output( - { - 'type': 'waiting_input', - 'content': '✅ Initial refinement completed. You can now provide additional feedback or modifications.', - 'role': 'system', - 'metadata': {'waiting': True}, + self.on_output({ + 'type': 'waiting_input', + 'content': + '✅ Initial refinement completed. You can now provide additional feedback or modifications.', + 'role': 'system', + 'metadata': { + 'waiting': True } - ) + }) self._waiting_input_sent = True return diff --git a/webui/backend/api.py b/webui/backend/api.py index d0664ad9b..126ac4c02 100644 --- a/webui/backend/api.py +++ b/webui/backend/api.py @@ -2,7 +2,6 @@ """ API endpoints for the MS-Agent Web UI """ - import mimetypes import os from pathlib import Path @@ -11,7 +10,6 @@ from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse from pydantic import BaseModel, Field - # Import shared instances from shared import config_manager, project_discovery, session_manager @@ -19,7 +17,8 @@ def get_backend_root() -> Path: - return Path(__file__).resolve().parents[1] # equal to dirname(dirname(__file__)) + return Path(__file__).resolve().parents[ + 1] # equal to dirname(dirname(__file__)) def get_session_root(session_id: str) -> Path: @@ -47,7 +46,8 @@ class ProjectInfo(BaseModel): class SessionCreate(BaseModel): project_id: Optional[str] = None # Optional for chat mode query: Optional[str] = None - workflow_type: Optional[str] = 'standard' # 'standard' or 'simple' for code_genesis + workflow_type: Optional[ + str] = 'standard' # 'standard' or 'simple' for code_genesis session_type: Optional[str] = 'project' # 'project' or 'chat' @@ -99,10 +99,14 @@ class DeepResearchSearchConfig(BaseModel): class DeepResearchConfig(BaseModel): - researcher: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) - searcher: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) - reporter: DeepResearchAgentConfig = Field(default_factory=DeepResearchAgentConfig) - search: DeepResearchSearchConfig = Field(default_factory=DeepResearchSearchConfig) + researcher: DeepResearchAgentConfig = Field( + default_factory=DeepResearchAgentConfig) + searcher: DeepResearchAgentConfig = Field( + default_factory=DeepResearchAgentConfig) + reporter: DeepResearchAgentConfig = Field( + default_factory=DeepResearchAgentConfig) + search: DeepResearchSearchConfig = Field( + default_factory=DeepResearchSearchConfig) class MCPServer(BaseModel): @@ -125,7 +129,9 @@ class GlobalConfig(BaseModel): @router.get('/projects', response_model=List[ProjectInfo]) async def list_projects(): """List all available projects""" - print(f'project_discovery.discover_projects(): {project_discovery.discover_projects()}') + print( + f'project_discovery.discover_projects(): {project_discovery.discover_projects()}' + ) return project_discovery.discover_projects() @@ -148,7 +154,8 @@ async def get_project_readme(project_id: str): @router.get('/projects/{project_id}/workflow') -async def get_project_workflow(project_id: str, session_id: Optional[str] = None): +async def get_project_workflow(project_id: str, + session_id: Optional[str] = None): """Get the workflow configuration for a project If session_id is provided, returns the workflow based on the session's workflow_type. @@ -181,12 +188,12 @@ async def get_project_workflow(project_id: str, session_id: Optional[str] = None try: import yaml - with open(workflow_file, 'r', encoding='utf-8') as f: workflow_data = yaml.safe_load(f) return {'workflow': workflow_data, 'workflow_type': workflow_type} except Exception as e: - raise HTTPException(status_code=500, detail=f'Error reading workflow file: {str(e)}') + raise HTTPException( + status_code=500, detail=f'Error reading workflow file: {str(e)}') # Session Endpoints @@ -197,8 +204,10 @@ async def create_session(session_data: SessionCreate): if session_data.session_type == 'chat': # Create chat session without requiring a project session = session_manager.create_session( - project_id='__chat__', project_name='Chat Assistant', workflow_type='standard', session_type='chat' - ) + project_id='__chat__', + project_name='Chat Assistant', + workflow_type='standard', + session_type='chat') return session # For project mode, validate project exists @@ -210,14 +219,15 @@ async def create_session(session_data: SessionCreate): workflow_type = session_data.workflow_type or 'standard' if project.get('supports_workflow_switch'): if workflow_type not in ['standard', 'simple']: - raise HTTPException(status_code=400, detail="workflow_type must be 'standard' or 'simple'") + raise HTTPException( + status_code=400, + detail="workflow_type must be 'standard' or 'simple'") session = session_manager.create_session( project_id=session_data.project_id, project_name=project['name'], workflow_type=workflow_type, - session_type='project', - ) + session_type='project') return session @@ -255,7 +265,8 @@ async def get_session_messages(session_id: str): @router.get('/sessions/{session_id}/dr_events') -async def get_session_dr_events(session_id: str, after_id: Optional[int] = Query(None, ge=0)): +async def get_session_dr_events(session_id: str, + after_id: Optional[int] = Query(None, ge=0)): """Get deep research event history for a session.""" events = session_manager.list_dr_events(session_id, after_id) if events is None: @@ -358,7 +369,8 @@ async def update_deep_research_config(config: DeepResearchConfig): @router.post('/config/mcp/servers') async def add_mcp_server(server: MCPServer): """Add a new MCP server""" - config_manager.add_mcp_server(server.name, server.model_dump(exclude={'name'})) + config_manager.add_mcp_server(server.name, + server.model_dump(exclude={'name'})) return {'status': 'added'} @@ -380,14 +392,38 @@ async def list_available_models(): { 'provider': 'modelscope', 'model': 'Qwen/Qwen3-235B-A22B-Instruct-2507', - 'display_name': 'Qwen3-235B (Recommended)', + 'display_name': 'Qwen3-235B (Recommended)' + }, + { + 'provider': 'modelscope', + 'model': 'Qwen/Qwen2.5-72B-Instruct', + 'display_name': 'Qwen2.5-72B' + }, + { + 'provider': 'modelscope', + 'model': 'Qwen/Qwen2.5-32B-Instruct', + 'display_name': 'Qwen2.5-32B' + }, + { + 'provider': 'modelscope', + 'model': 'deepseek-ai/DeepSeek-V3', + 'display_name': 'DeepSeek-V3' + }, + { + 'provider': 'openai', + 'model': 'gpt-4o', + 'display_name': 'GPT-4o' + }, + { + 'provider': 'openai', + 'model': 'gpt-4o-mini', + 'display_name': 'GPT-4o Mini' + }, + { + 'provider': 'anthropic', + 'model': 'claude-3-5-sonnet-20241022', + 'display_name': 'Claude 3.5 Sonnet' }, - {'provider': 'modelscope', 'model': 'Qwen/Qwen2.5-72B-Instruct', 'display_name': 'Qwen2.5-72B'}, - {'provider': 'modelscope', 'model': 'Qwen/Qwen2.5-32B-Instruct', 'display_name': 'Qwen2.5-32B'}, - {'provider': 'modelscope', 'model': 'deepseek-ai/DeepSeek-V3', 'display_name': 'DeepSeek-V3'}, - {'provider': 'openai', 'model': 'gpt-4o', 'display_name': 'GPT-4o'}, - {'provider': 'openai', 'model': 'gpt-4o-mini', 'display_name': 'GPT-4o Mini'}, - {'provider': 'anthropic', 'model': 'claude-3-5-sonnet-20241022', 'display_name': 'Claude 3.5 Sonnet'}, ] } @@ -401,22 +437,26 @@ class FileReadRequest(BaseModel): @router.get('/files/list') async def list_output_files( - output_dir: Optional[str] = Query(default='output'), - session_id: Optional[str] = Query(default=None), - root_dir: Optional[str] = Query(default=None), + output_dir: Optional[str] = Query(default='output'), + session_id: Optional[str] = Query(default=None), + root_dir: Optional[str] = Query(default=None), ): """List all files under root_dir as a tree structure. root_dir: optional. If not provided, defaults to ms-agent/output. Also supports 'projects' or 'projects/xxx' etc. """ # Excluded folders - exclude_dirs = {'node_modules', '__pycache__', '.git', '.venv', 'venv', 'dist', 'build'} + exclude_dirs = { + 'node_modules', '__pycache__', '.git', '.venv', 'venv', 'dist', 'build' + } # Base directories (same way as read_file_content) - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) projects_dir = os.path.join(base_dir, 'projects') if session_id: + session_root = get_session_root(session_id) resolved_root = (session_root / '').resolve() @@ -482,15 +522,14 @@ def build_tree(dir_path: str) -> dict: # Return RELATIVE path to resolved_root (better for frontend + read API) rel_path = os.path.relpath(full_path, resolved_root) - result['files'].append( - { - 'name': item, - 'path': rel_path, # <-- relative path - 'abs_path': full_path, # optional: if you still want absolute for debugging - 'size': os.path.getsize(full_path), - 'modified': os.path.getmtime(full_path), - } - ) + result['files'].append({ + 'name': item, + 'path': rel_path, # <-- relative path + 'abs_path': + full_path, # optional: if you still want absolute for debugging + 'size': os.path.getsize(full_path), + 'modified': os.path.getmtime(full_path) + }) result['files'].sort(key=lambda x: x['modified'], reverse=True) return result @@ -501,10 +540,12 @@ def build_tree(dir_path: str) -> dict: def get_allowed_roots(): - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) output_dir = os.path.join(base_dir, 'output') projects_dir = os.path.join(base_dir, 'projects') - return base_dir, os.path.normpath(output_dir), os.path.normpath(projects_dir) + return base_dir, os.path.normpath(output_dir), os.path.normpath( + projects_dir) def resolve_root_dir(root_dir: Optional[str]) -> str: @@ -535,7 +576,8 @@ def resolve_root_dir(root_dir: Optional[str]) -> str: cand1 = os.path.join(output_dir, rd) cand2 = os.path.join(projects_dir, rd) # choose existing one if possible, otherwise default to cand1 - resolved = cand1 if os.path.exists(cand1) else (cand2 if os.path.exists(cand2) else cand1) + resolved = cand1 if os.path.exists(cand1) else ( + cand2 if os.path.exists(cand2) else cand1) resolved = os.path.normpath(os.path.abspath(resolved)) @@ -561,11 +603,14 @@ def resolve_file_path(root_dir_abs: str, file_path: str) -> str: elif file_path.startswith('projects/'): # Special case: if path starts with 'projects/', resolve from base_dir # This handles: projects/code_genesis/output/config.js - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - full_path = os.path.normpath(os.path.abspath(os.path.join(base_dir, file_path))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + full_path = os.path.normpath( + os.path.abspath(os.path.join(base_dir, file_path))) else: # Try multiple locations - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) candidates = [ # First try with root_dir_abs (for session-based access) @@ -580,7 +625,8 @@ def resolve_file_path(root_dir_abs: str, file_path: str) -> str: for project_name in os.listdir(projects_dir): project_path = os.path.join(projects_dir, project_name) if os.path.isdir(project_path): - candidates.append(os.path.join(project_path, 'output', file_path)) + candidates.append( + os.path.join(project_path, 'output', file_path)) except (OSError, PermissionError): pass @@ -614,10 +660,12 @@ async def read_file_content(request: FileReadRequest): full_path = resolve_file_path(root_abs, request.path) if not os.path.exists(full_path): - raise HTTPException(status_code=404, detail=f'File not found: {full_path}') + raise HTTPException( + status_code=404, detail=f'File not found: {full_path}') if not os.path.isfile(full_path): - raise HTTPException(status_code=400, detail=f'Path {full_path} is not a file') + raise HTTPException( + status_code=400, detail=f'Path {full_path} is not a file') # limit 1MB file_size = os.path.getsize(full_path) if file_size > 1024 * 1024: @@ -658,17 +706,19 @@ async def read_file_content(request: FileReadRequest): 'root_dir': root_abs, 'filename': os.path.basename(full_path), 'language': language, - 'size': file_size, + 'size': file_size } except UnicodeDecodeError: raise HTTPException(status_code=400, detail='File is not a text file') except Exception as e: - raise HTTPException(status_code=500, detail=f'Error reading file: {str(e)}') + raise HTTPException( + status_code=500, detail=f'Error reading file: {str(e)}') def resolve_and_check_path(file_path: str) -> str: """Resolve file path, trying multiple locations""" - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) if os.path.isabs(file_path): full_path = file_path @@ -698,7 +748,9 @@ def resolve_and_check_path(file_path: str) -> str: project_path = os.path.join(projects_dir, project_name) if os.path.isdir(project_path): # Try project/output/filename - candidates.append(os.path.join(project_path, 'output', file_path)) + candidates.append( + os.path.join(project_path, 'output', + file_path)) except (OSError, PermissionError): pass @@ -712,7 +764,8 @@ def resolve_and_check_path(file_path: str) -> str: if not full_path: # If not found, use the first candidate for error message - full_path = os.path.normpath(candidates[0] if candidates else file_path) + full_path = os.path.normpath( + candidates[0] if candidates else file_path) full_path = os.path.normpath(full_path) @@ -722,15 +775,18 @@ def resolve_and_check_path(file_path: str) -> str: # TODO: Security check: ensure `full_path` is within configured allowed roots. if not os.path.exists(full_path): - raise HTTPException(status_code=404, detail=f'File not found: {full_path}') + raise HTTPException( + status_code=404, detail=f'File not found: {full_path}') if not os.path.isfile(full_path): - raise HTTPException(status_code=400, detail=f'Path {full_path} is not a file') + raise HTTPException( + status_code=400, detail=f'Path {full_path} is not a file') return full_path @router.get('/files/stream') -async def stream_file(path: str, session_id: Optional[str] = Query(default=None)): +async def stream_file(path: str, + session_id: Optional[str] = Query(default=None)): if session_id: session_root = get_session_root(session_id) root_abs = str(session_root.resolve()) @@ -744,5 +800,8 @@ async def stream_file(path: str, session_id: Optional[str] = Query(default=None) full_path, media_type=media_type, filename=os.path.basename(full_path), - headers={'Content-Disposition': f'inline; filename="{os.path.basename(full_path)}"'}, + headers={ + 'Content-Disposition': + f'inline; filename="{os.path.basename(full_path)}"' + }, ) diff --git a/webui/backend/config_manager.py b/webui/backend/config_manager.py index 9d1511af8..eddf94915 100644 --- a/webui/backend/config_manager.py +++ b/webui/backend/config_manager.py @@ -3,12 +3,12 @@ Configuration management for MS-Agent Web UI Handles global settings, LLM configuration, and MCP server configuration. """ - -import json import os from threading import Lock from typing import Any, Dict, Optional +import json + class ConfigManager: """Manages global configuration for the Web UI""" @@ -21,23 +21,46 @@ class ConfigManager: 'base_url': 'https://api-inference.modelscope.cn/v1/', 'temperature': None, 'temperature_enabled': False, - 'max_tokens': None, + 'max_tokens': None }, 'deep_research': { - 'researcher': {'model': '', 'api_key': '', 'base_url': ''}, - 'searcher': {'model': '', 'api_key': '', 'base_url': ''}, - 'reporter': {'model': '', 'api_key': '', 'base_url': ''}, - 'search': {'summarizer_model': '', 'summarizer_api_key': '', 'summarizer_base_url': ''}, + 'researcher': { + 'model': '', + 'api_key': '', + 'base_url': '' + }, + 'searcher': { + 'model': '', + 'api_key': '', + 'base_url': '' + }, + 'reporter': { + 'model': '', + 'api_key': '', + 'base_url': '' + }, + 'search': { + 'summarizer_model': '', + 'summarizer_api_key': '', + 'summarizer_base_url': '' + } + }, + 'edit_file_config': { + 'api_key': '', + 'base_url': 'https://api.morphllm.com/v1', + 'diff_model': 'morph-v3-fast' + }, + 'edgeone_pages': { + 'api_token': '', + 'project_name': '' }, - 'edit_file_config': {'api_key': '', 'base_url': 'https://api.morphllm.com/v1', 'diff_model': 'morph-v3-fast'}, - 'edgeone_pages': {'api_token': '', 'project_name': ''}, 'search_keys': { 'exa_api_key': '', 'serpapi_api_key': '', }, 'mcp_servers': {}, 'theme': 'dark', - 'output_dir': './output', + 'output_dir': './output' } def __init__(self, config_dir: str): @@ -85,7 +108,10 @@ def _save_config(self): """Save configuration to file""" with self._lock: # Save main config (without mcp_servers) - config_to_save = {k: v for k, v in self._config.items() if k != 'mcp_servers'} + config_to_save = { + k: v + for k, v in self._config.items() if k != 'mcp_servers' + } with open(self.config_file, 'w', encoding='utf-8') as f: json.dump(config_to_save, f, indent=2) @@ -132,7 +158,8 @@ def update_mcp_config(self, mcp_config: Dict[str, Any]): def get_edit_file_config(self) -> Dict[str, Any]: """Get edit_file_config configuration""" config = self._load_config() - return config.get('edit_file_config', self.DEFAULT_CONFIG['edit_file_config']) + return config.get('edit_file_config', + self.DEFAULT_CONFIG['edit_file_config']) def update_edit_file_config(self, edit_file_config: Dict[str, Any]): """Update edit_file_config configuration""" @@ -143,9 +170,11 @@ def update_edit_file_config(self, edit_file_config: Dict[str, Any]): def get_edgeone_pages_config(self) -> Dict[str, Any]: """Get EdgeOne Pages configuration""" config = self._load_config() - return config.get('edgeone_pages', self.DEFAULT_CONFIG['edgeone_pages']) + return config.get('edgeone_pages', + self.DEFAULT_CONFIG['edgeone_pages']) - def update_edgeone_pages_config(self, edgeone_pages_config: Dict[str, Any]): + def update_edgeone_pages_config(self, edgeone_pages_config: Dict[str, + Any]): """Update EdgeOne Pages configuration""" self._load_config() self._config['edgeone_pages'] = edgeone_pages_config @@ -165,9 +194,11 @@ def update_search_keys(self, search_keys: Dict[str, Any]): def get_deep_research_config(self) -> Dict[str, Any]: """Get deep research configuration""" config = self._load_config() - return config.get('deep_research', self.DEFAULT_CONFIG['deep_research']) + return config.get('deep_research', + self.DEFAULT_CONFIG['deep_research']) - def update_deep_research_config(self, deep_research_config: Dict[str, Any]): + def update_deep_research_config(self, deep_research_config: Dict[str, + Any]): """Update deep research configuration""" self._load_config() self._config['deep_research'] = deep_research_config diff --git a/webui/backend/deep_research_eventizer.py b/webui/backend/deep_research_eventizer.py index 261f3b513..84949674a 100644 --- a/webui/backend/deep_research_eventizer.py +++ b/webui/backend/deep_research_eventizer.py @@ -1,6 +1,6 @@ -import json from typing import Any, Callable, Dict, List, Optional +import json from ms_agent.llm.utils import Message, ToolCall @@ -16,6 +16,7 @@ def _stringify_content(content: Any) -> str: class HistoryEventizer: + def __init__( self, emit: Callable[[Dict[str, Any]], None], @@ -51,7 +52,8 @@ def reset(self) -> None: self._tool_call_args = {} self._tool_call_names = {} - def _wrap_event(self, event_type: str, payload: Dict[str, Any]) -> Dict[str, Any]: + def _wrap_event(self, event_type: str, + payload: Dict[str, Any]) -> Dict[str, Any]: event: Dict[str, Any] = {'type': event_type, 'payload': payload} if self._session_id: event['session_id'] = self._session_id @@ -65,7 +67,7 @@ def _emit_event(self, event_type: str, payload: Dict[str, Any]) -> None: def _should_reset(self, messages: List[Message]) -> bool: if len(messages) < len(self._prev_messages): return True - for idx, msg in enumerate(messages[: len(self._prev_messages)]): + for idx, msg in enumerate(messages[:len(self._prev_messages)]): if msg.role != self._prev_messages[idx].role: return True return False @@ -85,11 +87,8 @@ def _ensure_message_id(self, idx: int, message: Message) -> str: def _is_subagent_tool(self, tool_name: str) -> bool: if not tool_name: return False - return ( - tool_name.startswith('agent_tools---') - or tool_name.endswith('searcher_tool') - or tool_name.endswith('reporter_tool') - ) + return tool_name.startswith('agent_tools---') or tool_name.endswith( + 'searcher_tool') or tool_name.endswith('reporter_tool') def _extract_tool_name(self, call: ToolCall) -> str: if not isinstance(call, dict): @@ -126,7 +125,8 @@ def _parse_tool_args(self, raw: Any) -> Dict[str, Any]: return {'request': raw} return {} - def _build_subagent_title(self, tool_name: str, tool_args: Dict[str, Any]) -> str: + def _build_subagent_title(self, tool_name: str, + tool_args: Dict[str, Any]) -> str: if 'searcher' in tool_name: base = 'Searcher' elif 'reporter' in tool_name: @@ -140,18 +140,21 @@ def _build_subagent_title(self, tool_name: str, tool_args: Dict[str, Any]) -> st try: parsed = json.loads(request) if isinstance(parsed, dict): - summary = parsed.get('task_id') or parsed.get('调研目标') or parsed.get('目标') + summary = parsed.get('task_id') or parsed.get( + '调研目标') or parsed.get('目标') except Exception: summary = None if summary is None: summary = request.strip().splitlines()[0][:80] return f'{base}: {summary}' if summary else base - def _record_tool_call(self, call_id: str, tool_name: str, tool_args: Dict[str, Any]) -> tuple[bool, bool]: + def _record_tool_call(self, call_id: str, tool_name: str, + tool_args: Dict[str, Any]) -> tuple[bool, bool]: is_new = call_id not in self._seen_tool_calls prev_args = self._tool_call_args.get(call_id) prev_name = self._tool_call_names.get(call_id) - if not is_new and prev_args == tool_args and (not tool_name or tool_name == prev_name): + if (not is_new and prev_args == tool_args + and (not tool_name or tool_name == prev_name)): return False, False self._seen_tool_calls.add(call_id) self._tool_call_args[call_id] = tool_args @@ -159,10 +162,12 @@ def _record_tool_call(self, call_id: str, tool_name: str, tool_args: Dict[str, A self._tool_call_names[call_id] = tool_name return True, is_new - def _maybe_emit_todos(self, tool_name: str, result_text: str, call_id: Optional[str]) -> None: + def _maybe_emit_todos(self, tool_name: str, result_text: str, + call_id: Optional[str]) -> None: if not tool_name: return - if not ('todo_list---todo_write' in tool_name or 'todo_list---todo_read' in tool_name): + if not ('todo_list---todo_write' in tool_name + or 'todo_list---todo_read' in tool_name): return try: parsed = json.loads(result_text) @@ -179,7 +184,8 @@ def _maybe_emit_todos(self, tool_name: str, result_text: str, call_id: Optional[ payload['call_id'] = call_id self._emit_event('dr.state', payload) - def _emit_assistant_delta(self, message_id: str, delta: str, full: str) -> None: + def _emit_assistant_delta(self, message_id: str, delta: str, + full: str) -> None: payload = { 'message_id': message_id, 'delta': delta, @@ -187,7 +193,8 @@ def _emit_assistant_delta(self, message_id: str, delta: str, full: str) -> None: } self._emit_event('dr.chat.message.delta', payload) - def _emit_subagent_delta(self, message_id: str, delta: str, full: str) -> None: + def _emit_subagent_delta(self, message_id: str, delta: str, + full: str) -> None: payload = { 'card_id': self._card_id, 'message_id': message_id, @@ -196,7 +203,8 @@ def _emit_subagent_delta(self, message_id: str, delta: str, full: str) -> None: } self._emit_event('dr.subagent.message.delta', payload) - def _emit_subagent_message(self, message_id: str, role: str, content: str) -> None: + def _emit_subagent_message(self, message_id: str, role: str, + content: str) -> None: payload = { 'card_id': self._card_id, 'message_id': message_id, @@ -213,7 +221,8 @@ def _emit_assistant_completed(self, message_id: str, content: str) -> None: } self._emit_event('dr.chat.message.completed', payload) - def _emit_chat_message(self, message_id: str, role: str, content: str, name: Optional[str]) -> None: + def _emit_chat_message(self, message_id: str, role: str, content: str, + name: Optional[str]) -> None: payload = { 'message_id': message_id, 'role': role, @@ -223,15 +232,19 @@ def _emit_chat_message(self, message_id: str, role: str, content: str, name: Opt payload['name'] = name self._emit_event('dr.chat.message', payload) - def _process_tool_calls(self, message_id: str, tool_calls: List[ToolCall]) -> None: + def _process_tool_calls(self, message_id: str, + tool_calls: List[ToolCall]) -> None: for idx, call in enumerate(tool_calls or []): call_id = call.get('id') or f'{message_id}-call-{idx}' tool_name = self._extract_tool_name(call) - tool_args = self._parse_tool_args(self._extract_tool_args_raw(call)) - should_emit, is_new = self._record_tool_call(call_id, tool_name, tool_args) + tool_args = self._parse_tool_args( + self._extract_tool_args_raw(call)) + should_emit, is_new = self._record_tool_call( + call_id, tool_name, tool_args) if not should_emit: continue - category = 'subagent' if self._is_subagent_tool(tool_name) else 'normal' + category = 'subagent' if self._is_subagent_tool( + tool_name) else 'normal' payload = { 'call_id': call_id, 'source_message_id': message_id, @@ -276,13 +289,10 @@ def _process_tool_result(self, message: Message) -> None: self._emit_event('dr.tool.result', payload) if call_id in self._subagent_call_ids: summary = result_text.strip().splitlines()[0][:160] - self._emit_event( - 'dr.subagent.card.completed', - { - 'card_id': call_id, - 'summary': summary, - }, - ) + self._emit_event('dr.subagent.card.completed', { + 'card_id': call_id, + 'summary': summary, + }) self._maybe_emit_todos(tool_name, result_text, call_id) def process(self, messages: List[Message]) -> None: @@ -302,14 +312,16 @@ def process(self, messages: List[Message]) -> None: prev_content = self._assistant_contents.get(message_id, '') if content and content != prev_content: if content.startswith(prev_content): - delta = content[len(prev_content) :] + delta = content[len(prev_content):] else: delta = content if delta: - self._emit_assistant_delta(message_id, delta, content) + self._emit_assistant_delta(message_id, delta, + content) self._assistant_contents[message_id] = content if message.tool_calls: - self._process_tool_calls(message_id, message.tool_calls) + self._process_tool_calls(message_id, + message.tool_calls) elif role == 'tool': self._process_tool_result(message) else: @@ -318,25 +330,31 @@ def process(self, messages: List[Message]) -> None: if idx >= prev_len: content = _stringify_content(message.content) if content: - self._emit_chat_message(message_id, role, content, getattr(message, 'name', None)) + self._emit_chat_message( + message_id, role, content, + getattr(message, 'name', None)) else: if role == 'assistant': content = _stringify_content(message.content) prev_content = self._assistant_contents.get(message_id, '') if content and content != prev_content: if content.startswith(prev_content): - delta = content[len(prev_content) :] + delta = content[len(prev_content):] else: delta = content if delta: - self._emit_subagent_delta(message_id, delta, content) + self._emit_subagent_delta(message_id, delta, + content) self._assistant_contents[message_id] = content if message.tool_calls: for idx, call in enumerate(message.tool_calls or []): - call_id = call.get('id') or f'{message_id}-call-{idx}' + call_id = call.get( + 'id') or f'{message_id}-call-{idx}' tool_name = self._extract_tool_name(call) - tool_args = self._parse_tool_args(self._extract_tool_args_raw(call)) - should_emit, is_new = self._record_tool_call(call_id, tool_name, tool_args) + tool_args = self._parse_tool_args( + self._extract_tool_args_raw(call)) + should_emit, is_new = self._record_tool_call( + call_id, tool_name, tool_args) if not should_emit: continue payload = { @@ -357,7 +375,8 @@ def process(self, messages: List[Message]) -> None: if role != 'tool' and idx >= prev_len: content = _stringify_content(message.content) if content: - self._emit_subagent_message(message_id, role, content) + self._emit_subagent_message( + message_id, role, content) if role == 'tool' and message.tool_call_id: call_id = message.tool_call_id if call_id in self._seen_tool_results: @@ -373,8 +392,11 @@ def process(self, messages: List[Message]) -> None: tool_args = self._tool_call_args.get(call_id) if tool_args is not None: payload['tool'] = { - 'name': (tool_name or self._tool_call_names.get(call_id, '')), - 'arguments': tool_args, + 'name': + (tool_name + or self._tool_call_names.get(call_id, '')), + 'arguments': + tool_args, } self._emit_event('dr.subagent.tool.result', payload) diff --git a/webui/backend/deep_research_worker.py b/webui/backend/deep_research_worker.py index 5ca62420b..e8a3480a3 100644 --- a/webui/backend/deep_research_worker.py +++ b/webui/backend/deep_research_worker.py @@ -1,6 +1,5 @@ import argparse import asyncio -import json import os import signal import sys @@ -8,11 +7,11 @@ from pathlib import Path from typing import Any, Dict, Optional +import json from deep_research_eventizer import HistoryEventizer # noqa: E402 -from omegaconf import OmegaConf - from ms_agent.agent.loader import AgentLoader from ms_agent.tools.agent_tool import AgentTool +from omegaconf import OmegaConf BACKEND_DIR = Path(__file__).resolve().parent if str(BACKEND_DIR) not in sys.path: @@ -22,6 +21,7 @@ class NullWriter: + def write(self, _: str) -> int: return 0 @@ -30,6 +30,7 @@ def flush(self) -> None: class NDJSONEmitter: + def __init__(self, stream) -> None: self._stream = stream @@ -70,16 +71,21 @@ def _normalize_agent_override(raw: Optional[Dict[str, Any]]) -> Dict[str, str]: } -def _resolve_agent_llm_config(role: str, llm_config: Dict[str, Any], dr_config: Dict[str, Any]) -> Dict[str, str]: +def _resolve_agent_llm_config(role: str, llm_config: Dict[str, Any], + dr_config: Dict[str, Any]) -> Dict[str, str]: overrides = _normalize_agent_override((dr_config or {}).get(role)) return { - 'model': overrides.get('model') or str(llm_config.get('model') or ''), - 'api_key': overrides.get('api_key') or str(llm_config.get('api_key') or ''), - 'base_url': overrides.get('base_url') or str(llm_config.get('base_url') or ''), + 'model': + overrides.get('model') or str(llm_config.get('model') or ''), + 'api_key': + overrides.get('api_key') or str(llm_config.get('api_key') or ''), + 'base_url': + overrides.get('base_url') or str(llm_config.get('base_url') or ''), } -def _normalize_search_override(raw: Optional[Dict[str, Any]]) -> Dict[str, str]: +def _normalize_search_override( + raw: Optional[Dict[str, Any]]) -> Dict[str, str]: raw = raw or {} return { 'summarizer_model': str(raw.get('summarizer_model') or ''), @@ -89,8 +95,8 @@ def _normalize_search_override(raw: Optional[Dict[str, Any]]) -> Dict[str, str]: def _build_config_override( - llm_config: Dict[str, Any], output_dir: str, dr_config: Dict[str, Any] -) -> Optional[Dict[str, Any]]: + llm_config: Dict[str, Any], output_dir: str, + dr_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: override: Dict[str, Any] = {} if output_dir: override['output_dir'] = output_dir @@ -126,7 +132,8 @@ def _build_config_override( return override or None -async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, session_id: str) -> None: +async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, + session_id: str) -> None: last_snapshot: Dict[str, tuple[int, float]] = {} output_path = Path(output_dir) ignore_dirs = {'.locks', '__pycache__'} @@ -148,23 +155,25 @@ async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, session_id: except OSError: continue snapshot[rel_path] = (stat.st_size, stat.st_mtime) - files.append( - { - 'path': rel_path, - 'relative_path': rel_path, - 'size': stat.st_size, - 'modified': stat.st_mtime, - } - ) + files.append({ + 'path': rel_path, + 'relative_path': rel_path, + 'size': stat.st_size, + 'modified': stat.st_mtime, + }) if snapshot != last_snapshot: - emitter.emit( - { - 'type': 'dr.artifact.updated', - 'payload': {'files': sorted(files, key=lambda x: x.get('modified', 0), reverse=True)}, - 'session_id': session_id, - } - ) + emitter.emit({ + 'type': 'dr.artifact.updated', + 'payload': { + 'files': + sorted( + files, + key=lambda x: x.get('modified', 0), + reverse=True) + }, + 'session_id': session_id, + }) last_snapshot = snapshot await asyncio.sleep(1.0) @@ -172,14 +181,16 @@ async def _watch_artifacts(output_dir: str, emitter: NDJSONEmitter, session_id: async def run_worker(args: argparse.Namespace) -> None: emitter = NDJSONEmitter(sys.__stdout__) - main_eventizer = HistoryEventizer(emitter.emit, channel='main', session_id=args.session_id) + main_eventizer = HistoryEventizer( + emitter.emit, channel='main', session_id=args.session_id) subagent_eventizers: Dict[str, HistoryEventizer] = {} loop = asyncio.get_running_loop() subagent_queue: asyncio.Queue = asyncio.Queue() def chunk_callback(*, event_type: str, data: Dict[str, Any]) -> None: - loop.call_soon_threadsafe(subagent_queue.put_nowait, (event_type, data)) + loop.call_soon_threadsafe(subagent_queue.put_nowait, + (event_type, data)) async def consume_subagent_events(): while True: @@ -204,8 +215,10 @@ async def consume_subagent_events(): llm_config = _load_llm_config() dr_config = _load_deep_research_config() - config_override = _build_config_override(llm_config, args.output_dir, dr_config) - config_override = OmegaConf.create(config_override) if config_override else None + config_override = _build_config_override(llm_config, args.output_dir, + dr_config) + config_override = OmegaConf.create( + config_override) if config_override else None agent = AgentLoader.build( config_dir_or_id=args.config, @@ -233,10 +246,13 @@ async def prepare_tools_with_callback(): tool_name = str(spec.tool_name or '') if 'searcher' in tool_name: - resolved = _resolve_agent_llm_config('searcher', llm_config, dr_config) - search_override = _normalize_search_override((dr_config or {}).get('search')) + resolved = _resolve_agent_llm_config( + 'searcher', llm_config, dr_config) + search_override = _normalize_search_override( + (dr_config or {}).get('search')) elif 'reporter' in tool_name: - resolved = _resolve_agent_llm_config('reporter', llm_config, dr_config) + resolved = _resolve_agent_llm_config( + 'reporter', llm_config, dr_config) search_override = {} else: resolved = {} @@ -257,11 +273,16 @@ async def prepare_tools_with_callback(): tools_cfg = dict(updated.get('tools') or {}) web_cfg = dict(tools_cfg.get('web_search') or {}) if search_override.get('summarizer_model'): - web_cfg['summarizer_model'] = search_override['summarizer_model'] + web_cfg['summarizer_model'] = search_override[ + 'summarizer_model'] if search_override.get('summarizer_api_key'): - web_cfg['summarizer_api_key'] = search_override['summarizer_api_key'] + web_cfg[ + 'summarizer_api_key'] = search_override[ + 'summarizer_api_key'] if search_override.get('summarizer_base_url'): - web_cfg['summarizer_base_url'] = search_override['summarizer_base_url'] + web_cfg[ + 'summarizer_base_url'] = search_override[ + 'summarizer_base_url'] if web_cfg: tools_cfg['web_search'] = web_cfg updated['tools'] = tools_cfg @@ -276,7 +297,8 @@ async def prepare_tools_with_callback(): agent.prepare_tools = prepare_tools_with_callback - artifact_task = asyncio.create_task(_watch_artifacts(args.output_dir, emitter, args.session_id)) + artifact_task = asyncio.create_task( + _watch_artifacts(args.output_dir, emitter, args.session_id)) subagent_task = asyncio.create_task(consume_subagent_events()) had_error = False @@ -290,48 +312,40 @@ async def prepare_tools_with_callback(): main_eventizer.process(result) except Exception as exc: had_error = True - emitter.emit( - { - 'type': 'dr.worker.error', - 'payload': { - 'error': str(exc), - 'traceback': traceback.format_exc(), - }, - 'session_id': args.session_id, - } - ) - emitter.emit( - { - 'type': 'error', - 'message': str(exc), - } - ) + emitter.emit({ + 'type': 'dr.worker.error', + 'payload': { + 'error': str(exc), + 'traceback': traceback.format_exc(), + }, + 'session_id': args.session_id, + }) + emitter.emit({ + 'type': 'error', + 'message': str(exc), + }) raise finally: main_eventizer.finalize() - emitter.emit( - { - 'type': 'dr.worker.exited', - 'payload': {'status': 'completed'}, - 'session_id': args.session_id, - } - ) + emitter.emit({ + 'type': 'dr.worker.exited', + 'payload': { + 'status': 'completed' + }, + 'session_id': args.session_id, + }) if STOP_REQUESTED: - emitter.emit( - { - 'type': 'status', - 'status': 'stopped', - } - ) + emitter.emit({ + 'type': 'status', + 'status': 'stopped', + }) elif not had_error: - emitter.emit( - { - 'type': 'complete', - 'result': { - 'status': 'success', - }, - } - ) + emitter.emit({ + 'type': 'complete', + 'result': { + 'status': 'success', + }, + }) subagent_queue.put_nowait((None, None)) artifact_task.cancel() subagent_task.cancel() diff --git a/webui/backend/deep_research_worker_manager.py b/webui/backend/deep_research_worker_manager.py index 76c4b9d84..a5eb29d05 100644 --- a/webui/backend/deep_research_worker_manager.py +++ b/webui/backend/deep_research_worker_manager.py @@ -1,5 +1,4 @@ import asyncio -import json import os import signal import sys @@ -7,9 +6,13 @@ from pathlib import Path from typing import Any, Awaitable, Callable, Dict, Optional +import json + class DeepResearchWorkerManager: - def __init__(self, send_event: Callable[[str, Dict[str, Any]], Awaitable[None]]): + + def __init__(self, send_event: Callable[[str, Dict[str, Any]], + Awaitable[None]]): self._send_event = send_event self._processes: Dict[str, asyncio.subprocess.Process] = {} self._stdout_tasks: Dict[str, asyncio.Task] = {} @@ -23,18 +26,18 @@ def _get_worker_path(self) -> Path: return Path(__file__).resolve().parent / 'deep_research_worker.py' def _build_env( - self, - env_vars: Optional[Dict[str, str]], - llm_config: Optional[Dict[str, Any]], - deep_research_config: Optional[Dict[str, Any]], - ) -> Dict[str, str]: + self, env_vars: Optional[Dict[str, str]], + llm_config: Optional[Dict[str, Any]], + deep_research_config: Optional[Dict[str, Any]]) -> Dict[str, str]: env = os.environ.copy() if env_vars: env.update({k: v for k, v in env_vars.items() if v}) if llm_config: - env['MS_AGENT_LLM_CONFIG'] = json.dumps(llm_config, ensure_ascii=False) + env['MS_AGENT_LLM_CONFIG'] = json.dumps( + llm_config, ensure_ascii=False) if deep_research_config: - env['MS_AGENT_DEEP_RESEARCH_CONFIG'] = json.dumps(deep_research_config, ensure_ascii=False) + env['MS_AGENT_DEEP_RESEARCH_CONFIG'] = json.dumps( + deep_research_config, ensure_ascii=False) api_key = (llm_config or {}).get('api_key') base_url = (llm_config or {}).get('base_url') @@ -46,20 +49,20 @@ def _build_env( repo_root = str(self._get_repo_root()) existing_path = env.get('PYTHONPATH', '') if repo_root not in existing_path.split(os.pathsep): - env['PYTHONPATH'] = repo_root + (os.pathsep + existing_path if existing_path else '') + env['PYTHONPATH'] = repo_root + ( + os.pathsep + existing_path if existing_path else '') return env async def start( - self, - session_id: str, - *, - query: str, - config_path: str, - output_dir: str, - env_vars: Optional[Dict[str, str]] = None, - llm_config: Optional[Dict[str, Any]] = None, - deep_research_config: Optional[Dict[str, Any]] = None, - ) -> None: + self, + session_id: str, + *, + query: str, + config_path: str, + output_dir: str, + env_vars: Optional[Dict[str, str]] = None, + llm_config: Optional[Dict[str, Any]] = None, + deep_research_config: Optional[Dict[str, Any]] = None) -> None: if session_id in self._processes: await self.stop(session_id) @@ -92,17 +95,17 @@ async def start( ) self._processes[session_id] = process - self._stdout_tasks[session_id] = asyncio.create_task(self._read_stdout(session_id, process)) - self._stderr_tasks[session_id] = asyncio.create_task(self._read_stderr(session_id, process)) + self._stdout_tasks[session_id] = asyncio.create_task( + self._read_stdout(session_id, process)) + self._stderr_tasks[session_id] = asyncio.create_task( + self._read_stderr(session_id, process)) await self._send_event( - session_id, - { + session_id, { 'type': 'log', 'level': 'info', 'message': f'Deep research worker started (pid={process.pid})', 'timestamp': datetime.now().isoformat(), - }, - ) + }) async def stop(self, session_id: str) -> None: process = self._processes.get(session_id) @@ -132,7 +135,8 @@ async def stop(self, session_id: str) -> None: finally: self._cleanup(session_id) - async def _read_stdout(self, session_id: str, process: asyncio.subprocess.Process) -> None: + async def _read_stdout(self, session_id: str, + process: asyncio.subprocess.Process) -> None: if not process.stdout: return while True: @@ -158,16 +162,20 @@ async def _read_stdout(self, session_id: str, process: asyncio.subprocess.Proces return_code = None if return_code not in (None, 0) and session_id not in self._stopping: await self._send_event( - session_id, - { - 'type': 'error', - 'message': f'Deep research worker exited with code {return_code}', - }, - ) - await self._send_event(session_id, {'type': 'status', 'status': 'error'}) + session_id, { + 'type': + 'error', + 'message': + f'Deep research worker exited with code {return_code}', + }) + await self._send_event(session_id, { + 'type': 'status', + 'status': 'error' + }) self._cleanup(session_id) - async def _read_stderr(self, session_id: str, process: asyncio.subprocess.Process) -> None: + async def _read_stderr(self, session_id: str, + process: asyncio.subprocess.Process) -> None: if not process.stderr: return while True: @@ -180,14 +188,12 @@ async def _read_stderr(self, session_id: str, process: asyncio.subprocess.Proces sys.stderr.write(text) sys.stderr.flush() await self._send_event( - session_id, - { + session_id, { 'type': 'log', 'level': 'error', 'message': f'[deep_research_worker] {text.strip()}', 'timestamp': datetime.now().isoformat(), - }, - ) + }) except Exception: pass diff --git a/webui/backend/main.py b/webui/backend/main.py index 2f2d17291..2afcebbe3 100644 --- a/webui/backend/main.py +++ b/webui/backend/main.py @@ -3,7 +3,6 @@ MS-Agent Web UI Backend Server Provides REST API and WebSocket endpoints for the ms-agent framework. """ - import os import sys @@ -16,11 +15,15 @@ from websocket_handler import router as ws_router # Add ms-agent to path -MS_AGENT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'ms-agent')) +MS_AGENT_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), '..', '..', 'ms-agent')) if MS_AGENT_PATH not in sys.path: sys.path.insert(0, MS_AGENT_PATH) -app = FastAPI(title='MS-Agent Web UI', description='Web interface for the MS-Agent framework', version='1.0.0') +app = FastAPI( + title='MS-Agent Web UI', + description='Web interface for the MS-Agent framework', + version='1.0.0') # CORS configuration app.add_middleware( @@ -38,7 +41,10 @@ # Serve static files in production STATIC_DIR = os.path.join(os.path.dirname(__file__), '..', 'frontend', 'dist') if os.path.exists(STATIC_DIR): - app.mount('/assets', StaticFiles(directory=os.path.join(STATIC_DIR, 'assets')), name='assets') + app.mount( + '/assets', + StaticFiles(directory=os.path.join(STATIC_DIR, 'assets')), + name='assets') @app.get('/{full_path:path}') async def serve_spa(full_path: str): @@ -58,19 +64,19 @@ async def health_check(): def main(): """Start the server""" import argparse - parser = argparse.ArgumentParser(description='MS-Agent Web UI Server') parser.add_argument('--host', default='0.0.0.0', help='Host to bind') parser.add_argument('--port', type=int, default=7860, help='Port to bind') - parser.add_argument('--reload', action='store_true', help='Enable auto-reload') + parser.add_argument( + '--reload', action='store_true', help='Enable auto-reload') args = parser.parse_args() - print(f"\n{'=' * 60}") + print(f"\n{'='*60}") print(' MS-Agent Web UI Server') - print(f"{'=' * 60}") + print(f"{'='*60}") print(f' Server running at: http://{args.host}:{args.port}') print(f' API documentation: http://{args.host}:{args.port}/docs') - print(f"{'=' * 60}\n") + print(f"{'='*60}\n") uvicorn.run('main:app', host=args.host, port=args.port, reload=args.reload) diff --git a/webui/backend/project_discovery.py b/webui/backend/project_discovery.py index 43f9b9e0c..e30bf639d 100644 --- a/webui/backend/project_discovery.py +++ b/webui/backend/project_discovery.py @@ -3,7 +3,6 @@ Project discovery module for MS-Agent Web UI Discovers and manages available projects from the ms-agent/projects directory. """ - import os import re from typing import Any, Dict, List, Optional @@ -19,7 +18,8 @@ def __init__(self, projects_dir: str): self.projects_dir = projects_dir self._projects_cache: Optional[List[Dict[str, Any]]] = None - def discover_projects(self, force_refresh: bool = False) -> List[Dict[str, Any]]: + def discover_projects(self, + force_refresh: bool = False) -> List[Dict[str, Any]]: """Discover all available projects""" if self._projects_cache is not None and not force_refresh: return self._projects_cache @@ -32,7 +32,8 @@ def discover_projects(self, force_refresh: bool = False) -> List[Dict[str, Any]] for item in os.listdir(self.projects_dir): item_path = os.path.join(self.projects_dir, item) # Only show projects in the whitelist - if os.path.isdir(item_path) and not item.startswith('.') and item in self.VISIBLE_PROJECTS: + if os.path.isdir(item_path) and not item.startswith( + '.') and item in self.VISIBLE_PROJECTS: project_info = self._analyze_project(item, item_path) if project_info: projects.append(project_info) @@ -52,24 +53,24 @@ def _build_virtual_projects(self) -> List[Dict[str, Any]]: researcher_yaml = os.path.join(v2_root, 'researcher.yaml') if os.path.exists(researcher_yaml): readme_path = os.path.join(v2_root, 'README.md') - description = self._extract_description(readme_path) if os.path.exists(readme_path) else '' - projects.append( - { - 'id': 'deep_research_v2', - 'name': 'deep_research_v2', - 'display_name': 'Deep Research', - 'description': description, - 'type': 'agent', - 'path': v2_root, - 'has_readme': os.path.exists(readme_path), - 'config_file': researcher_yaml, - 'supports_workflow_switch': False, - } - ) + description = self._extract_description( + readme_path) if os.path.exists(readme_path) else '' + projects.append({ + 'id': 'deep_research_v2', + 'name': 'deep_research_v2', + 'display_name': 'Deep Research', + 'description': description, + 'type': 'agent', + 'path': v2_root, + 'has_readme': os.path.exists(readme_path), + 'config_file': researcher_yaml, + 'supports_workflow_switch': False + }) return projects - def _analyze_project(self, name: str, path: str) -> Optional[Dict[str, Any]]: + def _analyze_project(self, name: str, + path: str) -> Optional[Dict[str, Any]]: """Analyze a project directory and extract its information""" # Check for workflow.yaml or agent.yaml workflow_file = os.path.join(path, 'workflow.yaml') @@ -94,14 +95,16 @@ def _analyze_project(self, name: str, path: str) -> Optional[Dict[str, Any]]: # Check if project supports workflow switching (e.g., code_genesis) supports_workflow_switch = False - if project_type == 'workflow' and name == 'code_genesis' and os.path.exists(simple_workflow_file): + if project_type == 'workflow' and name == 'code_genesis' and os.path.exists( + simple_workflow_file): supports_workflow_switch = True # Generate display name from directory name display_name = self._format_display_name(name) # Extract description from README if available - description = self._extract_description(readme_file) if os.path.exists(readme_file) else '' + description = self._extract_description(readme_file) if os.path.exists( + readme_file) else '' return { 'id': name, @@ -112,7 +115,7 @@ def _analyze_project(self, name: str, path: str) -> Optional[Dict[str, Any]]: 'path': path, 'has_readme': os.path.exists(readme_file), 'config_file': config_file, - 'supports_workflow_switch': supports_workflow_switch, + 'supports_workflow_switch': supports_workflow_switch } def _format_display_name(self, name: str) -> str: @@ -138,7 +141,8 @@ def _extract_description(self, readme_path: str) -> str: stripped = line.strip() # Skip headers and empty lines at the beginning if not in_description: - if stripped and not stripped.startswith('#') and not stripped.startswith('['): + if stripped and not stripped.startswith( + '#') and not stripped.startswith('['): in_description = True description_lines.append(stripped) else: @@ -184,7 +188,6 @@ def get_project_config(self, project_id: str) -> Optional[Dict[str, Any]]: try: import yaml - with open(project['config_file'], 'r', encoding='utf-8') as f: return yaml.safe_load(f) except Exception: diff --git a/webui/backend/session_manager.py b/webui/backend/session_manager.py index 5852c386a..1ee20587b 100644 --- a/webui/backend/session_manager.py +++ b/webui/backend/session_manager.py @@ -3,7 +3,6 @@ Session management for MS-Agent Web UI Handles session lifecycle and message history. """ - import uuid from datetime import datetime from threading import Lock @@ -20,9 +19,11 @@ def __init__(self): self._dr_event_counters: Dict[str, int] = {} self._lock = Lock() - def create_session( - self, project_id: str, project_name: str, workflow_type: str = 'standard', session_type: str = 'project' - ) -> Dict[str, Any]: + def create_session(self, + project_id: str, + project_name: str, + workflow_type: str = 'standard', + session_type: str = 'project') -> Dict[str, Any]: """Create a new session""" session_id = str(uuid.uuid4()) session = { @@ -35,7 +36,7 @@ def create_session( 'file_progress': None, 'current_step': None, 'workflow_type': workflow_type, # 'standard' or 'simple' - 'session_type': session_type, # 'project' or 'chat' + 'session_type': session_type # 'project' or 'chat' } with self._lock: @@ -77,9 +78,12 @@ def list_sessions(self) -> List[Dict[str, Any]]: """List all sessions""" return list(self._sessions.values()) - def add_message( - self, session_id: str, role: str, content: str, message_type: str = 'text', metadata: Dict[str, Any] = None - ) -> bool: + def add_message(self, + session_id: str, + role: str, + content: str, + message_type: str = 'text', + metadata: Dict[str, Any] = None) -> bool: """Add a message to a session""" if session_id not in self._sessions: return False @@ -90,7 +94,7 @@ def add_message( 'content': content, 'type': message_type, # text, tool_call, tool_result, error, log 'timestamp': datetime.now().isoformat(), - 'metadata': metadata or {}, + 'metadata': metadata or {} } with self._lock: @@ -106,7 +110,8 @@ def get_messages(self, session_id: str) -> Optional[List[Dict[str, Any]]]: return None return self._messages.get(session_id, []) - def add_dr_event(self, session_id: str, event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def add_dr_event(self, session_id: str, + event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Add a deep research event for replay.""" if session_id not in self._sessions: return None @@ -118,14 +123,19 @@ def add_dr_event(self, session_id: str, event: Dict[str, Any]) -> Optional[Dict[ self._dr_events.setdefault(session_id, []).append(stored) return stored - def list_dr_events(self, session_id: str, after_id: Optional[int] = None) -> Optional[List[Dict[str, Any]]]: + def list_dr_events( + self, + session_id: str, + after_id: Optional[int] = None) -> Optional[List[Dict[str, Any]]]: """List deep research events for a session.""" if session_id not in self._sessions: return None events = self._dr_events.get(session_id, []) if after_id is None: return list(events) - return [event for event in events if event.get('event_id', 0) > after_id] + return [ + event for event in events if event.get('event_id', 0) > after_id + ] def update_last_message(self, session_id: str, content: str) -> bool: """Update the content of the last message (for streaming)""" @@ -136,7 +146,8 @@ def update_last_message(self, session_id: str, content: str) -> bool: self._messages[session_id][-1]['content'] = content return True - def set_workflow_progress(self, session_id: str, progress: Dict[str, Any]) -> bool: + def set_workflow_progress(self, session_id: str, + progress: Dict[str, Any]) -> bool: """Set workflow progress for a session""" if session_id not in self._sessions: return False @@ -145,7 +156,8 @@ def set_workflow_progress(self, session_id: str, progress: Dict[str, Any]) -> bo self._sessions[session_id]['workflow_progress'] = progress return True - def set_file_progress(self, session_id: str, progress: Dict[str, Any]) -> bool: + def set_file_progress(self, session_id: str, progress: Dict[str, + Any]) -> bool: """Set file writing progress for a session""" if session_id not in self._sessions: return False diff --git a/webui/backend/shared.py b/webui/backend/shared.py index c34b09944..1d3ce9754 100644 --- a/webui/backend/shared.py +++ b/webui/backend/shared.py @@ -3,7 +3,6 @@ Shared instances for backend modules. Ensures api.py and websocket_handler.py use the same manager instances. """ - import os from config_manager import ConfigManager @@ -11,7 +10,8 @@ from session_manager import SessionManager # Initialize paths -BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +BASE_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) PROJECTS_DIR = os.path.join(BASE_DIR, 'projects') # Use ~/.ms_agent/ for configuration storage (privacy-sensitive data) CONFIG_DIR = os.path.expanduser('~/.ms_agent') diff --git a/webui/backend/websocket_handler.py b/webui/backend/websocket_handler.py index fd5915d37..7b2830b17 100644 --- a/webui/backend/websocket_handler.py +++ b/webui/backend/websocket_handler.py @@ -3,18 +3,16 @@ WebSocket handler for real-time communication Handles agent execution, log streaming, and progress updates. """ - import asyncio -import json import os from datetime import datetime from pathlib import Path from typing import Any, Dict, Set +import json from agent_runner import AgentRunner from deep_research_worker_manager import DeepResearchWorkerManager from fastapi import APIRouter, WebSocket, WebSocketDisconnect - # Import shared instances from shared import config_manager, project_discovery, session_manager @@ -77,7 +75,8 @@ async def broadcast_log(self, log_entry: Dict[str, Any]): agent_tasks: Dict[str, asyncio.Task] = {} -async def _update_deep_research_status(session_id: str, event: Dict[str, Any]) -> None: +async def _update_deep_research_status(session_id: str, + event: Dict[str, Any]) -> None: event_type = event.get('type') if event_type == 'status': status = event.get('status') @@ -99,7 +98,8 @@ async def _update_deep_research_status(session_id: str, event: Dict[str, Any]) - session_manager.update_session(session_id, {'status': 'error'}) -async def _send_deep_research_event(session_id: str, event: Dict[str, Any]) -> None: +async def _send_deep_research_event(session_id: str, event: Dict[str, + Any]) -> None: event_type = str(event.get('type') or '') stored_event = event if event_type.startswith('dr.'): @@ -127,7 +127,8 @@ async def websocket_session(websocket: WebSocket, session_id: str): print(f'[WS] Client disconnected from session: {session_id}') connection_manager.disconnect(websocket, session_id) session = session_manager.get_session(session_id) - is_deep_research = bool(session and session.get('project_id') == 'deep_research_v2') + is_deep_research = bool( + session and session.get('project_id') == 'deep_research_v2') # Stop agent if running if session_id in agent_runners: await agent_runners[session_id].stop() @@ -152,7 +153,8 @@ async def websocket_logs(websocket: WebSocket): connection_manager.disconnect(websocket) -async def handle_session_message(session_id: str, data: Dict[str, Any], websocket: WebSocket): +async def handle_session_message(session_id: str, data: Dict[str, Any], + websocket: WebSocket): """Handle incoming WebSocket messages""" action = data.get('action') @@ -166,14 +168,18 @@ async def handle_session_message(session_id: str, data: Dict[str, Any], websocke await send_status(session_id, websocket) -async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocket): +async def start_agent(session_id: str, data: Dict[str, Any], + websocket: WebSocket): """Start an agent for a session""" print(f'[Agent] Starting agent for session: {session_id}') session = session_manager.get_session(session_id) if not session: print(f'[Agent] ERROR: Session not found: {session_id}') - await websocket.send_json({'type': 'error', 'message': 'Session not found'}) + await websocket.send_json({ + 'type': 'error', + 'message': 'Session not found' + }) return session_type = session.get('session_type', 'project') @@ -182,7 +188,6 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke if session_type == 'chat': # Create a virtual project for chat mode using the default agent.yaml import ms_agent - # Get ms_agent package installation path # Use __path__ which is always available for packages and gives real filesystem paths if hasattr(ms_agent, '__path__') and ms_agent.__path__: @@ -190,7 +195,8 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke elif ms_agent.__file__ is not None: ms_agent_package_path = Path(ms_agent.__file__).parent else: - raise RuntimeError('Cannot determine ms_agent package path. Please ensure ms_agent is properly installed.') + raise RuntimeError('Cannot determine ms_agent package path. ' + 'Please ensure ms_agent is properly installed.') chat_config_path = ms_agent_package_path / 'agent' / 'agent.yaml' project = { @@ -202,14 +208,17 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke 'path': str(ms_agent_package_path / 'agent'), 'config_file': str(chat_config_path), 'has_readme': False, - 'supports_workflow_switch': False, + 'supports_workflow_switch': False } else: # For project mode, get the project project = project_discovery.get_project(session['project_id']) if not project: print(f"[Agent] ERROR: Project not found: {session['project_id']}") - await websocket.send_json({'type': 'error', 'message': 'Project not found'}) + await websocket.send_json({ + 'type': 'error', + 'message': 'Project not found' + }) return # Clean up output directory for code_genesis before starting @@ -218,32 +227,30 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke if os.path.exists(output_dir): try: import shutil - shutil.rmtree(output_dir) print(f'[Agent] Cleaned up output directory: {output_dir}') await connection_manager.send_to_session( - session_id, - { + session_id, { 'type': 'log', 'level': 'info', 'message': 'Cleaned up previous output directory', - 'timestamp': datetime.now().isoformat(), - }, - ) + 'timestamp': datetime.now().isoformat() + }) except Exception as e: - print(f'[Agent] WARNING: Failed to clean output directory: {e}') + print( + f'[Agent] WARNING: Failed to clean output directory: {e}' + ) # Don't fail if cleanup fails, just log it # Get workflow_type from session (default to 'standard') workflow_type = session.get('workflow_type', 'standard') - print( - f"[Agent] Project: {project['id']}, type: {project['type']}, " - f"config: {project['config_file']}, workflow_type: {workflow_type}" - ) + print(f"[Agent] Project: {project['id']}, type: {project['type']}, " + f"config: {project['config_file']}, workflow_type: {workflow_type}") query = data.get('query', '') - print(f'[Agent] Query: {query[:100]}...' if len(query) > 100 else f'[Agent] Query: {query}') + print(f'[Agent] Query: {query[:100]}...' + if len(query) > 100 else f'[Agent] Query: {query}') # Add user message to session (but don't broadcast - frontend already has it) session_manager.add_message(session_id, 'user', query, 'text') @@ -262,11 +269,16 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke deep_research_config=config_manager.get_deep_research_config(), ) session_manager.update_session(session_id, {'status': 'running'}) - await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) + await connection_manager.send_to_session(session_id, { + 'type': 'status', + 'status': 'running' + }) except Exception as e: await connection_manager.send_to_session( - session_id, {'type': 'error', 'message': f'Worker 启动失败: {str(e)}'} - ) + session_id, { + 'type': 'error', + 'message': f'Worker 启动失败: {str(e)}' + }) session_manager.update_session(session_id, {'status': 'error'}) return @@ -275,19 +287,25 @@ async def start_agent(session_id: str, data: Dict[str, Any], websocket: WebSocke session_id=session_id, project=project, config_manager=config_manager, - on_output=lambda msg: asyncio.create_task(on_agent_output(session_id, msg)), + on_output=lambda msg: asyncio.create_task( + on_agent_output(session_id, msg)), on_log=lambda log: asyncio.create_task(on_agent_log(session_id, log)), - on_progress=lambda prog: asyncio.create_task(on_agent_progress(session_id, prog)), - on_complete=lambda result: asyncio.create_task(on_agent_complete(session_id, result)), - on_error=lambda err: asyncio.create_task(on_agent_error(session_id, err)), - workflow_type=workflow_type, - ) + on_progress=lambda prog: asyncio.create_task( + on_agent_progress(session_id, prog)), + on_complete=lambda result: asyncio.create_task( + on_agent_complete(session_id, result)), + on_error=lambda err: asyncio.create_task( + on_agent_error(session_id, err)), + workflow_type=workflow_type) agent_runners[session_id] = runner session_manager.update_session(session_id, {'status': 'running'}) # Notify session started - await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) + await connection_manager.send_to_session(session_id, { + 'type': 'status', + 'status': 'running' + }) # Start agent in background so the WS loop can still receive stop/input messages task = asyncio.create_task(runner.start(query)) @@ -310,7 +328,10 @@ async def stop_agent(session_id: str): del agent_tasks[session_id] session_manager.update_session(session_id, {'status': 'stopped'}) - await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'stopped'}) + await connection_manager.send_to_session(session_id, { + 'type': 'status', + 'status': 'stopped' + }) async def send_input(session_id: str, data: Dict[str, Any]): @@ -318,15 +339,13 @@ async def send_input(session_id: str, data: Dict[str, Any]): if session_id not in agent_runners: print(f'[WS] ERROR: Agent runner not found for session: {session_id}') await connection_manager.send_to_session( - session_id, - { - 'type': 'error', - 'message': ( - 'Agent is not running. The workflow may have completed. ' - 'Please start a new conversation or restart the agent.' - ), - }, - ) + session_id, { + 'type': + 'error', + 'message': + ('Agent is not running. The workflow may have completed. ' + 'Please start a new conversation or restart the agent.') + }) return input_text = data.get('input', '') @@ -335,21 +354,26 @@ async def send_input(session_id: str, data: Dict[str, Any]): # Check if process is still alive runner = agent_runners[session_id] if runner.process and runner.process.returncode is not None: - print(f'[WS] ERROR: Process has exited with code {runner.process.returncode}') - await connection_manager.send_to_session( - session_id, - { - 'type': 'error', - 'message': 'Agent process has terminated. The workflow completed. Please start a new conversation to continue.', - }, + print( + f'[WS] ERROR: Process has exited with code {runner.process.returncode}' ) + await connection_manager.send_to_session( + session_id, { + 'type': + 'error', + 'message': + 'Agent process has terminated. The workflow completed. Please start a new conversation to continue.' + }) # Clean up the runner del agent_runners[session_id] return # Update session status to running session_manager.update_session(session_id, {'status': 'running'}) - await connection_manager.send_to_session(session_id, {'type': 'status', 'status': 'running'}) + await connection_manager.send_to_session(session_id, { + 'type': 'status', + 'status': 'running' + }) # Add user message to session session_manager.add_message(session_id, 'user', input_text, 'text') @@ -360,18 +384,26 @@ async def send_input(session_id: str, data: Dict[str, Any]): except Exception as e: print(f'[WS] ERROR: Failed to send input: {e}') await connection_manager.send_to_session( - session_id, - {'type': 'error', 'message': f'Failed to send input: {str(e)}. The process may have terminated.'}, - ) + session_id, { + 'type': + 'error', + 'message': + f'Failed to send input: {str(e)}. The process may have terminated.' + }) async def send_status(session_id: str, websocket: WebSocket): """Send current status to a client""" session = session_manager.get_session(session_id) if session: - await websocket.send_json( - {'type': 'status', 'session': session, 'messages': session_manager.get_messages(session_id)} - ) + await websocket.send_json({ + 'type': + 'status', + 'session': + session, + 'messages': + session_manager.get_messages(session_id) + }) async def on_agent_output(session_id: str, message: Dict[str, Any]): @@ -383,27 +415,32 @@ async def on_agent_output(session_id: str, message: Dict[str, Any]): if msg_type == 'stream': # Streaming update await connection_manager.send_to_session( - session_id, {'type': 'stream', 'content': content, 'done': message.get('done', False)} - ) + session_id, { + 'type': 'stream', + 'content': content, + 'done': message.get('done', False) + }) if message.get('done'): session_manager.add_message(session_id, role, content, 'text') else: - session_manager.add_message(session_id, role, content, msg_type, message.get('metadata')) + session_manager.add_message(session_id, role, content, msg_type, + message.get('metadata')) await connection_manager.send_to_session( - session_id, - { + session_id, { 'type': 'message', 'role': role, 'content': content, 'message_type': msg_type, - 'metadata': message.get('metadata'), - }, - ) + 'metadata': message.get('metadata') + }) async def on_agent_log(session_id: str, log: Dict[str, Any]): """Handle agent log""" - await connection_manager.send_to_session(session_id, {'type': 'log', **log}) + await connection_manager.send_to_session(session_id, { + 'type': 'log', + **log + }) await connection_manager.broadcast_log({'session_id': session_id, **log}) @@ -413,11 +450,15 @@ async def on_agent_progress(session_id: str, progress: Dict[str, Any]): if progress_type == 'workflow': session_manager.set_workflow_progress(session_id, progress) - session_manager.set_current_step(session_id, progress.get('current_step')) + session_manager.set_current_step(session_id, + progress.get('current_step')) elif progress_type == 'file': session_manager.set_file_progress(session_id, progress) - await connection_manager.send_to_session(session_id, {'type': 'progress', **progress}) + await connection_manager.send_to_session(session_id, { + 'type': 'progress', + **progress + }) async def on_agent_complete(session_id: str, result: Dict[str, Any]): @@ -430,13 +471,17 @@ async def on_agent_complete(session_id: str, result: Dict[str, Any]): agent_tasks[session_id].cancel() del agent_tasks[session_id] - await connection_manager.send_to_session(session_id, {'type': 'complete', 'result': result}) + await connection_manager.send_to_session(session_id, { + 'type': 'complete', + 'result': result + }) async def on_agent_error(session_id: str, error: Dict[str, Any]): """Handle agent error""" session_manager.update_session(session_id, {'status': 'error'}) - session_manager.add_message(session_id, 'system', error.get('message', 'Unknown error'), 'error') + session_manager.add_message(session_id, 'system', + error.get('message', 'Unknown error'), 'error') if session_id in agent_runners: del agent_runners[session_id] @@ -444,4 +489,7 @@ async def on_agent_error(session_id: str, error: Dict[str, Any]): agent_tasks[session_id].cancel() del agent_tasks[session_id] - await connection_manager.send_to_session(session_id, {'type': 'error', **error}) + await connection_manager.send_to_session(session_id, { + 'type': 'error', + **error + }) From 34d5e92017d3ecef1661928de9bfa44d67fe464a Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 28 Apr 2026 16:15:21 +0800 Subject: [PATCH 40/40] fix lint --- .gitignore | 1 + .pre-commit-config.yaml | 13 +- ms-agent-skills/scripts/check_ms_agent.py | 3 +- ms_agent/agent/base.py | 2 +- ms_agent/agent/code_agent.py | 3 +- ms_agent/agent/llm_agent.py | 27 +- ms_agent/agent/loader.py | 5 +- ms_agent/callbacks/base.py | 2 +- ms_agent/callbacks/input_callback.py | 2 +- ms_agent/capabilities/__init__.py | 11 +- ms_agent/capabilities/mcp_server.py | 4 +- ms_agent/cli/app.py | 6 +- ms_agent/cli/run.py | 7 +- ms_agent/config/config.py | 7 +- ms_agent/config/env.py | 3 +- ms_agent/llm/anthropic_llm.py | 6 +- ms_agent/llm/dashscope_llm.py | 2 +- ms_agent/llm/deepseek_llm.py | 2 +- ms_agent/llm/llm.py | 5 +- ms_agent/llm/modelscope_llm.py | 3 +- ms_agent/llm/openai.py | 6 +- ms_agent/llm/openai_llm.py | 10 +- ms_agent/llm/utils.py | 12 +- ms_agent/memory/base.py | 2 +- ms_agent/memory/condenser/code_condenser.py | 2 +- .../memory/condenser/context_compressor.py | 2 +- ms_agent/memory/condenser/refine_condenser.py | 2 +- ms_agent/memory/default_memory.py | 6 +- ms_agent/memory/diversity.py | 3 +- ms_agent/memory/memory_manager.py | 2 +- ms_agent/prompting/file_resolver.py | 3 +- ms_agent/rag/extraction.py | 4 +- ms_agent/rag/llama_index_rag.py | 20 +- ms_agent/retriever/hybrid_retriever.py | 4 +- ms_agent/sandbox/sandbox.py | 2 +- ms_agent/skill/loader.py | 1 - ms_agent/skill/schema.py | 3 +- ms_agent/tools/agent_tool.py | 90 +++-- ms_agent/tools/base.py | 2 +- ms_agent/tools/code/code_executor.py | 7 +- ms_agent/tools/code/local_code_executor.py | 45 ++- ms_agent/tools/code/sandbox_manager.py | 5 +- ms_agent/tools/code_server/lsp_code_server.py | 2 +- ms_agent/tools/docling/chunker.py | 5 +- ms_agent/tools/docling/doc_loader.py | 6 +- ms_agent/tools/docling/doc_postprocess.py | 3 +- ms_agent/tools/docling/patches.py | 4 +- ms_agent/tools/fetch_playwright_fallback.py | 17 +- ms_agent/tools/filesystem_tool.py | 358 +++++++++++------- ms_agent/tools/findata/akshare_source.py | 2 +- ms_agent/tools/findata/baostock_source.py | 2 +- ms_agent/tools/findata/data_source_base.py | 3 +- ms_agent/tools/findata/findata_fetcher.py | 8 +- ms_agent/tools/findata/hybrid_source.py | 2 +- .../tools/image_generator/ds_image_gen.py | 1 - .../tools/image_generator/ms_image_gen.py | 3 +- ms_agent/tools/jina_reader.py | 17 +- ms_agent/tools/mcp_client.py | 11 +- ms_agent/tools/mineru/pdf_parser.py | 1 - ms_agent/tools/search/arxiv/schema.py | 6 +- ms_agent/tools/search/arxiv/search.py | 4 +- ms_agent/tools/search/content_optimizer.py | 4 +- ms_agent/tools/search/exa/schema.py | 5 +- ms_agent/tools/search/exa/search.py | 2 +- ms_agent/tools/search/localsearch_tool.py | 54 +-- ms_agent/tools/search/search_base.py | 3 +- ms_agent/tools/search/sirchmunk_search.py | 22 +- ms_agent/tools/search/tavily/fetcher.py | 7 +- ms_agent/tools/search/tavily/http.py | 3 +- ms_agent/tools/search/tavily/schema.py | 3 +- ms_agent/tools/search/tavily/search.py | 57 ++- ms_agent/tools/search/web_search_spill.py | 30 +- ms_agent/tools/search/websearch_tool.py | 78 ++-- ms_agent/tools/search_engine.py | 2 +- ms_agent/tools/task_control_tool.py | 18 +- ms_agent/tools/todolist_tool.py | 2 +- ms_agent/tools/tool_manager.py | 21 +- ms_agent/utils/artifact_manager.py | 25 +- ms_agent/utils/parser_utils.py | 3 +- ms_agent/utils/push_to_hub.py | 7 +- ms_agent/utils/snapshot.py | 45 ++- ms_agent/utils/stats.py | 3 +- ms_agent/utils/stream_writer.py | 10 +- ms_agent/utils/task_manager.py | 32 +- ms_agent/utils/thread_util.py | 5 +- ms_agent/utils/utils.py | 11 +- ms_agent/utils/workspace_policy.py | 8 +- ms_agent/workflow/base.py | 2 +- ms_agent/workflow/chain_workflow.py | 2 +- ms_agent/workflow/dag_workflow.py | 2 +- .../workflow/deep_research/research_utils.py | 3 +- .../deep_research/research_workflow.py | 8 +- .../deep_research/research_workflow_beta.py | 4 +- ms_agent/workflow/loader.py | 2 +- projects/code_genesis/workflow/api_search.py | 2 +- projects/code_genesis/workflow/coding.py | 4 +- projects/code_genesis/workflow/file_design.py | 2 +- projects/code_genesis/workflow/file_order.py | 2 +- projects/code_genesis/workflow/refine.py | 6 +- .../v2/callbacks/quality_checker.py | 4 +- .../v2/callbacks/reporter_callback.py | 8 +- .../v2/callbacks/researcher_callback.py | 6 +- .../v2/callbacks/searcher_callback.py | 15 +- .../deep_research/v2/eval/dr_bench_runner.py | 4 +- projects/deep_research/v2/reporter.py | 2 +- projects/deep_research/v2/researcher.py | 2 +- projects/deep_research/v2/time_handler.py | 2 +- .../deep_research/v2/tools/evidence_tool.py | 17 +- .../deep_research/v2/tools/report_tool.py | 86 ++--- projects/fin_research/aggregator.py | 6 +- .../callbacks/aggregator_callback.py | 2 +- .../callbacks/analyst_callback.py | 4 +- .../callbacks/collector_callback.py | 4 +- .../callbacks/orchestrator_callback.py | 4 +- projects/fin_research/searcher.py | 6 +- projects/fin_research/time_handler.py | 2 +- .../fin_research/tools/principle_skill.py | 2 +- projects/fin_research/tools/spec_loader.py | 6 +- .../singularity_cinema/compose_video/agent.py | 10 +- .../create_background/agent.py | 6 +- .../generate_animation/agent.py | 2 +- .../generate_animation/generate_manim_code.py | 6 +- .../generate_remotion_code.py | 6 +- .../generate_audio/agent.py | 10 +- .../generate_illustration_prompts/agent.py | 4 +- .../generate_images/agent.py | 10 +- .../generate_script/agent.py | 2 +- .../generate_subtitle/agent.py | 8 +- .../generate_video/agent.py | 6 +- .../generate_video_prompts/agent.py | 4 +- .../singularity_cinema/parse_images/agent.py | 6 +- .../render_animation/agent.py | 2 +- .../render_animation/render_manim.py | 8 +- .../render_animation/render_remotion.py | 8 +- projects/singularity_cinema/segment/agent.py | 4 +- setup.py | 5 +- shell-grep-glob-workspace-policy.md | 225 ----------- webui/backend/agent_runner.py | 3 +- webui/backend/api.py | 5 +- webui/backend/config_manager.py | 3 +- webui/backend/deep_research_eventizer.py | 2 +- webui/backend/deep_research_worker.py | 6 +- webui/backend/deep_research_worker_manager.py | 3 +- webui/backend/main.py | 1 - webui/backend/shared.py | 1 - webui/backend/websocket_handler.py | 10 +- 146 files changed, 878 insertions(+), 971 deletions(-) delete mode 100644 shell-grep-glob-workspace-policy.md diff --git a/.gitignore b/.gitignore index 30dfa8d1f..58fd44f05 100644 --- a/.gitignore +++ b/.gitignore @@ -150,6 +150,7 @@ apps/agentfabric/config/local_user/* *.pt .run/ .run/* +.claude* # ast template ast_index_file.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 00657312b..64a6fc0e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,23 @@ - - - +default_language_version: + python: python3.11 repos: - repo: https://github.com/pycqa/flake8.git - rev: 4.0.0 + rev: 7.3.0 hooks: - id: flake8 exclude: ^(thirdparty/|examples/|tests/|projects/agent_skills/|projects/fin_research/examples/|ms_agent/utils/prompts\.py) - repo: https://github.com/PyCQA/isort.git - rev: 4.3.21 + rev: 5.13.2 hooks: - id: isort exclude: ^(examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - repo: https://github.com/pre-commit/mirrors-yapf.git - rev: v0.30.0 + rev: v0.32.0 hooks: - id: yapf exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - repo: https://github.com/pre-commit/pre-commit-hooks.git - rev: v3.1.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude: ^(thirdparty/|tests/|projects/fin_research/examples/|projects/agent_skills/) diff --git a/ms-agent-skills/scripts/check_ms_agent.py b/ms-agent-skills/scripts/check_ms_agent.py index 64e4668fa..59f56c3f7 100644 --- a/ms-agent-skills/scripts/check_ms_agent.py +++ b/ms-agent-skills/scripts/check_ms_agent.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import subprocess import sys -import json - def check_import() -> dict: """Check that ms_agent is importable.""" diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index cb78d5ce2..ba568d04c 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Tuple, Union from ms_agent.llm import Message from ms_agent.utils import read_history, save_history from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_RETRY_COUNT -from omegaconf import DictConfig class Agent(ABC): diff --git a/ms_agent/agent/code_agent.py b/ms_agent/agent/code_agent.py index 33b2b63a4..63d378867 100644 --- a/ms_agent/agent/code_agent.py +++ b/ms_agent/agent/code_agent.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Any, List, Union from ms_agent.llm import Message -from omegaconf import DictConfig - from .base import Agent diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 14adfa8b3..cdca437e0 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,15 +2,16 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy +from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.llm.llm import LLM @@ -21,12 +22,10 @@ from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history -from ms_agent.utils.task_manager import TaskManager from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger from ms_agent.utils.snapshot import take_snapshot -from omegaconf import DictConfig, OmegaConf - +from ms_agent.utils.task_manager import TaskManager from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -103,15 +102,16 @@ def resolve_enable_snapshots(config: Any) -> bool: like ``\"false\"`` coerced to boolean). """ if OmegaConf.is_config(config): - raw = OmegaConf.select(config, 'enable_snapshots', - default=_MISSING_ENABLE_SNAPSHOTS) + raw = OmegaConf.select( + config, 'enable_snapshots', default=_MISSING_ENABLE_SNAPSHOTS) if raw is not _MISSING_ENABLE_SNAPSHOTS and raw is not None: return LLMAgent._coerce_enable_snapshots_value(raw) sub = bool( OmegaConf.select(config, 'ms_agent_subagent', default=False)) return not sub if isinstance(config, dict): - if 'enable_snapshots' in config and config['enable_snapshots'] is not None: + if 'enable_snapshots' in config and config[ + 'enable_snapshots'] is not None: return LLMAgent._coerce_enable_snapshots_value( config['enable_snapshots']) return not bool(config.get('ms_agent_subagent')) @@ -195,7 +195,8 @@ def _ensure_auto_skills(self) -> bool: # Check sandbox requirements use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: - from ms_agent.utils.docker_utils import is_docker_daemon_running + from ms_agent.utils.docker_utils import \ + is_docker_daemon_running if not is_docker_daemon_running(): logger.warning( @@ -536,7 +537,8 @@ async def on_task_begin(self, messages: List[Message]): ) take_snapshot( self.output_dir, - f'[pre] {_user_content}' if _user_content else '[pre] new task', + f'[pre] {_user_content}' + if _user_content else '[pre] new task', message_count=len(messages), ) await self.loop_callback('on_task_begin', messages) @@ -898,7 +900,8 @@ def handle_new_response(self, messages: List[Message], and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' - def _append_task_notifications(self, messages: List[Message]) -> List[Message]: + def _append_task_notifications(self, + messages: List[Message]) -> List[Message]: """Inject drained TaskManager completion notices as a user message.""" if self.task_manager is None: return messages @@ -1237,7 +1240,9 @@ async def run_loop(self, messages: Union[List[Message], str], if self.task_manager is not None: notifications = self.task_manager.drain_notifications() if notifications: - messages.append(Message(role='user', content='\n'.join(notifications))) + messages.append( + Message( + role='user', content='\n'.join(notifications))) async for messages in self.step(messages): yield messages self.runtime.round += 1 diff --git a/ms_agent/agent/loader.py b/ms_agent/agent/loader.py index 21b1687a7..48ed74d91 100644 --- a/ms_agent/agent/loader.py +++ b/ms_agent/agent/loader.py @@ -3,12 +3,11 @@ import inspect import os import sys +from omegaconf import DictConfig, OmegaConf from typing import Dict, Optional from ms_agent.config.config import Config from ms_agent.utils.constants import DEFAULT_AGENT_FILE, DEFAULT_TAG -from omegaconf import DictConfig, OmegaConf - from .base import Agent @@ -44,8 +43,8 @@ def build(cls, None) is None and config_dir_or_id is not None: agent_config.local_dir = config_dir_or_id - from .llm_agent import LLMAgent from .code_agent import CodeAgent + from .llm_agent import LLMAgent agent_type = LLMAgent.AGENT_NAME if 'code_file' in kwargs: code_file = kwargs.pop('code_file') diff --git a/ms_agent/callbacks/base.py b/ms_agent/callbacks/base.py index 849fe7069..e8b5f0740 100644 --- a/ms_agent/callbacks/base.py +++ b/ms_agent/callbacks/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime from ms_agent.llm.utils import Message -from omegaconf import DictConfig class Callback: diff --git a/ms_agent/callbacks/input_callback.py b/ms_agent/callbacks/input_callback.py index e44db1e31..4017681fd 100644 --- a/ms_agent/callbacks/input_callback.py +++ b/ms_agent/callbacks/input_callback.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/capabilities/__init__.py b/ms_agent/capabilities/__init__.py index 73feb864e..c5717be55 100644 --- a/ms_agent/capabilities/__init__.py +++ b/ms_agent/capabilities/__init__.py @@ -23,6 +23,7 @@ """ from __future__ import annotations + from typing import Any from ms_agent.capabilities.descriptor import CapabilityDescriptor @@ -39,13 +40,9 @@ def create_registry(config: Any = None) -> CapabilityRegistry: """ registry = CapabilityRegistry() - from ms_agent.capabilities.wrappers import ( - agent_delegate, - deep_research, - filesystem, - lsp_code_server, - web_search, - ) + from ms_agent.capabilities.wrappers import (agent_delegate, deep_research, + filesystem, lsp_code_server, + web_search) filesystem.register_all(registry, config) lsp_code_server.register_all(registry, config) diff --git a/ms_agent/capabilities/mcp_server.py b/ms_agent/capabilities/mcp_server.py index 650dc33cf..84e201f1e 100644 --- a/ms_agent/capabilities/mcp_server.py +++ b/ms_agent/capabilities/mcp_server.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import argparse +import json import logging import os import sys - -import json from dotenv import find_dotenv, load_dotenv + from ms_agent.capabilities import create_registry logger = logging.getLogger(__name__) diff --git a/ms_agent/cli/app.py b/ms_agent/cli/app.py index d9a690d7d..e9d94d281 100644 --- a/ms_agent/cli/app.py +++ b/ms_agent/cli/app.py @@ -55,13 +55,15 @@ def define_args(parsers: argparse.ArgumentParser): def execute(self): if self.args.app_type == 'doc_research': - from ms_agent.app.doc_research import launch_server as launch_doc_research + from ms_agent.app.doc_research import \ + launch_server as launch_doc_research launch_doc_research( server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share) elif self.args.app_type == 'fin_research': - from ms_agent.app.fin_research import launch_server as launch_fin_research + from ms_agent.app.fin_research import \ + launch_server as launch_fin_research launch_fin_research( server_name=self.args.server_name, server_port=self.args.server_port, diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 2accdb40e..1bb1d67ad 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -3,13 +3,12 @@ import asyncio import os from importlib import resources as importlib_resources +from omegaconf import OmegaConf from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII -from omegaconf import OmegaConf - from .base import CLICommand logger = get_logger() @@ -152,9 +151,7 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default=None, - help= - 'Comma-separated list of paths for knowledge search.' - ) + help='Comma-separated list of paths for knowledge search.') parser.set_defaults(func=subparser_func) def execute(self): diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 2f6175524..c1f1ab85c 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -3,14 +3,13 @@ import os.path from abc import abstractmethod from copy import deepcopy -from typing import Any, Dict, Union - -from ms_agent.prompting import apply_prompt_files -from ms_agent.utils import get_logger from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.basecontainer import BaseContainer +from typing import Any, Dict, Union from modelscope import snapshot_download +from ms_agent.prompting import apply_prompt_files +from ms_agent.utils import get_logger from ..utils.constants import TOOL_PLUGIN_NAME from .env import Env diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 83553254c..39a03e9ef 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from copy import copy -from typing import Dict, Optional - from dotenv import load_dotenv +from typing import Dict, Optional class Env: diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py index 5b35bfb5d..e229b5140 100644 --- a/ms_agent/llm/anthropic_llm.py +++ b/ms_agent/llm/anthropic_llm.py @@ -1,13 +1,13 @@ +import httpx import inspect +import json +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, Generator, Iterator, List, Optional, Union -import httpx -import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import assert_package_exist, retry from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf class _SSEEventInjector(httpx.SyncByteStream): diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index b4a6ddaa8..00239286e 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class DashScope(OpenAI): diff --git a/ms_agent/llm/deepseek_llm.py b/ms_agent/llm/deepseek_llm.py index e565308bc..3bc6a59d0 100644 --- a/ms_agent/llm/deepseek_llm.py +++ b/ms_agent/llm/deepseek_llm.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool -from omegaconf import DictConfig class DeepSeek(OpenAI): diff --git a/ms_agent/llm/llm.py b/ms_agent/llm/llm.py index 72af53467..803434f87 100644 --- a/ms_agent/llm/llm.py +++ b/ms_agent/llm/llm.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import abstractmethod +from omegaconf import DictConfig from typing import Any, Dict, List, Optional from ms_agent.config import Config -from omegaconf import DictConfig - from ..utils.constants import DEFAULT_RETRY_COUNT from .utils import Message, Tool @@ -69,7 +68,7 @@ def from_config(cls, config: DictConfig) -> Any: Returns: The LLM instance. """ - from .model_mapping import all_services_mapping, OpenAI + from .model_mapping import OpenAI, all_services_mapping if config.llm.get('service') in all_services_mapping: return all_services_mapping[config.llm.service](config) else: diff --git a/ms_agent/llm/modelscope_llm.py b/ms_agent/llm/modelscope_llm.py index 7b761c5c0..bd1b54a56 100644 --- a/ms_agent/llm/modelscope_llm.py +++ b/ms_agent/llm/modelscope_llm.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig + from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class ModelScope(OpenAI): diff --git a/ms_agent/llm/openai.py b/ms_agent/llm/openai.py index 390f1d4ab..cd902119f 100644 --- a/ms_agent/llm/openai.py +++ b/ms_agent/llm/openai.py @@ -1,11 +1,11 @@ # flake8: noqa +import json import uuid +from openai import OpenAI, Stream +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from typing import TYPE_CHECKING, Any, Dict, List, Literal -import json from ms_agent.utils.logger import get_logger -from openai import OpenAI, Stream -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall logger = get_logger() diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index fa2df6004..86969210f 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import httpx import inspect +import json from copy import deepcopy +from omegaconf import DictConfig, OmegaConf +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, Function) from typing import Any, Dict, Generator, Iterable, List, Optional -import httpx -import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, get_logger, retry) from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, Function) logger = get_logger() diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 4ae5833bf..12f66af24 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional, Union - -import json from typing_extensions import Literal, Required, TypedDict @@ -91,8 +90,13 @@ def to_dict_clean(self): required = ['content', 'role'] # Never send UI-only fields to model providers. rm = [ - 'completion_tokens', 'prompt_tokens', 'api_calls', 'tool_detail', - 'searching_detail', 'search_result', '_responses_output_items', + 'completion_tokens', + 'prompt_tokens', + 'api_calls', + 'tool_detail', + 'searching_detail', + 'search_result', + '_responses_output_items', ] return { key: value diff --git a/ms_agent/memory/base.py b/ms_agent/memory/base.py index fde42fb56..72baec529 100644 --- a/ms_agent/memory/base.py +++ b/ms_agent/memory/base.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import List from ms_agent.llm.utils import Message from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from omegaconf import DictConfig class Memory(ABC): diff --git a/ms_agent/memory/condenser/code_condenser.py b/ms_agent/memory/condenser/code_condenser.py index 12fe626e5..723b5eebe 100644 --- a/ms_agent/memory/condenser/code_condenser.py +++ b/ms_agent/memory/condenser/code_condenser.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils import get_logger diff --git a/ms_agent/memory/condenser/context_compressor.py b/ms_agent/memory/condenser/context_compressor.py index 9bec9bcf1..08ef02755 100644 --- a/ms_agent/memory/condenser/context_compressor.py +++ b/ms_agent/memory/condenser/context_compressor.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from typing import List, Optional -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils.logger import logger diff --git a/ms_agent/memory/condenser/refine_condenser.py b/ms_agent/memory/condenser/refine_condenser.py index 557e2a1ff..6d7ec804e 100644 --- a/ms_agent/memory/condenser/refine_condenser.py +++ b/ms_agent/memory/condenser/refine_condenser.py @@ -1,6 +1,6 @@ +import json from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index b087a2dd9..d84826de4 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -2,23 +2,23 @@ import asyncio import hashlib import importlib +import json +import json5 import os import re import traceback from datetime import datetime from functools import partial, wraps from inspect import signature +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, List, Optional, Tuple -import json -import json5 from ms_agent.llm.utils import Message from ms_agent.memory import Memory from ms_agent.utils import get_fact_retrieval_prompt from ms_agent.utils.constants import (DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, DEFAULT_USER, get_service_config) from ms_agent.utils.logger import logger -from omegaconf import DictConfig, OmegaConf class MemoryMapping: diff --git a/ms_agent/memory/diversity.py b/ms_agent/memory/diversity.py index 775e80da1..2e307a0ac 100644 --- a/ms_agent/memory/diversity.py +++ b/ms_agent/memory/diversity.py @@ -1,11 +1,10 @@ import asyncio import re from copy import deepcopy +from omegaconf import DictConfig from typing import List from ms_agent.utils import get_logger -from omegaconf import DictConfig - from ..llm import LLM, Message from .base import Memory diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 5a203505d..5d0a10210 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Dict from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/prompting/file_resolver.py b/ms_agent/prompting/file_resolver.py index c04b8bdde..2bf340200 100644 --- a/ms_agent/prompting/file_resolver.py +++ b/ms_agent/prompting/file_resolver.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from dataclasses import dataclass -from typing import List, Optional, Tuple - from omegaconf import DictConfig +from typing import List, Optional, Tuple @dataclass(frozen=True) diff --git a/ms_agent/rag/extraction.py b/ms_agent/rag/extraction.py index c62b10067..0215d8b7b 100644 --- a/ms_agent/rag/extraction.py +++ b/ms_agent/rag/extraction.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable from abc import ABC, abstractmethod -from typing import Any, Dict, List - from docling_core.transforms.chunker import BaseChunk from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem, DocItemLabel +from typing import Any, Dict, List + from ms_agent.rag.schema import KeyInformation from ms_agent.tools.docling.chunker import HybridDocumentChunker from ms_agent.tools.docling.doc_loader import DocLoader diff --git a/ms_agent/rag/llama_index_rag.py b/ms_agent/rag/llama_index_rag.py index e4535d38d..129fda95f 100644 --- a/ms_agent/rag/llama_index_rag.py +++ b/ms_agent/rag/llama_index_rag.py @@ -1,11 +1,10 @@ import os import shutil -from typing import Any, List, Optional - -from ms_agent.utils import assert_package_exist from omegaconf import DictConfig +from typing import Any, List, Optional from modelscope import snapshot_download +from ms_agent.utils import assert_package_exist from ..llm import LLM, Message from .base import RAG @@ -41,6 +40,7 @@ def __init__(self, config: DictConfig): from llama_index.core import Settings from llama_index.core.node_parser import SentenceSplitter + # Set node parser Settings.node_parser = SentenceSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) @@ -49,10 +49,10 @@ def __init__(self, config: DictConfig): if self.retrieve_only: Settings.llm = None else: + from llama_index.core.base.llms.types import (CompletionResponse, + LLMMetadata) from llama_index.core.llms import CustomLLM - from llama_index.core.base.llms.types import LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback - from llama_index.core.base.llms.types import CompletionResponse self._llm_instance = LLM.from_config(self.config) class MSCustomLLM(CustomLLM): @@ -108,7 +108,7 @@ def _validate_config(self, config: DictConfig): raise ValueError('chunk_size must be greater than 0') def _setup_embedding_model(self, config: DictConfig): - from llama_index.core import (Settings) + from llama_index.core import Settings from llama_index.embeddings.huggingface import HuggingFaceEmbedding try: use_hf = getattr(config, 'use_huggingface', False) @@ -124,7 +124,7 @@ def _setup_embedding_model(self, config: DictConfig): async def add_documents(self, documents: List[str]): if not documents: raise ValueError('Document list cannot be empty') - from llama_index.core import (Document, VectorStoreIndex) + from llama_index.core import Document, VectorStoreIndex docs = [Document(text=doc) for doc in documents] self.index = VectorStoreIndex.from_documents(docs) if not self.retrieve_only: @@ -159,6 +159,7 @@ async def _setup_query_engine(self): return from llama_index.core import Settings + # Check if LLM is set if Settings.llm is None and not self.retrieve_only: return @@ -239,10 +240,11 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: return [] from llama_index.core.retrievers import VectorIndexRetriever + # Try to import BM25 related modules try: - from llama_index.retrievers.bm25 import BM25Retriever from llama_index.core.retrievers import QueryFusionRetriever + from llama_index.retrievers.bm25 import BM25Retriever bm25_available = True except ImportError: bm25_available = False @@ -316,7 +318,7 @@ async def load_index(self, persist_dir: Optional[str] = None): raise FileNotFoundError( f'Index directory does not exist: {load_dir}') - from llama_index.core import (StorageContext, load_index_from_storage) + from llama_index.core import StorageContext, load_index_from_storage storage_context = StorageContext.from_defaults(persist_dir=load_dir) self.index = load_index_from_storage(storage_context) diff --git a/ms_agent/retriever/hybrid_retriever.py b/ms_agent/retriever/hybrid_retriever.py index e84bc8398..d97c4fcdc 100644 --- a/ms_agent/retriever/hybrid_retriever.py +++ b/ms_agent/retriever/hybrid_retriever.py @@ -1,11 +1,11 @@ import asyncio +import faiss import math +import numpy as np import os import threading from typing import List, Tuple -import faiss -import numpy as np from ms_agent.utils.tokenizer_util import TokenizerUtil os.environ['OMP_NUM_THREADS'] = '1' diff --git a/ms_agent/sandbox/sandbox.py b/ms_agent/sandbox/sandbox.py index 8a5753761..3f63bf962 100644 --- a/ms_agent/sandbox/sandbox.py +++ b/ms_agent/sandbox/sandbox.py @@ -46,7 +46,7 @@ def __init__(self, **kwargs): super().__init__() self._init() - from ms_enclave.sandbox import SandboxConfig, DockerSandboxConfig + from ms_enclave.sandbox import DockerSandboxConfig, SandboxConfig # Mount host directories into the sandbox container if provided _volumes = kwargs.pop('volumes', None) or [] diff --git a/ms_agent/skill/loader.py b/ms_agent/skill/loader.py index 1f5dca2a7..f6d898730 100644 --- a/ms_agent/skill/loader.py +++ b/ms_agent/skill/loader.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Union from ms_agent.utils.logger import logger - from .schema import SkillSchema, SkillSchemaParser diff --git a/ms_agent/skill/schema.py b/ms_agent/skill/schema.py index 722e0acc4..71a50d01b 100644 --- a/ms_agent/skill/schema.py +++ b/ms_agent/skill/schema.py @@ -6,13 +6,12 @@ Each Skill is represented as a self-contained directory with metadata. """ import re +import yaml from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union -import yaml from ms_agent.utils.logger import logger - from .spec import Spec SUPPORTED_SCRIPT_EXT = ('.py', '.sh', '.js') diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 02a57b8d5..554c6a6d6 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import multiprocessing as mp import os import threading @@ -7,11 +8,11 @@ import uuid from collections import defaultdict from dataclasses import dataclass +from omegaconf import DictConfig, ListConfig, OmegaConf from queue import Empty as QueueEmpty from queue import Full as QueueFull from typing import Any, Callable, Dict, List, Optional, Union -import json from ms_agent.agent.loader import AgentLoader from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase @@ -20,7 +21,6 @@ get_stats_path, monotonic, now_iso, summarize_usage) from ms_agent.utils.stream_writer import SubAgentStreamWriter -from omegaconf import DictConfig, ListConfig, OmegaConf logger = get_logger() @@ -232,15 +232,18 @@ def enabled(self) -> bool: 'type': 'object', 'properties': { 'tasks': { - 'type': 'array', - 'description': ( - 'MANDATORY: Each element is a dict, which must contains two fields: ' - '`system`(str) and `query`(str) to start one sub task.'), + 'type': + 'array', + 'description': + ('MANDATORY: Each element is a dict, which must contains two fields: ' + '`system`(str) and `query`(str) to start one sub task.'), }, 'execution_mode': { - 'type': 'string', + 'type': + 'string', 'enum': ['sequential', 'parallel'], - 'description': 'Whether to run sub-tasks sequentially or in parallel.', + 'description': + 'Whether to run sub-tasks sequentially or in parallel.', }, }, 'required': ['tasks'], @@ -255,7 +258,8 @@ def _load_specs(self): split_cfg = tools_cfg.split_task tag_prefix = getattr(split_cfg, 'tag_prefix', 'worker-') run_in_thread = bool(getattr(split_cfg, 'run_in_thread', True)) - run_in_process = bool(getattr(split_cfg, 'run_in_process', run_in_thread)) + run_in_process = bool( + getattr(split_cfg, 'run_in_process', run_in_thread)) builtin_spec = _AgentToolSpec( tool_name='split_to_sub_task', description=self._SPLIT_TASK_DESCRIPTION, @@ -391,7 +395,8 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], env_cfg = _to_container(env_cfg) if env_cfg is not None else None disallowed_raw = getattr(cfg, 'disallowed_tools', None) - disallowed_tools = _to_container(disallowed_raw) if disallowed_raw is not None else None + disallowed_tools = _to_container( + disallowed_raw) if disallowed_raw is not None else None if isinstance(disallowed_tools, list): disallowed_tools = [str(t) for t in disallowed_tools] elif disallowed_tools is not None: @@ -541,24 +546,28 @@ async def call_tool(self, server_name: str, *, tool_name: str, payload = self._build_payload(tool_args, spec) if spec.run_in_background: - return await self._launch_background(payload, spec, effective_call_id) + return await self._launch_background(payload, spec, + effective_call_id) use_subprocess = spec.run_in_thread and spec.run_in_process if use_subprocess: messages = await self._run_agent( None, payload, spec, call_id=effective_call_id) result_str = self._format_output(messages, spec) - return self._maybe_append_stream_path(result_str, effective_call_id) + return self._maybe_append_stream_path(result_str, + effective_call_id) # Pure async/await with optional escape-to-background support. - result = await self._run_sync_escapable(payload, spec, effective_call_id) + result = await self._run_sync_escapable(payload, spec, + effective_call_id) if isinstance(result, str): # Already formatted: escaped to background, returns async_launched JSON. return result result_str = self._format_output(result, spec) return self._maybe_append_stream_path(result_str, effective_call_id) - def _maybe_append_stream_path(self, result_str: str, effective_call_id: str) -> str: + def _maybe_append_stream_path(self, result_str: str, + effective_call_id: str) -> str: """Append a human- and LLM-readable note about the step-by-step execution log to *result_str* if streaming is enabled. @@ -575,15 +584,14 @@ def _maybe_append_stream_path(self, result_str: str, effective_call_id: str) -> result_str += ( f'\n\n[Note: The sub-agent\'s step-by-step execution trace ' f'(messages, tool calls, intermediate reasoning) was streamed ' - f'incrementally to: {path}]' - ) + f'incrementally to: {path}]') return result_str def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, - call_id: Optional[str]) -> Any: + call_id: Optional[str]) -> Any: """Run sub-agent inline (pure async/await). If spec.sync_timeout_s is set, the call auto-escapes to background after @@ -600,7 +608,7 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, self._run_agent(None, payload, spec, call_id=effective_call_id)) self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, - escape_event) + escape_event) try: if spec.sync_timeout_s and spec.sync_timeout_s > 0: @@ -634,9 +642,8 @@ async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, self._active_sync_tasks.pop(effective_call_id, None) async def _escape_running_task(self, call_id: str, - run_task: 'asyncio.Task[Any]', - spec: _AgentToolSpec, - payload: Any) -> str: + run_task: 'asyncio.Task[Any]', + spec: _AgentToolSpec, payload: Any) -> str: """Cancel the in-progress sync task and re-launch it as a background subprocess.""" if self._task_manager is None: raise RuntimeError( @@ -673,7 +680,7 @@ def escape_to_background(self, call_id: str) -> bool: return True async def _launch_background(self, payload: Any, spec: _AgentToolSpec, - call_id: Optional[str]) -> str: + call_id: Optional[str]) -> str: """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" if self._task_manager is None: raise RuntimeError( @@ -704,7 +711,9 @@ async def _watcher(): try: result = await self._wait_process_result(proc, result_queue) if result is None or not result.get('ok'): - err = (result or {}).get('error', 'subprocess exited without result') + err = (result + or {}).get('error', + 'subprocess exited without result') tb = (result or {}).get('traceback', '') if tb: logger.warning(tb) @@ -728,13 +737,16 @@ async def _watcher(): self._watcher_tasks.add(t) t.add_done_callback(self._watcher_tasks.discard) - return json.dumps({ - 'status': 'async_launched', - 'task_id': task_id, - 'tool_name': spec.tool_name, - }, ensure_ascii=False) + return json.dumps( + { + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': spec.tool_name, + }, + ensure_ascii=False) - async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: + async def _call_dynamic(self, tool_args: dict, + spec: '_AgentToolSpec') -> str: tasks = tool_args.get('tasks', []) execution_mode = tool_args.get('execution_mode', 'sequential') @@ -743,11 +755,13 @@ async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: async def _run_one(i: int, task: dict) -> str: system = task.get('system', '') query = task.get('query', '') - task_config = dict(base_config) if isinstance(base_config, dict) else {} + task_config = dict(base_config) if isinstance(base_config, + dict) else {} # Avoid inheriting the parent agent's snapshot preference into each # split sub-task; sub-agents use ms_agent_subagent defaults instead. task_config.pop('enable_snapshots', None) - if 'prompt' not in task_config or not isinstance(task_config.get('prompt'), dict): + if 'prompt' not in task_config or not isinstance( + task_config.get('prompt'), dict): task_config['prompt'] = {} task_config['prompt']['system'] = system sub_spec = _AgentToolSpec( @@ -925,7 +939,9 @@ async def _run_agent(self, ) logger.info( '[stream] %s (call_id=%s) streaming to %s', - spec.tool_name, _effective_call_id, _writer.stream_path, + spec.tool_name, + _effective_call_id, + _writer.stream_path, ) # ─────────────────────────────────────────────────────────────────── @@ -1174,7 +1190,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: f'Failed to write agent tool stats for {spec.tool_name}: {exc}' ) - def _save_transcript(self, messages: Any, agent_tag: Optional[str]) -> None: + def _save_transcript(self, messages: Any, + agent_tag: Optional[str]) -> None: if not isinstance(messages, list) or not agent_tag: return try: @@ -1185,9 +1202,12 @@ def _save_transcript(self, messages: Any, agent_tag: Optional[str]) -> None: with open(path, 'w', encoding='utf-8') as f: for msg in messages: if hasattr(msg, 'to_dict'): - f.write(json.dumps(msg.to_dict(), ensure_ascii=False) + '\n') + f.write( + json.dumps(msg.to_dict(), ensure_ascii=False) + + '\n') except Exception as exc: - logger.warning(f'Failed to save sub-agent transcript for {agent_tag}: {exc}') + logger.warning( + f'Failed to save sub-agent transcript for {agent_tag}: {exc}') def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): if spec.input_mode == 'messages': diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index 12ece9948..18e66d175 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import abstractmethod +from omegaconf import DictConfig from typing import Any, Dict from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from omegaconf import DictConfig class ToolBase: diff --git a/ms_agent/tools/code/code_executor.py b/ms_agent/tools/code/code_executor.py index 1df89e39c..cfe43c3b5 100644 --- a/ms_agent/tools/code/code_executor.py +++ b/ms_agent/tools/code/code_executor.py @@ -1,17 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import socket +from omegaconf import DictConfig from pathlib import Path from typing import Any, Dict, Optional, Union -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.code.sandbox_manager import SandboxManagerFactory from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package -from omegaconf import DictConfig logger = get_logger() @@ -117,7 +117,8 @@ def _build_sandbox_config( self, config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: """Build sandbox configuration from agent config""" - from ms_enclave.sandbox.model import DockerNotebookConfig, DockerSandboxConfig, SandboxType + from ms_enclave.sandbox.model import (DockerNotebookConfig, + DockerSandboxConfig, SandboxType) # Get sandbox-specific config or use defaults if isinstance(config, DictConfig) and hasattr( diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index d1e2104c5..63cd4d3b9 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -2,6 +2,7 @@ import asyncio.subprocess as ai_subprocess import inspect import io +import json import os import shlex import shutil @@ -10,14 +11,14 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package -from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel +from ms_agent.utils.workspace_policy import (WorkspacePolicyError, + WorkspacePolicyKernel) logger = get_logger() @@ -234,7 +235,8 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) self.output_dir = Path( - getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser().resolve() + getattr(config, 'output_dir', + DEFAULT_OUTPUT_DIR)).expanduser().resolve() self.output_dir.mkdir(parents=True, exist_ok=True) self.tool_config = getattr( @@ -264,13 +266,15 @@ def __init__(self, config): dg = getattr(wp, 'deny_globs', None) if dg: deny_globs = list(dg) - shell_cfg = getattr(self.tool_config, 'shell', None) if self.tool_config else None - shell_mode = getattr(shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False) - ) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) - ) if shell_cfg else 8192 + shell_cfg = getattr(self.tool_config, 'shell', + None) if self.tool_config else None + shell_mode = getattr( + shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', + False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', + 8192)) if shell_cfg else 8192 self._policy = WorkspacePolicyKernel( self.output_dir, extra_allow_roots=extra_allow, @@ -463,12 +467,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='shell_executor', server_name='code_executor', - description=( - 'Execute shell commands locally under the workspace output directory (cwd). ' - 'Subject to policy (read_only vs workspace_write, network toggle). ' - 'Large stdout/stderr may be spilled to .ms_agent_artifacts. ' - 'Use run_in_background=true to return immediately with task_id; poll via task notifications.' - ), + description= + ('Execute shell commands locally under the workspace output directory (cwd). ' + 'Subject to policy (read_only vs workspace_write, network toggle). ' + 'Large stdout/stderr may be spilled to .ms_agent_artifacts. ' + 'Use run_in_background=true to return immediately with task_id; poll via task notifications.' + ), parameters={ 'type': 'object', 'properties': { @@ -488,8 +492,10 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'default': False, }, '__call_id': { - 'type': 'string', - 'description': 'Optional correlation id (injected by host when supported).', + 'type': + 'string', + 'description': + 'Optional correlation id (injected by host when supported).', }, }, 'required': ['command'], @@ -709,7 +715,8 @@ async def shell_executor(self, if self._task_manager is None: return json.dumps( { - 'success': False, + 'success': + False, 'error': 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', }, diff --git a/ms_agent/tools/code/sandbox_manager.py b/ms_agent/tools/code/sandbox_manager.py index 9744ca7f8..4d3d843ce 100644 --- a/ms_agent/tools/code/sandbox_manager.py +++ b/ms_agent/tools/code/sandbox_manager.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Union from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -58,7 +58,8 @@ async def create_manager( Raises: ValueError: If sandbox mode is unknown """ - from ms_enclave.sandbox.manager import HttpSandboxManager, LocalSandboxManager + from ms_enclave.sandbox.manager import (HttpSandboxManager, + LocalSandboxManager) # Extract sandbox configuration if isinstance(config, DictConfig) and hasattr( diff --git a/ms_agent/tools/code_server/lsp_code_server.py b/ms_agent/tools/code_server/lsp_code_server.py index df1e043e5..ac7f3fbaf 100644 --- a/ms_agent/tools/code_server/lsp_code_server.py +++ b/ms_agent/tools/code_server/lsp_code_server.py @@ -1,11 +1,11 @@ import asyncio +import json import os import shutil import sys from pathlib import Path from typing import Any, Dict, List, Optional -import json from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, diff --git a/ms_agent/tools/docling/chunker.py b/ms_agent/tools/docling/chunker.py index 3781cb76b..5a7c894c6 100644 --- a/ms_agent/tools/docling/chunker.py +++ b/ms_agent/tools/docling/chunker.py @@ -1,5 +1,3 @@ -from typing import Iterable, Iterator, List, Union - from docling_core.transforms.chunker import BaseChunk, DocChunk from docling_core.transforms.chunker.hierarchical_chunker import ( ChunkingDocSerializer, ChunkingSerializerProvider) @@ -10,11 +8,12 @@ from docling_core.transforms.serializer.markdown import MarkdownParams from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel -from ms_agent.utils.logger import get_logger from rich.console import Console from rich.panel import Panel +from typing import Iterable, Iterator, List, Union from modelscope import AutoTokenizer +from ms_agent.utils.logger import get_logger logger = get_logger() diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py index 5daf1652b..daac553e5 100644 --- a/ms_agent/tools/docling/doc_loader.py +++ b/ms_agent/tools/docling/doc_loader.py @@ -2,9 +2,6 @@ # yapf: disable import ast import os -from typing import Dict, Iterator, List, Optional, Tuple, Union -from unittest.mock import patch as mock_patch - from docling.backend.html_backend import HTMLDocumentBackend from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.base_models import InputFormat @@ -18,6 +15,9 @@ from docling.models.table_structure_model import TableStructureModel from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem +from typing import Dict, Iterator, List, Optional, Tuple, Union +from unittest.mock import patch as mock_patch + from ms_agent.tools.docling.doc_postprocess import PostProcess from ms_agent.tools.docling.patches import (download_models_ms, download_models_pic_classifier_ms, diff --git a/ms_agent/tools/docling/doc_postprocess.py b/ms_agent/tools/docling/doc_postprocess.py index 2e21051d9..2b66dd95e 100644 --- a/ms_agent/tools/docling/doc_postprocess.py +++ b/ms_agent/tools/docling/doc_postprocess.py @@ -1,6 +1,5 @@ -from typing import Union - from docling_core.types import DoclingDocument +from typing import Union class PostProcess: diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py index b2ec77afd..4191e706b 100644 --- a/ms_agent/tools/docling/patches.py +++ b/ms_agent/tools/docling/patches.py @@ -1,10 +1,10 @@ # flake8: noqa import sys -from pathlib import Path - from bs4 import Tag from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel, ImageRef +from pathlib import Path + from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import (load_image_from_uri_to_pil, load_image_from_url_to_pil, validate_url) diff --git a/ms_agent/tools/fetch_playwright_fallback.py b/ms_agent/tools/fetch_playwright_fallback.py index 0ff89a91d..b173217df 100644 --- a/ms_agent/tools/fetch_playwright_fallback.py +++ b/ms_agent/tools/fetch_playwright_fallback.py @@ -100,7 +100,8 @@ def _thread_browser() -> object: except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium') + 'Install with: pip install playwright && playwright install chromium' + ) raise RuntimeError('playwright not installed') from None pw = sync_playwright().start() @@ -135,7 +136,8 @@ def try_playwright_inner_text( except ImportError: logger.debug( 'playwright is not installed; skip headless fetch. ' - 'Install with: pip install playwright && playwright install chromium') + 'Install with: pip install playwright && playwright install chromium' + ) return '' text = '' @@ -147,13 +149,11 @@ def try_playwright_inner_text( page.goto(url, wait_until='domcontentloaded', timeout=timeout_ms) if settle_ms: page.wait_for_timeout(settle_ms) - raw = page.evaluate( - """() => { + raw = page.evaluate("""() => { const b = document.body; if (!b) return ''; return b.innerText || ''; - }""" - ) + }""") if isinstance(raw, str): text = raw[:_MAX_INNER_TEXT_CHARS] finally: @@ -180,9 +180,8 @@ def looks_like_spa_shell_html(raw_html: str) -> bool: return False low = raw_html.lower() if any( - x in low - for x in ('enable javascript', 'javascript is required', - 'you need to enable javascript')): + x in low for x in ('enable javascript', 'javascript is required', + 'you need to enable javascript')): return True if re.search(r']+\bid=["\']root["\'][^>]*>\s*', low): return True diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index caa58c58b..0126f4f02 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -16,18 +16,58 @@ from ms_agent.utils import get_logger from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR -from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel +from ms_agent.utils.workspace_policy import (WorkspacePolicyError, + WorkspacePolicyKernel) logger = get_logger() -_FS_TOOL_ALIASES = {'read': 'read_file', 'edit': 'edit_file', 'write': 'write_file'} +_FS_TOOL_ALIASES = { + 'read': 'read_file', + 'edit': 'edit_file', + 'write': 'write_file' +} _TEXT_SUFFIXES = { - '.py', '.md', '.txt', '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', - '.sh', '.bash', '.js', '.ts', '.tsx', '.jsx', '.css', '.html', '.xml', - '.rs', '.go', '.java', '.c', '.h', '.cpp', '.hpp', '.cs', '.rb', '.php', - '.sql', '.vue', '.svelte', '.m', '.swift', '.kt', '.gradle', '.properties', - '.env', '.gitignore', '.dockerignore', 'Dockerfile', + '.py', + '.md', + '.txt', + '.yaml', + '.yml', + '.json', + '.toml', + '.cfg', + '.ini', + '.sh', + '.bash', + '.js', + '.ts', + '.tsx', + '.jsx', + '.css', + '.html', + '.xml', + '.rs', + '.go', + '.java', + '.c', + '.h', + '.cpp', + '.hpp', + '.cs', + '.rb', + '.php', + '.sql', + '.vue', + '.svelte', + '.m', + '.swift', + '.kt', + '.gradle', + '.properties', + '.env', + '.gitignore', + '.dockerignore', + 'Dockerfile', } @@ -38,8 +78,10 @@ class FileSystemTool(ToolBase): IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) # Curly quote → straight quote mapping for fuzzy matching CURLY_QUOTE_MAP = { - '\u2018': "'", '\u2019': "'", # ' ' - '\u201c': '"', '\u201d': '"', # " " + '\u2018': "'", + '\u2019': "'", # ' ' + '\u201c': '"', + '\u201d': '"', # " " } SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -88,7 +130,8 @@ def __init__(self, config, **kwargs): self._grep_timeout = int(getattr(fs_cfg, 'grep_timeout_s', 120) or 120) self._default_grep_head = int( getattr(fs_cfg, 'grep_head_limit', 250) or 250) - self._glob_max_files = int(getattr(fs_cfg, 'glob_max_files', 100) or 100) + self._glob_max_files = int( + getattr(fs_cfg, 'glob_max_files', 100) or 100) wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] @@ -96,12 +139,13 @@ def __init__(self, config, **kwargs): shell_cfg = getattr( getattr(config.tools, 'code_executor', None), 'shell', None) - shell_mode = getattr(shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', False) - ) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', 8192) - ) if shell_cfg else 8192 + shell_mode = getattr( + shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', + False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', + 8192)) if shell_cfg else 8192 _out_p = Path(self.output_dir).expanduser().resolve() try: @@ -133,25 +177,29 @@ async def _get_tools_inner(self): Tool( tool_name='write_file', server_name='file_system', - description=( - 'Write content to a file. Creates the file if it does not exist, ' - 'or overwrites it if it does.\n\n' - 'Usage:\n' - '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' - '- Use this tool to create new files or perform a complete rewrite.\n' - '- Parent directories are created automatically if they do not exist.\n\n' - 'No prior `read_file` is required; the path is resolved and the given `content` is written as-is.' - ), + description= + ('Write content to a file. Creates the file if it does not exist, ' + 'or overwrites it if it does.\n\n' + 'Usage:\n' + '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' + '- Use this tool to create new files or perform a complete rewrite.\n' + '- Parent directories are created automatically if they do not exist.\n\n' + 'No prior `read_file` is required; the path is resolved and the given `content` is written as-is.' + ), parameters={ 'type': 'object', 'properties': { 'path': { - 'type': 'string', - 'description': 'The relative path of the file to write', + 'type': + 'string', + 'description': + 'The relative path of the file to write', }, 'content': { - 'type': 'string', - 'description': 'The full content to write into the file', + 'type': + 'string', + 'description': + 'The full content to write into the file', }, }, 'required': ['path', 'content'], @@ -160,48 +208,55 @@ async def _get_tools_inner(self): Tool( tool_name='read_file', server_name='file_system', - description=( - 'Read the content of one or more files.\n\n' - '- `paths`: list of relative file paths to read (preferred).\n' - '- `path`: single relative file path (alias when the model passes one file).\n' - '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' - '- `offset`: line number to start reading from (1-based). ' - 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' - '- `limit`: number of lines to read. ' - 'Only effective when paths has exactly one element. Omit to read to the end.\n' - '- `abbreviate`: if true, use an LLM to return a condensed summary of each file ' - 'instead of the raw content. Cached after first call. ' - 'Use this for a quick structural overview; read the full file if more detail is needed.' - ), + description= + ('Read the content of one or more files.\n\n' + '- `paths`: list of relative file paths to read (preferred).\n' + '- `path`: single relative file path (alias when the model passes one file).\n' + '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' + '- `offset`: line number to start reading from (1-based). ' + 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' + '- `limit`: number of lines to read. ' + 'Only effective when paths has exactly one element. Omit to read to the end.\n' + '- `abbreviate`: if true, use an LLM to return a condensed summary of each file ' + 'instead of the raw content. Cached after first call. ' + 'Use this for a quick structural overview; read the full file if more detail is needed.' + ), parameters={ 'type': 'object', 'properties': { 'paths': { - 'type': 'array', - 'items': {'type': 'string'}, + 'type': + 'array', + 'items': { + 'type': 'string' + }, 'description': 'List of relative file path(s) to read. ' 'Use this OR `path` (single file).', }, 'path': { - 'type': 'string', + 'type': + 'string', 'description': 'Single relative file path to read (alias for `paths` of length 1).', }, 'offset': { - 'type': 'integer', + 'type': + 'integer', 'description': 'Line number to start reading from (1-based). ' 'Only provide if the file is too large to read at once.', }, 'limit': { - 'type': 'integer', + 'type': + 'integer', 'description': 'Number of lines to read. ' 'Only provide if the file is too large to read at once.', }, 'abbreviate': { - 'type': 'boolean', + 'type': + 'boolean', 'description': 'If true, return an LLM-generated summary instead of raw content. ' 'Useful for large files or quick structural overview.', @@ -213,38 +268,44 @@ async def _get_tools_inner(self): Tool( tool_name='edit_file', server_name='file_system', - description=( - 'Edit an existing file by replacing an exact string with new content.\n\n' - 'The tool reads the file from disk and checks that `old_string` appears in the current ' - 'contents before applying the edit — you do not need a prior `read_file` call for ' - 'staleness. Still use `read_file` or `grep` when you need the exact snippet in the ' - 'conversation so you can form a correct `old_string`.\n\n' - 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' - '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' - 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' - 'with the match count so you can add more context to make it unique.\n\n' - 'Special case — `old_string=""`:\n' - '- File does not exist: creates the file with `new_string` as its content.\n' - '- File exists and is empty: fills it with `new_string`.\n' - '- File exists and has content: returns an error. Use `write_file` for a full rewrite.' - ), + description= + ('Edit an existing file by replacing an exact string with new content.\n\n' + 'The tool reads the file from disk and checks that `old_string` appears in the current ' + 'contents before applying the edit — you do not need a prior `read_file` call for ' + 'staleness. Still use `read_file` or `grep` when you need the exact snippet in the ' + 'conversation so you can form a correct `old_string`.\n\n' + 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' + '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' + 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' + 'with the match count so you can add more context to make it unique.\n\n' + 'Special case — `old_string=""`:\n' + '- File does not exist: creates the file with `new_string` as its content.\n' + '- File exists and is empty: fills it with `new_string`.\n' + '- File exists and has content: returns an error. Use `write_file` for a full rewrite.' + ), parameters={ 'type': 'object', 'properties': { 'path': { - 'type': 'string', - 'description': 'The relative path of the file to edit.', + 'type': + 'string', + 'description': + 'The relative path of the file to edit.', }, 'old_string': { - 'type': 'string', - 'description': 'The exact string to find and replace.', + 'type': + 'string', + 'description': + 'The exact string to find and replace.', }, 'new_string': { 'type': 'string', - 'description': 'The string to replace it with.', + 'description': + 'The string to replace it with.', }, 'replace_all': { - 'type': 'boolean', + 'type': + 'boolean', 'description': 'If true, replace all occurrences. Default is false (replace only the first).', }, @@ -255,40 +316,50 @@ async def _get_tools_inner(self): Tool( tool_name='grep', server_name='file_system', - description=( - 'Search file contents under the workspace using ripgrep when available, ' - 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' - 'Read-only.' - ), + description= + ('Search file contents under the workspace using ripgrep when available, ' + 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' + 'Read-only.'), parameters={ 'type': 'object', 'properties': { 'pattern': { - 'type': 'string', - 'description': 'Regular expression (Rust regex if rg is used).', + 'type': + 'string', + 'description': + 'Regular expression (Rust regex if rg is used).', }, 'path': { - 'type': 'string', + 'type': + 'string', 'description': 'Directory or file to search (relative to output_dir if not absolute). Default ".".', }, 'glob': { - 'type': 'string', - 'description': 'Optional glob filter for files, e.g. "*.py"', + 'type': + 'string', + 'description': + 'Optional glob filter for files, e.g. "*.py"', }, 'output_mode': { - 'type': 'string', - 'enum': ['content', 'files_with_matches', 'count'], + 'type': + 'string', + 'enum': + ['content', 'files_with_matches', 'count'], 'description': 'content: matching lines; files_with_matches: paths only; count: per-file counts', }, 'head_limit': { - 'type': 'integer', - 'description': 'Max lines (content) or paths/count entries to return', + 'type': + 'integer', + 'description': + 'Max lines (content) or paths/count entries to return', }, 'offset': { - 'type': 'integer', - 'description': 'Skip first N lines/entries after collect', + 'type': + 'integer', + 'description': + 'Skip first N lines/entries after collect', }, 'case_insensitive': { 'type': 'boolean', @@ -302,10 +373,10 @@ async def _get_tools_inner(self): Tool( tool_name='glob', server_name='file_system', - description=( - 'List files under a workspace directory matching a glob pattern ' - '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' - ), + description= + ('List files under a workspace directory matching a glob pattern ' + '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' + ), parameters={ 'type': 'object', 'properties': { @@ -314,7 +385,8 @@ async def _get_tools_inner(self): 'description': 'Glob pattern relative to path', }, 'path': { - 'type': 'string', + 'type': + 'string', 'description': 'Base directory (relative to output_dir if not absolute).', }, @@ -323,7 +395,6 @@ async def _get_tools_inner(self): 'additionalProperties': False, }, ), - ] } return tools @@ -343,8 +414,8 @@ async def grep( case_insensitive: bool = False, ) -> str: call_id = f'grep-{pattern[:40]}' - head_limit = (head_limit if head_limit is not None else - self._default_grep_head) + head_limit = ( + head_limit if head_limit is not None else self._default_grep_head) offset = offset or 0 path = path or '.' try: @@ -352,7 +423,8 @@ async def grep( except WorkspacePolicyError as e: return json.dumps({'success': False, 'error': str(e)}, indent=2) - if pattern is None or (isinstance(pattern, str) and not pattern.strip()): + if pattern is None or (isinstance(pattern, str) + and not pattern.strip()): return json.dumps( { 'success': False, @@ -363,13 +435,13 @@ async def grep( if isinstance(pattern, str) and ('\n' in pattern or '\r' in pattern): return json.dumps( { - 'success': False, - 'error': ( - 'grep pattern must not contain raw newline characters; ' - 'ripgrep rejects them unless multiline mode is enabled server-side. ' - 'Use several single-line patterns, escape newlines as needed, ' - 'or search with read_file for fixed multi-line text.' - ), + 'success': + False, + 'error': + ('grep pattern must not contain raw newline characters; ' + 'ripgrep rejects them unless multiline mode is enabled server-side. ' + 'Use several single-line patterns, escape newlines as needed, ' + 'or search with read_file for fixed multi-line text.'), }, indent=2, ) @@ -398,12 +470,8 @@ async def grep( except Exception as e: err = str(e) # Expected user/tooling failures (bad regex, rg rules) — log without traceback noise. - _quiet = ( - 'rg:' in err - or 'exited' in err.lower() - or 'regex' in err.lower() - or 'pattern' in err.lower() - ) + _quiet = ('rg:' in err or 'exited' in err.lower() + or 'regex' in err.lower() or 'pattern' in err.lower()) logger.warning('grep failed: %s', e, exc_info=not _quiet) return json.dumps({'success': False, 'error': str(e)}, indent=2) @@ -449,8 +517,8 @@ async def _grep_rg_file( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for( + proc.communicate(), timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -486,8 +554,8 @@ async def _grep_rg_dir( stderr=asyncio.subprocess.PIPE, cwd=str(self._fs_policy.workspace_root), ) - out_b, err_b = await asyncio.wait_for(proc.communicate(), - timeout=self._grep_timeout) + out_b, err_b = await asyncio.wait_for( + proc.communicate(), timeout=self._grep_timeout) out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') if proc.returncode not in (0, 1): @@ -516,8 +584,9 @@ def _grep_python( def consider_file(fp: Path) -> bool: if glob_pat: rel = str(fp.relative_to(root)) if root.is_dir() else fp.name - if not fnmatch.fnmatch(fp.name, glob_pat) and not fnmatch.fnmatch( - rel, glob_pat): + if not fnmatch.fnmatch(fp.name, + glob_pat) and not fnmatch.fnmatch( + rel, glob_pat): return False suf = fp.suffix.lower() if suf not in _TEXT_SUFFIXES and fp.suffix == '': @@ -539,9 +608,9 @@ def consider_file(fp: Path) -> bool: text = fp.read_text(encoding='utf-8', errors='replace') except OSError: continue - rel = str(fp.relative_to(self._fs_policy.workspace_root) - ) if _is_relative(fp, self._fs_policy.workspace_root) else str( - fp) + rel = str(fp.relative_to( + self._fs_policy.workspace_root)) if _is_relative( + fp, self._fs_policy.workspace_root) else str(fp) if output_mode == 'files_with_matches': if rx.search(text): lines_out.append(rel) @@ -589,9 +658,9 @@ async def glob(self, pattern: str, path: str = '') -> str: continue if _is_denied_path(rp, base, deny): continue - rel = str(p.relative_to(self._fs_policy.workspace_root) - ) if _is_relative(p, self._fs_policy.workspace_root - ) else str(p) + rel = str(p.relative_to( + self._fs_policy.workspace_root)) if _is_relative( + p, self._fs_policy.workspace_root) else str(p) matches.append(rel) if len(matches) >= self._glob_max_files: truncated = True @@ -629,7 +698,8 @@ def _normalize_quotes(self, s: str) -> str: s = s.replace(curly, straight) return s - def _preserve_quote_style(self, old_string: str, actual_old: str, new_string: str) -> str: + def _preserve_quote_style(self, old_string: str, actual_old: str, + new_string: str) -> str: """If old_string matched via quote normalization, apply the same curly quotes to new_string.""" if old_string == actual_old: return new_string @@ -727,7 +797,8 @@ def _normalize_read_paths(self, paths, path) -> List[str]: p.strip() for p in paths if isinstance(p, str) and p.strip() ] - if not out and path is not None and isinstance(path, str) and path.strip(): + if not out and path is not None and isinstance(path, + str) and path.strip(): out = [path.strip()] return out @@ -753,11 +824,11 @@ async def read_file(self, if not paths: return json.dumps( { - 'success': False, - 'error': ( - 'read_file requires `paths` (list of strings) or `path` (single string). ' - 'Example: {"paths": ["a.md"]} or {"path": "a.md"}.' - ), + 'success': + False, + 'error': + ('read_file requires `paths` (list of strings) or `path` (single string). ' + 'Example: {"paths": ["a.md"]} or {"path": "a.md"}.'), }, indent=2, ) @@ -802,8 +873,7 @@ async def read_file(self, # Dedup: return stub if file unchanged since last read mtime = os.path.getmtime(target_path_real) cached = self._read_cache.get(target_path_real) - if (cached - and cached['mtime'] == mtime + if (cached and cached['mtime'] == mtime and cached['offset'] == offset and cached['limit'] == limit): results[path] = { @@ -830,10 +900,13 @@ async def read_file(self, if use_line_range: actual_start = max(1, offset) if offset is not None else 1 - actual_end = min(actual_start + limit - 1, total_lines) if limit is not None else total_lines + actual_end = min( + actual_start + limit + - 1, total_lines) if limit is not None else total_lines if actual_start > total_lines: - results[path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' + results[ + path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' continue selected = lines[actual_start - 1:actual_end] start_lineno = actual_start @@ -841,10 +914,8 @@ async def read_file(self, selected = lines start_lineno = 1 - results[path] = ''.join( - f'{start_lineno + i}\t{line}' - for i, line in enumerate(selected) - ) + results[path] = ''.join(f'{start_lineno + i}\t{line}' + for i, line in enumerate(selected)) # Update dedup cache self._read_cache[target_path_real] = { @@ -881,7 +952,10 @@ def process_file(path): messages = [ Message(role='system', content=self.system), - Message(role='user', content='The content to be abbreviated:\n\n' + content), + Message( + role='user', + content='The content to be abbreviated:\n\n' + + content), ] response = self.llm.generate(messages=messages, stream=False) os.makedirs(os.path.dirname(index_file), exist_ok=True) @@ -894,7 +968,10 @@ def process_file(path): return path, f'Process file <{path}> failed, error: ' + str(e) with ThreadPoolExecutor(max_workers=4) as executor: - future_to_path = {executor.submit(process_file, p): p for p in paths} + future_to_path = { + executor.submit(process_file, p): p + for p in paths + } for future in as_completed(future_to_path): path, result = future.result() results[path] = result @@ -931,7 +1008,8 @@ async def edit_file(self, if old_string == '': if not os.path.exists(target_path_real): # Create new file - os.makedirs(os.path.dirname(target_path_real), exist_ok=True) + os.makedirs( + os.path.dirname(target_path_real), exist_ok=True) with open(target_path_real, 'w', encoding='utf-8') as f: f.write(new_string) return f'Created file <{path}> successfully.' @@ -989,10 +1067,12 @@ async def edit_file(self, ) # Apply quote style preservation to new_string - actual_new = self._preserve_quote_style(old_string, actual_old, new_string) + actual_new = self._preserve_quote_style(old_string, actual_old, + new_string) # --- Fallback 3: smart delete — strip trailing newline when deleting --- - if actual_new == '' and not actual_old.endswith('\n') and actual_old + '\n' in content: + if actual_new == '' and not actual_old.endswith( + '\n') and actual_old + '\n' in content: actual_old = actual_old + '\n' # Strip trailing whitespace from new_string (skip markdown files) diff --git a/ms_agent/tools/findata/akshare_source.py b/ms_agent/tools/findata/akshare_source.py index 6e68aaf41..b9fa3f7b4 100644 --- a/ms_agent/tools/findata/akshare_source.py +++ b/ms_agent/tools/findata/akshare_source.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import re from typing import Any, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.data_source_base import (DataSourceError, FinancialDataSource, NoDataFoundError) diff --git a/ms_agent/tools/findata/baostock_source.py b/ms_agent/tools/findata/baostock_source.py index dd05b7a1f..9379b3ba2 100644 --- a/ms_agent/tools/findata/baostock_source.py +++ b/ms_agent/tools/findata/baostock_source.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import threading from contextlib import contextmanager from copy import deepcopy from typing import Any, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.data_source_base import (DataSourceError, FinancialDataSource, NoDataFoundError) diff --git a/ms_agent/tools/findata/data_source_base.py b/ms_agent/tools/findata/data_source_base.py index d287aebdb..e78b9fd29 100644 --- a/ms_agent/tools/findata/data_source_base.py +++ b/ms_agent/tools/findata/data_source_base.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -import pandas as pd - class DataSourceError(Exception): """Base data source error class""" diff --git a/ms_agent/tools/findata/findata_fetcher.py b/ms_agent/tools/findata/findata_fetcher.py index 996f28781..2b5afdcfd 100644 --- a/ms_agent/tools/findata/findata_fetcher.py +++ b/ms_agent/tools/findata/findata_fetcher.py @@ -1,14 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json +import numpy as np +import pandas as pd from concurrent.futures import ThreadPoolExecutor from datetime import date, datetime from functools import partial +from omegaconf import DictConfig from pathlib import Path from typing import Any, Dict, Optional, Union -import json -import numpy as np -import pandas as pd from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.findata.akshare_source import AKShareDataSource @@ -19,7 +20,6 @@ from ms_agent.tools.findata.hybrid_source import HybridDataSource from ms_agent.utils import get_logger from ms_agent.utils.rate_limiter import AdaptiveRateLimiter, RateLimiter -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/tools/findata/hybrid_source.py b/ms_agent/tools/findata/hybrid_source.py index 79380fd14..66b246c8e 100644 --- a/ms_agent/tools/findata/hybrid_source.py +++ b/ms_agent/tools/findata/hybrid_source.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import re from typing import Any, Callable, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource from ms_agent.tools.findata.data_source_base import (DataSourceError, diff --git a/ms_agent/tools/image_generator/ds_image_gen.py b/ms_agent/tools/image_generator/ds_image_gen.py index b7286b218..ea83d1d06 100644 --- a/ms_agent/tools/image_generator/ds_image_gen.py +++ b/ms_agent/tools/image_generator/ds_image_gen.py @@ -1,7 +1,6 @@ import os import uuid from io import BytesIO - from PIL import Image diff --git a/ms_agent/tools/image_generator/ms_image_gen.py b/ms_agent/tools/image_generator/ms_image_gen.py index b121458e6..b61b6d7af 100644 --- a/ms_agent/tools/image_generator/ms_image_gen.py +++ b/ms_agent/tools/image_generator/ms_image_gen.py @@ -1,9 +1,8 @@ import asyncio +import json import os import uuid from io import BytesIO - -import json from PIL import Image diff --git a/ms_agent/tools/jina_reader.py b/ms_agent/tools/jina_reader.py index b3f663971..f60b39cd8 100644 --- a/ms_agent/tools/jina_reader.py +++ b/ms_agent/tools/jina_reader.py @@ -10,8 +10,8 @@ from urllib.parse import quote, urlparse from urllib.request import Request, urlopen -from ms_agent.tools.fetch_playwright_fallback import (looks_like_spa_shell_html, - try_playwright_inner_text) +from ms_agent.tools.fetch_playwright_fallback import ( + looks_like_spa_shell_html, try_playwright_inner_text) from ms_agent.utils.logger import get_logger logger = get_logger() @@ -194,8 +194,8 @@ def _fetch_via_jina(url: str, config: JinaReaderConfig) -> str: return '' -def fetch_single_text_with_meta(url: str, - config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: +def fetch_single_text_with_meta( + url: str, config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: """ Tiered fetch: Jina Reader → direct HTTP → optional Playwright (empty / short / SPA shell). @@ -209,15 +209,16 @@ def fetch_single_text_with_meta(url: str, return jina_text, {'content_source': 'jina_reader'} if not config.direct_fetch_fallback: return '', {'content_source': 'none'} - d_timeout = (float(config.timeout) if float(config.direct_fetch_timeout or 0) - <= 0 else float(config.direct_fetch_timeout)) + d_timeout = ( + float(config.timeout) if float(config.direct_fetch_timeout or 0) <= 0 + else float(config.direct_fetch_timeout)) direct_plain, raw_html = _fetch_direct_http_pair(url, d_timeout) direct_text = _postprocess_text(direct_plain) try_playwright = ( bool(config.playwright_fetch_fallback) and _is_direct_http_allowed(url) - and _should_try_playwright_after_direct(direct_text, raw_html, - config.playwright_retry_min_chars)) + and _should_try_playwright_after_direct( + direct_text, raw_html, config.playwright_retry_min_chars)) if try_playwright: pw_text = _postprocess_text( diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index d6b971fbb..7fbfe148c 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -2,18 +2,18 @@ import os from contextlib import AsyncExitStack from datetime import timedelta -from types import TracebackType -from typing import Any, Dict, Literal, Optional - from mcp import ClientSession, ListToolsResult, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from omegaconf import DictConfig +from types import TracebackType +from typing import Any, Dict, Literal, Optional + from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import enhance_error, get_logger -from omegaconf import DictConfig logger = get_logger() @@ -184,7 +184,8 @@ async def connect_to_server(self, 'Using streamable_http transport. To configure a different transport such as sse, please' 'set the `type` or `transport` variable to "sse".') try: - from mcp.client.streamable_http import streamablehttp_client + from mcp.client.streamable_http import \ + streamablehttp_client except ImportError: raise ImportError( 'Could not import streamablehttp_client. ' diff --git a/ms_agent/tools/mineru/pdf_parser.py b/ms_agent/tools/mineru/pdf_parser.py index d210fd9b1..284fbff69 100644 --- a/ms_agent/tools/mineru/pdf_parser.py +++ b/ms_agent/tools/mineru/pdf_parser.py @@ -1,5 +1,4 @@ import os - from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.data.data_reader_writer import (FileBasedDataReader, FileBasedDataWriter) diff --git a/ms_agent/tools/search/arxiv/schema.py b/ms_agent/tools/search/arxiv/schema.py index 154c603ae..2e9439904 100644 --- a/ms_agent/tools/search/arxiv/schema.py +++ b/ms_agent/tools/search/arxiv/schema.py @@ -1,10 +1,10 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, Generator, List, Optional - import arxiv import json from arxiv import SortCriterion, SortOrder +from dataclasses import dataclass, field +from typing import Any, Dict, Generator, List, Optional + from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, SearchResponse, SearchResult) from ms_agent.utils.logger import get_logger diff --git a/ms_agent/tools/search/arxiv/search.py b/ms_agent/tools/search/arxiv/search.py index 8f407d454..4888c2f3a 100644 --- a/ms_agent/tools/search/arxiv/search.py +++ b/ms_agent/tools/search/arxiv/search.py @@ -1,10 +1,10 @@ # flake8: noqa +import arxiv import os +from arxiv import SortCriterion, SortOrder from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING -import arxiv -from arxiv import SortCriterion, SortOrder from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, ArxivSearchResult) from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType diff --git a/ms_agent/tools/search/content_optimizer.py b/ms_agent/tools/search/content_optimizer.py index 998b13b66..1e62798cb 100644 --- a/ms_agent/tools/search/content_optimizer.py +++ b/ms_agent/tools/search/content_optimizer.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import re from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse -import json from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/ms_agent/tools/search/exa/schema.py b/ms_agent/tools/search/exa/schema.py index a80a1c401..48b51c06b 100644 --- a/ms_agent/tools/search/exa/schema.py +++ b/ms_agent/tools/search/exa/schema.py @@ -1,9 +1,8 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - import json +from dataclasses import dataclass, field from exa_py.api import SearchResponse +from typing import Any, Dict, List, Optional @dataclass diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index 08fa6fb71..d09ffa237 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -1,9 +1,9 @@ # flake8: noqa import os import threading +from exa_py import Exa from typing import TYPE_CHECKING, List, Optional, Set, Union -from exa_py import Exa from ms_agent.tools.search.exa.schema import ExaSearchRequest, ExaSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.utils.logger import get_logger diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py index d8ab34ef7..2449d4453 100644 --- a/ms_agent/tools/search/localsearch_tool.py +++ b/ms_agent/tools/search/localsearch_tool.py @@ -5,12 +5,10 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from ms_agent.tools.search.sirchmunk_search import ( - SirchmunkSearch, - effective_localsearch_settings, -) from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, effective_localsearch_settings) from ms_agent.utils.logger import get_logger logger = get_logger() @@ -60,9 +58,8 @@ def _resolved_localsearch_paths_from_config(config) -> List[str]: def _format_configured_roots(paths: List[str]) -> str: if not paths: - return ( - '(none — set tools.localsearch.paths in agent config, ' - 'or legacy knowledge_search.paths)') + return ('(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') return '\n'.join(f'- {p}' for p in paths) @@ -87,7 +84,8 @@ class LocalSearchTool(ToolBase): def __init__(self, config, **kwargs): super().__init__(config) tools_root = getattr(config, 'tools', None) - tool_cfg = getattr(tools_root, 'localsearch', None) if tools_root else None + tool_cfg = getattr(tools_root, 'localsearch', + None) if tools_root else None if tool_cfg is not None: self.exclude_func(tool_cfg) self._searcher: Optional[SirchmunkSearch] = None @@ -122,8 +120,7 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: server_name=_SERVER, description=self._tool_description(), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'query': { 'type': @@ -132,13 +129,11 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Search keywords or natural-language question about local content.', }, 'paths': { - 'type': - 'array', + 'type': 'array', 'items': { 'type': 'string' }, - 'description': - self._paths_param_description(), + 'description': self._paths_param_description(), }, 'mode': { 'type': @@ -148,21 +143,28 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Search mode; omit to use agent default (usually FAST).', }, 'max_depth': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 20, + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, 'description': 'Max directory depth for filesystem search.', }, 'top_k_files': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 20, + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, 'description': 'Max files for evidence / filename hits.', }, 'include': { - 'type': 'array', + 'type': + 'array', 'items': { 'type': 'string' }, @@ -170,7 +172,8 @@ async def _get_tools_inner(self) -> Dict[str, List[Tool]]: 'Glob patterns to include (e.g. *.py, *.md).', }, 'exclude': { - 'type': 'array', + 'type': + 'array', 'items': { 'type': 'string' }, @@ -219,8 +222,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, if paths_arg: resolved_paths = searcher.resolve_tool_paths(paths_arg) if not resolved_paths: - roots = _format_configured_roots( - self._configured_roots) + roots = _format_configured_roots(self._configured_roots) return ( 'Error: `paths` are invalid. Each path must exist on disk and lie ' 'under one of these configured roots:\n' + roots) @@ -266,8 +268,7 @@ async def call_tool(self, server_name: str, *, tool_name: str, result_parts.append('\nSource paths:') for item in excerpts[:12]: meta = item.get('metadata') or {} - result_parts.append( - f'- {meta.get("source", "?")}') + result_parts.append(f'- {meta.get("source", "?")}') result_text = '\n'.join(result_parts) return { @@ -279,4 +280,3 @@ async def call_tool(self, server_name: str, *, tool_name: str, except Exception as exc: logger.warning(f'localsearch failed: {exc}') return f'Local search failed: {exc}' - diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py index 8a10c9d20..3c125f11d 100644 --- a/ms_agent/tools/search/search_base.py +++ b/ms_agent/tools/search/search_base.py @@ -1,12 +1,11 @@ # flake8: noqa import enum +import json import os from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar -import json - if TYPE_CHECKING: from ms_agent.llm.utils import Tool diff --git a/ms_agent/tools/search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py index cd86f9d63..cd8aaacf5 100644 --- a/ms_agent/tools/search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -7,11 +7,10 @@ import asyncio import json -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - from loguru import logger from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional def _paths_from_block(block: Any) -> List[str]: @@ -142,8 +141,8 @@ def _validate_config(self, config: DictConfig): 'tools.localsearch.paths (or legacy knowledge_search.paths) ' 'must be specified and non-empty') - def resolve_tool_paths( - self, paths: Optional[List[str]]) -> Optional[List[str]]: + def resolve_tool_paths(self, + paths: Optional[List[str]]) -> Optional[List[str]]: """Restrict per-call paths to configured search roots.""" if not paths: return None @@ -154,10 +153,10 @@ def resolve_tool_paths( continue p = Path(str(raw).strip()).expanduser().resolve() if not p.exists(): - logger.warning(f'localsearch: path does not exist, skipped: {p}') + logger.warning( + f'localsearch: path does not exist, skipped: {p}') continue - allowed = any( - p == r or p.is_relative_to(r) for r in roots) + allowed = any(p == r or p.is_relative_to(r) for r in roots) if not allowed: logger.warning( f'localsearch: path outside configured search roots, ' @@ -402,8 +401,9 @@ async def query( self._last_search_result = [] for item in result[:20]: if isinstance(item, dict): - src = (item.get('path') or item.get('file_path') - or item.get('file') or '') + src = ( + item.get('path') or item.get('file_path') + or item.get('file') or '') self._last_search_result.append({ 'text': json.dumps(item, ensure_ascii=False), @@ -429,7 +429,7 @@ async def query( result, score_threshold=0.7, limit=5) if hasattr(result, 'answer') and getattr(result, 'answer', - None) is not None: + None) is not None: return result.answer if isinstance(result, str): diff --git a/ms_agent/tools/search/tavily/fetcher.py b/ms_agent/tools/search/tavily/fetcher.py index 4082907a2..d4ed4927a 100644 --- a/ms_agent/tools/search/tavily/fetcher.py +++ b/ms_agent/tools/search/tavily/fetcher.py @@ -45,7 +45,9 @@ def __init__( self._include_favicon = include_favicon self._include_usage = include_usage - def fetch(self, url: str, query: Optional[str] = None) -> Tuple[str, Dict[str, Any]]: + def fetch(self, + url: str, + query: Optional[str] = None) -> Tuple[str, Dict[str, Any]]: """ Extract one URL. Optional ``query`` enables chunk reranking (more relevant raw_content). """ @@ -64,7 +66,8 @@ def fetch(self, url: str, query: Optional[str] = None) -> Tuple[str, Dict[str, A body['chunks_per_source'] = self._chunks_per_source try: - data = post_json(TAVILY_EXTRACT_URL, body, timeout=self._timeout + 30.0) + data = post_json( + TAVILY_EXTRACT_URL, body, timeout=self._timeout + 30.0) except Exception as e: logger.warning(f'Tavily extract failed for {url[:80]}: {e}') return '', { diff --git a/ms_agent/tools/search/tavily/http.py b/ms_agent/tools/search/tavily/http.py index d4916d271..e2ad2cb85 100644 --- a/ms_agent/tools/search/tavily/http.py +++ b/ms_agent/tools/search/tavily/http.py @@ -44,7 +44,6 @@ def post_json( detail = json.loads(err_body) if err_body else {} except json.JSONDecodeError: detail = {'raw': err_body} - raise RuntimeError( - f'Tavily HTTP {e.code}: {detail}') from e + raise RuntimeError(f'Tavily HTTP {e.code}: {detail}') from e except URLError as e: raise RuntimeError(f'Tavily network error: {e}') from e diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py index 75f3f0aed..112a41839 100644 --- a/ms_agent/tools/search/tavily/schema.py +++ b/ms_agent/tools/search/tavily/schema.py @@ -50,7 +50,8 @@ def to_api_body(self, api_key: str) -> Dict[str, Any]: } # chunks_per_source only meaningful for advanced (per Tavily docs) if self.search_depth == 'advanced': - body['chunks_per_source'] = max(1, min(3, int(self.chunks_per_source))) + body['chunks_per_source'] = max( + 1, min(3, int(self.chunks_per_source))) if self.time_range: body['time_range'] = self.time_range if self.start_date: diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py index b4b7d3f3b..38c3c69cb 100644 --- a/ms_agent/tools/search/tavily/search.py +++ b/ms_agent/tools/search/tavily/search.py @@ -4,7 +4,8 @@ from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType from ms_agent.tools.search.tavily.http import post_json -from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult +from ms_agent.tools.search.tavily.schema import (TavilySearchRequest, + TavilySearchResult) from ms_agent.utils.logger import get_logger if TYPE_CHECKING: @@ -39,7 +40,8 @@ def __init__( self._api_key = key self._request_timeout = float(request_timeout) - def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: + def search(self, + search_request: TavilySearchRequest) -> TavilySearchResult: body = search_request.to_api_body(self._api_key) try: data = post_json( @@ -64,7 +66,8 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'Search the web using Tavily (built for AI agents). ' 'Returns ranked results with optional full-page markdown via ' '`include_raw_content`. Use `search_depth` advanced for best ' - 'relevance and richer `content` chunks (higher API credit use).'), + 'relevance and richer `content` chunks (higher API credit use).' + ), parameters={ 'type': 'object', 'properties': { @@ -73,28 +76,36 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'description': 'Search query.', }, 'num_results': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 20, - 'description': 'Max results (maps to Tavily max_results). Default 10.', + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max results (maps to Tavily max_results). Default 10.', }, 'search_depth': { - 'type': 'string', + 'type': + 'string', 'enum': ['advanced', 'basic', 'fast', 'ultra-fast'], 'description': ('advanced: best quality, 2 credits; ' 'basic/fast/ultra-fast: 1 credit (see Tavily docs).'), }, 'topic': { - 'type': 'string', + 'type': + 'string', 'enum': ['general', 'news', 'finance'], 'description': 'Search category (`news` / `finance` for focused verticals).', }, 'time_range': { - 'type': 'string', + 'type': + 'string', 'description': - ('Filter by recency: day, week, month, year or d,w,m,y.'), + ('Filter by recency: day, week, month, year or d,w,m,y.' + ), }, 'start_date': { 'type': 'string', @@ -105,24 +116,28 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'description': 'Results before YYYY-MM-DD.', }, 'include_answer': { - 'type': 'string', + 'type': + 'string', 'enum': ['false', 'true', 'basic', 'advanced'], 'description': ('LLM answer: true/basic for short, advanced for detailed. ' 'Use false to skip.'), }, 'include_raw_content': { - 'type': 'string', - 'enum': - ['false', 'true', 'markdown', 'text'], + 'type': + 'string', + 'enum': ['false', 'true', 'markdown', 'text'], 'description': ('full page text: markdown (recommended) or text; ' 'false to skip raw content.'), }, 'chunks_per_source': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 3, + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 3, 'description': ('Relevant chunks per URL when search_depth=advanced. ' 'Each chunk up to ~500 chars in `content` field.'), @@ -142,13 +157,15 @@ def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': 'description': 'Exclude domains (max 150).', }, 'country': { - 'type': 'string', + 'type': + 'string', 'description': ('Boost results from country (e.g. united states). ' 'See Tavily docs for enum.'), }, 'exact_match': { - 'type': 'boolean', + 'type': + 'boolean', 'description': 'Only results with exact quoted phrases in query.', }, diff --git a/ms_agent/tools/search/web_search_spill.py b/ms_agent/tools/search/web_search_spill.py index 9c8da73c3..11a29bf16 100644 --- a/ms_agent/tools/search/web_search_spill.py +++ b/ms_agent/tools/search/web_search_spill.py @@ -206,9 +206,8 @@ def order_by_size() -> List[int]: work[idx]['content_path'] = rel_body work[idx]['content_chars_spilled'] = before_chars - preview_src = ( - item.get('content') or item.get('summary') or item.get('abstract') - or '')[:4000] + preview_src = (item.get('content') or item.get('summary') + or item.get('abstract') or '')[:4000] manifest_rows.append({ 'index': idx, @@ -251,8 +250,8 @@ def order_by_size() -> List[int]: 'rows': manifest_rows, } - rel_manifest = os.path.join(spill_subdir, run_key, 'manifest.json').replace( - '\\', '/') + rel_manifest = os.path.join(spill_subdir, run_key, + 'manifest.json').replace('\\', '/') abs_manifest = os.path.normpath( os.path.join(output_dir, rel_manifest.replace('/', os.sep))) with open(abs_manifest, 'w', encoding='utf-8') as mf: @@ -274,19 +273,12 @@ def order_by_size() -> List[int]: digest = '\n'.join(lines) spill_meta = { - 'spilled': - True, - 'run_key': - run_key, - 'artifact_dir': - f'{spill_subdir}/{run_key}'.replace('\\', '/'), - 'manifest_path': - rel_manifest, - 'digest': - digest, - 'inline_chars_before_spill': - total, - 'inline_chars_after_spill': - _total_inline_chars(work), + 'spilled': True, + 'run_key': run_key, + 'artifact_dir': f'{spill_subdir}/{run_key}'.replace('\\', '/'), + 'manifest_path': rel_manifest, + 'digest': digest, + 'inline_chars_before_spill': total, + 'inline_chars_after_spill': _total_inline_chars(work), } return work, spill_meta diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 16d6005d2..167bbe3bb 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -16,7 +16,8 @@ ContentOptimizerConfig, SearchResultReranker) from ms_agent.tools.search.search_base import ENGINE_TOOL_NAMES, SearchEngine -from ms_agent.tools.search.web_search_spill import maybe_spill_web_search_payload +from ms_agent.tools.search.web_search_spill import \ + maybe_spill_web_search_payload from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor @@ -199,8 +200,10 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', config = JinaReaderConfig( timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), - direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', True)), - direct_fetch_timeout=float(kwargs.get('direct_fetch_timeout', 15.0)), + direct_fetch_fallback=bool( + kwargs.get('direct_fetch_fallback', True)), + direct_fetch_timeout=float( + kwargs.get('direct_fetch_timeout', 15.0)), playwright_fetch_fallback=bool( kwargs.get('playwright_fetch_fallback', True)), playwright_retry_min_chars=int( @@ -217,10 +220,14 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', extract_depth=str(kwargs.get('tavily_extract_depth', 'advanced')), format=str(kwargs.get('tavily_extract_format', 'markdown')), timeout=float(kwargs.get('timeout', 45.0)), - chunks_per_source=int(kwargs.get('tavily_extract_chunks_per_source', 3)), - include_images=bool(kwargs.get('tavily_extract_include_images', False)), - include_favicon=bool(kwargs.get('tavily_extract_include_favicon', False)), - include_usage=bool(kwargs.get('tavily_extract_include_usage', False)), + chunks_per_source=int( + kwargs.get('tavily_extract_chunks_per_source', 3)), + include_images=bool( + kwargs.get('tavily_extract_include_images', False)), + include_favicon=bool( + kwargs.get('tavily_extract_include_favicon', False)), + include_usage=bool( + kwargs.get('tavily_extract_include_usage', False)), ) # Future: add more fetchers # elif fetcher_type == 'docling': @@ -233,8 +240,8 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', JinaReaderConfig( timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), - direct_fetch_fallback=bool(kwargs.get('direct_fetch_fallback', - True)), + direct_fetch_fallback=bool( + kwargs.get('direct_fetch_fallback', True)), direct_fetch_timeout=float( kwargs.get('direct_fetch_timeout', 15.0)), playwright_fetch_fallback=bool( @@ -490,8 +497,8 @@ def __init__(self, config, **kwargs): or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) - or os.getenv('TAVILY_API_KEY')) if tool_cfg else - os.getenv('TAVILY_API_KEY'), + or os.getenv('TAVILY_API_KEY')) + if tool_cfg else os.getenv('TAVILY_API_KEY'), } # Tavily search defaults from optional `tavily:` sub-block in YAML @@ -530,23 +537,25 @@ def __init__(self, config, **kwargs): self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) or 3) if tool_cfg else 3 self._jina_direct_fetch_fallback = bool( - getattr(tool_cfg, 'jina_direct_fetch_fallback', True) - ) if tool_cfg else True - if tool_cfg is not None and hasattr(tool_cfg, 'jina_direct_fetch_timeout'): + getattr(tool_cfg, 'jina_direct_fetch_fallback', + True)) if tool_cfg else True + if tool_cfg is not None and hasattr(tool_cfg, + 'jina_direct_fetch_timeout'): self._jina_direct_fetch_timeout = float( tool_cfg.jina_direct_fetch_timeout) else: self._jina_direct_fetch_timeout = 15.0 self._jina_playwright_fetch_fallback = bool( - getattr(tool_cfg, 'jina_playwright_fetch_fallback', True) - ) if tool_cfg else True + getattr(tool_cfg, 'jina_playwright_fetch_fallback', + True)) if tool_cfg else True self._jina_playwright_retry_min_chars = int( - getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) or 400 - ) if tool_cfg else 400 + getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) + or 400) if tool_cfg else 400 self._jina_playwright_timeout_ms = int( - getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) or 30000 - ) if tool_cfg else 30000 - if tool_cfg is not None and hasattr(tool_cfg, 'jina_playwright_settle_ms'): + getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) + or 30000) if tool_cfg else 30000 + if tool_cfg is not None and hasattr(tool_cfg, + 'jina_playwright_settle_ms'): self._jina_playwright_settle_ms = int( tool_cfg.jina_playwright_settle_ms) else: @@ -610,7 +619,8 @@ def __init__(self, config, **kwargs): # Large payload spill (write bodies to disk; keep JSON small) self._spill_enabled = bool( - getattr(tool_cfg, 'spill_large_results', True)) if tool_cfg else True + getattr(tool_cfg, 'spill_large_results', + True)) if tool_cfg else True self._spill_max_inline_chars = int( getattr(tool_cfg, 'spill_max_inline_chars', 120000) or 120000) if tool_cfg else 120000 @@ -689,7 +699,8 @@ async def connect(self) -> None: 'direct_fetch_fallback': self._jina_direct_fetch_fallback, 'direct_fetch_timeout': self._jina_direct_fetch_timeout, 'playwright_fetch_fallback': self._jina_playwright_fetch_fallback, - 'playwright_retry_min_chars': self._jina_playwright_retry_min_chars, + 'playwright_retry_min_chars': + self._jina_playwright_retry_min_chars, 'playwright_timeout_ms': self._jina_playwright_timeout_ms, 'playwright_settle_ms': self._jina_playwright_settle_ms, } @@ -967,10 +978,10 @@ async def _bounded_fetch(url: str) -> Dict[str, Any]: return await asyncio.gather(*tasks) def _do_search( - self, engine_type: str, engine: SearchEngine, - engine_cls: Type[SearchEngine], - tool_args: Dict[str, Any] - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + self, engine_type: str, engine: SearchEngine, + engine_cls: Type[SearchEngine], + tool_args: Dict[str, + Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Perform search; returns (result rows, extra top-level metadata e.g. Tavily).""" try: merged = dict(tool_args) @@ -1093,8 +1104,8 @@ async def _execute_search(self, engine_type: str, if urls: fetch_attempts = len(urls) fetch_results = await self._fetch_multiple_async(urls) - fetch_timeouts = sum( - 1 for r in fetch_results if r.get('fetch_timed_out')) + fetch_timeouts = sum(1 for r in fetch_results + if r.get('fetch_timed_out')) # Merge search metadata with fetched content url_to_fetch = {r['url']: r for r in fetch_results} @@ -1317,12 +1328,9 @@ async def _execute_search(self, engine_type: str, } if fetch_content and self._content_fetcher: response['fetch_stats'] = { - 'per_url_timeout_s': - self._per_url_fetch_timeout_s, - 'urls_fetched_this_call': - fetch_attempts, - 'urls_timed_out': - fetch_timeouts, + 'per_url_timeout_s': self._per_url_fetch_timeout_s, + 'urls_fetched_this_call': fetch_attempts, + 'urls_timed_out': fetch_timeouts, } if tavily_extra: response.update(tavily_extra) diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py index 6bb0be4d6..66e05de8e 100644 --- a/ms_agent/tools/search_engine.py +++ b/ms_agent/tools/search_engine.py @@ -1,8 +1,8 @@ import os import threading +from dotenv import load_dotenv from typing import Any, Dict, Optional -from dotenv import load_dotenv from ms_agent.config.env import Env from ms_agent.tools.search.arxiv import ArxivSearch from ms_agent.tools.search.exa import ExaSearch diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py index 01bd75f64..a0c781e94 100644 --- a/ms_agent/tools/task_control_tool.py +++ b/ms_agent/tools/task_control_tool.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import json +from omegaconf import DictConfig from typing import Any, Dict, Optional from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -45,9 +45,10 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='list_tasks', server_name=_SERVER, - description=( - 'List all background tasks and their current status. ' - 'Returns task_id, tool_name, description, status, and duration.'), + description= + ('List all background tasks and their current status. ' + 'Returns task_id, tool_name, description, status, and duration.' + ), parameters={ 'type': 'object', 'properties': {}, @@ -58,13 +59,16 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='cancel_task', server_name=_SERVER, - description='Cancel a running background task by its task_id.', + description= + 'Cancel a running background task by its task_id.', parameters={ 'type': 'object', 'properties': { 'task_id': { - 'type': 'string', - 'description': 'The task_id returned by the async_launched response.', + 'type': + 'string', + 'description': + 'The task_id returned by the async_launched response.', } }, 'required': ['task_id'], diff --git a/ms_agent/tools/todolist_tool.py b/ms_agent/tools/todolist_tool.py index aee860134..a2ae94bb5 100644 --- a/ms_agent/tools/todolist_tool.py +++ b/ms_agent/tools/todolist_tool.py @@ -1,9 +1,9 @@ +import json import os import time from dataclasses import dataclass from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 70703fbfb..161177088 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,6 +2,7 @@ import asyncio import importlib import inspect +import json import os import sys import uuid @@ -9,7 +10,6 @@ from types import TracebackType from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool, ToolCall from ms_agent.tools.agent_tool import AgentTool from ms_agent.tools.base import ToolBase @@ -18,7 +18,8 @@ from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient from ms_agent.tools.search.localsearch_tool import LocalSearchTool -from ms_agent.tools.search.sirchmunk_search import effective_localsearch_settings +from ms_agent.tools.search.sirchmunk_search import \ + effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool from ms_agent.tools.todolist_tool import TodoListTool from ms_agent.tools.video_generator import VideoGenerator @@ -75,11 +76,12 @@ def __init__(self, self.extra_tools.append(CodeExecutionTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'financial_data_fetcher'): - from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher + from ms_agent.tools.findata.findata_fetcher import \ + FinancialDataFetcher self.extra_tools.append(FinancialDataFetcher(config)) - if hasattr(config, 'tools') and ( - getattr(config.tools, 'agent_tools', None) - or hasattr(config.tools, 'split_task')): + if hasattr(config, + 'tools') and (getattr(config.tools, 'agent_tools', None) + or hasattr(config.tools, 'split_task')): agent_tool = AgentTool( config, trust_remote_code=self.trust_remote_code) if agent_tool.enabled: @@ -231,10 +233,9 @@ async def single_call_tool(self, tool_info: ToolCall): call_args = dict(tool_args or {}) call_id = tool_info.get('id') or str(uuid.uuid4()) call_args['__call_id'] = call_id - elif isinstance( - tool_ins, - LocalCodeExecutionTool) and tool_name.endswith( - f'{self.TOOL_SPLITER}shell_executor'): + elif isinstance(tool_ins, + LocalCodeExecutionTool) and tool_name.endswith( + f'{self.TOOL_SPLITER}shell_executor'): call_args = dict(tool_args or {}) call_args['__call_id'] = tool_info.get('id') or str( uuid.uuid4()) diff --git a/ms_agent/utils/artifact_manager.py b/ms_agent/utils/artifact_manager.py index 9953f25c0..2d73457f0 100644 --- a/ms_agent/utils/artifact_manager.py +++ b/ms_agent/utils/artifact_manager.py @@ -49,8 +49,8 @@ def pack_text_result( out.update(extra) return out - safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' for c in call_id - )[:120] or 'call' + safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' + for c in call_id)[:120] or 'call' rel_dir = Path(tool_name) / safe_id out_dir = self._artifact_root / rel_dir out_dir.mkdir(parents=True, exist_ok=True) @@ -63,10 +63,10 @@ def pack_text_result( preview = _make_preview(body, self.preview_head_chars, self.preview_tail_chars) result = { - 'output': stdout[:self.preview_head_chars] + 'output': + stdout[:self.preview_head_chars] if len(stdout) > self.preview_head_chars else stdout, - 'error': - (stderr[:self.preview_head_chars] if stderr else None), + 'error': (stderr[:self.preview_head_chars] if stderr else None), 'truncated': True, 'artifact_path': @@ -102,16 +102,11 @@ def pack_json_shell_result( ) # pack_text_result merged extra into top level; rebuild standard shell shape out = { - 'success': - payload.get('success'), - 'output': - packed.get('output'), - 'error': - packed.get('error'), - 'return_code': - payload.get('return_code'), - 'truncated': - packed.get('truncated', False), + 'success': payload.get('success'), + 'output': packed.get('output'), + 'error': packed.get('error'), + 'return_code': payload.get('return_code'), + 'truncated': packed.get('truncated', False), } if packed.get('artifact_path'): out['artifact_path'] = packed['artifact_path'] diff --git a/ms_agent/utils/parser_utils.py b/ms_agent/utils/parser_utils.py index 5455dae52..af22034c6 100644 --- a/ms_agent/utils/parser_utils.py +++ b/ms_agent/utils/parser_utils.py @@ -1,11 +1,10 @@ +import json import os import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional -import json - @dataclass class ImportInfo: diff --git a/ms_agent/utils/push_to_hub.py b/ms_agent/utils/push_to_hub.py index a164603dd..2cc6d7b7a 100644 --- a/ms_agent/utils/push_to_hub.py +++ b/ms_agent/utils/push_to_hub.py @@ -1,15 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import mimetypes import os import re +import requests import shutil from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Tuple -import json -import requests from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import (get_files_from_dir, is_package_installed, text_hash) @@ -353,8 +353,7 @@ def __init__( 'ModelScope package is not installed. Please install it with `pip install modelscope`.' ) - from modelscope.hub.api import HubApi - from modelscope.hub.api import get_endpoint + from modelscope.hub.api import HubApi, get_endpoint self.api = HubApi() self.token = token diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py index 2014b49da..d7362405a 100644 --- a/ms_agent/utils/snapshot.py +++ b/ms_agent/utils/snapshot.py @@ -8,8 +8,8 @@ All git commands are run with GIT_DIR and GIT_WORK_TREE explicitly set, so the snapshot repo is fully isolated from any surrounding repository. """ -import os import json +import os import subprocess from typing import Optional @@ -21,7 +21,9 @@ _META_FILE = 'snapshot_meta.json' -def _git(args: list[str], work_tree: str, git_dir: str, +def _git(args: list[str], + work_tree: str, + git_dir: str, check: bool = True) -> subprocess.CompletedProcess: env = os.environ.copy() env['GIT_DIR'] = git_dir @@ -42,7 +44,8 @@ def _snapshot_git_dir(output_dir: str) -> str: return os.path.join(output_dir, _SNAPSHOT_DIR_NAME) -def _configure_snapshot_repo_for_automation(work_tree: str, git_dir: str) -> None: +def _configure_snapshot_repo_for_automation(work_tree: str, + git_dir: str) -> None: """Disable hook execution for the nested snapshot repo. Without this, Git can inherit ``init.templateDir`` / global ``core.hooksPath`` @@ -68,9 +71,11 @@ def _ensure_repo(output_dir: str) -> str: # Do NOT pass a path argument; GIT_DIR env var points git at our custom dir. _git(['init'], work_tree=output_dir, git_dir=git_dir) _git(['config', 'user.email', 'ms-agent@snapshot'], - work_tree=output_dir, git_dir=git_dir) + work_tree=output_dir, + git_dir=git_dir) _git(['config', 'user.name', 'ms-agent'], - work_tree=output_dir, git_dir=git_dir) + work_tree=output_dir, + git_dir=git_dir) # Exclude the snapshot dir itself from tracking info_dir = os.path.join(git_dir, 'info') os.makedirs(info_dir, exist_ok=True) @@ -103,7 +108,8 @@ def _save_meta(output_dir: str, meta: dict) -> None: json.dump(meta, f, indent=2) -def take_snapshot(output_dir: str, message: str, +def take_snapshot(output_dir: str, + message: str, message_count: int = 0) -> Optional[str]: """ Stage all changes in output_dir and create a snapshot commit. @@ -128,14 +134,16 @@ def take_snapshot(output_dir: str, message: str, # Check if there's anything to commit status = _git(['status', '--porcelain'], - work_tree=output_dir, git_dir=git_dir) + work_tree=output_dir, + git_dir=git_dir) if not status.stdout.strip(): return None # Nothing changed # Truncate message to keep commit subject readable subject = message.strip().replace('\n', ' ')[:120] result = _git(['commit', '--no-verify', '-m', subject], - work_tree=output_dir, git_dir=git_dir) + work_tree=output_dir, + git_dir=git_dir) commit_hash = None for line in result.stdout.splitlines(): @@ -154,8 +162,7 @@ def take_snapshot(output_dir: str, message: str, return commit_hash except FileNotFoundError: - logger.warning_once( - '[snapshot] git not found — snapshots disabled.') + logger.warning_once('[snapshot] git not found — snapshots disabled.') return None except subprocess.CalledProcessError as e: logger.warning(f'[snapshot] git error: {e.stderr.strip()}') @@ -189,18 +196,21 @@ def list_snapshots(output_dir: str) -> list[dict]: if len(parts) == 3: h = parts[0] snapshots.append({ - 'hash': h, - 'date': parts[1], - 'message': parts[2], - 'message_count': meta.get(h, {}).get('message_count', 0), + 'hash': + h, + 'date': + parts[1], + 'message': + parts[2], + 'message_count': + meta.get(h, {}).get('message_count', 0), }) return snapshots except Exception: return [] -def restore_snapshot(output_dir: str, - commit_hash: str) -> tuple[bool, int]: +def restore_snapshot(output_dir: str, commit_hash: str) -> tuple[bool, int]: """ Restore output_dir to the state at commit_hash. @@ -213,7 +223,8 @@ def restore_snapshot(output_dir: str, return False, 0 try: _git(['checkout', commit_hash, '--', '.'], - work_tree=output_dir, git_dir=git_dir) + work_tree=output_dir, + git_dir=git_dir) logger.info(f'[snapshot] Restored to {commit_hash}') meta = _load_meta(output_dir) message_count = meta.get(commit_hash, {}).get('message_count', 0) diff --git a/ms_agent/utils/stats.py b/ms_agent/utils/stats.py index 7ed705a1d..beeccfc1f 100644 --- a/ms_agent/utils/stats.py +++ b/ms_agent/utils/stats.py @@ -1,13 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import time from datetime import datetime from typing import Any, Dict, Iterable, Optional -import json from ms_agent.llm.utils import Message - from .logger import get_logger logger = get_logger() diff --git a/ms_agent/utils/stream_writer.py b/ms_agent/utils/stream_writer.py index 46c5cf95b..cdeddec77 100644 --- a/ms_agent/utils/stream_writer.py +++ b/ms_agent/utils/stream_writer.py @@ -55,7 +55,8 @@ def __init__(self, output_dir: str, call_id: str, tool_name: str) -> None: subagents_dir = os.path.join(output_dir, 'subagents') os.makedirs(subagents_dir, exist_ok=True) safe_id = self._call_id.replace('/', '_').replace('\\', '_') - self._path: str = os.path.join(subagents_dir, f'{safe_id}.stream.jsonl') + self._path: str = os.path.join(subagents_dir, + f'{safe_id}.stream.jsonl') @property def stream_path(self) -> str: @@ -84,8 +85,8 @@ def on_start(self, agent_tag: Optional[str]) -> None: 'ts': _now_iso(), }) except Exception as exc: - logger.warning( - 'SubAgentStreamWriter: failed to open %s: %s', self._path, exc) + logger.warning('SubAgentStreamWriter: failed to open %s: %s', + self._path, exc) self._file = None def on_chunk(self, history: Any) -> None: @@ -137,7 +138,8 @@ def on_end(self, history: Any) -> None: self._file.close() except Exception as exc: logger.warning( - 'SubAgentStreamWriter: close error on %s: %s', self._path, exc) + 'SubAgentStreamWriter: close error on %s: %s', + self._path, exc) finally: self._file = None diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py index a897e2fa9..6edc00280 100644 --- a/ms_agent/utils/task_manager.py +++ b/ms_agent/utils/task_manager.py @@ -14,11 +14,12 @@ @dataclass class BackgroundTask: task_id: str - task_type: str # 'agent' | 'shell' - tool_name: str # which tool spawned this + task_type: str # 'agent' | 'shell' + tool_name: str # which tool spawned this description: str status: str = 'running' # 'running' | 'completed' | 'failed' | 'killed' - proc: Optional[Any] = field(default=None, repr=False) # mp.Process or asyncio.Task + proc: Optional[Any] = field( + default=None, repr=False) # mp.Process or asyncio.Task result: Optional[str] = None error: Optional[str] = None started_at: float = field(default_factory=time.monotonic) @@ -54,7 +55,9 @@ def register( proc=proc, ) self._tasks[task_id] = task - logger.info(f'[TaskManager] registered {task_type} task {task_id}: {description}') + logger.info( + f'[TaskManager] registered {task_type} task {task_id}: {description}' + ) return task_id async def complete(self, task_id: str, result: str) -> None: @@ -85,7 +88,8 @@ def kill(self, task_id: str) -> None: try: if isinstance(task.proc, mp.Process): task.proc.terminate() - elif asyncio.isfuture(task.proc) or asyncio.iscoroutine(task.proc): + elif asyncio.isfuture(task.proc) or asyncio.iscoroutine( + task.proc): task.proc.cancel() except Exception as e: logger.warning(f'[TaskManager] kill {task_id} failed: {e}') @@ -120,13 +124,11 @@ def _format_notification(task: BackgroundTask) -> str: duration = '' if task.ended_at: duration = f'\n{task.ended_at - task.started_at:.1f}' - return ( - f'\n' - f'{task.task_id}\n' - f'{task.task_type}\n' - f'{task.tool_name}\n' - f'{task.description}\n' - f'{task.status}' - f'{result_line}{error_line}{duration}\n' - f'' - ) + return (f'\n' + f'{task.task_id}\n' + f'{task.task_type}\n' + f'{task.tool_name}\n' + f'{task.description}\n' + f'{task.status}' + f'{result_line}{error_line}{duration}\n' + f'') diff --git a/ms_agent/utils/thread_util.py b/ms_agent/utils/thread_util.py index 16e46eba9..380dd99b3 100644 --- a/ms_agent/utils/thread_util.py +++ b/ms_agent/utils/thread_util.py @@ -4,9 +4,9 @@ import weakref from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps +from tqdm.auto import tqdm from ms_agent.utils.logger import get_logger -from tqdm.auto import tqdm logger = get_logger() @@ -102,7 +102,8 @@ def weakref_cb(_, q=self._work_queue): thread_name = '%s_%d' % (self._thread_name_prefix or self, num_threads) # Import internal helpers from stdlib to keep behavior consistent. - from concurrent.futures.thread import _worker, _threads_queues # type: ignore + from concurrent.futures.thread import ( # type: ignore + _threads_queues, _worker) t = threading.Thread( name=thread_name, diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index a6d0bc87d..91d9c6d42 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -5,21 +5,20 @@ import html import importlib import importlib.util +import json import os.path import re +import requests import subprocess import sys import time +import yaml from contextlib import contextmanager from io import BytesIO +from omegaconf import DictConfig, OmegaConf from pathlib import Path from typing import List, Optional, Tuple, Union -import json -import requests -import yaml -from omegaconf import DictConfig, OmegaConf - from .constants import DEFAULT_MEMORY_DIR from .logger import get_logger @@ -240,8 +239,8 @@ def read_history(output_dir: str, task: str): TypeError / AttributeError: If the deserialized JSON data lacks expected keys or structure for Message objects. """ - from ms_agent.llm import Message from ms_agent.config import Config + from ms_agent.llm import Message cache_dir = os.path.join(output_dir, DEFAULT_MEMORY_DIR) os.makedirs(cache_dir, exist_ok=True) config_file = os.path.join(cache_dir, f'{task}.yaml') diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py index a8380b710..c11edf4db 100644 --- a/ms_agent/utils/workspace_policy.py +++ b/ms_agent/utils/workspace_policy.py @@ -35,7 +35,7 @@ def __init__( if p not in self._roots: self._roots.append(p) if deny_globs is None or len(tuple(deny_globs)) == 0: - self._deny_globs: tuple[str, ...] = ('**/.git/**',) + self._deny_globs: tuple[str, ...] = ('**/.git/**', ) else: self._deny_globs = tuple(deny_globs) self.shell_default_mode = shell_default_mode @@ -113,13 +113,13 @@ def assert_shell_command_allowed(self, command: str) -> None: mode = self.shell_default_mode if mode == 'read_only': - if _shell_looks_mutating_or_network(command, - allow_network=False): + if _shell_looks_mutating_or_network(command, allow_network=False): raise WorkspacePolicyError( 'Shell is in read_only mode: mutating or network commands are not allowed' ) elif mode == 'workspace_write': - if not self.shell_network_enabled and _shell_looks_network(command): + if not self.shell_network_enabled and _shell_looks_network( + command): raise WorkspacePolicyError( 'Network commands are disabled for shell (enable tools.code_executor.shell.network_enabled)' ) diff --git a/ms_agent/workflow/base.py b/ms_agent/workflow/base.py index 9d484118c..e2f8b0cec 100644 --- a/ms_agent/workflow/base.py +++ b/ms_agent/workflow/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import Dict, Optional from ms_agent.config import Config -from omegaconf import DictConfig class Workflow(ABC): diff --git a/ms_agent/workflow/chain_workflow.py b/ms_agent/workflow/chain_workflow.py index 2b2b739e1..994c98066 100644 --- a/ms_agent/workflow/chain_workflow.py +++ b/ms_agent/workflow/chain_workflow.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +from omegaconf import DictConfig from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/workflow/dag_workflow.py b/ms_agent/workflow/dag_workflow.py index 1419ccf46..14ea1d671 100644 --- a/ms_agent/workflow/dag_workflow.py +++ b/ms_agent/workflow/dag_workflow.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. import os from collections import defaultdict, deque +from omegaconf import DictConfig from typing import Any, Dict, List, Set from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/workflow/deep_research/research_utils.py b/ms_agent/workflow/deep_research/research_utils.py index 7d5ccdb69..1bfdbc70a 100644 --- a/ms_agent/workflow/deep_research/research_utils.py +++ b/ms_agent/workflow/deep_research/research_utils.py @@ -1,8 +1,7 @@ -from typing import Dict, List, Optional - from pydantic import BaseModel, Field from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn +from typing import Dict, List, Optional console = Console() diff --git a/ms_agent/workflow/deep_research/research_workflow.py b/ms_agent/workflow/deep_research/research_workflow.py index 03d64bf80..23ec6957b 100644 --- a/ms_agent/workflow/deep_research/research_workflow.py +++ b/ms_agent/workflow/deep_research/research_workflow.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable import copy +import json import os import re from typing import Any, Dict, List, Optional, Union -import json from ms_agent.llm.openai import OpenAIChat from ms_agent.utils import get_logger @@ -222,7 +222,8 @@ def generate_todo(self, **kwargs) -> None: def search(self, search_request: 'SearchRequest') -> str: from ms_agent.tools.search.exa.schema import dump_batch_search_results - from ms_agent.tools.search.search_base import SearchRequest, SearchResult + from ms_agent.tools.search.search_base import (SearchRequest, + SearchResult) if self._reuse: # Load existing search results if they exist if os.path.exists(self.workdir_structure['search']): @@ -349,7 +350,8 @@ def run(self, from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.rag.schema import KeyInformation from ms_agent.tools.search.search_base import SearchResult - from ms_agent.tools.search.search_request import get_search_request_generator + from ms_agent.tools.search.search_request import \ + get_search_request_generator from ms_agent.utils.utils import remove_resource_info, text_hash special_resources: List = [] if urls_or_files: diff --git a/ms_agent/workflow/deep_research/research_workflow_beta.py b/ms_agent/workflow/deep_research/research_workflow_beta.py index 24816ae52..b490f098d 100644 --- a/ms_agent/workflow/deep_research/research_workflow_beta.py +++ b/ms_agent/workflow/deep_research/research_workflow_beta.py @@ -1,13 +1,14 @@ # yapf: disable import asyncio +import click import os import re import threading from concurrent.futures import ThreadPoolExecutor from datetime import datetime +from rich.prompt import Confirm, Prompt from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import click from ms_agent.llm.openai import OpenAIChat from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.tools.search.exa.schema import dump_batch_search_results @@ -21,7 +22,6 @@ ResearchProgress, ResearchResult) from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow -from rich.prompt import Confirm, Prompt logger = get_logger() diff --git a/ms_agent/workflow/loader.py b/ms_agent/workflow/loader.py index 7e0589cfa..b11a2763c 100644 --- a/ms_agent/workflow/loader.py +++ b/ms_agent/workflow/loader.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig, OmegaConf from typing import Dict, Optional from ms_agent.config.config import Config -from omegaconf import DictConfig, OmegaConf class WorkflowLoader: diff --git a/projects/code_genesis/workflow/api_search.py b/projects/code_genesis/workflow/api_search.py index 4e380d17a..ddc933475 100644 --- a/projects/code_genesis/workflow/api_search.py +++ b/projects/code_genesis/workflow/api_search.py @@ -1,9 +1,9 @@ +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.constants import DEFAULT_INDEX_DIR diff --git a/projects/code_genesis/workflow/coding.py b/projects/code_genesis/workflow/coding.py index c6cfc398f..055e91418 100644 --- a/projects/code_genesis/workflow/coding.py +++ b/projects/code_genesis/workflow/coding.py @@ -1,14 +1,15 @@ import asyncio import dataclasses +import json import os import re import shutil from collections import OrderedDict from copy import deepcopy +from omegaconf import DictConfig from pathlib import Path from typing import List, Optional, Set -import json from ms_agent import LLMAgent from ms_agent.agent import CodeAgent from ms_agent.llm import Message @@ -19,7 +20,6 @@ DEFAULT_TAG) from ms_agent.utils.parser_utils import ImportInfo, parse_imports from ms_agent.utils.utils import extract_code_blocks, file_lock -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/code_genesis/workflow/file_design.py b/projects/code_genesis/workflow/file_design.py index b33f977a5..216a0aa18 100644 --- a/projects/code_genesis/workflow/file_design.py +++ b/projects/code_genesis/workflow/file_design.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message diff --git a/projects/code_genesis/workflow/file_order.py b/projects/code_genesis/workflow/file_order.py index 592a68025..0a2ec17eb 100644 --- a/projects/code_genesis/workflow/file_order.py +++ b/projects/code_genesis/workflow/file_order.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message diff --git a/projects/code_genesis/workflow/refine.py b/projects/code_genesis/workflow/refine.py index 17c9ffba4..6e07e3359 100644 --- a/projects/code_genesis/workflow/refine.py +++ b/projects/code_genesis/workflow/refine.py @@ -1,15 +1,15 @@ +import json import os import sys +from coding import CodingAgent +from omegaconf import DictConfig from typing import List, OrderedDict -import json -from coding import CodingAgent from ms_agent import LLMAgent from ms_agent.llm import Message from ms_agent.memory.condenser.refine_condenser import RefineCondenser from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/quality_checker.py b/projects/deep_research/v2/callbacks/quality_checker.py index 36fadf902..bce1804fa 100644 --- a/projects/deep_research/v2/callbacks/quality_checker.py +++ b/projects/deep_research/v2/callbacks/quality_checker.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json from abc import ABC, abstractmethod +from omegaconf import DictConfig, OmegaConf from typing import List, Optional -import json from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 477623a74..87d4f58ea 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # yapf: disable +import json import os import re import shutil -from typing import Any, Dict, List, Optional, Set - -import json from callbacks.quality_checker import (ReportQualityChecker, build_quality_checkers) +from omegaconf import DictConfig +from typing import Any, Dict, List, Optional, Set + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_MEMORY_DIR -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 8306a6ba7..1b162f83e 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -3,16 +3,16 @@ import os import re import shutil -from typing import List, Optional - from callbacks.quality_checker import (ReportQualityChecker, build_quality_checkers) +from omegaconf import DictConfig, OmegaConf +from typing import List, Optional + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index 735a2d47a..b93a8fd38 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -1,15 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import uuid +from omegaconf import DictConfig from typing import Any, List, Optional -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -31,7 +31,8 @@ def _parse_search_result_json(text: str) -> Optional[Any]: return json.loads(text) except (json.JSONDecodeError, TypeError): pass - m = re.search(r'```(?:json)?\s*\r?\n(.*?)```', text, flags=re.DOTALL | re.IGNORECASE) + m = re.search( + r'```(?:json)?\s*\r?\n(.*?)```', text, flags=re.DOTALL | re.IGNORECASE) if m: block = m.group(1).strip() if block: @@ -275,9 +276,13 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): try: with open(json_path, 'x', encoding='utf-8') as f: json.dump( - parsed_json, f, ensure_ascii=False, indent=2) + parsed_json, + f, + ensure_ascii=False, + indent=2) logger.info( - f'Searcher: Search result saved to {json_path}') + f'Searcher: Search result saved to {json_path}' + ) except FileExistsError: logger.info( f'Search result already exists at {json_path}') diff --git a/projects/deep_research/v2/eval/dr_bench_runner.py b/projects/deep_research/v2/eval/dr_bench_runner.py index 1917564bf..add3634ec 100644 --- a/projects/deep_research/v2/eval/dr_bench_runner.py +++ b/projects/deep_research/v2/eval/dr_bench_runner.py @@ -17,7 +17,9 @@ """ from __future__ import annotations + import argparse +import json import os import subprocess import sys @@ -28,8 +30,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple -import json - try: # Auto-load environment variables from a nearby `.env` (if present). from dotenv import find_dotenv, load_dotenv diff --git a/projects/deep_research/v2/reporter.py b/projects/deep_research/v2/reporter.py index d9c6f2507..a0227ff65 100644 --- a/projects/deep_research/v2/reporter.py +++ b/projects/deep_research/v2/reporter.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/researcher.py b/projects/deep_research/v2/researcher.py index 6af015717..2b2726368 100644 --- a/projects/deep_research/v2/researcher.py +++ b/projects/deep_research/v2/researcher.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union from ms_agent.agent.llm_agent import LLMAgent @@ -8,7 +9,6 @@ from ms_agent.utils.stats import (append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage) -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/time_handler.py b/projects/deep_research/v2/time_handler.py index a36282b5f..2240ee515 100644 --- a/projects/deep_research/v2/time_handler.py +++ b/projects/deep_research/v2/time_handler.py @@ -1,9 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from datetime import datetime +from omegaconf import DictConfig from typing import Any from ms_agent.config.config import ConfigLifecycleHandler -from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 1379ce511..dd6866f4b 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock @@ -372,15 +372,15 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' 'Returns the generated note_id.'), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'title': { 'type': 'string', 'description': ('Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' - 'Optional: if omitted, a title is derived from the first line of `content`.'), + 'Optional: if omitted, a title is derived from the first line of `content`.' + ), }, 'content': { 'type': @@ -466,8 +466,7 @@ async def _get_tools_inner(self) -> Dict[str, Any]: }, }, 'required': ['content'], - 'additionalProperties': - False, + 'additionalProperties': False, }, ), Tool( @@ -864,8 +863,10 @@ async def write_note( content = (content or '').strip() if not content: return _json_dumps({ - 'status': 'error', - 'message': 'write_note requires non-empty content.', + 'status': + 'error', + 'message': + 'write_note requires non-empty content.', }) if title is None or not str(title).strip(): diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 819da108c..924b49804 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -43,7 +43,8 @@ def _write_text(path: str, content: str) -> None: f.write(content) -def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: +def _coerce_chapters_argument( + chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: """Normalize `chapters` from the model (list, JSON string, or nested strings).""" if chapters is None: return [], ( @@ -59,19 +60,20 @@ def _coerce_chapters_argument(chapters: Any) -> tuple[List[Dict[str, Any]], Opti f'or a JSON string of that array: {e}') if not isinstance(raw, list): return [], ( - f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.') + f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.' + ) out: List[Dict[str, Any]] = [] for i, ch in enumerate(raw): if isinstance(ch, str): try: ch = json.loads(ch.strip()) except json.JSONDecodeError: - return [], ( - f'commit_outline chapters[{i}] must be an object; ' - 'string entry is not valid JSON for an object.') + return [], (f'commit_outline chapters[{i}] must be an object; ' + 'string entry is not valid JSON for an object.') if not isinstance(ch, dict): return [], ( - f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.') + f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.' + ) out.append(ch) return out, None @@ -567,11 +569,11 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='load_index', server_name=self.SERVER_NAME, - description=( - 'Load the full evidence index (notes + analyses metadata). ' - 'Same data as evidence_store---load_index; provided here so calls ' - 'mistakenly prefixed with report_generator--- still work.' - ), + description= + ('Load the full evidence index (notes + analyses metadata). ' + 'Same data as evidence_store---load_index; provided here so calls ' + 'mistakenly prefixed with report_generator--- still work.' + ), parameters={ 'type': 'object', 'properties': {}, @@ -615,7 +617,8 @@ def _load_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: return {'notes': {}} return data - def _load_full_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: + def _load_full_evidence_index(self, paths: Dict[str, + str]) -> Dict[str, Any]: """Load evidence/index.json with the same defaults as EvidenceTool.""" data = _safe_read_json(paths['evidence_index']) if data is None or not isinstance(data, dict): @@ -627,7 +630,8 @@ def _load_full_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: } if 'notes' not in data or not isinstance(data.get('notes'), dict): data['notes'] = {} - if 'analyses' not in data or not isinstance(data.get('analyses'), dict): + if 'analyses' not in data or not isinstance( + data.get('analyses'), dict): data['analyses'] = {} legacy = data.get('conclusions') if isinstance(legacy, dict) and legacy and not data.get('analyses'): @@ -809,24 +813,26 @@ async def prepare_chapter_bundle( notes_meta = evidence_index.get('notes', {}) _known_sorted = sorted(notes_meta.keys()) _sample = _known_sorted[:48] - _note_id_hint = ( - 'Known note ids in evidence index (sample): ' - + (', '.join(_sample) if _sample else '(none)') - ) + _note_id_hint = ('Known note ids in evidence index (sample): ' + + (', '.join(_sample) if _sample else '(none)')) if len(_known_sorted) > len(_sample): _note_id_hint += f' … (+{len(_known_sorted) - len(_sample)} more)' - def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: + def _missing_note_entry(note_id: str, + meta: Dict[str, Any]) -> Dict[str, Any]: return { - 'note_id': note_id, - 'error': f'Note {note_id} not found', - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - 'hint': ( - f'{_note_id_hint}. ' - 'Align outline candidate_evidence with filenames under evidence_store, ' - 'or list existing notes before referencing ids.' - ), + 'note_id': + note_id, + 'error': + f'Note {note_id} not found', + 'title': + meta.get('title', ''), + 'summary': + meta.get('summary', ''), + 'hint': + (f'{_note_id_hint}. ' + 'Align outline candidate_evidence with filenames under evidence_store, ' + 'or list existing notes before referencing ids.'), } notes_content = [] @@ -913,20 +919,13 @@ def _missing_note_entry(note_id: str, meta: Dict[str, Any]) -> Dict[str, Any]: self._save_outline(paths, outline) out_bundle: Dict[str, Any] = { - 'status': - 'ok', - 'chapter_id': - chapter_id, - 'chapter_title': - chapter['title'], - 'chapter_goals': - chapter.get('goals', []), - 'evidence_count': - len(notes_content), - 'meta_path': - os.path.relpath(meta_path, self.output_dir), - 'notes_content': - notes_content, + 'status': 'ok', + 'chapter_id': chapter_id, + 'chapter_title': chapter['title'], + 'chapter_goals': chapter.get('goals', []), + 'evidence_count': len(notes_content), + 'meta_path': os.path.relpath(meta_path, self.output_dir), + 'notes_content': notes_content, } skipped: Dict[str, List[str]] = {} if cand_dropped: @@ -1116,8 +1115,7 @@ async def update_outline( out['invalid_candidate_evidence_removed'] = invalid_candidate_removed out['invalid_candidate_evidence_note'] = ( 'These ids were removed from candidate_evidence because no ' - 'matching evidence/notes/note_.md exists.' - ) + 'matching evidence/notes/note_.md exists.') return _json_dumps(out) async def assemble_draft( diff --git a/projects/fin_research/aggregator.py b/projects/fin_research/aggregator.py index 2380e58e3..b9e8fca5a 100644 --- a/projects/fin_research/aggregator.py +++ b/projects/fin_research/aggregator.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os +from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union -import json -from callbacks.file_parser import extract_code_blocks from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/aggregator_callback.py b/projects/fin_research/callbacks/aggregator_callback.py index b827f18a6..7faf6513f 100644 --- a/projects/fin_research/callbacks/aggregator_callback.py +++ b/projects/fin_research/callbacks/aggregator_callback.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os import re +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime @@ -8,7 +9,6 @@ from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/analyst_callback.py b/projects/fin_research/callbacks/analyst_callback.py index 31af7306b..8b0b0f4e6 100644 --- a/projects/fin_research/callbacks/analyst_callback.py +++ b/projects/fin_research/callbacks/analyst_callback.py @@ -1,15 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re +from omegaconf import DictConfig from pathlib import Path from typing import List -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/collector_callback.py b/projects/fin_research/callbacks/collector_callback.py index caffbcb71..c8c0eadb4 100644 --- a/projects/fin_research/callbacks/collector_callback.py +++ b/projects/fin_research/callbacks/collector_callback.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os +from omegaconf import DictConfig from pathlib import Path from typing import List -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/orchestrator_callback.py b/projects/fin_research/callbacks/orchestrator_callback.py index d9b507dce..7b684841a 100644 --- a/projects/fin_research/callbacks/orchestrator_callback.py +++ b/projects/fin_research/callbacks/orchestrator_callback.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +from file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import List -from file_parser import extract_code_blocks from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/searcher.py b/projects/fin_research/searcher.py index 23703fa64..543227ad1 100644 --- a/projects/fin_research/searcher.py +++ b/projects/fin_research/searcher.py @@ -1,8 +1,9 @@ +import json import os +from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import List, Union -import json -from callbacks.file_parser import extract_code_blocks from ms_agent.agent.code_agent import CodeAgent from ms_agent.llm import Message from ms_agent.llm.openai import OpenAIChat @@ -10,7 +11,6 @@ from ms_agent.utils import get_logger from ms_agent.workflow.deep_research.research_workflow_beta import \ ResearchWorkflowBeta -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/time_handler.py b/projects/fin_research/time_handler.py index dd01bc8ae..ef9d8af8b 100644 --- a/projects/fin_research/time_handler.py +++ b/projects/fin_research/time_handler.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from datetime import datetime +from omegaconf import DictConfig from typing import Any from ms_agent.config.config import ConfigLifecycleHandler -from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): diff --git a/projects/fin_research/tools/principle_skill.py b/projects/fin_research/tools/principle_skill.py index 19882f497..8c228de5f 100644 --- a/projects/fin_research/tools/principle_skill.py +++ b/projects/fin_research/tools/principle_skill.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # flake8: noqa +import json import os from typing import Any, Dict, List, Optional, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger diff --git a/projects/fin_research/tools/spec_loader.py b/projects/fin_research/tools/spec_loader.py index 6ed8e1e21..d2e5b803e 100644 --- a/projects/fin_research/tools/spec_loader.py +++ b/projects/fin_research/tools/spec_loader.py @@ -1,14 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # flake8: noqa +import json import os +from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, + WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) from typing import Any, Dict, List, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, - WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) logger = get_logger() diff --git a/projects/singularity_cinema/compose_video/agent.py b/projects/singularity_cinema/compose_video/agent.py index 09995c629..1fd2b0e95 100644 --- a/projects/singularity_cinema/compose_video/agent.py +++ b/projects/singularity_cinema/compose_video/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import math +import moviepy as mp import os import shutil from copy import deepcopy - -import json -import moviepy as mp from moviepy import AudioClip +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/create_background/agent.py b/projects/singularity_cinema/create_background/agent.py index 44e510b13..fc1be2610 100644 --- a/projects/singularity_cinema/create_background/agent.py +++ b/projects/singularity_cinema/create_background/agent.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import matplotlib.font_manager as fm import os import textwrap +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont -import matplotlib.font_manager as fm from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() diff --git a/projects/singularity_cinema/generate_animation/agent.py b/projects/singularity_cinema/generate_animation/agent.py index 26837050b..780c5b925 100644 --- a/projects/singularity_cinema/generate_animation/agent.py +++ b/projects/singularity_cinema/generate_animation/agent.py @@ -2,9 +2,9 @@ import importlib.util import os import sys +from omegaconf import DictConfig from ms_agent.agent import CodeAgent -from omegaconf import DictConfig class GenerateAnimation(CodeAgent): diff --git a/projects/singularity_cinema/generate_animation/generate_manim_code.py b/projects/singularity_cinema/generate_animation/generate_manim_code.py index 990bab72d..e332959e3 100644 --- a/projects/singularity_cinema/generate_animation/generate_manim_code.py +++ b/projects/singularity_cinema/generate_animation/generate_manim_code.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_animation/generate_remotion_code.py b/projects/singularity_cinema/generate_animation/generate_remotion_code.py index 1007a60b1..4d54585b7 100644 --- a/projects/singularity_cinema/generate_animation/generate_remotion_code.py +++ b/projects/singularity_cinema/generate_animation/generate_remotion_code.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import glob +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_audio/agent.py b/projects/singularity_cinema/generate_audio/agent.py index 499779c73..e296bd4e9 100644 --- a/projects/singularity_cinema/generate_audio/agent.py +++ b/projects/singularity_cinema/generate_audio/agent.py @@ -1,21 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import edge_tts +import json +import numpy as np import os import shutil from copy import deepcopy from dataclasses import dataclass, field +from moviepy import AudioClip, AudioFileClip +from omegaconf import DictConfig from typing import List -import edge_tts -import json -import numpy as np -from moviepy import AudioClip, AudioFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.tools.audio_generator import AudioGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_illustration_prompts/agent.py b/projects/singularity_cinema/generate_illustration_prompts/agent.py index 5529dc67a..b7c774aba 100644 --- a/projects/singularity_cinema/generate_illustration_prompts/agent.py +++ b/projects/singularity_cinema/generate_illustration_prompts/agent.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from omegaconf import DictConfig from typing import List, Optional, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_images/agent.py b/projects/singularity_cinema/generate_images/agent.py index e21eb325d..3e5efe03d 100644 --- a/projects/singularity_cinema/generate_images/agent.py +++ b/projects/singularity_cinema/generate_images/agent.py @@ -1,22 +1,22 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import aiohttp import asyncio +import json +import numpy as np import os import re import shutil from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from io import BytesIO +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import aiohttp -import json -import numpy as np from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.image_generator import ImageGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_script/agent.py b/projects/singularity_cinema/generate_script/agent.py index fc0a726cc..0586aafc2 100644 --- a/projects/singularity_cinema/generate_script/agent.py +++ b/projects/singularity_cinema/generate_script/agent.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from copy import deepcopy +from omegaconf import DictConfig from typing import List from ms_agent import LLMAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_subtitle/agent.py b/projects/singularity_cinema/generate_subtitle/agent.py index efa873851..d956aa2c2 100644 --- a/projects/singularity_cinema/generate_subtitle/agent.py +++ b/projects/singularity_cinema/generate_subtitle/agent.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json +import matplotlib.font_manager as fm import os import re +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont from typing import List -import json -import matplotlib.font_manager as fm from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() diff --git a/projects/singularity_cinema/generate_video/agent.py b/projects/singularity_cinema/generate_video/agent.py index aaecd7d5f..2393bb98e 100644 --- a/projects/singularity_cinema/generate_video/agent.py +++ b/projects/singularity_cinema/generate_video/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import aiohttp import asyncio +import json import os import shutil from concurrent.futures import ThreadPoolExecutor from copy import deepcopy +from omegaconf import DictConfig from typing import List, Union -import aiohttp -import json from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_video_prompts/agent.py b/projects/singularity_cinema/generate_video_prompts/agent.py index 2cd091f3b..3f70599ee 100644 --- a/projects/singularity_cinema/generate_video_prompts/agent.py +++ b/projects/singularity_cinema/generate_video_prompts/agent.py @@ -1,13 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/parse_images/agent.py b/projects/singularity_cinema/parse_images/agent.py index 138803dc0..c42ce46dc 100644 --- a/projects/singularity_cinema/parse_images/agent.py +++ b/projects/singularity_cinema/parse_images/agent.py @@ -1,19 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib +import json import os import re from concurrent.futures import ThreadPoolExecutor from copy import deepcopy +from omegaconf import DictConfig +from PIL import Image from urllib.request import urlretrieve -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/render_animation/agent.py b/projects/singularity_cinema/render_animation/agent.py index 1b3007c66..e2d2bee86 100644 --- a/projects/singularity_cinema/render_animation/agent.py +++ b/projects/singularity_cinema/render_animation/agent.py @@ -2,9 +2,9 @@ import os import sys +from omegaconf import DictConfig from ms_agent.agent import CodeAgent -from omegaconf import DictConfig class RenderAnimation(CodeAgent): diff --git a/projects/singularity_cinema/render_animation/render_manim.py b/projects/singularity_cinema/render_animation/render_manim.py index 19b9abc27..17f62aa4a 100644 --- a/projects/singularity_cinema/render_animation/render_manim.py +++ b/projects/singularity_cinema/render_animation/render_manim.py @@ -1,21 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import os import re import shutil import subprocess from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy +from moviepy import VideoFileClip +from omegaconf import DictConfig from os import getcwd +from PIL import Image from typing import List, Union -import json -from moviepy import VideoFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/render_animation/render_remotion.py b/projects/singularity_cinema/render_animation/render_remotion.py index e58e2927e..09426b91f 100644 --- a/projects/singularity_cinema/render_animation/render_remotion.py +++ b/projects/singularity_cinema/render_animation/render_remotion.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import shutil @@ -6,14 +7,13 @@ import urllib.request import zipfile from collections import defaultdict +from moviepy import VideoFileClip +from omegaconf import DictConfig from typing import List, Optional, Tuple, Union -import json -from moviepy import VideoFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -718,8 +718,8 @@ def _check_edge_clipping(frame_path, threshold=10): Returns True if clipping detected (colored pixels at edges). """ try: - from PIL import Image import numpy as np + from PIL import Image img = Image.open(frame_path).convert('RGB') pixels = np.array(img) diff --git a/projects/singularity_cinema/segment/agent.py b/projects/singularity_cinema/segment/agent.py index 35df82260..f35718450 100644 --- a/projects/singularity_cinema/segment/agent.py +++ b/projects/singularity_cinema/segment/agent.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from copy import deepcopy +from omegaconf import DictConfig -import json from ms_agent.agent import LLMAgent from ms_agent.llm import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/setup.py b/setup.py index 1bded7e04..aec42aa2f 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # !/usr/bin/env python -import os -import shutil from setuptools import find_packages, setup from setuptools.command.build_py import build_py as _build_py + +import os +import shutil from typing import List diff --git a/shell-grep-glob-workspace-policy.md b/shell-grep-glob-workspace-policy.md deleted file mode 100644 index ac4e3f912..000000000 --- a/shell-grep-glob-workspace-policy.md +++ /dev/null @@ -1,225 +0,0 @@ -# Shell / Grep / Glob 与策略内核架构方案 - -本文档描述在 modelscope-agent 中为 **Shell**、**Grep**、**Glob** 提供统一的安全、权限、沙箱与产物管理的设计,以及与 **`feat/agent-tool-overhaul`** 分支中 **TaskManager**(后台 Agent、预留 Shell)的兼容方式。 - ---- - -## 1. 目标与边界 - -### 目标 - -- 在「同一工作区、同一沙箱视图」下,为 **Shell / Grep / Glob** 提供统一的: - - **安全**(命令与路径约束) - - **权限**(只读 / 写工作区 / 网络等分级) - - **沙箱**(本地子进程 vs Docker enclave 等与现有 `CodeExecutionTool` 对齐) - - **产物管理**(大 stdout/stderr 落盘、预览、配额) -- **默认 `allow_list`(允许根路径)包含 `output_dir`**(及其规范化的绝对路径),可配置追加其它根。 - -### 边界 - -- **不替代** `FileSystemTool` 的精确编辑与读缓存等语义;Shell 面向构建、包管理、复杂管道。 -- **Grep / Glob** 作为**只读发现面**的独立工具,减少对裸 shell 的依赖;复杂 `find -exec` 等仍可由受控 Shell 在更高权限模式下完成(若产品允许)。 - ---- - -## 2. 分层架构 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Tool Facade 层 │ -│ ShellTool │ GrepTool │ GlobTool (独立 JSON Schema) │ -└────────────┬───────────────────────────────┬────────────────┘ - │ │ -┌────────────▼───────────────────────────────▼────────────────┐ -│ WorkspacePolicyKernel(策略内核,纯逻辑、可单测) │ -│ - roots: 默认含 canonical(output_dir),可配置追加 │ -│ - allow_list / deny_list 合并与优先级 │ -│ - resolve_path(rel|abs) → 必须在 allow_roots 下 │ -│ - classify(op): read | search | mutate | exec | network_hint │ -└────────────┬────────────────────────────────────────────────┘ - │ -┌────────────▼────────────────────────────────────────────────┐ -│ SandboxRuntime(执行面,可替换实现) │ -│ - LocalProcessRuntime(asyncio subprocess,cwd=workspace) │ -│ - EnclaveRuntime(现有 ms_enclave / CodeExecutionTool 路径) │ -│ - 会话级 sandbox_id / working_dir 与挂载点一致 │ -└────────────┬────────────────────────────────────────────────┘ - │ -┌────────────▼────────────────────────────────────────────────┐ -│ ArtifactManager(产物管理) │ -│ - 超阈值 stdout/stderr → 落盘 + preview + 相对路径引用 │ -│ - 按 task_id / tool_call_id 分目录 │ -│ - TTL / 总配额(建议:output_dir/.ms_agent_artifacts/) │ -└─────────────────────────────────────────────────────────────┘ -``` - -**原则**:Grep/Glob 的**主路径**不是「拼一条 shell 给模型」;内部可调用 `rg` 或文件系统 walk,但必须经过 **PolicyKernel** 与 **SandboxRuntime**,输出经 **ArtifactManager**。 - ---- - -## 3. WorkspacePolicyKernel(共享策略内核) - -### 3.1 默认 allow_list(允许根集合) - -- 初始化:`allow_roots = { canonical_abs(output_dir) }`。 -- 配置可追加,例如:`tools.code_executor.extra_allow_roots` 或 `tools.workspace_policy.allow`(列表),合并去重。 -- Shell / Grep / Glob 涉及的 **`path`、`cwd`、搜索根目录** 均先执行 `resolve_under_allow_roots()`;失败则**拒绝**并返回结构化错误(不静默改路径到其它目录)。 - -### 3.2 权限与操作分类(建议) - -| 类别 | 示例 | Shell | Grep | Glob | -|------|------|-------|------|------| -| read | 读取工作区内文件 | 受模式 + 策略约束 | ✓ | ✓ | -| search | 内容/文件名发现 | 可引导至 Grep/Glob | ✓ | ✓ | -| mutate | rm、chmod、git 写入等 | 需 `workspace_write` | — | — | -| network | curl、pip 等 | 需显式 **network** 能力位 | — | — | - -Shell 在 **`read_only`** 模式下:仅允许白名单类命令(如 `git status`/`diff`/`log`、只读参数的 `rg` 等),并对重定向、写入工作区外等行为做拒绝或降级(可用前缀表 + 危险模式黑名单,必要时辅以轻量解析)。 - -### 3.3 Shell 安全补充 - -- **固定 cwd**:默认 `workspace_root`(与 `output_dir` 或沙箱内挂载点一致)。 -- **环境变量**:最小集或白名单继承;避免将宿主敏感变量原样传入。 -- **命令预处理**:与现有 `CodeExecutionTool.shell_executor` 思路一致——含 `| && ; > <` 等时使用 `sh -lc` 与安全 quoting;另加**命令长度上限**、**可配置的危险构造限制**(如嵌套命令替换,按产品分级)。 -- **(暂时不做)** 与 `FileSystemTool` 的「写前必读 / staleness」策略对齐:对会修改工作区文件的 Shell 子类共享元数据(若产品需要强一致)。 - ---- - -## 4. SandboxRuntime(共享沙箱) - -- **会话级**:每个 Agent 运行周期内一个 `SandboxSession`(或复用现有 `sandbox_id`)。 -- **Shell / Grep / Glob** 共用同一 **`working_dir` / 挂载视图** 与同一 **`SandboxRuntime` 实现**(本地 `asyncio` 子进程 vs Docker enclave,由 `implementation: sandbox | python_env` 等与现有一致)。 -- **Grep**:在 enclave 内调用 `rg` 或使用宿主 `ripgrep` 库(由部署二选一);**Glob**:在策略解析后的根上做目录遍历或 `pathspec`,避免默认可执行任意 `find -exec`。 - ---- - -## 5. ArtifactManager(产物管理) - -- **阈值**:例如 stdout+stderr 合计超过 N KB 则 spill 至 - `{output_dir}/.ms_agent_artifacts/{tool_name}/{task_or_call_id}.txt`(路径可配置)。 -- **返回**:JSON 中包含 `preview`(首尾若干字符/行)、`artifact_path`(相对 `output_dir`)、`truncated: true`。 -- **与 TaskManager 配合**:后台任务完成时,`TaskManager.complete(task_id, result)` 的 `result` 宜为「短摘要 + artifact 路径」,避免通知与下一轮上下文被撑爆。 - ---- - -## 6. GrepTool / GlobTool(独立工具、共享内核) - -- **输入**:结构化字段(如 pattern、path、glob、head_limit、offset、output_mode),不把「整条 shell」作为唯一 API。 -- **实现**:内部调用 `SandboxRuntime.exec_rg(...)` 或在策略内核限定根上的 glob 遍历;**禁止**由用户可控字符串直接拼接未校验的 shell。 -- **共享**:同一 `WorkspacePolicyKernel` + `SandboxRuntime` + `ArtifactManager`(由 `ToolManager` 或执行类工具在初始化时注入)。 -- **注册**:在 `ToolManager` 中作为独立 `ToolBase`(可一个 server 多个 tool,或两个 server);与 `file_system` 解耦,保持 `file_system` 精简。 - ---- - -## 7. 与 `feat/agent-tool-overhaul` 的 Task 体系兼容 - -### 7.1 分支中的现状(摘要) - -- **`TaskManager`**(`ms_agent/utils/task_manager.py`):进程级后台任务注册表;`BackgroundTask` 中 **`task_type` 注释已包含 `'agent' | 'shell'`**。 -- **`AgentTool`**:`run_in_background` 时 `register(task_type='agent', proc=mp.Process, ...)`,watcher 在子进程结束后调用 `complete` / `fail`;`LLMAgent` 通过 `set_task_manager` 注入同一 `TaskManager`,每轮 `drain_notifications()` 将完成事件注入对话。 - -### 7.2 Shell 后台(与 Agent 对称) - -**建议接口** - -- **同步**:`shell_executor(command, timeout)` → 行为与现网接近,但走 PolicyKernel + ArtifactManager。 -- **后台**:增加 `run_in_background: bool`(或等价命名), **`__call_id`**(与 `AgentTool` 注入一致,便于对账与「推后台」扩展)。 - -**后台行为** - -1. `task_id = task_manager.register(task_type='shell', tool_name='shell_executor', description=command[:200], proc=...)` -2. `proc` 可为 **`asyncio.create_subprocess_*` 返回的 `Process`**(与 Agent 的 `mp.Process` 不同,需在 **`TaskManager.kill` / `kill_all` 中扩展**:对 `asyncio.subprocess.Process` 调用 `kill()` / `terminate()`,并处理已结束进程)。 -3. `asyncio.create_task(watcher)`:等待结束 → `ArtifactManager.maybe_spill` → `await task_manager.complete(task_id, result_str)`(失败则 `fail`)。 - -**立即返回 JSON**(与 Agent 后台对齐,便于统一文档与客户端): - -```json -{ - "status": "async_launched", - "task_id": "", - "tool_name": "shell_executor" -} -``` - -### 7.3 LLMAgent 接线 - -- 与 overhaul 一致:构造 `TaskManager()`,遍历 `extra_tools`,若实现 **`set_task_manager(self.task_manager)`** 则注入。 -- **`LocalCodeExecutionTool` / 未来的 `SecureShellTool`** 实现 `set_task_manager`,与 `AgentTool` 共享同一 `TaskManager` 实例。 - -### 7.4 长同步 Shell → Escape 到后台 - -- 与 `AgentTool._run_sync_escapable` 类似:同步 Shell 带 `sync_timeout_s`,超时或显式信号后取消当前子进程并改为 `register(task_type='shell', ...)` 后台重跑或仅保留已产出部分(产品二选一)。 -- 若存在 **TaskControlTool** 类机制,可复用「`__call_id` + escape 事件」模式,Shell 侧维护 `call_id → Process` 映射以支持 **kill / escape**。 - -### 7.5 兼容对照表 - -| 能力 | overhaul 行为 | 本方案落点 | -|------|----------------|------------| -| 后台 Agent | `register(task_type='agent', proc=Process)` | 不变 | -| 预留 Shell | `task_type` 含 `'shell'` | `shell_executor(run_in_background=true)` 走同一 register / complete / fail | -| 回合内通知 | `drain_notifications()` | Shell 完成同样入队 | -| Kill / 清理 | `kill` / `kill_all` | 扩展支持 asyncio 子进程;watcher `finally` 释放资源 | - ---- - -## 8. 配置示例(OmegaConf / YAML 意向) - -```yaml -tools: - workspace_policy: - allow_roots: [] # 追加;默认已含 output_dir - deny_globs: ['**/.git/**'] - code_executor: - implementation: python_env # or sandbox - shell: - default_mode: workspace_write # read_only | workspace_write - max_output_kb: 256 - wall_time_s: 900 - grep: - default_head_limit: 250 - glob: - max_files: 100 -``` - ---- - -## 9. 实施顺序建议 - -1. 抽出 **`WorkspacePolicyKernel`** + 单元测试(路径解析、默认 `output_dir`、追加 allow)。 -2. 实现 **`ArtifactManager`**,接到现有 `shell_executor` 返回(先本地工具、后接沙箱)。 -3. 将 **`TaskManager`**(overhaul)合入主线并 **扩展 `kill` 支持 `asyncio.subprocess.Process`**。 -4. **`LocalCodeExecutionTool.set_task_manager` + `run_in_background` 的 `shell_executor`**。 -5. 新增 **GrepTool / GlobTool** façade,共享上述内核与运行时。 -6. 更新文档与系统提示:默认 **发现用 Grep/Glob,构建用 Shell,改文件用 file_system**。 - ---- - -## 10. 设计取舍小结 - -- **Shell**:强约束的通用执行面 + 后台,与 **TaskManager** 统一生命周期与通知。 -- **Grep / Glob**:独立 Schema、只读、易截断,与 Shell **共享策略与沙箱**,避免把一切搜索都绑在一条 shell 字符串上。 -- **默认 allow_roots 含 `output_dir`**:与现有 Agent 工作区模型一致,减少越权访问宿主路径的风险。 - ---- - -## 修订记录 - -| 日期 | 说明 | -|------|------| -| 2026-04-13 | 初版:根据设计与 `feat/agent-tool-overhaul` 中 TaskManager / AgentTool 后台模型整理成文。 | -| 2026-04-13 | 实现落地:见下文「实现映射」。 | - -## 11. 实现映射(代码位置) - -| 组件 | 路径 | -|------|------| -| WorkspacePolicyKernel | `ms_agent/utils/workspace_policy.py` | -| ArtifactManager | `ms_agent/utils/artifact_manager.py` | -| TaskManager | `ms_agent/utils/task_manager.py` | -| Shell 策略 / 产物 / 后台 | `ms_agent/tools/code/local_code_executor.py`(`set_task_manager`、`shell_executor`) | -| Grep / Glob | `ms_agent/tools/filesystem_tool.py` 中 `grep` / `glob` 工具(与 `read_file` / `edit_file` / `write_file` 同属 `file_system` server;用 `tools.file_system.include` / `exclude` 控制)。可选键:`grep_timeout_s`、`grep_head_limit`、`glob_max_files`;`include` 短名 `read` / `edit` / `write` 分别等价 `read_file` / `edit_file` / `write_file`。 | -| `__call_id` 注入 shell | `ms_agent/tools/tool_manager.py` | -| TaskManager 与通知 | `ms_agent/agent/llm_agent.py`(`prepare_tools` / `cleanup_tools` / `_append_task_notifications`) | -| 单测 | `tests/utils/test_workspace_policy.py` | - -**未在本阶段实现**:文档 §7.4 长同步 Shell escape 到后台;Docker `CodeExecutionTool` 侧 shell 与策略对齐(仍沿用原沙箱实现)。 diff --git a/webui/backend/agent_runner.py b/webui/backend/agent_runner.py index 1c122ff27..3f3879ac1 100644 --- a/webui/backend/agent_runner.py +++ b/webui/backend/agent_runner.py @@ -9,11 +9,10 @@ import signal import subprocess import sys +import yaml from datetime import datetime from typing import Any, Callable, Dict, Optional -import yaml - class AgentRunner: """Runs ms-agent as a subprocess with output streaming""" diff --git a/webui/backend/api.py b/webui/backend/api.py index 126ac4c02..fa68b849b 100644 --- a/webui/backend/api.py +++ b/webui/backend/api.py @@ -4,14 +4,13 @@ """ import mimetypes import os -from pathlib import Path -from typing import Any, Dict, List, Optional - from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse +from pathlib import Path from pydantic import BaseModel, Field # Import shared instances from shared import config_manager, project_discovery, session_manager +from typing import Any, Dict, List, Optional router = APIRouter() diff --git a/webui/backend/config_manager.py b/webui/backend/config_manager.py index eddf94915..8d9c7f623 100644 --- a/webui/backend/config_manager.py +++ b/webui/backend/config_manager.py @@ -3,12 +3,11 @@ Configuration management for MS-Agent Web UI Handles global settings, LLM configuration, and MCP server configuration. """ +import json import os from threading import Lock from typing import Any, Dict, Optional -import json - class ConfigManager: """Manages global configuration for the Web UI""" diff --git a/webui/backend/deep_research_eventizer.py b/webui/backend/deep_research_eventizer.py index 84949674a..f1d5d6d74 100644 --- a/webui/backend/deep_research_eventizer.py +++ b/webui/backend/deep_research_eventizer.py @@ -1,6 +1,6 @@ +import json from typing import Any, Callable, Dict, List, Optional -import json from ms_agent.llm.utils import Message, ToolCall diff --git a/webui/backend/deep_research_worker.py b/webui/backend/deep_research_worker.py index e8a3480a3..5a0ce2d80 100644 --- a/webui/backend/deep_research_worker.py +++ b/webui/backend/deep_research_worker.py @@ -1,17 +1,17 @@ import argparse import asyncio +import json import os import signal import sys import traceback +from deep_research_eventizer import HistoryEventizer # noqa: E402 +from omegaconf import OmegaConf from pathlib import Path from typing import Any, Dict, Optional -import json -from deep_research_eventizer import HistoryEventizer # noqa: E402 from ms_agent.agent.loader import AgentLoader from ms_agent.tools.agent_tool import AgentTool -from omegaconf import OmegaConf BACKEND_DIR = Path(__file__).resolve().parent if str(BACKEND_DIR) not in sys.path: diff --git a/webui/backend/deep_research_worker_manager.py b/webui/backend/deep_research_worker_manager.py index a5eb29d05..30b666207 100644 --- a/webui/backend/deep_research_worker_manager.py +++ b/webui/backend/deep_research_worker_manager.py @@ -1,4 +1,5 @@ import asyncio +import json import os import signal import sys @@ -6,8 +7,6 @@ from pathlib import Path from typing import Any, Awaitable, Callable, Dict, Optional -import json - class DeepResearchWorkerManager: diff --git a/webui/backend/main.py b/webui/backend/main.py index 2afcebbe3..8eb584bc1 100644 --- a/webui/backend/main.py +++ b/webui/backend/main.py @@ -5,7 +5,6 @@ """ import os import sys - import uvicorn from api import router as api_router from fastapi import FastAPI diff --git a/webui/backend/shared.py b/webui/backend/shared.py index 1d3ce9754..fb7b7fe67 100644 --- a/webui/backend/shared.py +++ b/webui/backend/shared.py @@ -4,7 +4,6 @@ Ensures api.py and websocket_handler.py use the same manager instances. """ import os - from config_manager import ConfigManager from project_discovery import ProjectDiscovery from session_manager import SessionManager diff --git a/webui/backend/websocket_handler.py b/webui/backend/websocket_handler.py index 7b2830b17..549d1f41c 100644 --- a/webui/backend/websocket_handler.py +++ b/webui/backend/websocket_handler.py @@ -4,17 +4,16 @@ Handles agent execution, log streaming, and progress updates. """ import asyncio -import os -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Set - import json +import os from agent_runner import AgentRunner +from datetime import datetime from deep_research_worker_manager import DeepResearchWorkerManager from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from pathlib import Path # Import shared instances from shared import config_manager, project_discovery, session_manager +from typing import Any, Dict, Set router = APIRouter() @@ -188,6 +187,7 @@ async def start_agent(session_id: str, data: Dict[str, Any], if session_type == 'chat': # Create a virtual project for chat mode using the default agent.yaml import ms_agent + # Get ms_agent package installation path # Use __path__ which is always available for packages and gives real filesystem paths if hasattr(ms_agent, '__path__') and ms_agent.__path__: