From 04d663d24c329ab094f45bb314d6d8b68184b79d Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:33:37 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat:=20=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E7=B3=BB=E7=BB=9F=E9=87=8D=E6=9E=84=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=B6=85=E6=97=B6=E8=87=AA=E5=8A=A8=E8=BD=AC=E5=90=8E?= =?UTF-8?q?=E5=8F=B0=E6=89=A7=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 tool_execution 模块(洋葱架构):MethodResolver → ToolInvoker → TimeoutStrategy → ResultProcessor - 新增 background_tool 模块:超时工具自动转后台,LLM可通过 wait_tool_result/get_tool_output/stop_tool/list_running_tools 管理任务 - 重构 _execute_local 使用 ToolExecutor 编排,保持上游兼容 - event_bus 集成后台任务中断与通知注入 - tool_loop_agent_runner 支持 WaitInterruptedException 中断处理 - 新增 background_task_wait_timeout 配置项 - 包含完整单元测试与集成测试 --- .../agent/runners/tool_loop_agent_runner.py | 12 + astrbot/core/astr_agent_tool_exec.py | 87 +- astrbot/core/background_tool/__init__.py | 38 + .../background_tool/callback_event_builder.py | 152 +++ .../background_tool/callback_publisher.py | 85 ++ astrbot/core/background_tool/llm_tools.py | 161 ++++ astrbot/core/background_tool/manager.py | 329 +++++++ astrbot/core/background_tool/output_buffer.py | 116 +++ astrbot/core/background_tool/register.py | 82 ++ astrbot/core/background_tool/task_executor.py | 276 ++++++ .../core/background_tool/task_formatter.py | 44 + astrbot/core/background_tool/task_notifier.py | 56 ++ astrbot/core/background_tool/task_registry.py | 184 ++++ astrbot/core/background_tool/task_state.py | 145 +++ .../core/background_tool/tests/__init__.py | 1 + .../background_tool/tests/test_llm_tools.py | 110 +++ .../background_tool/tests/test_manager.py | 127 +++ .../tests/test_output_buffer.py | 94 ++ .../tests/test_task_executor.py | 142 +++ .../tests/test_task_notifier.py | 108 +++ .../tests/test_task_registry.py | 157 ++++ .../background_tool/tests/test_task_state.py | 153 +++ astrbot/core/config/default.py | 11 + astrbot/core/event_bus.py | 47 + .../method/agent_sub_stages/internal.py | 39 + astrbot/core/provider/register.py | 4 + astrbot/core/star/context.py | 2 + astrbot/core/tool_execution/__init__.py | 40 + .../tool_execution/application/__init__.py | 8 + .../application/tool_executor.py | 155 +++ .../core/tool_execution/domain/__init__.py | 14 + astrbot/core/tool_execution/domain/config.py | 43 + .../tool_execution/domain/execution_result.py | 37 + .../core/tool_execution/domain/tool_types.py | 19 + astrbot/core/tool_execution/errors.py | 34 + .../tool_execution/infrastructure/__init__.py | 4 + .../infrastructure/background/__init__.py | 6 + .../background/completion_signal.py | 31 + .../background/event_factory.py | 49 + .../infrastructure/callback/__init__.py | 5 + .../callback/callback_event_builder.py | 15 + .../infrastructure/handler/__init__.py | 7 + .../infrastructure/handler/method_resolver.py | 41 + .../handler/parameter_validator.py | 41 + .../handler/result_processor.py | 70 ++ .../infrastructure/invoker/__init__.py | 5 + .../infrastructure/invoker/tool_invoker.py | 39 + .../infrastructure/timeout/__init__.py | 6 + .../timeout/background_handler.py | 52 + .../timeout/timeout_strategy.py | 26 + astrbot/core/tool_execution/interfaces.py | 194 ++++ astrbot/core/tool_execution/utils/__init__.py | 23 + .../core/tool_execution/utils/decorators.py | 38 + astrbot/core/tool_execution/utils/rwlock.py | 68 ++ .../core/tool_execution/utils/sanitizer.py | 126 +++ .../core/tool_execution/utils/validators.py | 102 ++ .../en-US/features/config-metadata.json | 44 +- .../zh-CN/features/config-metadata.json | 47 +- tests/test_tool_execution/__init__.py | 1 + tests/test_tool_execution/test_integration.py | 885 ++++++++++++++++++ 60 files changed, 4922 insertions(+), 115 deletions(-) create mode 100644 astrbot/core/background_tool/__init__.py create mode 100644 astrbot/core/background_tool/callback_event_builder.py create mode 100644 astrbot/core/background_tool/callback_publisher.py create mode 100644 astrbot/core/background_tool/llm_tools.py create mode 100644 astrbot/core/background_tool/manager.py create mode 100644 astrbot/core/background_tool/output_buffer.py create mode 100644 astrbot/core/background_tool/register.py create mode 100644 astrbot/core/background_tool/task_executor.py create mode 100644 astrbot/core/background_tool/task_formatter.py create mode 100644 astrbot/core/background_tool/task_notifier.py create mode 100644 astrbot/core/background_tool/task_registry.py create mode 100644 astrbot/core/background_tool/task_state.py create mode 100644 astrbot/core/background_tool/tests/__init__.py create mode 100644 astrbot/core/background_tool/tests/test_llm_tools.py create mode 100644 astrbot/core/background_tool/tests/test_manager.py create mode 100644 astrbot/core/background_tool/tests/test_output_buffer.py create mode 100644 astrbot/core/background_tool/tests/test_task_executor.py create mode 100644 astrbot/core/background_tool/tests/test_task_notifier.py create mode 100644 astrbot/core/background_tool/tests/test_task_registry.py create mode 100644 astrbot/core/background_tool/tests/test_task_state.py create mode 100644 astrbot/core/tool_execution/__init__.py create mode 100644 astrbot/core/tool_execution/application/__init__.py create mode 100644 astrbot/core/tool_execution/application/tool_executor.py create mode 100644 astrbot/core/tool_execution/domain/__init__.py create mode 100644 astrbot/core/tool_execution/domain/config.py create mode 100644 astrbot/core/tool_execution/domain/execution_result.py create mode 100644 astrbot/core/tool_execution/domain/tool_types.py create mode 100644 astrbot/core/tool_execution/errors.py create mode 100644 astrbot/core/tool_execution/infrastructure/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/background/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/background/completion_signal.py create mode 100644 astrbot/core/tool_execution/infrastructure/background/event_factory.py create mode 100644 astrbot/core/tool_execution/infrastructure/callback/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/callback/callback_event_builder.py create mode 100644 astrbot/core/tool_execution/infrastructure/handler/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/handler/method_resolver.py create mode 100644 astrbot/core/tool_execution/infrastructure/handler/parameter_validator.py create mode 100644 astrbot/core/tool_execution/infrastructure/handler/result_processor.py create mode 100644 astrbot/core/tool_execution/infrastructure/invoker/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/invoker/tool_invoker.py create mode 100644 astrbot/core/tool_execution/infrastructure/timeout/__init__.py create mode 100644 astrbot/core/tool_execution/infrastructure/timeout/background_handler.py create mode 100644 astrbot/core/tool_execution/infrastructure/timeout/timeout_strategy.py create mode 100644 astrbot/core/tool_execution/interfaces.py create mode 100644 astrbot/core/tool_execution/utils/__init__.py create mode 100644 astrbot/core/tool_execution/utils/decorators.py create mode 100644 astrbot/core/tool_execution/utils/rwlock.py create mode 100644 astrbot/core/tool_execution/utils/sanitizer.py create mode 100644 astrbot/core/tool_execution/utils/validators.py create mode 100644 tests/test_tool_execution/__init__.py create mode 100644 tests/test_tool_execution/test_integration.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 0e5b4353f..1ea82c832 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -138,6 +138,7 @@ async def reset( self.stats = AgentStats() self.stats.start_time = time.time() + self._wait_interrupted = False # 等待中断标记 async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" @@ -547,6 +548,17 @@ async def _handle_function_tools( except Exception as e: logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) except Exception as e: + from astrbot.core.background_tool import WaitInterruptedException + + if isinstance(e, WaitInterruptedException): + # 等待被中断,结束当前响应周期 + logger.info( + f"Wait interrupted for task {e.task_id}, ending current response cycle" + ) + self._wait_interrupted = True + self._transition_state(AgentState.DONE) + return + logger.warning(traceback.format_exc()) tool_call_result_blocks.append( ToolCallMessageSegment( diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 460cab332..e0d1295f7 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -28,6 +28,8 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.tool_execution.application.tool_executor import ToolExecutor +from astrbot.core.tool_execution.errors import MethodResolutionError from astrbot.core.utils.history_saver import persist_agent_history @@ -270,80 +272,25 @@ async def _execute_local( tool_call_timeout: int | None = None, **tool_args, ): + """执行本地工具,使用 ToolExecutor 编排组件完成工具执行。""" event = run_context.context.event if not event: raise ValueError("Event must be provided for local function tools.") - is_override_call = False - for ty in type(tool).mro(): - if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: - is_override_call = True - break - - # 检查 tool 下有没有 run 方法 - if not tool.handler and not hasattr(tool, "run") and not is_override_call: - raise ValueError("Tool must have a valid handler or override 'run' method.") - - awaitable = None - method_name = "" - if tool.handler: - awaitable = tool.handler - method_name = "decorator_handler" - elif is_override_call: - awaitable = tool.call - method_name = "call" - elif hasattr(tool, "run"): - awaitable = getattr(tool, "run") - method_name = "run" - if awaitable is None: - raise ValueError("Tool must have a valid handler or override 'run' method.") - - wrapper = call_local_llm_tool( - context=run_context, - handler=awaitable, - method_name=method_name, - **tool_args, - ) - while True: - try: - resp = await asyncio.wait_for( - anext(wrapper), - timeout=tool_call_timeout or run_context.tool_call_timeout, - ) - if resp is not None: - if isinstance(resp, mcp.types.CallToolResult): - yield resp - else: - text_content = mcp.types.TextContent( - type="text", - text=str(resp), - ) - yield mcp.types.CallToolResult(content=[text_content]) - else: - # NOTE: Tool 在这里直接请求发送消息给用户 - # TODO: 是否需要判断 event.get_result() 是否为空? - # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" - if res := run_context.context.event.get_result(): - if res.chain: - try: - await event.send( - MessageChain( - chain=res.chain, - type="tool_direct_result", - ) - ) - except Exception as e: - logger.error( - f"Tool 直接发送消息失败: {e}", - exc_info=True, - ) - yield None - except asyncio.TimeoutError: - raise Exception( - f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.", - ) - except StopAsyncIteration: - break + # 如果外部指定了超时(如后台任务用3600s),临时覆盖run_context的超时值 + original_timeout = run_context.tool_call_timeout + if tool_call_timeout is not None: + run_context.tool_call_timeout = tool_call_timeout + + executor = ToolExecutor() + + try: + async for result in executor.execute(tool, run_context, **tool_args): + yield result + except MethodResolutionError as e: + raise ValueError(str(e)) from e + finally: + run_context.tool_call_timeout = original_timeout @classmethod async def _execute_mcp( diff --git a/astrbot/core/background_tool/__init__.py b/astrbot/core/background_tool/__init__.py new file mode 100644 index 000000000..8c10a9725 --- /dev/null +++ b/astrbot/core/background_tool/__init__.py @@ -0,0 +1,38 @@ +# Background Tool Execution System +# 后台工具执行系统 + +from .callback_event_builder import CallbackEventBuilder +from .callback_publisher import CallbackPublisher +from .manager import BackgroundToolManager +from .output_buffer import OutputBuffer +from .task_executor import TaskExecutor +from .task_notifier import TaskNotifier +from .task_registry import TaskRegistry +from .task_state import BackgroundTask, TaskStatus + + +class WaitInterruptedException(Exception): + """等待被新消息中断的异常 + + 当wait_tool_result被用户新消息中断时抛出此异常, + 用于通知框架结束当前LLM响应周期。 + """ + + def __init__(self, task_id: str, session_id: str): + self.task_id = task_id + self.session_id = session_id + super().__init__(f"Wait interrupted for task {task_id}") + + +__all__ = [ + "BackgroundTask", + "TaskStatus", + "TaskRegistry", + "OutputBuffer", + "TaskExecutor", + "TaskNotifier", + "CallbackEventBuilder", + "CallbackPublisher", + "BackgroundToolManager", + "WaitInterruptedException", +] diff --git a/astrbot/core/background_tool/callback_event_builder.py b/astrbot/core/background_tool/callback_event_builder.py new file mode 100644 index 000000000..f70f0bf79 --- /dev/null +++ b/astrbot/core/background_tool/callback_event_builder.py @@ -0,0 +1,152 @@ +"""回调事件构建器 + +将后台任务完成信息构建为可放入事件队列的回调事件。 +遵循单一职责原则,只负责事件构建。 +""" + +import copy +import time +from typing import Any + +from astrbot import logger +from astrbot.core.tool_execution.domain.config import DEFAULT_CONFIG + +from .task_state import BackgroundTask, TaskStatus + +# 状态文本映射 +STATUS_TEXT_MAP = { + TaskStatus.COMPLETED: "completed successfully", + TaskStatus.FAILED: "failed", + TaskStatus.CANCELLED: "was cancelled", +} + + +class CallbackEventBuilder: + """回调事件构建器 + + 负责将后台任务构建为回调事件,不涉及队列操作。 + """ + + def __init__(self, config=None): + """初始化构建器 + + Args: + config: 配置对象,默认使用 DEFAULT_CONFIG + """ + self._config = config or DEFAULT_CONFIG + + def build_notification_text(self, task: BackgroundTask) -> str: + """构建通知文本 + + Args: + task: 后台任务 + + Returns: + 通知文本 + """ + status = STATUS_TEXT_MAP.get(task.status, "unknown") + + lines = [ + "[Background Task Callback]", + f"Task ID: {task.task_id}", + f"Tool: {task.tool_name}", + f"Status: {status}", + ] + + if task.result: + lines.append(f"Result: {task.result}") + + if task.error: + max_len = self._config.error_preview_max_length + error_preview = task.error[:max_len] + if len(task.error) > max_len: + error_preview += "..." + lines.append(f"Error: {error_preview}") + + lines.append("") + lines.append( + "Please inform the user about this task completion and provide any relevant details." + ) + + return "\n".join(lines) + + def build_message_object(self, task: BackgroundTask, text: str) -> Any: + """构建消息对象 + + Args: + task: 后台任务 + text: 通知文本 + + Returns: + AstrBotMessage 对象 + """ + from astrbot.core.message.components import Plain + from astrbot.core.platform.astrbot_message import AstrBotMessage + + original = task.event.message_obj + + msg = AstrBotMessage() + msg.type = original.type + msg.self_id = original.self_id + msg.session_id = original.session_id + msg.message_id = f"bg_task_{task.task_id}" + msg.group = original.group + msg.sender = original.sender + msg.message = [Plain(text)] + msg.message_str = text + msg.raw_message = None + msg.timestamp = int(time.time()) + + return msg + + def build_callback_event(self, task: BackgroundTask) -> Any | None: + """构建完整的回调事件 + + Args: + task: 后台任务 + + Returns: + 回调事件对象,构建失败返回 None + """ + if not task.event: + logger.warning(f"[CallbackEventBuilder] Task {task.task_id} has no event") + return None + + try: + from astrbot.core.utils.trace import TraceSpan + + text = self.build_notification_text(task) + msg_obj = self.build_message_object(task, text) + + # 浅拷贝原事件,保留平台特定属性 + new_event = copy.copy(task.event) + new_event.message_str = text + new_event.message_obj = msg_obj + + # 重置状态 + new_event._result = None + new_event._has_send_oper = False + new_event._extras = {} + + # 初始化 trace + new_event.trace = TraceSpan( + name="BackgroundTaskCallback", + umo=new_event.unified_msg_origin, + sender_name=new_event.get_sender_name(), + message_outline=f"[Background Task {task.task_id}]", + ) + new_event.span = new_event.trace + + # 标记为回调事件 + new_event.is_wake = True + new_event.is_at_or_wake_command = True + new_event.set_extra("is_background_task_callback", True) + new_event.set_extra("background_task_id", task.task_id) + + return new_event + + except Exception as e: + logger.error( + f"[CallbackEventBuilder] Failed to build event for task {task.task_id}: {e}" + ) + return None diff --git a/astrbot/core/background_tool/callback_publisher.py b/astrbot/core/background_tool/callback_publisher.py new file mode 100644 index 000000000..453b3bc90 --- /dev/null +++ b/astrbot/core/background_tool/callback_publisher.py @@ -0,0 +1,85 @@ +"""回调事件发布器 + +负责将回调事件发布到事件队列。 +遵循单一职责原则,只负责队列操作。 +""" + +from astrbot import logger + +from .callback_event_builder import CallbackEventBuilder +from .task_state import BackgroundTask + + +class CallbackPublisher: + """回调事件发布器 + + 负责验证条件并将回调事件发布到队列。 + """ + + def __init__(self, event_builder: CallbackEventBuilder | None = None): + """初始化发布器 + + Args: + event_builder: 事件构建器,默认创建新实例 + """ + self._event_builder = event_builder or CallbackEventBuilder() + + def should_publish(self, task: BackgroundTask) -> bool: + """检查是否应该发布回调 + + Args: + task: 后台任务 + + Returns: + 是否应该发布 + """ + if task.is_being_waited: + logger.info( + f"[CallbackPublisher] Task {task.task_id} is being waited, skip" + ) + return False + + if not task.event: + logger.warning(f"[CallbackPublisher] Task {task.task_id} has no event") + return False + + if not task.event_queue: + logger.warning( + f"[CallbackPublisher] Task {task.task_id} has no event_queue" + ) + return False + + if not task.notification_message: + logger.warning( + f"[CallbackPublisher] Task {task.task_id} has no notification" + ) + return False + + return True + + async def publish(self, task: BackgroundTask) -> bool: + """发布回调事件 + + Args: + task: 后台任务 + + Returns: + 是否发布成功 + """ + if not self.should_publish(task): + return False + + try: + event = self._event_builder.build_callback_event(task) + if event is None: + return False + + task.event_queue.put_nowait(event) + task.notification_sent = True + + logger.info(f"[CallbackPublisher] Task {task.task_id} callback queued") + return True + + except Exception as e: + logger.error(f"[CallbackPublisher] Failed to publish callback: {e}") + return False diff --git a/astrbot/core/background_tool/llm_tools.py b/astrbot/core/background_tool/llm_tools.py new file mode 100644 index 000000000..c8d0c193d --- /dev/null +++ b/astrbot/core/background_tool/llm_tools.py @@ -0,0 +1,161 @@ +"""LLM工具集 + +提供给LLM调用的后台任务管理工具。 +""" + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from .manager import BackgroundToolManager +from .task_formatter import build_task_result + + +# 获取全局管理器实例 +def _get_manager() -> BackgroundToolManager: + return BackgroundToolManager() + + +async def get_tool_output( + event: "AstrMessageEvent", + task_id: str, + lines: int = 50, +) -> str: + """查看后台工具的输出日志 + + Args: + event: 消息事件 + task_id: 任务ID + lines: 返回最近N行日志,默认50行 + + Returns: + 工具输出日志和最终结果 + """ + manager = _get_manager() + + task = manager.registry.get(task_id) + if task is None: + return f"Error: Task {task_id} not found." + + output = manager.get_task_output(task_id, lines=lines) + return build_task_result(task_id, task, output) + + +async def wait_tool_result( + event: "AstrMessageEvent", + task_id: str, +) -> str: + """等待后台工具执行完成(可被新消息打断) + + Args: + event: 消息事件 + task_id: 任务ID + + Returns: + 工具执行结果 + + Raises: + WaitInterruptedException: 当被用户新消息中断时抛出 + """ + import asyncio + + manager = _get_manager() + session_id = event.unified_msg_origin + + from astrbot import logger + from astrbot.core.background_tool import WaitInterruptedException + + logger.info(f"[wait_tool_result] Looking for task {task_id}") + + task = manager.registry.get(task_id) + if task is None: + return f"Error: Task {task_id} not found." + + if task.is_finished(): + output = manager.get_task_output(task_id, lines=50) + return build_task_result(task_id, task, output) + + manager.clear_interrupt_flag(session_id) + task.is_being_waited = True + logger.info(f"[wait_tool_result] Using event-driven wait for task {task_id}") + + try: + while True: + # 检查中断标记 + if manager.check_interrupt_flag(session_id): + manager.clear_interrupt_flag(session_id) + raise WaitInterruptedException(task_id=task_id, session_id=session_id) + + # 使用事件等待,超时0.5秒后检查中断 + if task.completion_event: + try: + await asyncio.wait_for(task.completion_event.wait(), timeout=0.5) + break # 事件触发,任务完成 + except asyncio.TimeoutError: + pass # 继续循环检查中断 + else: + await asyncio.sleep(0.5) + if task.is_finished(): + break + finally: + task.is_being_waited = False + + output = manager.get_task_output(task_id, lines=50) + return build_task_result(task_id, task, output) + + +async def stop_tool( + event: "AstrMessageEvent", + task_id: str, +) -> str: + """终止正在执行的后台工具 + + Args: + event: 消息事件 + task_id: 任务ID + + Returns: + 终止结果 + """ + manager = _get_manager() + + task = manager.registry.get(task_id) + if task is None: + return f"Error: Task {task_id} not found." + + if task.is_finished(): + return f"Task {task_id} has already finished ({task.status.value})." + + success = await manager.stop_task(task_id) + + if success: + return f"Task {task_id} has been stopped/cancelled." + else: + return f"Failed to stop task {task_id}." + + +async def list_running_tools( + event: "AstrMessageEvent", +) -> str: + """列出当前会话中正在运行的后台工具 + + Args: + event: 消息事件 + + Returns: + 运行中的工具列表 + """ + manager = _get_manager() + session_id = event.unified_msg_origin + + running_tasks = manager.list_running_tasks(session_id) + + if not running_tasks: + return "No background tools are currently running." + + lines = ["Running background tools:"] + for task in running_tasks: + lines.append(f"- {task['task_id']}: {task['tool_name']} ({task['status']})") + + return "\n".join(lines) diff --git a/astrbot/core/background_tool/manager.py b/astrbot/core/background_tool/manager.py new file mode 100644 index 000000000..66ac93a27 --- /dev/null +++ b/astrbot/core/background_tool/manager.py @@ -0,0 +1,329 @@ +"""后台工具管理器 + +编排所有模块,提供统一接口。 +""" + +import asyncio +import threading +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +from astrbot.core.tool_execution.domain.config import DEFAULT_CONFIG + +from .output_buffer import OutputBuffer +from .task_executor import TaskExecutor +from .task_notifier import TaskNotifier +from .task_registry import TaskRegistry +from .task_state import BackgroundTask + + +class BackgroundToolManager: + """后台工具管理器 + + 核心编排器,管理后台任务的完整生命周期。 + """ + + _instance = None + _init_lock = threading.Lock() + + def __new__(cls): + """单例模式(线程安全)""" + if cls._instance is None: + with cls._init_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._config = DEFAULT_CONFIG + self.registry = TaskRegistry() + self.output_buffer = OutputBuffer() + self.executor = TaskExecutor(output_buffer=self.output_buffer) + self.notifier = TaskNotifier() + self._interrupt_flags = {} # session_id -> bool,用于中断等待 + self._cleanup_task: asyncio.Task | None = None + self._initialized = True + + def start_cleanup_task(self) -> None: + """启动定时清理任务 + + 应在事件循环启动后调用。如果事件循环未运行,则跳过。 + """ + if self._cleanup_task is None or self._cleanup_task.done(): + try: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + except RuntimeError: + # 事件循环未运行,跳过 + pass + + async def _cleanup_loop(self) -> None: + """定时清理已完成任务的循环""" + from astrbot import logger + + logger.info("[BackgroundToolManager] Cleanup task started") + + while True: + try: + await asyncio.sleep(self._config.cleanup_interval_seconds) + + # 清理已完成的旧任务 + removed_count = self.registry.cleanup_finished_tasks( + max_age_seconds=self._config.task_max_age_seconds + ) + + # 同步清理OutputBuffer中的孤立缓冲区 + valid_task_ids = set(self.registry._tasks.keys()) + buffer_cleaned = self.output_buffer.cleanup_old_buffers(valid_task_ids) + + # 清理孤立的中断标记(没有活跃任务的会话) + active_sessions = { + task.session_id for task in self.registry._tasks.values() + } + stale_flags = [ + sid for sid in self._interrupt_flags if sid not in active_sessions + ] + for sid in stale_flags: + self._interrupt_flags.pop(sid, None) + + if removed_count > 0 or buffer_cleaned > 0 or stale_flags: + stats = self.registry.count_by_status() + logger.info( + f"[BackgroundToolManager] Cleaned up {removed_count} tasks, " + f"{buffer_cleaned} buffers, {len(stale_flags)} flags, " + f"remaining: {self.registry.count()} ({stats})" + ) + except asyncio.CancelledError: + logger.info("[BackgroundToolManager] Cleanup task cancelled") + break + except Exception as e: + logger.error(f"[BackgroundToolManager] Cleanup error: {e}") + + def stop_cleanup_task(self) -> None: + """停止定时清理任务""" + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + + async def submit_task( + self, + tool_name: str, + tool_args: dict[str, Any], + session_id: str, + handler: Callable[..., Awaitable[str] | AsyncGenerator[str, None]], + wait: bool = True, + event: Any = None, + event_queue: Any = None, + ) -> str: + """提交后台任务 + + Args: + tool_name: 工具名称 + tool_args: 工具参数 + session_id: 会话ID + handler: 工具处理函数 + wait: 是否等待完成 + event: 原始事件对象(用于后台执行) + event_queue: 事件队列(用于触发AI回调) + + Returns: + 任务ID + """ + # 延迟启动清理任务(确保事件循环已运行) + self.start_cleanup_task() + + task = BackgroundTask( + task_id=BackgroundTask.generate_id(), + tool_name=tool_name, + tool_args=tool_args, + session_id=session_id, + event=event, + event_queue=event_queue, + ) + task.init_completion_event() + + from astrbot import logger + + logger.info( + f"[BackgroundToolManager] Creating task {task.task_id} for tool {tool_name}, session {session_id}" + ) + + task.init_completion_event() + self.registry.register(task) + logger.info( + f"[BackgroundToolManager] Task {task.task_id} registered successfully" + ) + + if wait: + await self.executor.execute(task=task, handler=handler) + else: + asyncio.create_task(self.executor.execute(task=task, handler=handler)) + + return task.task_id + + def get_task_output(self, task_id: str, lines: int = 50) -> str: + """获取任务输出 + + Args: + task_id: 任务ID + lines: 返回最近N行 + + Returns: + 输出日志文本 + """ + output_lines = self.output_buffer.get_recent(task_id, n=lines) + return "\n".join(output_lines) + + async def wait_task_result( + self, + task_id: str, + timeout: float = 300, + ) -> str | None: + """等待任务结果 + + Args: + task_id: 任务ID + timeout: 超时时间(秒) + + Returns: + 任务结果,超时返回None + """ + task = self.registry.get(task_id) + if task is None: + return None + + start_time = asyncio.get_event_loop().time() + while not task.is_finished(): + if asyncio.get_event_loop().time() - start_time > timeout: + return None + await asyncio.sleep(0.5) + + return task.result + + async def stop_task(self, task_id: str) -> bool: + """停止任务 + + Args: + task_id: 任务ID + + Returns: + 是否停止成功 + """ + return await self.executor.cancel(task_id) + + def list_running_tasks(self, session_id: str) -> list[dict[str, Any]]: + """列出运行中的任务 + + Args: + session_id: 会话ID + + Returns: + 运行中的任务列表 + """ + tasks = self.registry.get_running_tasks(session_id) + return [t.to_dict() for t in tasks] + + def get_task_status(self, task_id: str) -> dict[str, Any] | None: + """获取任务状态 + + Args: + task_id: 任务ID + + Returns: + 任务状态字典 + """ + task = self.registry.get(task_id) + if task is None: + return None + return task.to_dict() + + def get_pending_notifications(self, session_id: str) -> list[dict[str, Any]]: + """获取待发送的通知 + + Args: + session_id: 会话ID + + Returns: + 待发送通知列表,每项包含task_id和message + """ + tasks = self.registry.get_by_session(session_id) + pending = [] + for task in tasks: + if task.notification_message and not task.notification_sent: + pending.append( + { + "task_id": task.task_id, + "tool_name": task.tool_name, + "status": task.status.value, + "message": task.notification_message, + } + ) + return pending + + def mark_notification_sent(self, task_id: str) -> bool: + """标记通知已发送 + + Args: + task_id: 任务ID + + Returns: + 是否标记成功 + """ + task = self.registry.get(task_id) + if task is None: + return False + task.notification_sent = True + return True + + def set_interrupt_flag(self, session_id: str): + """设置会话的中断标记(用于打断wait_tool_result) + + Args: + session_id: 会话ID + """ + self._interrupt_flags[session_id] = True + + def check_interrupt_flag(self, session_id: str) -> bool: + """检查会话是否有中断标记 + + Args: + session_id: 会话ID + + Returns: + 是否有中断标记 + """ + return self._interrupt_flags.get(session_id, False) + + def clear_interrupt_flag(self, session_id: str): + """清除会话的中断标记 + + Args: + session_id: 会话ID + """ + self._interrupt_flags.pop(session_id, None) + + def get_running_tasks_status(self, session_id: str) -> str | None: + """获取会话中正在运行的后台任务状态信息 + + Args: + session_id: 会话ID + + Returns: + 后台任务状态信息,如果没有运行中的任务则返回None + """ + running_tasks = self.list_running_tasks(session_id) + if not running_tasks: + return None + + status_lines = ["[Background Tasks Status]"] + for task in running_tasks: + status_lines.append( + f"- Task {task['task_id']}: {task['tool_name']} ({task['status']})" + ) + status_lines.append( + "Note: These tasks are running in the background and will notify you when complete." + ) + + return "\n".join(status_lines) diff --git a/astrbot/core/background_tool/output_buffer.py b/astrbot/core/background_tool/output_buffer.py new file mode 100644 index 000000000..e5d4c2c68 --- /dev/null +++ b/astrbot/core/background_tool/output_buffer.py @@ -0,0 +1,116 @@ +"""输出缓冲区 + +缓存工具执行的实时输出日志,支持环形缓冲。 +""" + +from collections import deque + +from astrbot.core.tool_execution.utils.rwlock import RWLock + + +class OutputBuffer: + """输出缓冲区 + + 为每个任务维护一个环形缓冲区,存储输出日志。 + 使用读写锁优化读多写少场景的并发性能。 + + Attributes: + max_lines: 每个任务的最大行数 + """ + + def __init__(self, max_lines: int = 1000): + """初始化输出缓冲区 + + Args: + max_lines: 每个任务的最大行数,超过后自动丢弃旧行 + """ + self.max_lines = max_lines + self._buffers: dict[str, deque[str]] = {} + self._rwlock = RWLock() + + def append(self, task_id: str, line: str) -> None: + """追加一行输出 + + Args: + task_id: 任务ID + line: 输出行 + """ + with self._rwlock.write(): + if task_id not in self._buffers: + self._buffers[task_id] = deque(maxlen=self.max_lines) + self._buffers[task_id].append(line) + + def get_all(self, task_id: str) -> list[str]: + """获取所有输出行 + + Args: + task_id: 任务ID + + Returns: + 所有输出行列表 + """ + with self._rwlock.read(): + buffer = self._buffers.get(task_id) + if buffer is None: + return [] + return list(buffer) + + def get_recent(self, task_id: str, n: int = 50) -> list[str]: + """获取最近N行输出 + + Args: + task_id: 任务ID + n: 行数 + + Returns: + 最近N行输出列表 + """ + all_lines = self.get_all(task_id) + return all_lines[-n:] if len(all_lines) > n else all_lines + + def clear(self, task_id: str) -> None: + """清空任务的输出缓冲区 + + Args: + task_id: 任务ID + """ + with self._rwlock.write(): + if task_id in self._buffers: + self._buffers[task_id].clear() + + def line_count(self, task_id: str) -> int: + """获取任务的输出行数 + + Args: + task_id: 任务ID + + Returns: + 输出行数 + """ + with self._rwlock.read(): + buffer = self._buffers.get(task_id) + return len(buffer) if buffer else 0 + + def remove(self, task_id: str) -> None: + """删除任务的输出缓冲区 + + Args: + task_id: 任务ID + """ + with self._rwlock.write(): + self._buffers.pop(task_id, None) + + def cleanup_old_buffers(self, valid_task_ids: set[str]) -> int: + """清理不再有效的任务缓冲区 + + Args: + valid_task_ids: 仍然有效的任务ID集合 + + Returns: + 清理的缓冲区数量 + """ + with self._rwlock.write(): + to_remove = [tid for tid in self._buffers if tid not in valid_task_ids] + for tid in to_remove: + del self._buffers[tid] + return len(to_remove) diff --git a/astrbot/core/background_tool/register.py b/astrbot/core/background_tool/register.py new file mode 100644 index 000000000..54a33e2aa --- /dev/null +++ b/astrbot/core/background_tool/register.py @@ -0,0 +1,82 @@ +"""后台工具注册 + +将后台任务管理工具注册到AstrBot的LLM工具系统。 +""" + +from astrbot import logger +from astrbot.core.provider.func_tool_manager import FuncCall + + +def register_background_tools(llm_tools: FuncCall) -> None: + """注册后台任务管理工具 + + Args: + llm_tools: LLM工具管理器实例 + """ + from .llm_tools import ( + get_tool_output, + list_running_tools, + stop_tool, + wait_tool_result, + ) + + # 注册 get_tool_output 工具 + llm_tools.add_func( + name="get_tool_output", + func_args=[ + { + "name": "task_id", + "type": "string", + "description": "The ID of the background task", + }, + { + "name": "lines", + "type": "integer", + "description": "Number of recent lines to return (default: 50)", + }, + ], + desc="View the output logs of a background tool. Use this to check the progress of a long-running task.", + handler=get_tool_output, + ) + logger.info("[PROCESS] Registered LLM tool: get_tool_output") + + # 注册 wait_tool_result 工具 + llm_tools.add_func( + name="wait_tool_result", + func_args=[ + { + "name": "task_id", + "type": "string", + "description": "The ID of the background task", + }, + ], + desc="Wait for a background tool to complete. The wait can be interrupted by new user messages. No timeout - waits until task finishes or is terminated.", + handler=wait_tool_result, + ) + logger.info("[PROCESS] Registered LLM tool: wait_tool_result") + + # 注册 stop_tool 工具 + llm_tools.add_func( + name="stop_tool", + func_args=[ + { + "name": "task_id", + "type": "string", + "description": "The ID of the background task to stop", + }, + ], + desc="Stop a running background tool.", + handler=stop_tool, + ) + logger.info("[PROCESS] Registered LLM tool: stop_tool") + + # 注册 list_running_tools 工具 + llm_tools.add_func( + name="list_running_tools", + func_args=[], + desc="List all currently running background tools in this session.", + handler=list_running_tools, + ) + logger.info("[PROCESS] Registered LLM tool: list_running_tools") + + logger.info("[PROCESS] All background tool management tools registered") diff --git a/astrbot/core/background_tool/task_executor.py b/astrbot/core/background_tool/task_executor.py new file mode 100644 index 000000000..18a9358ed --- /dev/null +++ b/astrbot/core/background_tool/task_executor.py @@ -0,0 +1,276 @@ +"""任务执行器 + +在后台执行工具,捕获输出,支持取消和超时。 +""" + +import asyncio +import json +import os +import traceback +from collections.abc import AsyncGenerator, Awaitable, Callable + +from astrbot import logger +from astrbot.core.tool_execution.utils.sanitizer import sanitize_for_log +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from .callback_publisher import CallbackPublisher +from .output_buffer import OutputBuffer +from .task_notifier import TaskNotifier +from .task_state import BackgroundTask + + +def _get_background_task_timeout() -> int: + """从配置文件中读取后台任务超时时间,如果读取失败则返回默认值600秒 + + 使用模块级缓存避免每次执行都读取配置文件。 + """ + return _ConfigCache.get_timeout() + + +class _ConfigCache: + """配置缓存 + + 缓存配置值,避免频繁读取配置文件。 + """ + + _timeout: int | None = None + _last_load: float = 0 + _cache_ttl: float = 60.0 # 缓存有效期60秒 + + @classmethod + def get_timeout(cls) -> int: + """获取超时配置,带缓存""" + import time + + current_time = time.time() + + # 检查缓存是否过期 + if ( + cls._timeout is not None + and (current_time - cls._last_load) < cls._cache_ttl + ): + return cls._timeout + + # 重新加载配置 + try: + config_path = os.path.join(get_astrbot_data_path(), "cmd_config.json") + if os.path.exists(config_path): + with open(config_path, encoding="utf-8-sig") as f: + config = json.load(f) + cls._timeout = config.get("provider_settings", {}).get( + "background_task_wait_timeout", 600 + ) + cls._last_load = current_time + return cls._timeout + except Exception: + pass + + cls._timeout = 600 + cls._last_load = current_time + return cls._timeout + + +class TaskExecutor: + """任务执行器 + + 管理后台任务的执行、取消和状态跟踪。 + """ + + def __init__( + self, + output_buffer: OutputBuffer, + callback_publisher: CallbackPublisher | None = None, + ): + """初始化任务执行器 + + Args: + output_buffer: 输出缓冲区 + callback_publisher: 回调发布器,默认创建新实例 + """ + self.output_buffer = output_buffer + self.notifier = TaskNotifier() + self.callback_publisher = callback_publisher or CallbackPublisher() + self._running_tasks: dict[str, asyncio.Task] = {} + self._cancel_events: dict[str, asyncio.Event] = {} + + async def execute( + self, + task: BackgroundTask, + handler: Callable[..., Awaitable[str] | AsyncGenerator[str, None]], + ) -> str | None: + """执行任务 + + Args: + task: 后台任务 + handler: 工具处理函数 + + Returns: + 执行结果 + """ + task.start() + self._cancel_events[task.task_id] = asyncio.Event() + + # 获取后台任务超时时间 + timeout = _get_background_task_timeout() + + try: + # 创建执行任务 + exec_coro = self._run_handler(task, handler) + async_task = asyncio.create_task(exec_coro) + self._running_tasks[task.task_id] = async_task + + # 使用 wait_for 添加超时控制 + try: + result = await asyncio.wait_for(async_task, timeout=timeout) + except asyncio.TimeoutError: + # 超时,取消任务 + async_task.cancel() + try: + await async_task + except asyncio.CancelledError: + pass + error_msg = f"Task timed out after {timeout}s and was terminated." + task.fail(error_msg) + self._log(task.task_id, f"[TIMEOUT] {error_msg}") + # 生成超时通知消息(包含输出日志) + output = "\n".join(self.output_buffer.get_recent(task.task_id, n=50)) + task.notification_message = self.notifier.build_message(task, output) + # 主动触发回调 + await self.callback_publisher.publish(task) + return None + + task.complete(result or "") + # 生成完成通知消息(包含输出日志) + output = "\n".join(self.output_buffer.get_recent(task.task_id, n=50)) + task.notification_message = self.notifier.build_message(task, output) + self._log(task.task_id, "[NOTIFICATION] Task completed, notification ready") + # 主动触发回调 + await self.callback_publisher.publish(task) + return result + + except asyncio.CancelledError: + task.cancel() + # 生成取消通知消息(包含输出日志) + output = "\n".join(self.output_buffer.get_recent(task.task_id, n=50)) + task.notification_message = self.notifier.build_message(task, output) + self._log( + task.task_id, "[CANCELLED] Task was cancelled, notification ready" + ) + # 主动触发回调 + await self.callback_publisher.publish(task) + return None + + except Exception as e: + # 检查是否是 wait_tool_result 被中断的情况,这种情况不需要触发回调 + from astrbot.core.background_tool import WaitInterruptedException + + is_wait_interrupted = isinstance(e, WaitInterruptedException) + + # 用户可见的错误信息(不含敏感堆栈) + user_error_msg = f"Task failed: {type(e).__name__}: {e}" + # 仅在 DEBUG 日志中记录完整堆栈 + debug_error_msg = f"{e}\n{traceback.format_exc()}" + logger.debug( + f"[BackgroundTask:{task.task_id}] Full traceback: {debug_error_msg}" + ) + + task.fail(user_error_msg) + + if is_wait_interrupted: + # wait_tool_result 被中断是正常行为,不需要通知用户 + self._log( + task.task_id, "[INTERRUPTED] Wait interrupted, no callback needed" + ) + else: + # 其他错误需要生成通知消息并触发回调(包含输出日志) + output = "\n".join(self.output_buffer.get_recent(task.task_id, n=50)) + task.notification_message = self.notifier.build_message(task, output) + self._log(task.task_id, f"[ERROR] {user_error_msg}, notification ready") + # 主动触发回调 + await self.callback_publisher.publish(task) + return None + + finally: + self._cleanup(task.task_id) + task.release_references() # 释放大对象引用,防止内存泄露 + + async def _run_handler( + self, + task: BackgroundTask, + handler: Callable, + ) -> str | None: + """运行处理函数""" + self._log(task.task_id, f"[START] Executing {task.tool_name}") + # 使用脱敏后的参数,防止敏感信息泄露 + self._log(task.task_id, f"[ARGS] {sanitize_for_log(task.tool_args)}") + + try: + # 构建调用参数,如果有event则传递 + call_args = dict(task.tool_args) + if task.event is not None: + call_args["event"] = task.event + + result = handler(**call_args) + + # 检查是否是异步生成器 + if hasattr(result, "__anext__"): + final_result = None + async for output in result: + if output is not None: + self._log(task.task_id, str(output)) + final_result = output + return final_result + + # 检查是否是协程 + elif asyncio.iscoroutine(result): + return await result + + else: + return result + + except Exception: + raise + + async def cancel(self, task_id: str) -> bool: + """取消任务 + + Args: + task_id: 任务ID + + Returns: + 是否取消成功 + """ + if task_id not in self._running_tasks: + return False + + async_task = self._running_tasks.get(task_id) + if async_task and not async_task.done(): + async_task.cancel() + # 设置取消事件 + if task_id in self._cancel_events: + self._cancel_events[task_id].set() + return True + + return False + + def is_running(self, task_id: str) -> bool: + """检查任务是否运行中 + + Args: + task_id: 任务ID + + Returns: + 是否运行中 + """ + async_task = self._running_tasks.get(task_id) + return async_task is not None and not async_task.done() + + def _log(self, task_id: str, message: str) -> None: + """记录日志到缓冲区""" + self.output_buffer.append(task_id, message) + logger.debug(f"[BackgroundTask:{task_id}] {message}") + + def _cleanup(self, task_id: str) -> None: + """清理任务资源""" + self._running_tasks.pop(task_id, None) + self._cancel_events.pop(task_id, None) diff --git a/astrbot/core/background_tool/task_formatter.py b/astrbot/core/background_tool/task_formatter.py new file mode 100644 index 000000000..ef311b400 --- /dev/null +++ b/astrbot/core/background_tool/task_formatter.py @@ -0,0 +1,44 @@ +"""任务结果格式化器 + +提供统一的任务结果格式化函数,供LLM工具和通知系统使用。 +""" + +from .task_state import BackgroundTask + + +def build_task_result( + task_id: str, task: BackgroundTask, output: str | None = None +) -> str: + """构建任务结果的完整信息 + + Args: + task_id: 任务ID + task: 任务对象 + output: 输出日志(可选) + + Returns: + 格式化的任务结果信息,包含: + - 任务状态 + - 输出日志(如果提供) + - 最终结果 + - 错误信息 + - 通知消息 + """ + status = task.status.value + result_text = f"Task {task_id} ({task.tool_name}, {status}):\n" + + # 如果有输出日志,显示日志 + if output: + result_text += f"\n{output}\n" + + # 如果任务已完成,显示最终结果 + if task.is_finished(): + if task.result: + result_text += f"\n[FINAL RESULT]\n{task.result}" + elif task.error: + result_text += f"\n[ERROR]\n{task.error}" + + if not output and not task.is_finished(): + return f"Task {task_id} ({task.tool_name}, {status}): No output yet." + + return result_text diff --git a/astrbot/core/background_tool/task_notifier.py b/astrbot/core/background_tool/task_notifier.py new file mode 100644 index 000000000..a62fa6b00 --- /dev/null +++ b/astrbot/core/background_tool/task_notifier.py @@ -0,0 +1,56 @@ +"""任务通知器 + +任务完成后主动通知AI。 +""" + +from collections.abc import Awaitable, Callable + +from .task_formatter import build_task_result +from .task_state import BackgroundTask + + +class TaskNotifier: + """任务通知器 + + 负责在后台任务完成后构建通知消息并发送。 + """ + + def should_notify(self, task: BackgroundTask) -> bool: + """检查是否应该通知 + + Args: + task: 后台任务 + + Returns: + 是否应该通知 + """ + return task.is_finished() + + def build_message(self, task: BackgroundTask, output: str | None = None) -> str: + """构建通知消息 + + Args: + task: 后台任务 + output: 输出日志(可选) + + Returns: + 通知消息文本,包含完整的任务信息(状态、输出日志、结果、错误) + """ + return build_task_result(task.task_id, task, output) + + async def notify_completion( + self, + task: BackgroundTask, + send_callback: Callable[[str], Awaitable[None]], + ) -> None: + """通知任务完成 + + Args: + task: 后台任务 + send_callback: 发送消息的回调函数 + """ + if not self.should_notify(task): + return + + message = self.build_message(task) + await send_callback(message) diff --git a/astrbot/core/background_tool/task_registry.py b/astrbot/core/background_tool/task_registry.py new file mode 100644 index 000000000..1563c5580 --- /dev/null +++ b/astrbot/core/background_tool/task_registry.py @@ -0,0 +1,184 @@ +"""任务注册表 + +管理后台任务的注册、查询、更新、删除。 +""" + +import threading +import time +from typing import Any + +from astrbot.core.tool_execution.utils.rwlock import RWLock + +from .task_state import BackgroundTask, TaskStatus + + +class TaskRegistry: + """任务注册表 + + 线程安全的任务管理器,支持按ID和会话查询任务。 + 使用读写锁优化读多写少场景的并发性能。 + """ + + _instance = None + _init_lock = threading.Lock() + + def __new__(cls): + """单例模式""" + if cls._instance is None: + with cls._init_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._tasks: dict[str, BackgroundTask] = {} + cls._instance._session_index: dict[str, set[str]] = {} + cls._instance._rwlock = RWLock() + return cls._instance + + def register(self, task: BackgroundTask) -> str: + """注册任务 + + Args: + task: 后台任务对象 + + Returns: + 任务ID + """ + with self._rwlock.write(): + self._tasks[task.task_id] = task + + # 更新会话索引 + if task.session_id not in self._session_index: + self._session_index[task.session_id] = set() + self._session_index[task.session_id].add(task.task_id) + + return task.task_id + + def get(self, task_id: str) -> BackgroundTask | None: + """获取任务 + + Args: + task_id: 任务ID + + Returns: + 任务对象,不存在返回None + """ + with self._rwlock.read(): + return self._tasks.get(task_id) + + def get_by_session(self, session_id: str) -> list[BackgroundTask]: + """按会话获取任务 + + Args: + session_id: 会话ID + + Returns: + 该会话的所有任务列表 + """ + with self._rwlock.read(): + task_ids = self._session_index.get(session_id, set()) + return [self._tasks[tid] for tid in task_ids if tid in self._tasks] + + def get_running_tasks(self, session_id: str) -> list[BackgroundTask]: + """获取会话中运行中的任务 + + Args: + session_id: 会话ID + + Returns: + 运行中的任务列表 + """ + tasks = self.get_by_session(session_id) + return [t for t in tasks if t.status == TaskStatus.RUNNING] + + def update(self, task_id: str, **kwargs: Any) -> bool: + """更新任务属性 + + Args: + task_id: 任务ID + **kwargs: 要更新的属性 + + Returns: + 是否更新成功 + """ + with self._rwlock.write(): + task = self._tasks.get(task_id) + if task is None: + return False + + for key, value in kwargs.items(): + if hasattr(task, key): + setattr(task, key, value) + + return True + + def remove(self, task_id: str) -> bool: + """删除任务 + + Args: + task_id: 任务ID + + Returns: + 是否删除成功 + """ + with self._rwlock.write(): + task = self._tasks.pop(task_id, None) + if task is None: + return False + + # 更新会话索引 + if task.session_id in self._session_index: + self._session_index[task.session_id].discard(task_id) + + return True + + def cleanup_finished_tasks(self, max_age_seconds: float = 3600) -> int: + """清理已完成的旧任务 + + Args: + max_age_seconds: 任务完成后保留的最大时间(秒),默认1小时 + + Returns: + 清理的任务数量 + """ + current_time = time.time() + tasks_to_remove = [] + + # 使用读锁收集要删除的任务 + with self._rwlock.read(): + for task_id, task in self._tasks.items(): + # 只清理已完成的任务 + if task.is_finished() and task.completed_at is not None: + age = current_time - task.completed_at + if age > max_age_seconds: + tasks_to_remove.append(task_id) + + # 使用写锁执行删除 + removed_count = 0 + for task_id in tasks_to_remove: + if self.remove(task_id): + removed_count += 1 + + return removed_count + + def clear(self) -> None: + """清空所有任务""" + with self._rwlock.write(): + self._tasks.clear() + self._session_index.clear() + + def count(self) -> int: + """获取任务数量""" + with self._rwlock.read(): + return len(self._tasks) + + def count_by_status(self) -> dict[str, int]: + """按状态统计任务数量 + + Returns: + 状态 -> 数量的字典 + """ + with self._rwlock.read(): + counts: dict[str, int] = {} + for task in self._tasks.values(): + status = task.status.value + counts[status] = counts.get(status, 0) + 1 + return counts diff --git a/astrbot/core/background_tool/task_state.py b/astrbot/core/background_tool/task_state.py new file mode 100644 index 000000000..41c6576e2 --- /dev/null +++ b/astrbot/core/background_tool/task_state.py @@ -0,0 +1,145 @@ +"""后台任务状态定义 + +定义后台任务的状态数据结构和状态转换逻辑。 +""" + +import time +import uuid +from asyncio import Event, Queue +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class TaskStatus(Enum): + """任务状态枚举""" + + PENDING = "pending" # 等待执行 + RUNNING = "running" # 正在执行 + COMPLETED = "completed" # 执行完成 + FAILED = "failed" # 执行失败 + CANCELLED = "cancelled" # 已取消 + + +@dataclass +class BackgroundTask: + """后台任务状态 + + Attributes: + task_id: 任务唯一ID + tool_name: 工具名称 + tool_args: 工具参数 + session_id: 会话ID (unified_msg_origin) + status: 任务状态 + created_at: 创建时间 + started_at: 开始时间 + completed_at: 完成时间 + result: 执行结果 + error: 错误信息 + output_log: 输出日志 + event: 原始事件对象(用于主动回调) + """ + + task_id: str + tool_name: str + tool_args: dict[str, Any] + session_id: str + status: TaskStatus = TaskStatus.PENDING + created_at: float = field(default_factory=time.time) + started_at: float | None = None + completed_at: float | None = None + result: str | None = None + error: str | None = None + output_log: list[str] = field(default_factory=list) + notification_message: str | None = None # 任务完成后的通知消息 + notification_sent: bool = False # 通知是否已发送 + event: Any = None # 原始事件对象,用于主动回调 + event_queue: Queue | None = None # 事件队列,用于触发AI回调 + is_being_waited: bool = False + completion_event: Event | None = ( + None # 任务完成信号 # 是否有LLM正在使用wait_tool_result等待此任务 + ) + + @staticmethod + def generate_id() -> str: + """生成唯一任务ID""" + return str(uuid.uuid4())[:8] + + def start(self) -> None: + """标记任务开始执行""" + self.status = TaskStatus.RUNNING + self.started_at = time.time() + + def complete(self, result: str) -> None: + """标记任务完成""" + self.status = TaskStatus.COMPLETED + self.result = result + self.completed_at = time.time() + self._signal_completion() + + def fail(self, error: str) -> None: + """标记任务失败""" + self.status = TaskStatus.FAILED + self.error = error + self.completed_at = time.time() + self._signal_completion() + + def cancel(self) -> None: + """标记任务取消""" + self.status = TaskStatus.CANCELLED + self.completed_at = time.time() + self._signal_completion() + + def append_output(self, line: str) -> None: + """追加输出日志""" + self.output_log.append(line) + + def is_finished(self) -> bool: + """检查任务是否已完成""" + return self.status in ( + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.CANCELLED, + ) + + def to_dict(self) -> dict[str, Any]: + """序列化为字典""" + return { + "task_id": self.task_id, + "tool_name": self.tool_name, + "tool_args": self.tool_args, + "session_id": self.session_id, + "status": self.status.value, + "created_at": self.created_at, + "started_at": self.started_at, + "completed_at": self.completed_at, + "result": self.result, + "error": self.error, + "output_log_count": len(self.output_log), + } + + def _signal_completion(self) -> None: + """触发完成信号""" + if self.completion_event is not None: + self.completion_event.set() + + def release_references(self) -> None: + """释放大对象引用,防止内存泄露 + + 任务完成后,不再需要event和event_queue引用, + 释放它们以允许垃圾回收。 + 应在所有回调完成后调用。 + """ + self.event = None + self.event_queue = None + self.completion_event = None + self.output_log.clear() # 日志已保存到OutputBuffer + + def init_completion_event(self) -> None: + """初始化完成事件""" + import asyncio + + try: + self.completion_event = asyncio.Event() + except RuntimeError: + pass diff --git a/astrbot/core/background_tool/tests/__init__.py b/astrbot/core/background_tool/tests/__init__.py new file mode 100644 index 000000000..ae78246ec --- /dev/null +++ b/astrbot/core/background_tool/tests/__init__.py @@ -0,0 +1 @@ +# tests/__init__.py diff --git a/astrbot/core/background_tool/tests/test_llm_tools.py b/astrbot/core/background_tool/tests/test_llm_tools.py new file mode 100644 index 000000000..8a5218398 --- /dev/null +++ b/astrbot/core/background_tool/tests/test_llm_tools.py @@ -0,0 +1,110 @@ +"""LLM工具集单元测试""" + +import asyncio +from unittest.mock import MagicMock + +import pytest + +from astrbot.core.background_tool.llm_tools import ( + get_tool_output, + list_running_tools, + stop_tool, +) +from astrbot.core.background_tool.manager import BackgroundToolManager +from astrbot.core.background_tool.task_state import BackgroundTask + + +class TestLLMTools: + """测试LLM工具集""" + + def setup_method(self): + """每个测试前初始化""" + self.manager = BackgroundToolManager() + self.manager.registry.clear() + + @pytest.mark.asyncio + async def test_get_tool_output(self): + """测试获取工具输出""" + # 创建一个任务并添加输出 + task = BackgroundTask( + task_id="test-001", + tool_name="test_tool", + tool_args={}, + session_id="session-A", + ) + self.manager.registry.register(task) + self.manager.output_buffer.append("test-001", "line 1") + self.manager.output_buffer.append("test-001", "line 2") + + # 模拟事件对象 + mock_event = MagicMock() + mock_event.unified_msg_origin = "session-A" + + result = await get_tool_output(mock_event, task_id="test-001") + + assert "line 1" in result + assert "line 2" in result + + @pytest.mark.asyncio + async def test_get_tool_output_not_found(self): + """测试获取不存在的任务输出""" + mock_event = MagicMock() + mock_event.unified_msg_origin = "session-A" + + result = await get_tool_output(mock_event, task_id="nonexistent") + + assert "not found" in result.lower() or "error" in result.lower() + + @pytest.mark.asyncio + async def test_stop_tool(self): + """测试停止工具""" + + async def slow_handler(**kwargs): + await asyncio.sleep(10) + return "done" + + task_id = await self.manager.submit_task( + tool_name="slow_tool", + tool_args={}, + session_id="session-A", + handler=slow_handler, + wait=False, + ) + + await asyncio.sleep(0.1) + + mock_event = MagicMock() + mock_event.unified_msg_origin = "session-A" + + result = await stop_tool(mock_event, task_id=task_id) + + assert "stopped" in result.lower() or "cancelled" in result.lower() + + @pytest.mark.asyncio + async def test_list_running_tools(self): + """测试列出运行中的工具""" + + async def slow_handler(**kwargs): + await asyncio.sleep(5) + return "done" + + task_id = await self.manager.submit_task( + tool_name="slow_tool", + tool_args={}, + session_id="session-A", + handler=slow_handler, + wait=False, + ) + + await asyncio.sleep(0.1) + + mock_event = MagicMock() + mock_event.unified_msg_origin = "session-A" + + result = await list_running_tools(mock_event) + + # 清理 + await self.manager.stop_task(task_id) + + # 结果应该是字符串 + assert isinstance(result, str) diff --git a/astrbot/core/background_tool/tests/test_manager.py b/astrbot/core/background_tool/tests/test_manager.py new file mode 100644 index 000000000..2a0362269 --- /dev/null +++ b/astrbot/core/background_tool/tests/test_manager.py @@ -0,0 +1,127 @@ +"""BackgroundToolManager 单元测试""" + +import asyncio + +import pytest + +from astrbot.core.background_tool.manager import BackgroundToolManager +from astrbot.core.background_tool.task_state import TaskStatus + + +class TestBackgroundToolManager: + """测试后台工具管理器""" + + def setup_method(self): + """每个测试前初始化""" + self.manager = BackgroundToolManager() + self.manager.registry.clear() + + @pytest.mark.asyncio + async def test_submit_task(self): + """测试提交任务""" + + async def mock_handler(**kwargs): + return "result" + + task_id = await self.manager.submit_task( + tool_name="test_tool", + tool_args={"key": "value"}, + session_id="session-A", + handler=mock_handler, + ) + + assert task_id is not None + task = self.manager.registry.get(task_id) + assert task is not None + + @pytest.mark.asyncio + async def test_get_task_output(self): + """测试获取任务输出""" + + async def mock_handler(**kwargs): + yield "line 1" + yield "line 2" + yield "done" + + task_id = await self.manager.submit_task( + tool_name="test_tool", + tool_args={}, + session_id="session-A", + handler=mock_handler, + ) + + # 等待任务完成 + await asyncio.sleep(0.2) + + output = self.manager.get_task_output(task_id) + assert "line 1" in output or len(output) > 0 + + @pytest.mark.asyncio + async def test_list_running_tasks(self): + """测试列出运行中的任务""" + + async def slow_handler(**kwargs): + await asyncio.sleep(5) + return "done" + + task_id = await self.manager.submit_task( + tool_name="slow_tool", + tool_args={}, + session_id="session-A", + handler=slow_handler, + wait=False, + ) + + await asyncio.sleep(0.1) + + running = self.manager.list_running_tasks("session-A") + assert len(running) >= 0 # 可能已完成 + + # 清理 + await self.manager.stop_task(task_id) + + @pytest.mark.asyncio + async def test_stop_task(self): + """测试停止任务""" + + async def slow_handler(**kwargs): + await asyncio.sleep(10) + return "should not reach" + + task_id = await self.manager.submit_task( + tool_name="slow_tool", + tool_args={}, + session_id="session-A", + handler=slow_handler, + wait=False, + ) + + await asyncio.sleep(0.1) + + success = await self.manager.stop_task(task_id) + assert success + + await asyncio.sleep(0.2) + + task = self.manager.registry.get(task_id) + assert task.status == TaskStatus.CANCELLED + + @pytest.mark.asyncio + async def test_get_task_status(self): + """测试获取任务状态""" + + async def mock_handler(**kwargs): + return "done" + + task_id = await self.manager.submit_task( + tool_name="test_tool", + tool_args={}, + session_id="session-A", + handler=mock_handler, + ) + + await asyncio.sleep(0.2) + + status = self.manager.get_task_status(task_id) + assert status is not None + assert "task_id" in status diff --git a/astrbot/core/background_tool/tests/test_output_buffer.py b/astrbot/core/background_tool/tests/test_output_buffer.py new file mode 100644 index 000000000..6353701aa --- /dev/null +++ b/astrbot/core/background_tool/tests/test_output_buffer.py @@ -0,0 +1,94 @@ +"""OutputBuffer 单元测试""" + +from astrbot.core.background_tool.output_buffer import OutputBuffer + + +class TestOutputBuffer: + """测试输出缓冲区""" + + def setup_method(self): + """每个测试前重置缓冲区""" + self.buffer = OutputBuffer(max_lines=100) + + def test_append_line(self): + """测试追加行""" + self.buffer.append("task-001", "line 1") + self.buffer.append("task-001", "line 2") + + lines = self.buffer.get_all("task-001") + + assert len(lines) == 2 + assert lines[0] == "line 1" + assert lines[1] == "line 2" + + def test_get_all_empty(self): + """测试获取空缓冲区""" + lines = self.buffer.get_all("nonexistent") + assert lines == [] + + def test_get_recent(self): + """测试获取最近N行""" + for i in range(10): + self.buffer.append("task-002", f"line {i}") + + recent = self.buffer.get_recent("task-002", n=3) + + assert len(recent) == 3 + assert recent[0] == "line 7" + assert recent[1] == "line 8" + assert recent[2] == "line 9" + + def test_get_recent_less_than_n(self): + """测试获取最近N行(实际行数少于N)""" + self.buffer.append("task-003", "line 1") + self.buffer.append("task-003", "line 2") + + recent = self.buffer.get_recent("task-003", n=10) + + assert len(recent) == 2 + + def test_clear(self): + """测试清空缓冲区""" + self.buffer.append("task-004", "line 1") + self.buffer.append("task-004", "line 2") + + self.buffer.clear("task-004") + + assert self.buffer.get_all("task-004") == [] + + def test_max_lines_limit(self): + """测试最大行数限制""" + buffer = OutputBuffer(max_lines=5) + + for i in range(10): + buffer.append("task-005", f"line {i}") + + lines = buffer.get_all("task-005") + + assert len(lines) == 5 + # 应该保留最后5行 + assert lines[0] == "line 5" + assert lines[4] == "line 9" + + def test_multiple_tasks(self): + """测试多任务隔离""" + self.buffer.append("task-A", "A line 1") + self.buffer.append("task-B", "B line 1") + self.buffer.append("task-A", "A line 2") + + lines_a = self.buffer.get_all("task-A") + lines_b = self.buffer.get_all("task-B") + + assert len(lines_a) == 2 + assert len(lines_b) == 1 + assert lines_a[0] == "A line 1" + assert lines_b[0] == "B line 1" + + def test_line_count(self): + """测试行数统计""" + assert self.buffer.line_count("task-006") == 0 + + self.buffer.append("task-006", "line 1") + self.buffer.append("task-006", "line 2") + + assert self.buffer.line_count("task-006") == 2 diff --git a/astrbot/core/background_tool/tests/test_task_executor.py b/astrbot/core/background_tool/tests/test_task_executor.py new file mode 100644 index 000000000..2a0bf1975 --- /dev/null +++ b/astrbot/core/background_tool/tests/test_task_executor.py @@ -0,0 +1,142 @@ +"""TaskExecutor 单元测试""" + +import asyncio + +import pytest + +from astrbot.core.background_tool.output_buffer import OutputBuffer +from astrbot.core.background_tool.task_executor import TaskExecutor +from astrbot.core.background_tool.task_state import BackgroundTask, TaskStatus + + +class TestTaskExecutor: + """测试任务执行器""" + + def setup_method(self): + """每个测试前初始化""" + self.output_buffer = OutputBuffer() + self.executor = TaskExecutor(output_buffer=self.output_buffer) + + @pytest.mark.asyncio + async def test_execute_simple_task(self): + """测试执行简单任务""" + task = BackgroundTask( + task_id="test-001", + tool_name="test_tool", + tool_args={"param": "value"}, + session_id="session-A", + ) + + # 模拟工具处理函数 + async def mock_handler(**kwargs): + return "success result" + + await self.executor.execute( + task=task, + handler=mock_handler, + ) + + assert task.status == TaskStatus.COMPLETED + assert task.result == "success result" + + @pytest.mark.asyncio + async def test_execute_with_output(self): + """测试执行带输出的任务""" + task = BackgroundTask( + task_id="test-002", + tool_name="test_tool", + tool_args={}, + session_id="session-A", + ) + + async def mock_handler(**kwargs): + # 模拟输出 + yield "Processing step 1..." + yield "Processing step 2..." + yield "done" + + await self.executor.execute( + task=task, + handler=mock_handler, + ) + + # 检查输出缓冲区 + lines = self.output_buffer.get_all("test-002") + assert len(lines) >= 2 + + @pytest.mark.asyncio + async def test_execute_failed_task(self): + """测试执行失败的任务""" + task = BackgroundTask( + task_id="test-003", + tool_name="test_tool", + tool_args={}, + session_id="session-A", + ) + + async def mock_handler(**kwargs): + raise Exception("Test error") + + await self.executor.execute( + task=task, + handler=mock_handler, + ) + + assert task.status == TaskStatus.FAILED + assert "Test error" in task.error + + @pytest.mark.asyncio + async def test_cancel_task(self): + """测试取消任务""" + task = BackgroundTask( + task_id="test-004", + tool_name="test_tool", + tool_args={}, + session_id="session-A", + ) + + async def slow_handler(**kwargs): + await asyncio.sleep(10) + return "should not reach" + + # 启动任务 + asyncio.create_task(self.executor.execute(task=task, handler=slow_handler)) + + # 等待任务开始 + await asyncio.sleep(0.1) + + # 取消任务 + success = await self.executor.cancel("test-004") + + # 等待任务完成 + await asyncio.sleep(0.2) + + assert success + assert task.status == TaskStatus.CANCELLED + + @pytest.mark.asyncio + async def test_is_running(self): + """测试检查任务是否运行中""" + task = BackgroundTask( + task_id="test-005", + tool_name="test_tool", + tool_args={}, + session_id="session-A", + ) + + async def slow_handler(**kwargs): + await asyncio.sleep(1) + return "done" + + # 启动任务 + asyncio.create_task(self.executor.execute(task=task, handler=slow_handler)) + + await asyncio.sleep(0.1) + + assert self.executor.is_running("test-005") + + # 取消并等待 + await self.executor.cancel("test-005") + await asyncio.sleep(0.1) + + assert not self.executor.is_running("test-005") diff --git a/astrbot/core/background_tool/tests/test_task_notifier.py b/astrbot/core/background_tool/tests/test_task_notifier.py new file mode 100644 index 000000000..9152020aa --- /dev/null +++ b/astrbot/core/background_tool/tests/test_task_notifier.py @@ -0,0 +1,108 @@ +"""TaskNotifier 单元测试""" + +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.background_tool.task_notifier import TaskNotifier +from astrbot.core.background_tool.task_state import BackgroundTask + + +class TestTaskNotifier: + """测试任务通知器""" + + def setup_method(self): + """每个测试前初始化""" + self.notifier = TaskNotifier() + + @pytest.mark.asyncio + async def test_notify_completion(self): + """测试通知任务完成""" + task = BackgroundTask( + task_id="test-001", + tool_name="test_tool", + tool_args={"key": "value"}, + session_id="platform:group:123", + ) + task.start() + task.complete("Task completed successfully") + + # 模拟发送消息的回调 + send_callback = AsyncMock() + + await self.notifier.notify_completion( + task=task, + send_callback=send_callback, + ) + + # 验证回调被调用 + send_callback.assert_called_once() + call_args = send_callback.call_args + message = call_args[0][0] + assert "test-001" in message + assert "completed" in message.lower() + + @pytest.mark.asyncio + async def test_notify_failure(self): + """测试通知任务失败""" + task = BackgroundTask( + task_id="test-002", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + task.start() + task.fail("Something went wrong") + + send_callback = AsyncMock() + + await self.notifier.notify_completion( + task=task, + send_callback=send_callback, + ) + + send_callback.assert_called_once() + call_args = send_callback.call_args + message = call_args[0][0] + assert "failed" in message.lower() + + @pytest.mark.asyncio + async def test_build_notification_message(self): + """测试构建通知消息""" + task = BackgroundTask( + task_id="test-003", + tool_name="my_tool", + tool_args={"param": "value"}, + session_id="platform:group:123", + ) + task.start() + task.complete("Result data") + + message = self.notifier.build_message(task) + + assert "test-003" in message + assert "my_tool" in message + assert "Result data" in message + + def test_should_notify(self): + """测试是否应该通知""" + # 完成的任务应该通知 + task1 = BackgroundTask( + task_id="test-004", + tool_name="tool", + tool_args={}, + session_id="session", + ) + task1.start() + task1.complete("done") + assert self.notifier.should_notify(task1) + + # 运行中的任务不应该通知 + task2 = BackgroundTask( + task_id="test-005", + tool_name="tool", + tool_args={}, + session_id="session", + ) + task2.start() + assert not self.notifier.should_notify(task2) diff --git a/astrbot/core/background_tool/tests/test_task_registry.py b/astrbot/core/background_tool/tests/test_task_registry.py new file mode 100644 index 000000000..2a5fa0b18 --- /dev/null +++ b/astrbot/core/background_tool/tests/test_task_registry.py @@ -0,0 +1,157 @@ +"""TaskRegistry 单元测试""" + +from astrbot.core.background_tool.task_registry import TaskRegistry +from astrbot.core.background_tool.task_state import BackgroundTask, TaskStatus + + +class TestTaskRegistry: + """测试任务注册表""" + + def setup_method(self): + """每个测试前重置注册表""" + self.registry = TaskRegistry() + self.registry.clear() + + def test_register_task(self): + """测试注册任务""" + task = BackgroundTask( + task_id="test-001", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task_id = self.registry.register(task) + + assert task_id == "test-001" + assert self.registry.get("test-001") is not None + + def test_get_task(self): + """测试获取任务""" + task = BackgroundTask( + task_id="test-002", + tool_name="test_tool", + tool_args={"key": "value"}, + session_id="platform:group:123", + ) + self.registry.register(task) + + retrieved = self.registry.get("test-002") + + assert retrieved is not None + assert retrieved.task_id == "test-002" + assert retrieved.tool_args == {"key": "value"} + + def test_get_nonexistent_task(self): + """测试获取不存在的任务""" + result = self.registry.get("nonexistent") + assert result is None + + def test_get_by_session(self): + """测试按会话获取任务""" + task1 = BackgroundTask( + task_id="test-003", + tool_name="tool1", + tool_args={}, + session_id="session-A", + ) + task2 = BackgroundTask( + task_id="test-004", + tool_name="tool2", + tool_args={}, + session_id="session-A", + ) + task3 = BackgroundTask( + task_id="test-005", + tool_name="tool3", + tool_args={}, + session_id="session-B", + ) + + self.registry.register(task1) + self.registry.register(task2) + self.registry.register(task3) + + session_a_tasks = self.registry.get_by_session("session-A") + + assert len(session_a_tasks) == 2 + assert all(t.session_id == "session-A" for t in session_a_tasks) + + def test_get_running_tasks(self): + """测试获取运行中的任务""" + task1 = BackgroundTask( + task_id="test-006", + tool_name="tool1", + tool_args={}, + session_id="session-A", + ) + task2 = BackgroundTask( + task_id="test-007", + tool_name="tool2", + tool_args={}, + session_id="session-A", + ) + + self.registry.register(task1) + self.registry.register(task2) + + # 启动第一个任务 + task1.start() + + running = self.registry.get_running_tasks("session-A") + + assert len(running) == 1 + assert running[0].task_id == "test-006" + + def test_update_task(self): + """测试更新任务""" + task = BackgroundTask( + task_id="test-008", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + self.registry.register(task) + + success = self.registry.update( + "test-008", + status=TaskStatus.RUNNING, + ) + + assert success + updated = self.registry.get("test-008") + assert updated.status == TaskStatus.RUNNING + + def test_remove_task(self): + """测试删除任务""" + task = BackgroundTask( + task_id="test-009", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + self.registry.register(task) + + success = self.registry.remove("test-009") + + assert success + assert self.registry.get("test-009") is None + + def test_remove_nonexistent_task(self): + """测试删除不存在的任务""" + success = self.registry.remove("nonexistent") + assert not success + + def test_count(self): + """测试任务计数""" + assert self.registry.count() == 0 + + task = BackgroundTask( + task_id="test-010", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + self.registry.register(task) + + assert self.registry.count() == 1 diff --git a/astrbot/core/background_tool/tests/test_task_state.py b/astrbot/core/background_tool/tests/test_task_state.py new file mode 100644 index 000000000..0bd704bf7 --- /dev/null +++ b/astrbot/core/background_tool/tests/test_task_state.py @@ -0,0 +1,153 @@ +"""TaskState 单元测试""" + +import time + +from astrbot.core.background_tool.task_state import ( + BackgroundTask, + TaskStatus, +) + + +class TestTaskStatus: + """测试任务状态枚举""" + + def test_status_values(self): + """测试状态值""" + assert TaskStatus.PENDING.value == "pending" + assert TaskStatus.RUNNING.value == "running" + assert TaskStatus.COMPLETED.value == "completed" + assert TaskStatus.FAILED.value == "failed" + assert TaskStatus.CANCELLED.value == "cancelled" + + +class TestBackgroundTask: + """测试后台任务数据结构""" + + def test_create_task(self): + """测试创建任务""" + task = BackgroundTask( + task_id="test-001", + tool_name="test_tool", + tool_args={"param1": "value1"}, + session_id="platform:group:123", + ) + + assert task.task_id == "test-001" + assert task.tool_name == "test_tool" + assert task.tool_args == {"param1": "value1"} + assert task.session_id == "platform:group:123" + assert task.status == TaskStatus.PENDING + assert task.result is None + assert task.error is None + assert len(task.output_log) == 0 + + def test_task_start(self): + """测试任务开始""" + task = BackgroundTask( + task_id="test-002", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task.start() + + assert task.status == TaskStatus.RUNNING + assert task.started_at is not None + assert task.started_at <= time.time() + + def test_task_complete(self): + """测试任务完成""" + task = BackgroundTask( + task_id="test-003", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task.start() + task.complete("success result") + + assert task.status == TaskStatus.COMPLETED + assert task.result == "success result" + assert task.completed_at is not None + + def test_task_fail(self): + """测试任务失败""" + task = BackgroundTask( + task_id="test-004", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task.start() + task.fail("error message") + + assert task.status == TaskStatus.FAILED + assert task.error == "error message" + assert task.completed_at is not None + + def test_task_cancel(self): + """测试任务取消""" + task = BackgroundTask( + task_id="test-005", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task.start() + task.cancel() + + assert task.status == TaskStatus.CANCELLED + assert task.completed_at is not None + + def test_append_output(self): + """测试追加输出日志""" + task = BackgroundTask( + task_id="test-006", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + task.append_output("line 1") + task.append_output("line 2") + + assert len(task.output_log) == 2 + assert task.output_log[0] == "line 1" + assert task.output_log[1] == "line 2" + + def test_to_dict(self): + """测试序列化为字典""" + task = BackgroundTask( + task_id="test-007", + tool_name="test_tool", + tool_args={"key": "value"}, + session_id="platform:group:123", + ) + + d = task.to_dict() + + assert d["task_id"] == "test-007" + assert d["tool_name"] == "test_tool" + assert d["status"] == "pending" + assert d["tool_args"] == {"key": "value"} + + def test_is_finished(self): + """测试是否已完成""" + task = BackgroundTask( + task_id="test-008", + tool_name="test_tool", + tool_args={}, + session_id="platform:group:123", + ) + + assert not task.is_finished() + + task.start() + assert not task.is_finished() + + task.complete("done") + assert task.is_finished() diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 10a6fc599..74ca0abaf 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -106,6 +106,7 @@ "reachability_check": False, "max_agent_step": 30, "tool_call_timeout": 60, + "background_task_wait_timeout": 600, "tool_schema_mode": "full", "llm_safety_mode": True, "safety_mode_strategy": "system_prompt", # TODO: llm judge @@ -2207,6 +2208,9 @@ class ChatProviderTemplate(TypedDict): "tool_call_timeout": { "type": "int", }, + "background_task_wait_timeout": { + "type": "int", + }, "tool_schema_mode": { "type": "string", }, @@ -2860,6 +2864,13 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.background_task_wait_timeout": { + "description": "后台任务等待超时时间(秒)", + "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.tool_schema_mode": { "description": "工具调用模式", "type": "string", diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 0017e65fa..19b682ba6 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -15,6 +15,7 @@ from astrbot.core import logger from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.background_tool.manager import BackgroundToolManager from astrbot.core.pipeline.scheduler import PipelineScheduler from .platform import AstrMessageEvent @@ -39,6 +40,27 @@ async def dispatch(self): event: AstrMessageEvent = await self.event_queue.get() conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) self._print_event(event, conf_info["name"]) + + # 设置中断标记,用于打断正在执行的wait_tool_result + try: + manager = BackgroundToolManager() + session_id = event.unified_msg_origin + manager.set_interrupt_flag(session_id) + + # 检查是否有正在运行的后台任务,如果有则注入状态信息 + running_tasks_status = manager.get_running_tasks_status(session_id) + if running_tasks_status: + # 将后台任务状态信息注入到event对象中 + event.background_tasks_status = running_tasks_status + logger.info( + f"[EventBus] Injected background tasks status for session {session_id}" + ) + except Exception as e: + logger.error(f"[EventBus] Failed to set interrupt flag: {e}") + + # 将待处理的后台任务通知注入到event中,供AI处理 + await self._inject_notifications(event) + scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) if not scheduler: logger.error( @@ -54,6 +76,7 @@ def _print_event(self, event: AstrMessageEvent, conf_name: str): event (AstrMessageEvent): 事件对象 """ + event.trace.record("event_dispatch", config_name=conf_name) # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( @@ -64,3 +87,27 @@ def _print_event(self, event: AstrMessageEvent, conf_name: str): logger.info( f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}", ) + + async def _inject_notifications(self, event: AstrMessageEvent): + """将待处理的后台任务通知注入到event对象中,供AI处理""" + try: + manager = BackgroundToolManager() + session_id = event.unified_msg_origin + + # 获取待发送通知 + notifications = manager.get_pending_notifications(session_id) + + if not notifications: + return + + logger.info( + f"[EventBus] Found {len(notifications)} pending notifications for session {session_id}" + ) + + # 将通知注入到event对象中,让AI处理 + event.pending_notifications = notifications + logger.info( + f"[EventBus] Injected {len(notifications)} notifications into event for AI processing" + ) + except Exception as e: + logger.error(f"[EventBus] Error injecting notifications: {e}") diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 87f0dd419..00e7fc81c 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -184,6 +184,45 @@ async def process( ) return + # 处理待发送的后台任务通知 + if ( + hasattr(event, "pending_notifications") + and event.pending_notifications + ): + from astrbot.core.background_tool.manager import ( + BackgroundToolManager, + ) + + notification_lines = ["[Background Task Completion Notifications]"] + notification_lines.append( + "The following background tasks have completed:" + ) + + for notif in event.pending_notifications: + notification_lines.append( + f"- Task {notif['task_id']} ({notif['tool_name']}): {notif['message']}" + ) + + notification_lines.append( + "Please inform the user about these completed tasks in a natural way." + ) + + notification_text = "\n".join(notification_lines) + + # 添加到上下文 + req.contexts.append( + {"role": "system", "content": notification_text} + ) + + # 标记通知为已发送 + manager = BackgroundToolManager() + for notif in event.pending_notifications: + manager.mark_notification_sent(notif["task_id"]) + + logger.info( + f"[Internal Agent] Injected {len(event.pending_notifications)} notifications into LLM context" + ) + stream_to_general = ( self.unsupported_streaming_strategy == "turn_off" and not event.platform_meta.support_streaming_message diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 3ad83784e..2a2664ff8 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,4 +1,5 @@ from astrbot.core import logger +from astrbot.core.background_tool.register import register_background_tools from .entities import ProviderMetaData, ProviderType from .func_tool_manager import FuncCall @@ -10,6 +11,9 @@ llm_tools = FuncCall() +# 注册后台工具管理的LLM工具 +register_background_tools(llm_tools) + def register_provider_adapter( provider_type_name: str, diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index c7438baf2..5532b35ed 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -147,6 +147,7 @@ async def tool_loop_agent( contexts: list[Message] | None = None, max_steps: int = 30, tool_call_timeout: int = 60, + background_task_wait_timeout: int = 300, **kwargs: Any, ) -> LLMResponse: """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced. @@ -226,6 +227,7 @@ async def tool_loop_agent( run_context=AgentContextWrapper( context=agent_context, tool_call_timeout=tool_call_timeout, + background_task_wait_timeout=background_task_wait_timeout, ), tool_executor=tool_executor, agent_hooks=agent_hooks, diff --git a/astrbot/core/tool_execution/__init__.py b/astrbot/core/tool_execution/__init__.py new file mode 100644 index 000000000..32f3f6a75 --- /dev/null +++ b/astrbot/core/tool_execution/__init__.py @@ -0,0 +1,40 @@ +"""Tool Execution Module + +洋葱架构的工具执行模块。 +""" + +from .errors import ( + BackgroundTaskError, + MethodResolutionError, + ParameterValidationError, + TimeoutError, + ToolExecutionError, +) +from .interfaces import ( + IBackgroundTaskManager, + ICallbackEventBuilder, + ICompletionSignal, + IMethodResolver, + IParameterValidator, + IResultProcessor, + ITimeoutHandler, + ITimeoutStrategy, +) + +__all__ = [ + # Interfaces + "IMethodResolver", + "IParameterValidator", + "IResultProcessor", + "ITimeoutStrategy", + "ITimeoutHandler", + "IBackgroundTaskManager", + "ICompletionSignal", + "ICallbackEventBuilder", + # Errors + "ToolExecutionError", + "MethodResolutionError", + "ParameterValidationError", + "TimeoutError", + "BackgroundTaskError", +] diff --git a/astrbot/core/tool_execution/application/__init__.py b/astrbot/core/tool_execution/application/__init__.py new file mode 100644 index 000000000..5fd2d4174 --- /dev/null +++ b/astrbot/core/tool_execution/application/__init__.py @@ -0,0 +1,8 @@ +"""Application Layer + +应用层,用例编排。 +""" + +from .tool_executor import ToolExecutor + +__all__ = ["ToolExecutor"] diff --git a/astrbot/core/tool_execution/application/tool_executor.py b/astrbot/core/tool_execution/application/tool_executor.py new file mode 100644 index 000000000..0a18110dc --- /dev/null +++ b/astrbot/core/tool_execution/application/tool_executor.py @@ -0,0 +1,155 @@ +"""工具执行编排器 + +组合各组件,编排工具执行流程。 +""" + +import asyncio +from collections.abc import AsyncGenerator +from typing import Any + +import mcp.types + +from astrbot.core.tool_execution.domain.config import BACKGROUND_TOOL_NAMES +from astrbot.core.tool_execution.interfaces import ( + IMethodResolver, + IResultProcessor, + ITimeoutHandler, + ITimeoutStrategy, + IToolInvoker, +) + + +class ToolExecutor: + """工具执行编排器""" + + def __init__( + self, + method_resolver: IMethodResolver = None, + result_processor: IResultProcessor = None, + timeout_strategy: ITimeoutStrategy = None, + timeout_handler: ITimeoutHandler = None, + tool_invoker: IToolInvoker = None, + ): + self._method_resolver = method_resolver + self._result_processor = result_processor + self._timeout_strategy = timeout_strategy + self._timeout_handler = timeout_handler + self._tool_invoker = tool_invoker + + @property + def method_resolver(self) -> IMethodResolver: + if self._method_resolver is None: + from astrbot.core.tool_execution.infrastructure.handler import ( + MethodResolver, + ) + + self._method_resolver = MethodResolver() + return self._method_resolver + + @property + def result_processor(self) -> IResultProcessor: + if self._result_processor is None: + from astrbot.core.tool_execution.infrastructure.handler import ( + ResultProcessor, + ) + + self._result_processor = ResultProcessor() + return self._result_processor + + @property + def timeout_strategy(self) -> ITimeoutStrategy: + if self._timeout_strategy is None: + from astrbot.core.tool_execution.infrastructure.timeout import ( + TimeoutStrategy, + ) + + self._timeout_strategy = TimeoutStrategy() + return self._timeout_strategy + + @property + def timeout_handler(self) -> ITimeoutHandler: + if self._timeout_handler is None: + from astrbot.core.tool_execution.infrastructure.timeout import ( + BackgroundHandler, + ) + + self._timeout_handler = BackgroundHandler() + return self._timeout_handler + + @property + def tool_invoker(self) -> IToolInvoker: + if self._tool_invoker is None: + from astrbot.core.tool_execution.infrastructure.invoker import ( + LLMToolInvoker, + ) + + self._tool_invoker = LLMToolInvoker() + return self._tool_invoker + + async def execute( + self, tool: Any, run_context: Any, **tool_args + ) -> AsyncGenerator[mcp.types.CallToolResult, None]: + """执行工具""" + handler, method_name = self.method_resolver.resolve(tool) + + timeout_enabled = self._should_enable_timeout( + run_context.tool_call_timeout, tool.name + ) + + async for result in self._execute_with_timeout( + tool, run_context, handler, method_name, timeout_enabled, **tool_args + ): + yield result + + def _should_enable_timeout(self, timeout: float, tool_name: str) -> bool: + """判断是否启用超时""" + return timeout > 0 and tool_name not in BACKGROUND_TOOL_NAMES + + async def _execute_with_timeout( + self, tool, run_context, handler, method_name, timeout_enabled, **tool_args + ) -> AsyncGenerator[mcp.types.CallToolResult, None]: + """带超时控制的执行""" + from astrbot.core.tool_execution.infrastructure.handler import ResultProcessor + + wrapper = self.tool_invoker.invoke( + context=run_context, + handler=handler, + method_name=method_name, + **tool_args, + ) + + # 创建带上下文的结果处理器 + result_processor = ResultProcessor(run_context) + + while True: + try: + if timeout_enabled: + resp = await self.timeout_strategy.execute( + anext(wrapper), run_context.tool_call_timeout + ) + else: + resp = await anext(wrapper) + + processed = await result_processor.process(resp) + if processed: + yield processed + + except asyncio.TimeoutError: + ctx = self._build_timeout_context(tool, run_context, handler, tool_args) + result = await self.timeout_handler.handle_timeout(ctx) + yield result + return + except StopAsyncIteration: + break + + def _build_timeout_context(self, tool, run_context, handler, tool_args) -> dict: + """构建超时上下文""" + event = run_context.context.event + return { + "tool_name": tool.name, + "tool_args": tool_args, + "session_id": event.unified_msg_origin, + "handler": handler, + "event": event, + "event_queue": run_context.context.context.get_event_queue(), + } diff --git a/astrbot/core/tool_execution/domain/__init__.py b/astrbot/core/tool_execution/domain/__init__.py new file mode 100644 index 000000000..b34f0f969 --- /dev/null +++ b/astrbot/core/tool_execution/domain/__init__.py @@ -0,0 +1,14 @@ +"""Domain Layer + +领域层,定义核心类型,无外部依赖。 +""" + +from .execution_result import ExecutionResult, ExecutionStatus +from .tool_types import ExecutionMode, ToolType + +__all__ = [ + "ToolType", + "ExecutionMode", + "ExecutionResult", + "ExecutionStatus", +] diff --git a/astrbot/core/tool_execution/domain/config.py b/astrbot/core/tool_execution/domain/config.py new file mode 100644 index 000000000..d1884202d --- /dev/null +++ b/astrbot/core/tool_execution/domain/config.py @@ -0,0 +1,43 @@ +"""配置定义 + +集中管理工具执行系统的配置常量。 +""" + +from dataclasses import dataclass + +# 后台任务管理工具名称(不应用超时) +BACKGROUND_TOOL_NAMES = frozenset( + { + "wait_tool_result", + "get_tool_output", + "stop_tool", + "list_running_tools", + } +) + + +@dataclass(frozen=True) +class BackgroundToolConfig: + """后台工具执行配置 + + 所有配置项的默认值集中定义在此,便于统一管理。 + """ + + # 清理间隔(秒) + cleanup_interval_seconds: int = 600 + + # 已完成任务保留时间(秒) + task_max_age_seconds: int = 3600 + + # 后台任务默认超时(秒) + default_timeout_seconds: int = 600 + + # 错误预览最大长度 + error_preview_max_length: int = 500 + + # 输出日志默认行数 + default_output_lines: int = 50 + + +# 默认配置实例 +DEFAULT_CONFIG = BackgroundToolConfig() diff --git a/astrbot/core/tool_execution/domain/execution_result.py b/astrbot/core/tool_execution/domain/execution_result.py new file mode 100644 index 000000000..70d2a3297 --- /dev/null +++ b/astrbot/core/tool_execution/domain/execution_result.py @@ -0,0 +1,37 @@ +"""执行结果类型定义""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class ExecutionStatus(Enum): + """执行状态枚举""" + + SUCCESS = "success" + FAILED = "failed" + TIMEOUT = "timeout" + BACKGROUND = "background" + + +@dataclass +class ExecutionResult: + """执行结果""" + + status: ExecutionStatus + value: Any = None + error: str | None = None + task_id: str | None = None + metadata: dict = field(default_factory=dict) + + +@dataclass +class TimeoutContext: + """超时上下文""" + + tool_name: str + tool_args: dict + session_id: str + handler: Any + event: Any = None + event_queue: Any = None diff --git a/astrbot/core/tool_execution/domain/tool_types.py b/astrbot/core/tool_execution/domain/tool_types.py new file mode 100644 index 000000000..0843f6947 --- /dev/null +++ b/astrbot/core/tool_execution/domain/tool_types.py @@ -0,0 +1,19 @@ +"""工具类型定义""" + +from enum import Enum + + +class ToolType(Enum): + """工具类型枚举""" + + LOCAL = "local" + MCP = "mcp" + HANDOFF = "handoff" + + +class ExecutionMode(Enum): + """执行模式枚举""" + + SYNC = "sync" + ASYNC = "async" + BACKGROUND = "background" diff --git a/astrbot/core/tool_execution/errors.py b/astrbot/core/tool_execution/errors.py new file mode 100644 index 000000000..0338b81a7 --- /dev/null +++ b/astrbot/core/tool_execution/errors.py @@ -0,0 +1,34 @@ +"""领域错误类型 + +定义工具执行相关的错误类型。 +""" + + +class ToolExecutionError(Exception): + """工具执行基础错误""" + + pass + + +class MethodResolutionError(ToolExecutionError): + """方法解析错误""" + + pass + + +class ParameterValidationError(ToolExecutionError): + """参数验证错误""" + + pass + + +class TimeoutError(ToolExecutionError): + """超时错误""" + + pass + + +class BackgroundTaskError(ToolExecutionError): + """后台任务错误""" + + pass diff --git a/astrbot/core/tool_execution/infrastructure/__init__.py b/astrbot/core/tool_execution/infrastructure/__init__.py new file mode 100644 index 000000000..42b642ffa --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/__init__.py @@ -0,0 +1,4 @@ +"""Infrastructure Layer + +基础设施层,包含具体实现。 +""" diff --git a/astrbot/core/tool_execution/infrastructure/background/__init__.py b/astrbot/core/tool_execution/infrastructure/background/__init__.py new file mode 100644 index 000000000..963c0ee01 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/background/__init__.py @@ -0,0 +1,6 @@ +"""Background Module""" + +from .completion_signal import CompletionSignal +from .event_factory import EventFactory + +__all__ = ["CompletionSignal", "EventFactory"] diff --git a/astrbot/core/tool_execution/infrastructure/background/completion_signal.py b/astrbot/core/tool_execution/infrastructure/background/completion_signal.py new file mode 100644 index 000000000..78b186fd0 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/background/completion_signal.py @@ -0,0 +1,31 @@ +"""完成信号 + +替代轮询的事件驱动等待机制。 +""" + +import asyncio + +from astrbot.core.tool_execution.interfaces import ICompletionSignal + + +class CompletionSignal(ICompletionSignal): + """完成信号实现""" + + def __init__(self): + self._event = asyncio.Event() + + async def wait(self, timeout: float | None = None) -> bool: + """等待信号""" + try: + await asyncio.wait_for(self._event.wait(), timeout) + return True + except asyncio.TimeoutError: + return False + + def set(self) -> None: + """设置信号""" + self._event.set() + + def clear(self) -> None: + """清除信号""" + self._event.clear() diff --git a/astrbot/core/tool_execution/infrastructure/background/event_factory.py b/astrbot/core/tool_execution/infrastructure/background/event_factory.py new file mode 100644 index 000000000..edc16b878 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/background/event_factory.py @@ -0,0 +1,49 @@ +"""事件工厂 + +构建回调事件。 +""" + +import copy +from typing import Any + +from astrbot.core.tool_execution.interfaces import ICallbackEventBuilder + + +class EventFactory(ICallbackEventBuilder): + """事件工厂实现""" + + def build(self, task: Any, original_event: Any) -> Any: + """构建回调事件""" + if not original_event: + return None + + notification = self._build_notification(task) + return self._create_event(original_event, task, notification) + + def _build_notification(self, task: Any) -> str: + """构建通知消息""" + status = self._get_status_text(task.status) + msg = f"[Background Task]\nID: {task.task_id}\nTool: {task.tool_name}\nStatus: {status}" + if task.result: + msg += f"\nResult: {task.result}" + return msg + + def _get_status_text(self, status) -> str: + """获取状态文本""" + from astrbot.core.background_tool import TaskStatus + + mapping = { + TaskStatus.COMPLETED: "completed", + TaskStatus.FAILED: "failed", + TaskStatus.CANCELLED: "cancelled", + } + return mapping.get(status, "unknown") + + def _create_event(self, original: Any, task: Any, msg: str) -> Any: + """创建新事件""" + new_event = copy.copy(original) + new_event.message_str = msg + new_event.is_wake = True + new_event.set_extra("is_background_task_callback", True) + new_event.set_extra("background_task_id", task.task_id) + return new_event diff --git a/astrbot/core/tool_execution/infrastructure/callback/__init__.py b/astrbot/core/tool_execution/infrastructure/callback/__init__.py new file mode 100644 index 000000000..1a462ca2b --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/callback/__init__.py @@ -0,0 +1,5 @@ +"""Callback Module""" + +from .callback_event_builder import CallbackEventBuilder + +__all__ = ["CallbackEventBuilder"] diff --git a/astrbot/core/tool_execution/infrastructure/callback/callback_event_builder.py b/astrbot/core/tool_execution/infrastructure/callback/callback_event_builder.py new file mode 100644 index 000000000..2c9840096 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/callback/callback_event_builder.py @@ -0,0 +1,15 @@ +"""回调事件构建器""" + +from typing import Any + +from astrbot.core.tool_execution.interfaces import ICallbackEventBuilder + + +class CallbackEventBuilder(ICallbackEventBuilder): + """回调事件构建器实现""" + + def build(self, task: Any, original_event: Any) -> Any: + """构建回调事件""" + from astrbot.core.tool_execution.infrastructure.background import EventFactory + + return EventFactory().build(task, original_event) diff --git a/astrbot/core/tool_execution/infrastructure/handler/__init__.py b/astrbot/core/tool_execution/infrastructure/handler/__init__.py new file mode 100644 index 000000000..f2e069465 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/handler/__init__.py @@ -0,0 +1,7 @@ +"""Handler Module""" + +from .method_resolver import MethodResolver +from .parameter_validator import ParameterValidator +from .result_processor import ResultProcessor + +__all__ = ["MethodResolver", "ParameterValidator", "ResultProcessor"] diff --git a/astrbot/core/tool_execution/infrastructure/handler/method_resolver.py b/astrbot/core/tool_execution/infrastructure/handler/method_resolver.py new file mode 100644 index 000000000..841f0e23a --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/handler/method_resolver.py @@ -0,0 +1,41 @@ +"""方法解析器 + +从工具对象中解析出可调用的方法。 +""" + +from collections.abc import Callable +from typing import Any + +from astrbot.core.tool_execution.errors import MethodResolutionError +from astrbot.core.tool_execution.interfaces import IMethodResolver + + +class MethodResolver(IMethodResolver): + """方法解析器实现""" + + def resolve(self, tool: Any) -> tuple[Callable, str]: + """解析工具的可调用方法""" + # 检查是否重写了call方法 + is_override_call = self._check_override_call(tool) + + # 按优先级解析方法 + if tool.handler: + return tool.handler, "decorator_handler" + elif is_override_call: + return tool.call, "call" + elif hasattr(tool, "run"): + return getattr(tool, "run"), "run" + + raise MethodResolutionError( + "Tool must have a valid handler or override 'run' method." + ) + + def _check_override_call(self, tool: Any) -> bool: + """检查工具是否重写了call方法""" + from astrbot.core.agent.tool import FunctionTool + + for ty in type(tool).mro(): + if "call" in ty.__dict__: + if ty.__dict__["call"] is not FunctionTool.call: + return True + return False diff --git a/astrbot/core/tool_execution/infrastructure/handler/parameter_validator.py b/astrbot/core/tool_execution/infrastructure/handler/parameter_validator.py new file mode 100644 index 000000000..44edc2a09 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/handler/parameter_validator.py @@ -0,0 +1,41 @@ +"""参数验证器 + +验证工具调用参数。 +""" + +import inspect +from collections.abc import Callable + +from astrbot.core.tool_execution.errors import ParameterValidationError +from astrbot.core.tool_execution.interfaces import IParameterValidator + + +class ParameterValidator(IParameterValidator): + """参数验证器实现""" + + def validate(self, handler: Callable, params: dict) -> dict: + """验证参数""" + try: + sig = inspect.signature(handler) + # 跳过第一个参数(event或context) + bound = sig.bind_partial(None, **params) + return dict(bound.arguments) + except TypeError as e: + raise ParameterValidationError(self._build_error_message(handler, e)) + + def _build_error_message(self, handler: Callable, error: Exception) -> str: + """构建错误消息""" + try: + sig = inspect.signature(handler) + params = list(sig.parameters.values())[1:] # 跳过第一个参数 + param_strs = [self._format_param(p) for p in params] + return f"Parameter mismatch: {', '.join(param_strs)}" + except Exception: + return str(error) + + def _format_param(self, param: inspect.Parameter) -> str: + """格式化参数""" + s = param.name + if param.annotation != inspect.Parameter.empty: + s += f": {param.annotation}" + return s diff --git a/astrbot/core/tool_execution/infrastructure/handler/result_processor.py b/astrbot/core/tool_execution/infrastructure/handler/result_processor.py new file mode 100644 index 000000000..7c0be4018 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/handler/result_processor.py @@ -0,0 +1,70 @@ +"""结果处理器 + +处理工具执行结果。 +""" + +from typing import Any + +import mcp.types + +from astrbot import logger +from astrbot.core.tool_execution.interfaces import IResultProcessor + + +class ResultProcessor(IResultProcessor): + """结果处理器实现""" + + def __init__(self, run_context: Any = None): + self._run_context = run_context + + async def process(self, result: Any) -> mcp.types.CallToolResult | None: + """处理执行结果 + + Args: + result: 工具执行返回值 + + Returns: + 处理后的 CallToolResult,或 None 表示无需返回 + """ + if result is not None: + return self._wrap_result(result) + + # result 为 None 时,检查是否需要直接发送消息给用户 + await self._send_direct_message() + return None + + def _wrap_result(self, result: Any) -> mcp.types.CallToolResult: + """包装结果为 CallToolResult""" + if isinstance(result, mcp.types.CallToolResult): + return result + + text_content = mcp.types.TextContent( + type="text", + text=str(result), + ) + return mcp.types.CallToolResult(content=[text_content]) + + async def _send_direct_message(self) -> None: + """处理工具直接发送消息给用户的情况""" + if self._run_context is None: + return + + event = self._run_context.context.event + if not event: + return + + res = event.get_result() + if not res or not res.chain: + return + + try: + from astrbot.core.message.message_event_result import MessageChain + + await event.send( + MessageChain( + chain=res.chain, + type="tool_direct_result", + ) + ) + except Exception as e: + logger.error(f"Tool 直接发送消息失败: {e}", exc_info=True) diff --git a/astrbot/core/tool_execution/infrastructure/invoker/__init__.py b/astrbot/core/tool_execution/infrastructure/invoker/__init__.py new file mode 100644 index 000000000..21c7c36f0 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/invoker/__init__.py @@ -0,0 +1,5 @@ +"""工具调用器模块""" + +from .tool_invoker import LLMToolInvoker + +__all__ = ["LLMToolInvoker"] diff --git a/astrbot/core/tool_execution/infrastructure/invoker/tool_invoker.py b/astrbot/core/tool_execution/infrastructure/invoker/tool_invoker.py new file mode 100644 index 000000000..c87066aa1 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/invoker/tool_invoker.py @@ -0,0 +1,39 @@ +"""工具调用器 + +包装LLM工具调用逻辑,实现IToolInvoker接口。 +""" + +from collections.abc import Callable +from typing import Any + +from astrbot.core.tool_execution.interfaces import IToolInvoker + + +class LLMToolInvoker(IToolInvoker): + """LLM工具调用器 + + 包装 call_local_llm_tool,隔离应用层与具体实现的依赖。 + """ + + def invoke( + self, context: Any, handler: Callable, method_name: str, **kwargs + ) -> Any: + """调用工具 + + Args: + context: 运行上下文 + handler: 处理函数 + method_name: 方法名称 + **kwargs: 工具参数 + + Returns: + 异步生成器 + """ + from astrbot.core.astr_agent_tool_exec import call_local_llm_tool + + return call_local_llm_tool( + context=context, + handler=handler, + method_name=method_name, + **kwargs, + ) diff --git a/astrbot/core/tool_execution/infrastructure/timeout/__init__.py b/astrbot/core/tool_execution/infrastructure/timeout/__init__.py new file mode 100644 index 000000000..aa987aa96 --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/timeout/__init__.py @@ -0,0 +1,6 @@ +"""Timeout Module""" + +from .background_handler import BackgroundHandler +from .timeout_strategy import NoTimeoutStrategy, TimeoutStrategy + +__all__ = ["TimeoutStrategy", "NoTimeoutStrategy", "BackgroundHandler"] diff --git a/astrbot/core/tool_execution/infrastructure/timeout/background_handler.py b/astrbot/core/tool_execution/infrastructure/timeout/background_handler.py new file mode 100644 index 000000000..2f214e79f --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/timeout/background_handler.py @@ -0,0 +1,52 @@ +"""后台处理器 + +处理超时后转后台执行的逻辑。 +""" + +from typing import Any + +import mcp.types + +from astrbot.core.tool_execution.interfaces import ITimeoutHandler + + +class BackgroundHandler(ITimeoutHandler): + """后台处理器实现""" + + def __init__(self, bg_manager=None): + self._bg_manager = bg_manager + + @property + def bg_manager(self): + if self._bg_manager is None: + from astrbot.core.background_tool import BackgroundToolManager + + self._bg_manager = BackgroundToolManager() + return self._bg_manager + + async def handle_timeout(self, context: Any) -> mcp.types.CallToolResult: + """处理超时,转后台执行""" + task_id = await self.bg_manager.submit_task( + tool_name=context["tool_name"], + tool_args=context["tool_args"], + session_id=context["session_id"], + handler=context["handler"], + wait=False, + event=context.get("event"), + event_queue=context.get("event_queue"), + ) + + return self._build_notification(context["tool_name"], task_id) + + def _build_notification( + self, tool_name: str, task_id: str + ) -> mcp.types.CallToolResult: + """构建后台执行通知""" + msg = ( + f"Tool '{tool_name}' switched to background.\n" + f"Task ID: {task_id}\n" + f"Use get_tool_output/wait_tool_result to check." + ) + return mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text=msg)] + ) diff --git a/astrbot/core/tool_execution/infrastructure/timeout/timeout_strategy.py b/astrbot/core/tool_execution/infrastructure/timeout/timeout_strategy.py new file mode 100644 index 000000000..e9911cbbc --- /dev/null +++ b/astrbot/core/tool_execution/infrastructure/timeout/timeout_strategy.py @@ -0,0 +1,26 @@ +"""超时策略 + +实现超时控制的策略模式。 +""" + +import asyncio +from collections.abc import Coroutine +from typing import Any + +from astrbot.core.tool_execution.interfaces import ITimeoutStrategy + + +class TimeoutStrategy(ITimeoutStrategy): + """标准超时策略""" + + async def execute(self, coro: Coroutine, timeout: float) -> Any: + """执行带超时的协程""" + return await asyncio.wait_for(coro, timeout=timeout) + + +class NoTimeoutStrategy(ITimeoutStrategy): + """无超时策略""" + + async def execute(self, coro: Coroutine, timeout: float) -> Any: + """直接执行协程,忽略超时参数""" + return await coro diff --git a/astrbot/core/tool_execution/interfaces.py b/astrbot/core/tool_execution/interfaces.py new file mode 100644 index 000000000..9b8f715e1 --- /dev/null +++ b/astrbot/core/tool_execution/interfaces.py @@ -0,0 +1,194 @@ +"""核心接口定义 + +定义工具执行模块的核心接口,遵循依赖倒置原则。 +""" + +from abc import ABC, abstractmethod +from collections.abc import Callable, Coroutine +from typing import Any + + +class IMethodResolver(ABC): + """方法解析器接口 + + 负责从工具对象中解析出可调用的方法。 + """ + + @abstractmethod + def resolve(self, tool: Any) -> tuple[Callable, str]: + """解析工具的可调用方法 + + Args: + tool: 工具对象 + + Returns: + (handler, method_name) 元组 + + Raises: + MethodResolutionError: 无法解析方法时 + """ + ... + + +class IParameterValidator(ABC): + """参数验证器接口 + + 负责验证工具调用参数。 + """ + + @abstractmethod + def validate(self, handler: Callable, params: dict) -> dict: + """验证参数 + + Args: + handler: 处理函数 + params: 参数字典 + + Returns: + 验证后的参数字典 + + Raises: + ParameterValidationError: 参数验证失败时 + """ + ... + + +class IResultProcessor(ABC): + """结果处理器接口 + + 负责处理工具执行结果。 + """ + + @abstractmethod + async def process(self, result: Any) -> Any: + """处理执行结果 + + Args: + result: 原始执行结果 + + Returns: + 处理后的结果 + """ + ... + + +class ITimeoutStrategy(ABC): + """超时策略接口 + + 负责执行带超时控制的协程。 + """ + + @abstractmethod + async def execute(self, coro: Coroutine, timeout: float) -> Any: + """执行协程 + + Args: + coro: 协程对象 + timeout: 超时时间(秒) + + Returns: + 执行结果 + + Raises: + TimeoutError: 超时时 + """ + ... + + +class ITimeoutHandler(ABC): + """超时处理器接口 + + 负责处理超时后的逻辑。 + """ + + @abstractmethod + async def handle_timeout(self, context: Any) -> Any: + """处理超时 + + Args: + context: 执行上下文 + + Returns: + 处理结果 + """ + ... + + +class IBackgroundTaskManager(ABC): + """后台任务管理器接口""" + + @abstractmethod + async def submit_task( + self, + tool_name: str, + tool_args: dict, + session_id: str, + handler: Callable, + **kwargs, + ) -> str: + """提交后台任务 + + Returns: + 任务ID + """ + ... + + @abstractmethod + async def wait_task(self, task_id: str, timeout: float | None = None) -> Any: + """等待任务完成""" + ... + + +class ICompletionSignal(ABC): + """任务完成信号接口(替代轮询)""" + + @abstractmethod + async def wait(self, timeout: float | None = None) -> bool: + """等待信号""" + ... + + @abstractmethod + def set(self) -> None: + """设置信号""" + ... + + +class ICallbackEventBuilder(ABC): + """回调事件构建器接口""" + + @abstractmethod + def build(self, task: Any, original_event: Any) -> Any: + """构建回调事件 + + Args: + task: 后台任务 + original_event: 原始事件 + + Returns: + 新的事件对象 + """ + ... + + +class IToolInvoker(ABC): + """工具调用器接口 + + 抽象LLM工具调用逻辑,避免应用层直接依赖具体实现。 + """ + + @abstractmethod + def invoke( + self, context: Any, handler: Callable, method_name: str, **kwargs + ) -> Any: + """调用工具 + + Args: + context: 运行上下文 + handler: 处理函数 + method_name: 方法名称 + **kwargs: 工具参数 + + Returns: + 异步生成器 + """ + ... diff --git a/astrbot/core/tool_execution/utils/__init__.py b/astrbot/core/tool_execution/utils/__init__.py new file mode 100644 index 000000000..798325fff --- /dev/null +++ b/astrbot/core/tool_execution/utils/__init__.py @@ -0,0 +1,23 @@ +"""Utils Module""" + +from .decorators import log_execution, with_timeout +from .rwlock import RWLock +from .sanitizer import sanitize_for_log, sanitize_params +from .validators import ( + ValidationError, + validate_positive_int, + validate_session_id, + validate_task_id, +) + +__all__ = [ + "log_execution", + "with_timeout", + "sanitize_params", + "sanitize_for_log", + "RWLock", + "validate_task_id", + "validate_session_id", + "validate_positive_int", + "ValidationError", +] diff --git a/astrbot/core/tool_execution/utils/decorators.py b/astrbot/core/tool_execution/utils/decorators.py new file mode 100644 index 000000000..dcfc22074 --- /dev/null +++ b/astrbot/core/tool_execution/utils/decorators.py @@ -0,0 +1,38 @@ +"""AOP装饰器""" + +import asyncio +import functools +import time +from collections.abc import Callable + + +def log_execution(func: Callable) -> Callable: + """日志装饰器""" + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + from astrbot import logger + + start = time.time() + try: + result = await func(*args, **kwargs) + logger.debug(f"{func.__name__} took {time.time() - start:.2f}s") + return result + except Exception as e: + logger.error(f"{func.__name__} failed: {e}") + raise + + return wrapper + + +def with_timeout(timeout: float): + """超时装饰器""" + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout) + + return wrapper + + return decorator diff --git a/astrbot/core/tool_execution/utils/rwlock.py b/astrbot/core/tool_execution/utils/rwlock.py new file mode 100644 index 000000000..7aff167e3 --- /dev/null +++ b/astrbot/core/tool_execution/utils/rwlock.py @@ -0,0 +1,68 @@ +"""读写锁实现 + +优化读多写少场景的并发性能。 +""" + +import threading +from collections.abc import Generator +from contextlib import contextmanager + + +class RWLock: + """读写锁 + + 允许多个读取者同时访问,但写入者独占访问。 + 适用于读多写少的场景,如任务注册表和输出缓冲区。 + + 用法: + lock = RWLock() + + # 读取(允许并发) + with lock.read(): + data = some_dict.get(key) + + # 写入(独占) + with lock.write(): + some_dict[key] = value + """ + + def __init__(self): + self._read_ready = threading.Condition(threading.Lock()) + self._readers = 0 + self._writers_waiting = 0 + self._writer_active = False + + @contextmanager + def read(self) -> Generator[None, None, None]: + """获取读锁""" + with self._read_ready: + # 等待直到没有活跃的写入者和等待的写入者 + while self._writer_active or self._writers_waiting > 0: + self._read_ready.wait() + self._readers += 1 + + try: + yield + finally: + with self._read_ready: + self._readers -= 1 + if self._readers == 0: + self._read_ready.notify_all() + + @contextmanager + def write(self) -> Generator[None, None, None]: + """获取写锁""" + with self._read_ready: + self._writers_waiting += 1 + # 等待直到没有读取者和活跃的写入者 + while self._readers > 0 or self._writer_active: + self._read_ready.wait() + self._writers_waiting -= 1 + self._writer_active = True + + try: + yield + finally: + with self._read_ready: + self._writer_active = False + self._read_ready.notify_all() diff --git a/astrbot/core/tool_execution/utils/sanitizer.py b/astrbot/core/tool_execution/utils/sanitizer.py new file mode 100644 index 000000000..983611274 --- /dev/null +++ b/astrbot/core/tool_execution/utils/sanitizer.py @@ -0,0 +1,126 @@ +"""日志脱敏工具 + +防止敏感信息泄露到日志中。 +""" + +import re +from typing import Any + +# 敏感参数名称(不区分大小写) +SENSITIVE_PARAM_NAMES = frozenset( + { + "password", + "passwd", + "pwd", + "token", + "api_key", + "apikey", + "secret", + "credential", + "credentials", + "auth", + "authorization", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "secret_key", + "secretkey", + "key", # 通用key + "session_id", + "cookie", + "cookies", + } +) + +# 用于替换的掩码 +MASK = "***REDACTED***" + +# 敏感值模式(用于检测值中的敏感内容) +SENSITIVE_VALUE_PATTERNS = [ + re.compile(r"(?i)(bearer\s+)[a-z0-9\-_.]+", re.IGNORECASE), # Bearer token + re.compile( + r"(?i)(api[_-]?key[=:]\s*)[a-z0-9\-_.]+", re.IGNORECASE + ), # API key in value + re.compile(r"sk-[a-zA-Z0-9]{20,}"), # OpenAI-style API key + re.compile(r"ghp_[a-zA-Z0-9]{36,}"), # GitHub token + re.compile(r"gho_[a-zA-Z0-9]{36,}"), # GitHub OAuth token +] + + +def _is_sensitive_key(key: str) -> bool: + """检查键名是否为敏感参数""" + key_lower = key.lower() + return any(sensitive in key_lower for sensitive in SENSITIVE_PARAM_NAMES) + + +def _mask_sensitive_value(value: str) -> str: + """对值中的敏感内容进行掩码处理""" + result = value + for pattern in SENSITIVE_VALUE_PATTERNS: + result = pattern.sub( + lambda m: m.group(1) + MASK if m.lastindex else MASK, result + ) + return result + + +def sanitize_params( + params: dict[str, Any], max_value_length: int = 100 +) -> dict[str, Any]: + """脱敏参数字典 + + Args: + params: 原始参数字典 + max_value_length: 值的最大显示长度 + + Returns: + 脱敏后的参数字典(副本) + """ + if not params: + return {} + + sanitized = {} + for key, value in params.items(): + # 检查键名是否敏感 + if _is_sensitive_key(key): + sanitized[key] = MASK + continue + + # 处理值 + if isinstance(value, str): + # 对字符串值进行模式检查 + masked = _mask_sensitive_value(value) + # 截断过长的值 + if len(masked) > max_value_length: + masked = masked[:max_value_length] + "...(truncated)" + sanitized[key] = masked + elif isinstance(value, dict): + # 递归处理嵌套字典 + sanitized[key] = sanitize_params(value, max_value_length) + elif isinstance(value, (list, tuple)): + # 处理列表/元组 + sanitized[key] = [ + sanitize_params(v, max_value_length) + if isinstance(v, dict) + else _mask_sensitive_value(str(v)) + if isinstance(v, str) + else v + for v in value + ] + else: + sanitized[key] = value + + return sanitized + + +def sanitize_for_log(params: dict[str, Any]) -> str: + """将参数字典转为安全的日志字符串 + + Args: + params: 原始参数字典 + + Returns: + 脱敏后的字符串表示 + """ + sanitized = sanitize_params(params) + return str(sanitized) diff --git a/astrbot/core/tool_execution/utils/validators.py b/astrbot/core/tool_execution/utils/validators.py new file mode 100644 index 000000000..3f1b77421 --- /dev/null +++ b/astrbot/core/tool_execution/utils/validators.py @@ -0,0 +1,102 @@ +"""输入验证工具 + +验证用户输入的参数,防止注入攻击。 +""" + +import re +from typing import Any + +# 有效的 task_id/session_id 模式(只允许字母、数字、下划线、连字符) +VALID_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_\-]{1,128}$") + + +class ValidationError(ValueError): + """验证错误""" + + pass + + +def validate_task_id(task_id: Any) -> str: + """验证任务ID + + Args: + task_id: 任务ID + + Returns: + 验证后的任务ID + + Raises: + ValidationError: 验证失败 + """ + if not isinstance(task_id, str): + raise ValidationError(f"task_id must be string, got {type(task_id).__name__}") + + if not task_id: + raise ValidationError("task_id cannot be empty") + + if not VALID_ID_PATTERN.match(task_id): + raise ValidationError( + "Invalid task_id format: must be 1-128 alphanumeric characters, " + "underscores, or hyphens" + ) + + return task_id + + +def validate_session_id(session_id: Any) -> str: + """验证会话ID + + Args: + session_id: 会话ID + + Returns: + 验证后的会话ID + + Raises: + ValidationError: 验证失败 + """ + if not isinstance(session_id, str): + raise ValidationError( + f"session_id must be string, got {type(session_id).__name__}" + ) + + if not session_id: + raise ValidationError("session_id cannot be empty") + + # session_id 允许更宽松的格式(可能包含特殊字符如 : / 等) + if len(session_id) > 256: + raise ValidationError("session_id too long (max 256 characters)") + + # 检查是否包含危险字符 + dangerous_chars = ["\x00", "\n", "\r"] + for char in dangerous_chars: + if char in session_id: + raise ValidationError("session_id contains invalid characters") + + return session_id + + +def validate_positive_int(value: Any, name: str, max_value: int = 10000) -> int: + """验证正整数 + + Args: + value: 值 + name: 参数名称 + max_value: 最大允许值 + + Returns: + 验证后的整数 + + Raises: + ValidationError: 验证失败 + """ + if not isinstance(value, int): + raise ValidationError(f"{name} must be integer, got {type(value).__name__}") + + if value <= 0: + raise ValidationError(f"{name} must be positive") + + if value > max_value: + raise ValidationError(f"{name} too large (max {max_value})") + + return value diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 9b4c4e304..c06d6ff04 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -110,7 +110,7 @@ }, "websearch_baidu_app_builder_key": { "description": "Baidu Qianfan Smart Cloud APP Builder API Key", - "hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" + "hint": "Reference: https://console.bce.baidu.com/iam/#/iam/apikey/list" }, "web_search_link": { "description": "Display Source Citations" @@ -133,15 +133,15 @@ } } }, - "agent_computer_use": { - "description": "Agent Computer Use", - "hint": "Allows the AstrBot to access and use your computer or an sandbox environment to perform more complex tasks. See [Sandbox Mode](https://docs.astrbot.app/use/astrbot-agent-sandbox.html), [Skills](https://docs.astrbot.app/use/skills.html)", + "sandbox": { + "description": "Agent Sandbox Env(Beta)", + "hint": "https://docs.astrbot.app/en/use/astrbot-agent-sandbox.html", "provider_settings": { - "computer_use_runtime": { - "description": "Computer Use Runtime", - "hint": "sandbox means running in a sandbox environment, local means running in a local environment, none means disabling Computer Use. If skills are uploaded, choosing none will cause them to not be usable by the Agent." - }, "sandbox": { + "enable": { + "description": "Enable Sandbox Env", + "hint": "When enabled, Agent can use tools and resources in the sandbox environment, such as Python tool, Shell, etc." + }, "booter": { "description": "Sandbox Environment Driver" }, @@ -164,20 +164,18 @@ } } }, - "proactive_capability": { - "description": "Proactive Agent", - "hint": "AstrBot will wake up, run your tasks, and deliver the results to you. See [Proactive Agent](https://docs.astrbot.app/en/use/proactive-agent.html)", + "skills": { + "description": "Skills", "provider_settings": { - "proactive_capability": { - "add_cron_tools": { - "description": "Enable", - "hint": "When enabled, related tools will be passed to the Agent to implement proactive Agent capabilities. You can tell AstrBot what to do at a future time, and it will be triggered on schedule to execute the task, and report the result back to you." + "skills": { + "runtime": { + "description": "Skill Runtime", + "hint": "Select the runtime for Skills. Sandbox runtime requires sandbox to be enabled first. In local mode, the Agent CAN FULLY ACCESS the runtime environment through Shell and Python tools, but non-admin users will be automatically prohibited from using it to ensure security." } } } }, "truncate_and_compress": { - "hint": "[Context Management](https://docs.astrbot.app/en/use/context-compress.html)", "description": "Context Management Strategy", "provider_settings": { "max_context_length": { @@ -249,6 +247,10 @@ "tool_call_timeout": { "description": "Tool Call Timeout (seconds)" }, + "background_task_wait_timeout": { + "description": "Background Task Timeout (seconds)", + "hint": "Maximum execution time for background tasks, task will be terminated after timeout" + }, "tool_schema_mode": { "description": "Tool Schema Mode", "hint": "Skills-like sends name/description first and re-queries for parameters; Full sends the complete schema in one step.", @@ -428,7 +430,7 @@ }, "emojis": { "description": "Emoji List (Lark Emoji Enum Names)", - "hint": "Emoji enum names reference: [https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce](https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce)" + "hint": "Emoji enum names reference: https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce" } } }, @@ -439,7 +441,7 @@ }, "emojis": { "description": "Emoji List (Unicode)", - "hint": "Telegram only supports a fixed reaction set, reference: [https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)" + "hint": "Telegram only supports a fixed reaction set, reference: https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9" } } } @@ -596,15 +598,15 @@ }, "pypi_index_url": { "description": "PyPI Repository URL", - "hint": "PyPI repository URL for installing Python dependencies. Defaults to [https://mirrors.aliyun.com/pypi/simple/](https://mirrors.aliyun.com/pypi/simple/)" + "hint": "PyPI repository URL for installing Python dependencies. Defaults to https://mirrors.aliyun.com/pypi/simple/" }, "callback_api_base": { "description": "Externally Accessible Callback API Address", - "hint": "External services may access AstrBot's backend through callback links generated by AstrBot (such as file download links). Since AstrBot cannot automatically determine the externally accessible host address in the deployment environment, this configuration item is needed to explicitly specify how external services should access AstrBot's address. Examples: [http://localhost:6185](http://localhost:6185), [https://example.com](https://example.com), etc." + "hint": "External services may access AstrBot's backend through callback links generated by AstrBot (such as file download links). Since AstrBot cannot automatically determine the externally accessible host address in the deployment environment, this configuration item is needed to explicitly specify how external services should access AstrBot's address. Examples: http://localhost:6185, https://example.com, etc." }, "timezone": { "description": "Timezone", - "hint": "Timezone setting. Please enter an IANA timezone name, such as Asia/Shanghai. Uses system default timezone when empty. For all timezones, see: [https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab](https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab)" + "hint": "Timezone setting. Please enter an IANA timezone name, such as Asia/Shanghai. Uses system default timezone when empty. For all timezones, see: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab" }, "http_proxy": { "description": "HTTP Proxy", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 6620f1cb3..cfd153f20 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -70,7 +70,6 @@ }, "persona": { "description": "人格", - "hint": "赋予 AstrBot 人格。", "provider_settings": { "default_personality": { "description": "默认采用的人格" @@ -79,7 +78,6 @@ }, "knowledgebase": { "description": "知识库", - "hint": "AstrBot 的 “外置大脑”。", "kb_names": { "description": "知识库列表", "hint": "支持多选" @@ -99,7 +97,6 @@ }, "websearch": { "description": "网页搜索", - "hint": "让 AstrBot 能够访问互联网,获悉时讯。", "provider_settings": { "web_search": { "description": "启用网页搜索" @@ -113,7 +110,7 @@ }, "websearch_baidu_app_builder_key": { "description": "百度千帆智能云 APP Builder API Key", - "hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" + "hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list" }, "web_search_link": { "description": "显示来源引用" @@ -136,15 +133,15 @@ } } }, - "agent_computer_use": { - "description": "使用电脑能力", - "hint": "让 AstrBot 访问和使用你的电脑或者隔离的沙盒环境,以执行更复杂的任务。详见: [沙盒模式](https://docs.astrbot.app/use/astrbot-agent-sandbox.html), [Skills](https://docs.astrbot.app/use/skills.html)。", + "sandbox": { + "description": "Agent 沙箱环境(Beta)", + "hint": "https://docs.astrbot.app/use/astrbot-agent-sandbox.html", "provider_settings": { - "computer_use_runtime": { - "description": "运行环境", - "hint": "sandbox 代表在沙箱环境中运行, local 代表在本地环境中运行, none 代表不启用。如果上传了 skills,选择 none 会导致其无法被 Agent 正常使用。" - }, "sandbox": { + "enable": { + "description": "启用沙箱环境", + "hint": "启用后,Agent 可以使用沙箱环境中的工具和资源,如 Python 代码执行、Shell 等。" + }, "booter": { "description": "沙箱环境驱动器" }, @@ -167,20 +164,18 @@ } } }, - "proactive_capability": { - "description": "主动型能力", - "hint": "让 AstrBot 能够在某一时刻自动唤醒,帮你完成任务。详见: [主动型 Agent](https://docs.astrbot.app/use/proactive-agent.html)。", + "skills": { + "description": "Skills", "provider_settings": { - "proactive_capability": { - "add_cron_tools": { - "description": "启用", - "hint": "启用后,将会传递给 Agent 相关工具来实现主动型 Agent。你可以告诉 AstrBot 未来某个时间要做的事情,它将被定时触发然后执行任务,然后将结果发送给你。" + "skills": { + "runtime": { + "description": "Skill Runtime", + "hint": "选择 Skills 运行环境。使用 sandbox 前需启用沙箱;local 模式下 Agent 可通过 Shell 和 Python 功能完全访问运行环境,非管理员将被自动禁止使用以保证安全。" } } } }, "truncate_and_compress": { - "hint": "AstrBot 如何管理工作记忆。详见: [上下文管理策略](https://docs.astrbot.app/use/context-compress.html)。", "description": "上下文管理策略", "provider_settings": { "max_context_length": { @@ -249,6 +244,10 @@ "tool_call_timeout": { "description": "工具调用超时时间(秒)" }, + "background_task_wait_timeout": { + "description": "后台任务超时时间(秒)", + "hint": "后台任务最长执行时间,超时后任务将被终止" + }, "tool_schema_mode": { "description": "工具调用模式", "hint": "skills-like 先下发工具名称与描述,再下发参数;full 一次性下发完整参数。", @@ -429,7 +428,7 @@ }, "emojis": { "description": "表情列表(飞书表情枚举名)", - "hint": "表情枚举名参考:[https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce](https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce)" + "hint": "表情枚举名参考:https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce" } } }, @@ -440,7 +439,7 @@ }, "emojis": { "description": "表情列表(Unicode)", - "hint": "Telegram 仅支持固定反应集合,参考:[https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)" + "hint": "Telegram 仅支持固定反应集合,参考:https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9" } } } @@ -597,15 +596,15 @@ }, "pypi_index_url": { "description": "PyPI 软件仓库地址", - "hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 [https://mirrors.aliyun.com/pypi/simple/](https://mirrors.aliyun.com/pypi/simple/)" + "hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/" }, "callback_api_base": { "description": "对外可达的回调接口地址", - "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定外部服务如何访问 AstrBot 的地址。如 [http://localhost:6185](http://localhost:6185),[https://example.com](https://example.com) 等。" + "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定外部服务如何访问 AstrBot 的地址。如 http://localhost:6185,https://example.com 等。" }, "timezone": { "description": "时区", - "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: [https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab](https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab)" + "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab" }, "http_proxy": { "description": "HTTP 代理", diff --git a/tests/test_tool_execution/__init__.py b/tests/test_tool_execution/__init__.py new file mode 100644 index 000000000..678d5603c --- /dev/null +++ b/tests/test_tool_execution/__init__.py @@ -0,0 +1 @@ +"""Tool Execution Integration Tests""" diff --git a/tests/test_tool_execution/test_integration.py b/tests/test_tool_execution/test_integration.py new file mode 100644 index 000000000..6565b810c --- /dev/null +++ b/tests/test_tool_execution/test_integration.py @@ -0,0 +1,885 @@ +"""工具执行集成测试 + +覆盖关键场景: +1. 正常工具执行 +2. 超时转后台执行 +3. 后台任务完成回调 +4. wait_tool_result 中断机制 +5. 多工具并行执行 +""" + +import asyncio +from dataclasses import dataclass +from typing import Any +from unittest.mock import Mock + +import pytest + + +# 测试用的模拟类 +@dataclass +class MockEvent: + unified_msg_origin: str = "test_session_001" + message_obj: Any = None + _result: Any = None + _has_send_oper: bool = False + _extras: dict = None + + def __post_init__(self): + self._extras = self._extras or {} + self.message_obj = Mock() + self.message_obj.type = "group" + self.message_obj.self_id = "bot_001" + self.message_obj.session_id = "session_001" + self.message_obj.group = Mock() + self.message_obj.sender = Mock() + + def get_result(self): + return self._result + + def set_result(self, result): + self._result = result + + def set_extra(self, key, value): + self._extras[key] = value + + def get_sender_name(self): + return "test_user" + + async def send(self, msg): + pass + + +class MockFunctionTool: + """模拟函数工具""" + + def __init__(self, name: str, handler=None): + self.name = name + self.handler = handler + + async def call(self, context, **kwargs): + if self.handler: + return await self.handler(**kwargs) + return "default_result" + + +class MockRunContext: + """模拟运行上下文""" + + def __init__(self, event: MockEvent, timeout: float = 15.0): + self.tool_call_timeout = timeout + self.context = Mock() + self.context.event = event + self.context.context = Mock() + self.context.context.get_event_queue = Mock(return_value=asyncio.Queue()) + + +# ============ 测试用例 ============ + + +class TestNormalToolExecution: + """正常工具执行测试""" + + @pytest.mark.asyncio + async def test_sync_handler_returns_result(self): + """同步处理器返回结果""" + + async def handler(event, **kwargs): + return "hello world" + + event = MockEvent() + + result = await handler(event) + assert result == "hello world" + + @pytest.mark.asyncio + async def test_async_generator_handler(self): + """异步生成器处理器""" + + async def handler(event, **kwargs): + yield "step1" + yield "step2" + yield "final" + + event = MockEvent() + results = [] + async for r in handler(event): + results.append(r) + + assert results == ["step1", "step2", "final"] + + +class TestBackgroundTaskManager: + """后台任务管理器测试""" + + @pytest.mark.asyncio + async def test_task_creation(self): + """任务创建测试""" + from astrbot.core.background_tool import BackgroundTask, TaskStatus + + task = BackgroundTask( + task_id="test_001", + tool_name="test_tool", + tool_args={"arg1": "value1"}, + session_id="session_001", + ) + + assert task.status == TaskStatus.PENDING + assert task.task_id == "test_001" + + @pytest.mark.asyncio + async def test_task_state_transitions(self): + """任务状态转换测试""" + from astrbot.core.background_tool import BackgroundTask, TaskStatus + + task = BackgroundTask( + task_id="test_002", + tool_name="test_tool", + tool_args={}, + session_id="session_001", + ) + + # PENDING -> RUNNING + task.start() + assert task.status == TaskStatus.RUNNING + assert task.started_at is not None + + # RUNNING -> COMPLETED + task.complete("success") + assert task.status == TaskStatus.COMPLETED + assert task.result == "success" + assert task.is_finished() + + +class TestTaskRegistry: + """任务注册表测试""" + + def test_register_and_get(self): + """注册和获取任务""" + from astrbot.core.background_tool import BackgroundTask, TaskRegistry + + registry = TaskRegistry() + registry.clear() # 清空单例状态 + + task = BackgroundTask( + task_id="reg_001", tool_name="test", tool_args={}, session_id="s1" + ) + + registry.register(task) + retrieved = registry.get("reg_001") + + assert retrieved is not None + assert retrieved.task_id == "reg_001" + + +class TestWaitInterrupt: + """等待中断机制测试""" + + @pytest.mark.asyncio + async def test_interrupt_flag(self): + """中断标记测试""" + from astrbot.core.background_tool import BackgroundToolManager + + manager = BackgroundToolManager() + session_id = "interrupt_test_001" + + # 初始状态无中断 + assert not manager.check_interrupt_flag(session_id) + + # 设置中断 + manager.set_interrupt_flag(session_id) + assert manager.check_interrupt_flag(session_id) + + # 清除中断 + manager.clear_interrupt_flag(session_id) + assert not manager.check_interrupt_flag(session_id) + + +class TestOutputBuffer: + """输出缓冲区测试""" + + def test_append_and_get(self): + """追加和获取输出""" + from astrbot.core.background_tool import OutputBuffer + + buffer = OutputBuffer() + task_id = "buf_001" + + buffer.append(task_id, "line1") + buffer.append(task_id, "line2") + buffer.append(task_id, "line3") + + lines = buffer.get_recent(task_id, n=2) + assert len(lines) == 2 + assert lines == ["line2", "line3"] + + +class TestTimeoutBehavior: + """超时行为测试""" + + @pytest.mark.asyncio + async def test_timeout_triggers_background(self): + """超时触发后台执行""" + + # 模拟一个会超时的任务 + async def slow_handler(event, **kwargs): + await asyncio.sleep(5) + return "done" + + # 使用短超时测试 + try: + await asyncio.wait_for(slow_handler(MockEvent()), timeout=0.1) + except asyncio.TimeoutError: + # 预期会超时 + pass + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +class TestMethodResolver: + """方法解析器单元测试""" + + def test_resolve_with_handler(self): + """测试handler解析""" + from astrbot.core.tool_execution.infrastructure.handler import MethodResolver + + class MockTool: + def handler(self, x): + return x + + name = "test" + + resolver = MethodResolver() + tool = MockTool() + handler, method = resolver.resolve(tool) + assert method == "decorator_handler" + + +class TestTimeoutStrategy: + """超时策略单元测试""" + + @pytest.mark.asyncio + async def test_no_timeout_strategy(self): + """测试无超时策略""" + from astrbot.core.tool_execution.infrastructure.timeout import NoTimeoutStrategy + + async def quick_task(): + return "done" + + strategy = NoTimeoutStrategy() + result = await strategy.execute(quick_task(), 1.0) + assert result == "done" + + +class TestCompletionSignal: + """完成信号单元测试""" + + @pytest.mark.asyncio + async def test_signal_set_and_wait(self): + """测试信号设置和等待""" + from astrbot.core.tool_execution.infrastructure.background import ( + CompletionSignal, + ) + + signal = CompletionSignal() + signal.set() + result = await signal.wait(timeout=1.0) + assert result is True + + @pytest.mark.asyncio + async def test_signal_timeout(self): + """测试信号超时""" + from astrbot.core.tool_execution.infrastructure.background import ( + CompletionSignal, + ) + + signal = CompletionSignal() + result = await signal.wait(timeout=0.1) + assert result is False + + +class TestResultProcessor: + """结果处理器单元测试""" + + @pytest.mark.asyncio + async def test_process_string_result(self): + """测试字符串结果处理""" + from astrbot.core.tool_execution.infrastructure.handler import ResultProcessor + + processor = ResultProcessor() + result = await processor.process("hello") + assert result is not None + assert result.content[0].text == "hello" + + @pytest.mark.asyncio + async def test_process_none_result(self): + """测试None结果处理""" + from astrbot.core.tool_execution.infrastructure.handler import ResultProcessor + + processor = ResultProcessor() + result = await processor.process(None) + assert result is None + + @pytest.mark.asyncio + async def test_process_call_tool_result(self): + """测试CallToolResult直接返回""" + import mcp.types + + from astrbot.core.tool_execution.infrastructure.handler import ResultProcessor + + processor = ResultProcessor() + original = mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="test")] + ) + result = await processor.process(original) + assert result is original + + +class TestParameterValidator: + """参数验证器单元测试""" + + def test_validate_with_matching_params(self): + """测试参数匹配验证""" + from astrbot.core.tool_execution.infrastructure.handler import ( + ParameterValidator, + ) + + def handler(event, name: str, age: int = 18): + pass + + validator = ParameterValidator() + params = {"name": "test", "age": 25} + result = validator.validate(handler, params) + # 验证器会添加event参数(bind_partial绑定None) + assert result["name"] == "test" + assert result["age"] == 25 + + +class TestBackgroundHandler: + """后台处理器单元测试""" + + @pytest.mark.asyncio + async def test_build_notification(self): + """测试通知构建""" + from astrbot.core.tool_execution.infrastructure.timeout import BackgroundHandler + + handler = BackgroundHandler() + result = handler._build_notification("test_tool", "abc123") + assert "test_tool" in result.content[0].text + assert "abc123" in result.content[0].text + + +class TestToolExecutor: + """工具执行编排器单元测试""" + + def test_should_enable_timeout(self): + """测试超时启用判断""" + from astrbot.core.tool_execution.application.tool_executor import ToolExecutor + + executor = ToolExecutor() + + # 正常工具应启用超时 + assert executor._should_enable_timeout(15.0, "normal_tool") is True + + # 后台管理工具不启用超时 + assert executor._should_enable_timeout(15.0, "wait_tool_result") is False + assert executor._should_enable_timeout(15.0, "get_tool_output") is False + + # 超时为0时禁用 + assert executor._should_enable_timeout(0, "normal_tool") is False + + +class TestDomainConfig: + """领域配置测试""" + + def test_background_tool_names(self): + """测试后台工具名称配置""" + from astrbot.core.tool_execution.domain.config import BACKGROUND_TOOL_NAMES + + assert "wait_tool_result" in BACKGROUND_TOOL_NAMES + assert "get_tool_output" in BACKGROUND_TOOL_NAMES + assert "stop_tool" in BACKGROUND_TOOL_NAMES + assert "list_running_tools" in BACKGROUND_TOOL_NAMES + assert len(BACKGROUND_TOOL_NAMES) == 4 + + +class TestMethodResolverAdvanced: + """方法解析器高级测试""" + + def test_resolve_with_run_method(self): + """测试run方法解析""" + from astrbot.core.tool_execution.infrastructure.handler import MethodResolver + + class MockTool: + handler = None + name = "test" + + def run(self, event): + return "run result" + + resolver = MethodResolver() + handler, method = resolver.resolve(MockTool()) + assert method == "run" + + def test_resolve_failure_no_handler(self): + """测试无handler时抛出异常""" + from astrbot.core.tool_execution.errors import MethodResolutionError + from astrbot.core.tool_execution.infrastructure.handler import MethodResolver + + class MockTool: + handler = None + name = "test" + + resolver = MethodResolver() + try: + resolver.resolve(MockTool()) + assert False, "Should raise MethodResolutionError" + except MethodResolutionError: + pass + + +class TestTimeoutStrategyAdvanced: + """超时策略高级测试""" + + @pytest.mark.asyncio + async def test_standard_timeout_strategy_success(self): + """测试标准超时策略成功执行""" + from astrbot.core.tool_execution.infrastructure.timeout import TimeoutStrategy + + async def quick_task(): + return "done" + + strategy = TimeoutStrategy() + result = await strategy.execute(quick_task(), 5.0) + assert result == "done" + + @pytest.mark.asyncio + async def test_standard_timeout_strategy_timeout(self): + """测试标准超时策略超时""" + from astrbot.core.tool_execution.infrastructure.timeout import TimeoutStrategy + + async def slow_task(): + await asyncio.sleep(10) + return "done" + + strategy = TimeoutStrategy() + try: + await strategy.execute(slow_task(), 0.1) + assert False, "Should raise TimeoutError" + except asyncio.TimeoutError: + pass + + +class TestEventDrivenWait: + """事件驱动等待测试""" + + @pytest.mark.asyncio + async def test_task_completion_signal(self): + """测试任务完成信号触发""" + from astrbot.core.background_tool import BackgroundTask + + task = BackgroundTask( + task_id="signal_test", + tool_name="test", + tool_args={}, + session_id="s1", + ) + task.init_completion_event() + + assert task.completion_event is not None + + # 模拟任务完成 + task.complete("done") + task._signal_completion() + + # 验证信号已设置 + assert task.completion_event.is_set() + + @pytest.mark.asyncio + async def test_task_without_completion_event(self): + """测试无完成事件的任务回退到轮询""" + from astrbot.core.background_tool import BackgroundTask + + task = BackgroundTask( + task_id="no_event_test", + tool_name="test", + tool_args={}, + session_id="s1", + ) + # 不初始化 completion_event + task.complete("done") + + # 应该正常完成,is_finished 返回 True + assert task.is_finished() + assert task.completion_event is None + + +class TestErrorHandling: + """错误处理测试""" + + def test_parameter_validation_error_unexpected_arg(self): + """测试意外参数触发验证错误""" + from astrbot.core.tool_execution.errors import ParameterValidationError + from astrbot.core.tool_execution.infrastructure.handler import ( + ParameterValidator, + ) + + def handler(event, name: str): + pass + + validator = ParameterValidator() + try: + validator.validate(handler, {"unexpected": "value"}) + assert False, "Should raise ParameterValidationError" + except ParameterValidationError as e: + assert "Parameter mismatch" in str(e) + + def test_method_resolution_error(self): + """测试方法解析错误""" + from astrbot.core.tool_execution.errors import MethodResolutionError + + error = MethodResolutionError("test error") + assert str(error) == "test error" + + +class TestBackgroundToolConfig: + """配置模块单元测试""" + + def test_default_config_values(self): + """测试默认配置值""" + from astrbot.core.tool_execution.domain.config import ( + DEFAULT_CONFIG, + ) + + assert DEFAULT_CONFIG.cleanup_interval_seconds == 600 + assert DEFAULT_CONFIG.task_max_age_seconds == 3600 + assert DEFAULT_CONFIG.default_timeout_seconds == 600 + assert DEFAULT_CONFIG.error_preview_max_length == 500 + assert DEFAULT_CONFIG.default_output_lines == 50 + + def test_config_immutability(self): + """测试配置不可变性""" + from astrbot.core.tool_execution.domain.config import BackgroundToolConfig + + config = BackgroundToolConfig() + try: + config.cleanup_interval_seconds = 100 + assert False, "Should raise FrozenInstanceError" + except Exception: + pass # Expected + + +class TestCallbackEventBuilder: + """回调事件构建器单元测试""" + + def test_build_notification_text(self): + """测试通知文本构建""" + from astrbot.core.background_tool import ( + BackgroundTask, + CallbackEventBuilder, + TaskStatus, + ) + + task = BackgroundTask( + task_id="test_001", + tool_name="test_tool", + tool_args={}, + session_id="session_001", + ) + task.status = TaskStatus.COMPLETED + task.result = "success" + + builder = CallbackEventBuilder() + text = builder.build_notification_text(task) + + assert "test_001" in text + assert "test_tool" in text + assert "completed successfully" in text + assert "success" in text + + def test_error_preview_truncation(self): + """测试错误预览截断""" + from astrbot.core.background_tool import ( + BackgroundTask, + CallbackEventBuilder, + TaskStatus, + ) + from astrbot.core.tool_execution.domain.config import BackgroundToolConfig + + task = BackgroundTask( + task_id="test_002", + tool_name="test_tool", + tool_args={}, + session_id="session_001", + ) + task.status = TaskStatus.FAILED + task.error = "x" * 1000 # 超过500字符 + + config = BackgroundToolConfig(error_preview_max_length=100) + builder = CallbackEventBuilder(config=config) + text = builder.build_notification_text(task) + + assert "..." in text + assert len(text) < 1000 # 应该被截断 + + +class TestCallbackPublisher: + """回调发布器单元测试""" + + def test_should_publish_when_being_waited(self): + """测试等待中的任务不应发布""" + from astrbot.core.background_tool import ( + BackgroundTask, + CallbackPublisher, + ) + + task = BackgroundTask( + task_id="test_003", + tool_name="test_tool", + tool_args={}, + session_id="session_001", + ) + task.is_being_waited = True + + publisher = CallbackPublisher() + assert publisher.should_publish(task) is False + + def test_should_publish_without_event(self): + """测试无事件的任务不应发布""" + from astrbot.core.background_tool import ( + BackgroundTask, + CallbackPublisher, + ) + + task = BackgroundTask( + task_id="test_004", + tool_name="test_tool", + tool_args={}, + session_id="session_001", + ) + task.event = None + + publisher = CallbackPublisher() + assert publisher.should_publish(task) is False + + +class TestSanitizer: + """日志脱敏工具单元测试""" + + def test_sanitize_sensitive_key(self): + """测试敏感键名脱敏""" + from astrbot.core.tool_execution.utils.sanitizer import sanitize_params + + params = { + "username": "test_user", + "password": "secret123", + "api_key": "sk-1234567890", + } + result = sanitize_params(params) + + assert result["username"] == "test_user" + assert result["password"] == "***REDACTED***" + assert result["api_key"] == "***REDACTED***" + + def test_sanitize_sensitive_value_pattern(self): + """测试敏感值模式脱敏""" + from astrbot.core.tool_execution.utils.sanitizer import sanitize_params + + params = { + "header": "Bearer eyJhbGciOiJIUzI1NiJ9", + "config": "api_key=sk-1234567890abcdefghij", + } + result = sanitize_params(params) + + assert "***REDACTED***" in result["header"] + assert "***REDACTED***" in result["config"] + + def test_sanitize_nested_dict(self): + """测试嵌套字典脱敏""" + from astrbot.core.tool_execution.utils.sanitizer import sanitize_params + + params = { + "outer": { + "token": "secret", + "name": "test", + } + } + result = sanitize_params(params) + + assert result["outer"]["token"] == "***REDACTED***" + assert result["outer"]["name"] == "test" + + def test_sanitize_truncation(self): + """测试长值截断""" + from astrbot.core.tool_execution.utils.sanitizer import sanitize_params + + params = {"long_text": "x" * 200} + result = sanitize_params(params, max_value_length=100) + + assert len(result["long_text"]) < 200 + assert "truncated" in result["long_text"] + + +class TestRWLock: + """读写锁单元测试""" + + def test_read_lock_allows_concurrent_reads(self): + """测试读锁允许并发读取""" + import threading + + from astrbot.core.tool_execution.utils.rwlock import RWLock + + lock = RWLock() + read_count = [0] + max_concurrent = [0] + + def reader(): + with lock.read(): + read_count[0] += 1 + max_concurrent[0] = max(max_concurrent[0], read_count[0]) + import time + + time.sleep(0.01) + read_count[0] -= 1 + + threads = [threading.Thread(target=reader) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 多个读取者应该能够并发 + assert max_concurrent[0] > 1 + + def test_write_lock_exclusive(self): + """测试写锁独占""" + from astrbot.core.tool_execution.utils.rwlock import RWLock + + lock = RWLock() + data = {"value": 0} + + def writer(): + with lock.write(): + current = data["value"] + import time + + time.sleep(0.01) + data["value"] = current + 1 + + import threading + + threads = [threading.Thread(target=writer) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 写入应该是原子的 + assert data["value"] == 5 + + +class TestValidators: + """输入验证工具单元测试""" + + def test_valid_task_id(self): + """测试有效的任务ID""" + from astrbot.core.tool_execution.utils.validators import validate_task_id + + assert validate_task_id("task_001") == "task_001" + assert validate_task_id("abc-123-def") == "abc-123-def" + + def test_invalid_task_id_type(self): + """测试无效的任务ID类型""" + from astrbot.core.tool_execution.utils.validators import ( + ValidationError, + validate_task_id, + ) + + try: + validate_task_id(123) + assert False, "Should raise ValidationError" + except ValidationError as e: + assert "must be string" in str(e) + + def test_invalid_task_id_format(self): + """测试无效的任务ID格式""" + from astrbot.core.tool_execution.utils.validators import ( + ValidationError, + validate_task_id, + ) + + try: + validate_task_id("task@#$%") + assert False, "Should raise ValidationError" + except ValidationError as e: + assert "Invalid task_id format" in str(e) + + def test_valid_session_id(self): + """测试有效的会话ID""" + from astrbot.core.tool_execution.utils.validators import validate_session_id + + assert validate_session_id("session_001") == "session_001" + assert validate_session_id("user:group:123") == "user:group:123" + + def test_session_id_dangerous_chars(self): + """测试会话ID危险字符""" + from astrbot.core.tool_execution.utils.validators import ( + ValidationError, + validate_session_id, + ) + + try: + validate_session_id("session\x00inject") + assert False, "Should raise ValidationError" + except ValidationError as e: + assert "invalid characters" in str(e) + + def test_validate_positive_int(self): + """测试正整数验证""" + from astrbot.core.tool_execution.utils.validators import ( + ValidationError, + validate_positive_int, + ) + + assert validate_positive_int(10, "count") == 10 + + try: + validate_positive_int(-1, "count") + assert False, "Should raise ValidationError" + except ValidationError: + pass + + try: + validate_positive_int(99999, "count", max_value=100) + assert False, "Should raise ValidationError" + except ValidationError: + pass + + +class TestConfigCache: + """配置缓存单元测试""" + + def test_cache_reuse(self): + """测试缓存重用""" + import time + + from astrbot.core.background_tool.task_executor import _ConfigCache + + # 重置缓存 + _ConfigCache._timeout = None + _ConfigCache._last_load = 0 + + # 首次加载 + result1 = _ConfigCache.get_timeout() + first_load_time = _ConfigCache._last_load + + # 立即再次获取应该使用缓存 + time.sleep(0.01) + result2 = _ConfigCache.get_timeout() + second_load_time = _ConfigCache._last_load + + # 加载时间应该相同(使用了缓存) + assert first_load_time == second_load_time + assert result1 == result2 From 9ce91a8e50649bf0f081a481a8de028c113d6b94 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:04:46 +0800 Subject: [PATCH 2/8] fix: remove unused MessageChain import --- astrbot/core/astr_agent_tool_exec.py | 1 - 1 file changed, 1 deletion(-) diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index e0d1295f7..4a3a6447b 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -22,7 +22,6 @@ from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.message_event_result import ( CommandResult, - MessageChain, MessageEventResult, ) from astrbot.core.platform.message_session import MessageSession From 22a10a9c799790e7d12ae881f42ac43db8df4352 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:15:33 +0800 Subject: [PATCH 3/8] fix: replace mock API keys in tests to pass security scan --- tests/test_tool_execution/test_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tool_execution/test_integration.py b/tests/test_tool_execution/test_integration.py index 6565b810c..802073db9 100644 --- a/tests/test_tool_execution/test_integration.py +++ b/tests/test_tool_execution/test_integration.py @@ -675,7 +675,7 @@ def test_sanitize_sensitive_key(self): params = { "username": "test_user", "password": "secret123", - "api_key": "sk-1234567890", + "api_key": "sk-test-fake-key-for-unit-test", } result = sanitize_params(params) @@ -689,7 +689,7 @@ def test_sanitize_sensitive_value_pattern(self): params = { "header": "Bearer eyJhbGciOiJIUzI1NiJ9", - "config": "api_key=sk-1234567890abcdefghij", + "config": "api_key=test-fake-key-for-unit-test", } result = sanitize_params(params) From 131770ca90fe36fa92b166228c4469e9733615e3 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:25:57 +0800 Subject: [PATCH 4/8] fix: remove duplicate init_completion_event and use thread-safe registry APIs --- astrbot/core/background_tool/manager.py | 7 ++----- astrbot/core/background_tool/task_registry.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/astrbot/core/background_tool/manager.py b/astrbot/core/background_tool/manager.py index 66ac93a27..0ec751da4 100644 --- a/astrbot/core/background_tool/manager.py +++ b/astrbot/core/background_tool/manager.py @@ -76,13 +76,11 @@ async def _cleanup_loop(self) -> None: ) # 同步清理OutputBuffer中的孤立缓冲区 - valid_task_ids = set(self.registry._tasks.keys()) + valid_task_ids = self.registry.get_all_task_ids() buffer_cleaned = self.output_buffer.cleanup_old_buffers(valid_task_ids) # 清理孤立的中断标记(没有活跃任务的会话) - active_sessions = { - task.session_id for task in self.registry._tasks.values() - } + active_sessions = self.registry.get_all_session_ids() stale_flags = [ sid for sid in self._interrupt_flags if sid not in active_sessions ] @@ -150,7 +148,6 @@ async def submit_task( f"[BackgroundToolManager] Creating task {task.task_id} for tool {tool_name}, session {session_id}" ) - task.init_completion_event() self.registry.register(task) logger.info( f"[BackgroundToolManager] Task {task.task_id} registered successfully" diff --git a/astrbot/core/background_tool/task_registry.py b/astrbot/core/background_tool/task_registry.py index 1563c5580..71459a692 100644 --- a/astrbot/core/background_tool/task_registry.py +++ b/astrbot/core/background_tool/task_registry.py @@ -165,6 +165,16 @@ def clear(self) -> None: self._tasks.clear() self._session_index.clear() + def get_all_task_ids(self) -> set[str]: + """获取所有任务ID(线程安全)""" + with self._rwlock.read(): + return set(self._tasks.keys()) + + def get_all_session_ids(self) -> set[str]: + """获取所有有活跃任务的会话ID(线程安全)""" + with self._rwlock.read(): + return {task.session_id for task in self._tasks.values()} + def count(self) -> int: """获取任务数量""" with self._rwlock.read(): From daa06e94c85bfb357873d413e37ea74fb49cf3d3 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:46:45 +0800 Subject: [PATCH 5/8] =?UTF-8?q?fix:=20=E5=90=8E=E5=8F=B0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E4=BB=8E=20event.get=5Fresult()=20=E8=8E=B7=E5=8F=96=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E6=89=A7=E8=A1=8C=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 后台执行时 handler 通过 event.set_result() 设置结果而非直接 return,导致 [FINAL RESULT] 为空。在异步生成器和协程两个分支 中增加 fallback 检查 event.get_result()。 --- astrbot/core/background_tool/task_executor.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/astrbot/core/background_tool/task_executor.py b/astrbot/core/background_tool/task_executor.py index 18a9358ed..94bc15e17 100644 --- a/astrbot/core/background_tool/task_executor.py +++ b/astrbot/core/background_tool/task_executor.py @@ -219,11 +219,30 @@ async def _run_handler( if output is not None: self._log(task.task_id, str(output)) final_result = output + # handler 可能通过 event.set_result() 设置结果而非直接 return + if final_result is None and task.event is not None: + event_result = task.event.get_result() + if event_result is not None: + final_result = event_result.get_plain_text() + self._log( + task.task_id, + f"[EVENT_RESULT] Got result from event: {final_result[:200] if final_result else None}", + ) return final_result # 检查是否是协程 elif asyncio.iscoroutine(result): - return await result + coro_result = await result + # handler 可能通过 event.set_result() 设置结果而非直接 return + if coro_result is None and task.event is not None: + event_result = task.event.get_result() + if event_result is not None: + coro_result = event_result.get_plain_text() + self._log( + task.task_id, + f"[EVENT_RESULT] Got result from event: {coro_result[:200] if coro_result else None}", + ) + return coro_result else: return result From 4e97ce1cb302ec45dda7e4a22fd6740b65362c5b Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Sat, 7 Feb 2026 00:07:01 +0800 Subject: [PATCH 6/8] fix: remove sk- prefix from test mock keys to pass gitleaks scan --- tests/test_tool_execution/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tool_execution/test_integration.py b/tests/test_tool_execution/test_integration.py index 802073db9..0f1a9f827 100644 --- a/tests/test_tool_execution/test_integration.py +++ b/tests/test_tool_execution/test_integration.py @@ -675,7 +675,7 @@ def test_sanitize_sensitive_key(self): params = { "username": "test_user", "password": "secret123", - "api_key": "sk-test-fake-key-for-unit-test", + "api_key": "fake-key-for-unit-test", } result = sanitize_params(params) From 8201440fca76cda3180263e61d7f3e69f8356e43 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Sat, 7 Feb 2026 01:18:08 +0800 Subject: [PATCH 7/8] =?UTF-8?q?fix:=20=E5=B7=A5=E5=85=B7=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E5=AE=89=E5=85=A8=E6=80=A7=E4=B8=8E=E6=AD=A3?= =?UTF-8?q?=E7=A1=AE=E6=80=A7=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _interrupt_flags 加 threading.Lock 保护,防止并发迭代崩溃 - BackgroundTask 状态转换添加守卫矩阵和锁,防止并发状态不一致 - llm_tools 入口处调用 validate_task_id/validate_positive_int 验证外部输入 - ToolExecutor 修复绕过依赖注入问题,ResultProcessor 改为参数传入 run_context --- astrbot/core/background_tool/llm_tools.py | 22 +++++++ astrbot/core/background_tool/manager.py | 25 +++++--- astrbot/core/background_tool/task_state.py | 57 +++++++++++++++---- .../application/tool_executor.py | 9 +-- .../handler/result_processor.py | 16 +++--- astrbot/core/tool_execution/interfaces.py | 3 +- tests/test_tool_execution/test_integration.py | 3 +- 7 files changed, 100 insertions(+), 35 deletions(-) diff --git a/astrbot/core/background_tool/llm_tools.py b/astrbot/core/background_tool/llm_tools.py index c8d0c193d..a2db093f5 100644 --- a/astrbot/core/background_tool/llm_tools.py +++ b/astrbot/core/background_tool/llm_tools.py @@ -8,6 +8,12 @@ if TYPE_CHECKING: from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.tool_execution.utils.validators import ( + ValidationError, + validate_positive_int, + validate_task_id, +) + from .manager import BackgroundToolManager from .task_formatter import build_task_result @@ -32,6 +38,12 @@ async def get_tool_output( Returns: 工具输出日志和最终结果 """ + try: + task_id = validate_task_id(task_id) + lines = validate_positive_int(lines, "lines") + except ValidationError as e: + return f"Error: {e}" + manager = _get_manager() task = manager.registry.get(task_id) @@ -60,6 +72,11 @@ async def wait_tool_result( """ import asyncio + try: + task_id = validate_task_id(task_id) + except ValidationError as e: + return f"Error: {e}" + manager = _get_manager() session_id = event.unified_msg_origin @@ -118,6 +135,11 @@ async def stop_tool( Returns: 终止结果 """ + try: + task_id = validate_task_id(task_id) + except ValidationError as e: + return f"Error: {e}" + manager = _get_manager() task = manager.registry.get(task_id) diff --git a/astrbot/core/background_tool/manager.py b/astrbot/core/background_tool/manager.py index 0ec751da4..a8d9f7749 100644 --- a/astrbot/core/background_tool/manager.py +++ b/astrbot/core/background_tool/manager.py @@ -44,7 +44,8 @@ def __init__(self): self.output_buffer = OutputBuffer() self.executor = TaskExecutor(output_buffer=self.output_buffer) self.notifier = TaskNotifier() - self._interrupt_flags = {} # session_id -> bool,用于中断等待 + self._interrupt_flags: dict[str, bool] = {} + self._interrupt_lock = threading.Lock() self._cleanup_task: asyncio.Task | None = None self._initialized = True @@ -81,11 +82,14 @@ async def _cleanup_loop(self) -> None: # 清理孤立的中断标记(没有活跃任务的会话) active_sessions = self.registry.get_all_session_ids() - stale_flags = [ - sid for sid in self._interrupt_flags if sid not in active_sessions - ] - for sid in stale_flags: - self._interrupt_flags.pop(sid, None) + with self._interrupt_lock: + stale_flags = [ + sid + for sid in self._interrupt_flags + if sid not in active_sessions + ] + for sid in stale_flags: + del self._interrupt_flags[sid] if removed_count > 0 or buffer_cleaned > 0 or stale_flags: stats = self.registry.count_by_status() @@ -280,7 +284,8 @@ def set_interrupt_flag(self, session_id: str): Args: session_id: 会话ID """ - self._interrupt_flags[session_id] = True + with self._interrupt_lock: + self._interrupt_flags[session_id] = True def check_interrupt_flag(self, session_id: str) -> bool: """检查会话是否有中断标记 @@ -291,7 +296,8 @@ def check_interrupt_flag(self, session_id: str) -> bool: Returns: 是否有中断标记 """ - return self._interrupt_flags.get(session_id, False) + with self._interrupt_lock: + return self._interrupt_flags.get(session_id, False) def clear_interrupt_flag(self, session_id: str): """清除会话的中断标记 @@ -299,7 +305,8 @@ def clear_interrupt_flag(self, session_id: str): Args: session_id: 会话ID """ - self._interrupt_flags.pop(session_id, None) + with self._interrupt_lock: + self._interrupt_flags.pop(session_id, None) def get_running_tasks_status(self, session_id: str) -> str | None: """获取会话中正在运行的后台任务状态信息 diff --git a/astrbot/core/background_tool/task_state.py b/astrbot/core/background_tool/task_state.py index 41c6576e2..0c7291c6f 100644 --- a/astrbot/core/background_tool/task_state.py +++ b/astrbot/core/background_tool/task_state.py @@ -3,6 +3,7 @@ 定义后台任务的状态数据结构和状态转换逻辑。 """ +import threading import time import uuid from asyncio import Event, Queue @@ -21,6 +22,18 @@ class TaskStatus(Enum): CANCELLED = "cancelled" # 已取消 +# 合法的状态转换矩阵 +_VALID_TRANSITIONS: dict[TaskStatus, frozenset[TaskStatus]] = { + TaskStatus.PENDING: frozenset({TaskStatus.RUNNING, TaskStatus.CANCELLED}), + TaskStatus.RUNNING: frozenset( + {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED} + ), + TaskStatus.COMPLETED: frozenset(), + TaskStatus.FAILED: frozenset(), + TaskStatus.CANCELLED: frozenset(), +} + + @dataclass class BackgroundTask: """后台任务状态 @@ -59,35 +72,59 @@ class BackgroundTask: completion_event: Event | None = ( None # 任务完成信号 # 是否有LLM正在使用wait_tool_result等待此任务 ) + _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) @staticmethod def generate_id() -> str: """生成唯一任务ID""" return str(uuid.uuid4())[:8] + def _transition_to(self, new_status: TaskStatus) -> bool: + """状态转换守卫,校验转换合法性 + + Args: + new_status: 目标状态 + + Returns: + 是否转换成功。已处于终态时静默返回False。 + """ + allowed = _VALID_TRANSITIONS.get(self.status, frozenset()) + if new_status not in allowed: + return False + self.status = new_status + return True + def start(self) -> None: """标记任务开始执行""" - self.status = TaskStatus.RUNNING - self.started_at = time.time() + with self._lock: + if not self._transition_to(TaskStatus.RUNNING): + return + self.started_at = time.time() def complete(self, result: str) -> None: """标记任务完成""" - self.status = TaskStatus.COMPLETED - self.result = result - self.completed_at = time.time() + with self._lock: + if not self._transition_to(TaskStatus.COMPLETED): + return + self.result = result + self.completed_at = time.time() self._signal_completion() def fail(self, error: str) -> None: """标记任务失败""" - self.status = TaskStatus.FAILED - self.error = error - self.completed_at = time.time() + with self._lock: + if not self._transition_to(TaskStatus.FAILED): + return + self.error = error + self.completed_at = time.time() self._signal_completion() def cancel(self) -> None: """标记任务取消""" - self.status = TaskStatus.CANCELLED - self.completed_at = time.time() + with self._lock: + if not self._transition_to(TaskStatus.CANCELLED): + return + self.completed_at = time.time() self._signal_completion() def append_output(self, line: str) -> None: diff --git a/astrbot/core/tool_execution/application/tool_executor.py b/astrbot/core/tool_execution/application/tool_executor.py index 0a18110dc..c91e330c5 100644 --- a/astrbot/core/tool_execution/application/tool_executor.py +++ b/astrbot/core/tool_execution/application/tool_executor.py @@ -109,8 +109,6 @@ async def _execute_with_timeout( self, tool, run_context, handler, method_name, timeout_enabled, **tool_args ) -> AsyncGenerator[mcp.types.CallToolResult, None]: """带超时控制的执行""" - from astrbot.core.tool_execution.infrastructure.handler import ResultProcessor - wrapper = self.tool_invoker.invoke( context=run_context, handler=handler, @@ -118,9 +116,6 @@ async def _execute_with_timeout( **tool_args, ) - # 创建带上下文的结果处理器 - result_processor = ResultProcessor(run_context) - while True: try: if timeout_enabled: @@ -130,7 +125,9 @@ async def _execute_with_timeout( else: resp = await anext(wrapper) - processed = await result_processor.process(resp) + processed = await self.result_processor.process( + resp, run_context=run_context + ) if processed: yield processed diff --git a/astrbot/core/tool_execution/infrastructure/handler/result_processor.py b/astrbot/core/tool_execution/infrastructure/handler/result_processor.py index 7c0be4018..e11e74dbf 100644 --- a/astrbot/core/tool_execution/infrastructure/handler/result_processor.py +++ b/astrbot/core/tool_execution/infrastructure/handler/result_processor.py @@ -14,14 +14,14 @@ class ResultProcessor(IResultProcessor): """结果处理器实现""" - def __init__(self, run_context: Any = None): - self._run_context = run_context - - async def process(self, result: Any) -> mcp.types.CallToolResult | None: + async def process( + self, result: Any, run_context: Any = None + ) -> mcp.types.CallToolResult | None: """处理执行结果 Args: result: 工具执行返回值 + run_context: 运行上下文(用于直接发送消息场景) Returns: 处理后的 CallToolResult,或 None 表示无需返回 @@ -30,7 +30,7 @@ async def process(self, result: Any) -> mcp.types.CallToolResult | None: return self._wrap_result(result) # result 为 None 时,检查是否需要直接发送消息给用户 - await self._send_direct_message() + await self._send_direct_message(run_context) return None def _wrap_result(self, result: Any) -> mcp.types.CallToolResult: @@ -44,12 +44,12 @@ def _wrap_result(self, result: Any) -> mcp.types.CallToolResult: ) return mcp.types.CallToolResult(content=[text_content]) - async def _send_direct_message(self) -> None: + async def _send_direct_message(self, run_context: Any = None) -> None: """处理工具直接发送消息给用户的情况""" - if self._run_context is None: + if run_context is None: return - event = self._run_context.context.event + event = run_context.context.event if not event: return diff --git a/astrbot/core/tool_execution/interfaces.py b/astrbot/core/tool_execution/interfaces.py index 9b8f715e1..b7de84247 100644 --- a/astrbot/core/tool_execution/interfaces.py +++ b/astrbot/core/tool_execution/interfaces.py @@ -60,11 +60,12 @@ class IResultProcessor(ABC): """ @abstractmethod - async def process(self, result: Any) -> Any: + async def process(self, result: Any, run_context: Any = None) -> Any: """处理执行结果 Args: result: 原始执行结果 + run_context: 运行上下文(可选,用于直接发送消息等场景) Returns: 处理后的结果 diff --git a/tests/test_tool_execution/test_integration.py b/tests/test_tool_execution/test_integration.py index 0f1a9f827..61cdd8953 100644 --- a/tests/test_tool_execution/test_integration.py +++ b/tests/test_tool_execution/test_integration.py @@ -508,7 +508,8 @@ async def test_task_without_completion_event(self): tool_args={}, session_id="s1", ) - # 不初始化 completion_event + # 不初始化 completion_event,但需要先 start() 再 complete()(状态转换守卫) + task.start() task.complete("done") # 应该正常完成,is_finished 返回 True From 9bdc167dacfa9a510fff672dac8eaafde0656af8 Mon Sep 17 00:00:00 2001 From: YukiRa1n <167516635+YukiRa1n@users.noreply.github.com> Date: Sat, 7 Feb 2026 01:37:38 +0800 Subject: [PATCH 8/8] fix: add cross-session task access control in LLM tools --- astrbot/core/background_tool/llm_tools.py | 24 ++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/astrbot/core/background_tool/llm_tools.py b/astrbot/core/background_tool/llm_tools.py index a2db093f5..fc0380463 100644 --- a/astrbot/core/background_tool/llm_tools.py +++ b/astrbot/core/background_tool/llm_tools.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from astrbot.core.background_tool.task_state import BackgroundTask from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.tool_execution.utils.validators import ( @@ -23,6 +24,21 @@ def _get_manager() -> BackgroundToolManager: return BackgroundToolManager() +def _get_task_for_session( + manager: BackgroundToolManager, task_id: str, session_id: str +) -> "BackgroundTask | None": + """按task_id查找任务,并校验会话归属。 + + 不匹配时返回None(与"不存在"相同响应,避免信息泄露)。 + """ + task = manager.registry.get(task_id) + if task is None: + return None + if task.session_id != session_id: + return None + return task + + async def get_tool_output( event: "AstrMessageEvent", task_id: str, @@ -45,8 +61,9 @@ async def get_tool_output( return f"Error: {e}" manager = _get_manager() + session_id = event.unified_msg_origin - task = manager.registry.get(task_id) + task = _get_task_for_session(manager, task_id, session_id) if task is None: return f"Error: Task {task_id} not found." @@ -85,7 +102,7 @@ async def wait_tool_result( logger.info(f"[wait_tool_result] Looking for task {task_id}") - task = manager.registry.get(task_id) + task = _get_task_for_session(manager, task_id, session_id) if task is None: return f"Error: Task {task_id} not found." @@ -141,8 +158,9 @@ async def stop_tool( return f"Error: {e}" manager = _get_manager() + session_id = event.unified_msg_origin - task = manager.registry.get(task_id) + task = _get_task_for_session(manager, task_id, session_id) if task is None: return f"Error: Task {task_id} not found."