Skip to content

Commit e49f056

Browse files
committed
fix(cli): 优化消息接收逻辑,用 finalize 机制替代延迟响应
- CLIMessageEvent: 用 finalize() 替代 _delayed_response() 延迟机制, 管道完成后统一返回响应,解决工具调用响应截断问题 - CLIMessageEvent: 添加 send_streaming() 支持,采用收集后一次性发送策略 - SocketClientHandler: 超时从30s增加到120s,处理 finalize 返回 None 的情况 - PipelineScheduler: 管道完成后调用 event.finalize()(鸭子类型) - CLIAdapter: 添加 get_stats()/unified_webhook() 兼容 CLIConfig 数据类 - PlatformManager: 安全获取平台ID,兼容 dict 和 dataclass 类型 - PlatformRoute: 兼容 CLIConfig 的 webhook_uuid 获取方式 - MessageConverter: 补充 raw_message 字段
1 parent 9387d5e commit e49f056

7 files changed

Lines changed: 100 additions & 44 deletions

File tree

astrbot/core/pipeline/scheduler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,8 @@ async def execute(self, event: AstrMessageEvent) -> None:
8585
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
8686
await event.send(None)
8787

88+
# 通知事件管道已完成(鸭子类型,供需要收集完整响应的适配器使用)
89+
if hasattr(event, "finalize"):
90+
await event.finalize()
91+
8892
logger.debug("pipeline 执行完毕。")

astrbot/core/platform/manager.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ async def terminate(self) -> None:
292292
def get_insts(self):
293293
return self.platform_insts
294294

