Skip to content
12 changes: 12 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def reset(

self.stats = AgentStats()
self.stats.start_time = time.time()
self._wait_interrupted = False # 等待中断标记

async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
Expand Down Expand Up @@ -547,6 +548,17 @@ async def _handle_function_tools(
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
except Exception as e:
from astrbot.core.background_tool import WaitInterruptedException

if isinstance(e, WaitInterruptedException):
# 等待被中断,结束当前响应周期
logger.info(
f"Wait interrupted for task {e.task_id}, ending current response cycle"
)
self._wait_interrupted = True
self._transition_state(AgentState.DONE)
return

logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
Expand Down
88 changes: 17 additions & 71 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
MessageEventResult,
)
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.tool_execution.application.tool_executor import ToolExecutor
from astrbot.core.tool_execution.errors import MethodResolutionError
from astrbot.core.utils.history_saver import persist_agent_history


Expand Down Expand Up @@ -270,80 +271,25 @@ async def _execute_local(
tool_call_timeout: int | None = None,
**tool_args,
):
"""执行本地工具,使用 ToolExecutor 编排组件完成工具执行。"""
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")

is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
is_override_call = True
break

# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")

awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")

wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=tool_call_timeout or run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
# 如果外部指定了超时(如后台任务用3600s),临时覆盖run_context的超时值
original_timeout = run_context.tool_call_timeout
if tool_call_timeout is not None:
run_context.tool_call_timeout = tool_call_timeout

executor = ToolExecutor()

try:
async for result in executor.execute(tool, run_context, **tool_args):
yield result
except MethodResolutionError as e:
raise ValueError(str(e)) from e
finally:
run_context.tool_call_timeout = original_timeout

@classmethod
async def _execute_mcp(
Expand Down
38 changes: 38 additions & 0 deletions astrbot/core/background_tool/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Background Tool Execution System
# 后台工具执行系统

from .callback_event_builder import CallbackEventBuilder
from .callback_publisher import CallbackPublisher
from .manager import BackgroundToolManager
from .output_buffer import OutputBuffer
from .task_executor import TaskExecutor
from .task_notifier import TaskNotifier
from .task_registry import TaskRegistry
from .task_state import BackgroundTask, TaskStatus


class WaitInterruptedException(Exception):
"""等待被新消息中断的异常

当wait_tool_result被用户新消息中断时抛出此异常,
用于通知框架结束当前LLM响应周期。
"""

def __init__(self, task_id: str, session_id: str):
self.task_id = task_id
self.session_id = session_id
super().__init__(f"Wait interrupted for task {task_id}")


__all__ = [
"BackgroundTask",
"TaskStatus",
"TaskRegistry",
"OutputBuffer",
"TaskExecutor",
"TaskNotifier",
"CallbackEventBuilder",
"CallbackPublisher",
"BackgroundToolManager",
"WaitInterruptedException",
]
152 changes: 152 additions & 0 deletions astrbot/core/background_tool/callback_event_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""回调事件构建器

将后台任务完成信息构建为可放入事件队列的回调事件。
遵循单一职责原则,只负责事件构建。
"""

import copy
import time
from typing import Any

from astrbot import logger
from astrbot.core.tool_execution.domain.config import DEFAULT_CONFIG

from .task_state import BackgroundTask, TaskStatus

# 状态文本映射
STATUS_TEXT_MAP = {
TaskStatus.COMPLETED: "completed successfully",
TaskStatus.FAILED: "failed",
TaskStatus.CANCELLED: "was cancelled",
}


class CallbackEventBuilder:
"""回调事件构建器

负责将后台任务构建为回调事件,不涉及队列操作。
"""

def __init__(self, config=None):
"""初始化构建器

Args:
config: 配置对象,默认使用 DEFAULT_CONFIG
"""
self._config = config or DEFAULT_CONFIG

def build_notification_text(self, task: BackgroundTask) -> str:
"""构建通知文本

Args:
task: 后台任务

Returns:
通知文本
"""
status = STATUS_TEXT_MAP.get(task.status, "unknown")

lines = [
"[Background Task Callback]",
f"Task ID: {task.task_id}",
f"Tool: {task.tool_name}",
f"Status: {status}",
]

if task.result:
lines.append(f"Result: {task.result}")

if task.error:
max_len = self._config.error_preview_max_length
error_preview = task.error[:max_len]
if len(task.error) > max_len:
error_preview += "..."
lines.append(f"Error: {error_preview}")

lines.append("")
lines.append(
"Please inform the user about this task completion and provide any relevant details."
)

return "\n".join(lines)

def build_message_object(self, task: BackgroundTask, text: str) -> Any:
"""构建消息对象

Args:
task: 后台任务
text: 通知文本

Returns:
AstrBotMessage 对象
"""
from astrbot.core.message.components import Plain
from astrbot.core.platform.astrbot_message import AstrBotMessage

original = task.event.message_obj

msg = AstrBotMessage()
msg.type = original.type
msg.self_id = original.self_id
msg.session_id = original.session_id
msg.message_id = f"bg_task_{task.task_id}"
msg.group = original.group
msg.sender = original.sender
msg.message = [Plain(text)]
msg.message_str = text
msg.raw_message = None
msg.timestamp = int(time.time())

return msg

def build_callback_event(self, task: BackgroundTask) -> Any | None:
"""构建完整的回调事件

Args:
task: 后台任务

Returns:
回调事件对象,构建失败返回 None
"""
if not task.event:
logger.warning(f"[CallbackEventBuilder] Task {task.task_id} has no event")
return None

try:
from astrbot.core.utils.trace import TraceSpan

text = self.build_notification_text(task)
msg_obj = self.build_message_object(task, text)

# 浅拷贝原事件,保留平台特定属性
new_event = copy.copy(task.event)
new_event.message_str = text
new_event.message_obj = msg_obj

# 重置状态
new_event._result = None
new_event._has_send_oper = False
new_event._extras = {}

# 初始化 trace
new_event.trace = TraceSpan(
name="BackgroundTaskCallback",
umo=new_event.unified_msg_origin,
sender_name=new_event.get_sender_name(),
message_outline=f"[Background Task {task.task_id}]",
)
new_event.span = new_event.trace

# 标记为回调事件
new_event.is_wake = True
new_event.is_at_or_wake_command = True
new_event.set_extra("is_background_task_callback", True)
new_event.set_extra("background_task_id", task.task_id)

return new_event

except Exception as e:
logger.error(
f"[CallbackEventBuilder] Failed to build event for task {task.task_id}: {e}"
)
return None
Loading