295+
@staticmethod
296+
def _get_platform_id(inst) -> str:
297+
"""安全获取平台ID,兼容dict和dataclass类型的config"""
298+
config = getattr(inst, "config", None)
299+
if config is None:
300+
return "unknown"
301+
if isinstance(config, dict):
302+
return config.get("id", "unknown")
303+
return getattr(config, "platform_id", getattr(config, "id", "unknown"))
304+
295305
def get_all_stats(self) -> dict:
296306
"""获取所有平台的统计信息
297307
@@ -317,7 +327,7 @@ def get_all_stats(self) -> dict:
317327
logger.warning(f"获取平台统计信息失败: {e}")
318328
stats_list.append(
319329
{
320-
"id": getattr(inst, "config", {}).get("id", "unknown"),
330+
"id": self._get_platform_id(inst),
321331
"type": "unknown",
322332
"status": "unknown",
323333
"error_count": 0,

astrbot/core/platform/sources/cli/cli_adapter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,39 @@ def meta(self) -> PlatformMetadata:
195195
"""获取平台元数据"""
196196
return self.metadata
197197

198+
def unified_webhook(self) -> bool:
199+
"""CLI不使用webhook"""
200+
return False
201+
202+
def get_stats(self) -> dict:
203+
"""获取平台统计信息(兼容CLIConfig数据类)"""
204+
meta = self.meta()
205+
meta_info = {
206+
"id": meta.id,
207+
"name": meta.name,
208+
"display_name": meta.adapter_display_name or meta.name,
209+
"description": meta.description,
210+
"support_streaming_message": meta.support_streaming_message,
211+
"support_proactive_message": meta.support_proactive_message,
212+
}
213+
return {
214+
"id": meta.id or self.config.platform_id,
215+
"type": meta.name,
216+
"display_name": meta.adapter_display_name or meta.name,
217+
"status": self._status.value,
218+
"started_at": self._started_at.isoformat() if self._started_at else None,
219+
"error_count": len(self._errors),
220+
"last_error": {
221+
"message": self.last_error.message,
222+
"timestamp": self.last_error.timestamp.isoformat(),
223+
"traceback": self.last_error.traceback,
224+
}
225+
if self.last_error
226+
else None,
227+
"unified_webhook": False,
228+
"meta": meta_info,
229+
}
230+
198231
async def terminate(self) -> None:
199232
"""终止平台运行"""
200233
self._running = False

astrbot/core/platform/sources/cli/cli_event.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import asyncio
9+
from collections.abc import AsyncGenerator
910
from typing import Any
1011

1112
from astrbot import logger
@@ -21,11 +22,9 @@ class CLIMessageEvent(AstrMessageEvent):
2122
"""CLI消息事件
2223
2324
处理命令行模拟器的消息事件。
25+
Socket模式下收集管道中所有send()调用的消息,在管道完成(finalize)后统一返回。
2426
"""
2527

26-
# 延迟配置
27-
INITIAL_DELAY = 5.0 # 首次发送延迟
28-
EXTENDED_DELAY = 10.0 # 后续发送延迟
2928
MAX_BUFFER_SIZE = 100 # 缓冲区最大消息组件数
3029

3130
def __init__(
@@ -48,29 +47,21 @@ def __init__(
4847
self.output_queue = output_queue
4948
self.response_future = response_future
5049

51-
# 多次回复收集
50+
# 多次回复收集(Socket模式)
5251
self.send_buffer = None
53-
self._response_delay_task = None
54-
self._response_delay = self.INITIAL_DELAY
5552

5653
async def send(self, message_chain: MessageChain) -> dict[str, Any]:
5754
"""发送消息到CLI"""
58-
# 调用父类方法以设置 _has_send_oper 标志
59-
# 这告诉 ProcessStage 已经有发送操作,避免触发 LLM
6055
await super().send(message_chain)
6156

62-
# Socket模式:收集多次回复
57+
# Socket模式:收集所有回复到buffer,等待finalize()统一返回
6358
if self.response_future is not None and not self.response_future.done():
64-
# 使用 ImageProcessor 预处理图片(避免临时文件被删除)
6559
ImageProcessor.preprocess_chain(message_chain)
6660

67-
# 收集多次回复到buffer
6861
if not self.send_buffer:
6962
self.send_buffer = message_chain
70-
self._response_delay = self.INITIAL_DELAY
7163
logger.debug("[CLI] First send: buffer initialized")
7264
else:
73-
# 检查缓冲区大小限制
7465
current_size = len(self.send_buffer.chain)
7566
new_size = len(message_chain.chain)
7667
if current_size + new_size > self.MAX_BUFFER_SIZE:
@@ -80,47 +71,61 @@ async def send(self, message_chain: MessageChain) -> dict[str, Any]:
8071
new_size,
8172
self.MAX_BUFFER_SIZE,
8273
)
83-
# 只添加能容纳的部分
8474
available = self.MAX_BUFFER_SIZE - current_size
8575
if available > 0:
8676
self.send_buffer.chain.extend(message_chain.chain[:available])
8777
else:
8878
self.send_buffer.chain.extend(message_chain.chain)
89-
self._response_delay = self.EXTENDED_DELAY
9079
logger.debug(
9180
"[CLI] Appended to buffer, total: %d", len(self.send_buffer.chain)
9281
)
93-
94-
# 重置延迟任务
95-
if self._response_delay_task and not self._response_delay_task.done():
96-
self._response_delay_task.cancel()
97-
98-
self._response_delay_task = asyncio.create_task(self._delayed_response())
9982
else:
100-
# 其他模式:直接放入输出队列
83+
# 非Socket模式或future已完成:直接放入输出队列
10184
await self.output_queue.put(message_chain)
10285

10386
return {"success": True}
10487

88+
async def send_streaming(
89+
self,
90+
generator: AsyncGenerator[MessageChain, None],
91+
use_fallback: bool = False,
92+
) -> None:
93+
"""处理流式LLM响应
94+
95+
CLI不支持真正的流式输出,采用收集后一次性发送的策略。
96+
与aiocqhttp的非fallback模式一致。
97+
"""
98+
buffer = None
99+
async for chain in generator:
100+
if not buffer:
101+
buffer = chain
102+
else:
103+
buffer.chain.extend(chain.chain)
104+
105+
if not buffer:
106+
return
107+
108+
buffer.squash_plain()
109+
await self.send(buffer)
110+
await super().send_streaming(generator, use_fallback)
111+
105112
async def reply(self, message_chain: MessageChain) -> dict[str, Any]:
106113
"""回复消息"""
107114
return await self.send(message_chain)
108115

109-
async def _delayed_response(self) -> None:
110-
"""延迟响应:收集所有回复后统一返回"""
111-
try:
112-
await asyncio.sleep(self._response_delay)
116+
async def finalize(self) -> None:
117+
"""管道完成后调用,将收集的所有回复统一返回给Socket客户端。
113118
114-
if self.response_future and not self.response_future.done():
119+
由PipelineScheduler.execute()在所有阶段执行完毕后调用。
120+
"""
121+
if self.response_future and not self.response_future.done():
122+
if self.send_buffer:
115123
self.response_future.set_result(self.send_buffer)
116124
logger.debug(
117-
"[CLI] Delayed response set, %d components",
125+
"[CLI] Pipeline done, response set with %d components",
118126
len(self.send_buffer.chain),
119127
)
120-
121-
except asyncio.CancelledError:
122-
pass
123-
except Exception as e:
124-
logger.error("[CLI] Delayed response error: %s", e)
125-
if self.response_future and not self.response_future.done():
126-
self.response_future.set_exception(e)
128+
else:
129+
# 管道完成但没有任何发送操作(如被白名单/频率限制拦截)
130+
self.response_future.set_result(None)
131+
logger.debug("[CLI] Pipeline done, no response to send")

astrbot/core/platform/sources/cli/handlers/socket_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import TYPE_CHECKING
1313

1414
from astrbot import logger
15+
from astrbot.core.message.message_event_result import MessageChain
1516

1617
from ..interfaces import IHandler, IMessageConverter, ISessionManager, ITokenValidator
1718
from ..message.response_builder import ResponseBuilder
@@ -34,7 +35,7 @@ class SocketClientHandler:
3435

3536
RECV_BUFFER_SIZE = 4096
3637
MAX_REQUEST_SIZE = 1024 * 1024 # 1MB 最大请求大小
37-
RESPONSE_TIMEOUT = 30.0
38+
RESPONSE_TIMEOUT = 120.0
3839

3940
def __init__(
4041
self,
@@ -181,14 +182,13 @@ async def _process_message(self, message_text: str, request_id: str) -> str:
181182
message_chain = await asyncio.wait_for(
182183
response_future, timeout=self.RESPONSE_TIMEOUT
183184
)
185+
if message_chain is None:
186+
# 管道完成但没有产生任何回复(被白名单/频率限制等拦截)
187+
return ResponseBuilder.build_success(
188+
MessageChain([]), request_id
189+
)
184190
return ResponseBuilder.build_success(message_chain, request_id)
185191
except asyncio.TimeoutError:
186-
# 超时时取消延迟响应任务,防止资源泄露
187-
if (
188-
hasattr(message_event, "_response_delay_task")
189-
and message_event._response_delay_task
190-
):
191-
message_event._response_delay_task.cancel()
192192
return ResponseBuilder.build_error("Request timeout", request_id, "TIMEOUT")
193193

194194
async def _get_logs(self, request: dict, request_id: str) -> str:

astrbot/core/platform/sources/cli/message/converter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,6 @@ def convert(
7171
nickname=self.user_nickname,
7272
)
7373

74+
message.raw_message = None
75+
7476
return message

astrbot/dashboard/routes/platform.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
8181
平台适配器实例,未找到则返回 None
8282
"""
8383
for platform in self.platform_manager.platform_insts:
84-
if platform.config.get("webhook_uuid") == webhook_uuid:
84+
config = platform.config
85+
uuid_val = config.get("webhook_uuid") if isinstance(config, dict) else getattr(config, "webhook_uuid", None)
86+
if uuid_val == webhook_uuid:
8587
if platform.unified_webhook():
8688
return platform
8789
return None

0 commit comments

Comments
 (0)