diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py index 63db07a72..65d7ccae5 100644 --- a/astrbot/api/star/__init__.py +++ b/astrbot/api/star/__init__.py @@ -1,7 +1,7 @@ -from astrbot.core.star import Context, Star, StarTools +from astrbot.core.star import Context, NodeResult, NodeStar, Star, StarTools from astrbot.core.star.config import * from astrbot.core.star.register import ( register_star as register, # 注册插件(Star) ) -__all__ = ["Context", "Star", "StarTools", "register"] +__all__ = ["Context", "NodeResult", "NodeStar", "Star", "StarTools", "register"] diff --git a/astrbot/builtin_stars/agent/_node_config_schema.json b/astrbot/builtin_stars/agent/_node_config_schema.json new file mode 100644 index 000000000..401f72366 --- /dev/null +++ b/astrbot/builtin_stars/agent/_node_config_schema.json @@ -0,0 +1,59 @@ +{ + "agent_runner_type": { + "type": "string", + "default": "local", + "description": "此节点使用的 Agent 运行器类型", + "options": ["local", "dify", "coze", "dashscope"], + "labels": ["内置 Agent", "Dify", "Coze", "DashScope"] + }, + "coze_agent_runner_provider_id": { + "type": "string", + "default": "", + "description": "Coze Agent 运行器的提供商 ID", + "_special": "select_agent_runner_provider:coze", + "condition": { + "agent_runner_type": "coze" + } + }, + "dify_agent_runner_provider_id": { + "type": "string", + "default": "", + "description": "Dify Agent 运行器的提供商 ID", + "_special": "select_agent_runner_provider:dify", + "condition": { + "agent_runner_type": "dify" + } + }, + "dashscope_agent_runner_provider_id": { + "type": "string", + "default": "", + "description": "DashScope Agent 运行器的提供商 ID", + "_special": "select_agent_runner_provider:dashscope", + "condition": { + "agent_runner_type": "dashscope" + } + }, + "provider_id": { + "type": "string", + "default": "", + "description": "覆盖此节点的对话模型提供商 ID", + "_special": "select_provider" + }, + "image_caption_provider_id": { + "type": "string", + "default": "", + "description": "覆盖此节点的图像描述模型提供商 ID", + "_special": "select_provider" + }, + "image_caption_prompt": { + "type": "string", + "default": "", + "description": "覆盖此节点的图像描述提示词" + }, + "persona_id": { + "type": "string", + "default": "", + "description": "覆盖此节点使用的人格 ID", + "_special": "select_persona" + } +} diff --git a/astrbot/builtin_stars/agent/main.py b/astrbot/builtin_stars/agent/main.py new file mode 100644 index 000000000..bf391a7c3 --- /dev/null +++ b/astrbot/builtin_stars/agent/main.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.pipeline.engine.node_context import NodePacket +from astrbot.core.star.node_star import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class AgentNode(NodeStar): + """Agent execution node (local + third-party).""" + + async def process(self, event: AstrMessageEvent) -> NodeResult: + ctx = event.node_context + + # 合并上游输出作为 agent 输入 + if ctx: + merged_input = await event.get_node_input(strategy="text_concat") + if isinstance(merged_input, str): + if merged_input.strip(): + ctx.input = NodePacket.create(merged_input) + elif merged_input is not None: + ctx.input = NodePacket.create(merged_input) + + if event.get_extra("_provider_request_consumed", False): + return NodeResult.SKIP + + has_provider_request = event.get_extra("has_provider_request", False) + if not has_provider_request: + has_upstream_input = bool(ctx and ctx.input is not None) + should_wake = ( + not 
event._has_send_oper + and event.is_at_or_wake_command + and not event.call_llm + ) + if not (has_upstream_input or should_wake): + return NodeResult.SKIP + + # 从 event 获取 AgentExecutor + agent_executor = event.agent_executor + if not agent_executor: + logger.warning("AgentExecutor missing in event services.") + return NodeResult.SKIP + + outcome = await agent_executor.run(event) + + if outcome.result: + event.set_node_output(outcome.result) + + if outcome.stopped or event.is_stopped(): + return NodeResult.STOP + + return NodeResult.CONTINUE diff --git a/astrbot/builtin_stars/agent/metadata.yaml b/astrbot/builtin_stars/agent/metadata.yaml new file mode 100644 index 000000000..5e9558673 --- /dev/null +++ b/astrbot/builtin_stars/agent/metadata.yaml @@ -0,0 +1,4 @@ +name: agent +desc: Builtin agent pipeline node +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index e08cdc515..eeda26e4e 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -24,7 +24,9 @@ def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: """记录群成员的群聊记录""" def cfg(self, event: AstrMessageEvent): - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) try: max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"]) except BaseException as e: @@ -68,9 +70,15 @@ async def get_image_caption( image_url: str, image_caption_provider_id: str, image_caption_prompt: str, + event: AstrMessageEvent | None = None, ) -> str: if not image_caption_provider_id: - provider = self.context.get_using_provider() + provider = ( + self.context.get_chat_provider_for_event(event) if event else None + ) + if provider is None: + providers = self.context.get_all_providers() + provider = providers[0] if providers else None else: provider = self.context.get_provider_by_id(image_caption_provider_id) if not provider: @@ -133,6 +141,7 @@ async def handle_message(self, event: AstrMessageEvent) -> None: url, cfg["image_caption_provider_id"], cfg["image_caption_prompt"], + event, ) parts.append(f" [Image: {caption}]") except Exception as e: diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index da2a00835..ce76f6d3f 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -19,9 +19,8 @@ def __init__(self, context: star.Context) -> None: logger.error(f"聊天增强 err: {e}") def ltm_enabled(self, event: AstrMessageEvent): - ltmse = self.context.get_config(umo=event.unified_msg_origin)[ - "provider_ltm_settings" - ] + chain_config_id = event.chain_config.config_id if event.chain_config else None + ltmse = self.context.get_config_by_id(chain_config_id)["provider_ltm_settings"] return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"] @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) @@ -36,9 +35,12 @@ async def on_message(self, event: AstrMessageEvent): if self.ltm_enabled(event) and self.ltm and has_image_or_plain: need_active = await self.ltm.need_active_reply(event) - group_icl_enable = self.context.get_config()["provider_ltm_settings"][ - "group_icl_enable" - ] + chain_config_id = ( + event.chain_config.config_id if event.chain_config else None + ) + group_icl_enable = self.context.get_config_by_id(chain_config_id)[ + "provider_ltm_settings" + 
]["group_icl_enable"] if group_icl_enable: """记录对话""" try: @@ -48,7 +50,25 @@ async def on_message(self, event: AstrMessageEvent): if need_active: """主动回复""" - provider = self.context.get_using_provider(event.unified_msg_origin) + chain_config_id = ( + event.chain_config.config_id if event.chain_config else None + ) + runtime_cfg = self.context.get_config_by_id(chain_config_id) + default_provider_id = str( + runtime_cfg.get("provider_settings", {}).get( + "default_provider_id", + "", + ) + or "" + ).strip() + provider = ( + self.context.get_provider_by_id(default_provider_id) + if default_provider_id + else None + ) + if not provider: + all_providers = self.context.get_all_providers() + provider = all_providers[0] if all_providers else None if not provider: logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复") return diff --git a/astrbot/builtin_stars/builtin_commands/commands/__init__.py b/astrbot/builtin_stars/builtin_commands/commands/__init__.py index 46d255965..6d319ff93 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/__init__.py +++ b/astrbot/builtin_stars/builtin_commands/commands/__init__.py @@ -10,6 +10,7 @@ from .provider import ProviderCommands from .setunset import SetUnsetCommands from .sid import SIDCommand +from .stt import STTCommand from .t2i import T2ICommand from .tts import TTSCommand @@ -24,6 +25,7 @@ "ProviderCommands", "SIDCommand", "SetUnsetCommands", + "STTCommand", "T2ICommand", "TTSCommand", ] diff --git a/astrbot/builtin_stars/builtin_commands/commands/_node_binding.py b/astrbot/builtin_stars/builtin_commands/commands/_node_binding.py new file mode 100644 index 000000000..ccfe90698 --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/_node_binding.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from astrbot.api.event import AstrMessageEvent +from astrbot.core.config.node_config import AstrBotNodeConfig +from astrbot.core.pipeline.engine.chain_config import ChainNodeConfig +from astrbot.core.star.context import Context + + +@dataclass +class NodeTarget: + node: ChainNodeConfig + config: AstrBotNodeConfig + + +def _get_node_schema(context: Context, node_name: str) -> dict | None: + meta = context.get_registered_star(node_name) + if meta: + return meta.node_schema + return None + + +def get_chain_nodes(event: AstrMessageEvent, node_name: str) -> list[ChainNodeConfig]: + chain_config = event.chain_config + if not chain_config: + return [] + return [node for node in chain_config.nodes if node.name == node_name] + + +def resolve_node_selector( + nodes: list[ChainNodeConfig], selector: str +) -> ChainNodeConfig | None: + selector = (selector or "").strip() + if not selector: + return None + if selector.isdigit(): + idx = int(selector) + if idx < 1 or idx > len(nodes): + return None + return nodes[idx - 1] + for node in nodes: + if node.uuid == selector: + return node + return None + + +def get_node_target( + context: Context, + event: AstrMessageEvent, + node_name: str, + selector: str | None = None, +) -> NodeTarget | None: + chain_config = event.chain_config + if not chain_config: + return None + + nodes = get_chain_nodes(event, node_name) + if not nodes: + return None + + target: ChainNodeConfig | None = None + if selector: + target = resolve_node_selector(nodes, selector) + elif len(nodes) == 1: + target = nodes[0] + + if target is None: + return None + + schema = _get_node_schema(context, node_name) + cfg = AstrBotNodeConfig.get_cached( + node_name=node_name, + chain_id=chain_config.chain_id, + 
node_uuid=target.uuid, + schema=schema, + ) + return NodeTarget(node=target, config=cfg) + + +def list_nodes_with_config( + context: Context, + event: AstrMessageEvent, + node_name: str, +) -> list[NodeTarget]: + chain_config = event.chain_config + if not chain_config: + return [] + + schema = _get_node_schema(context, node_name) + ret: list[NodeTarget] = [] + for node in get_chain_nodes(event, node_name): + cfg = AstrBotNodeConfig.get_cached( + node_name=node_name, + chain_id=chain_config.chain_id, + node_uuid=node.uuid, + schema=schema, + ) + ret.append(NodeTarget(node=node, config=cfg)) + return ret diff --git a/astrbot/builtin_stars/builtin_commands/commands/admin.py b/astrbot/builtin_stars/builtin_commands/commands/admin.py index a4f46b603..8daec81e7 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/admin.py +++ b/astrbot/builtin_stars/builtin_commands/commands/admin.py @@ -48,7 +48,9 @@ async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: ), ) return - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) cfg["platform_settings"]["id_whitelist"].append(str(sid)) cfg.save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) @@ -63,7 +65,9 @@ async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None: ) return try: - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) cfg["platform_settings"]["id_whitelist"].remove(str(sid)) cfg.save_config() event.set_result(MessageEventResult().message("删除白名单成功。")) diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index eb8cfdefa..41e31c00f 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -2,9 +2,11 @@ from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.pipeline.agent.runner_config import resolve_agent_runner_config from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.message_type import MessageType +from ._node_binding import list_nodes_with_config from .utils.rst_scene import RstScene THIRD_PARTY_AGENT_RUNNER_KEY = { @@ -19,6 +21,17 @@ class ConversationCommands: def __init__(self, context: star.Context) -> None: self.context = context + def _resolve_agent_runner_type(self, message: AstrMessageEvent) -> str: + agent_node_config: dict = {} + targets = list_nodes_with_config(self.context, message, "agent") + if targets: + target = targets[0] + if isinstance(target.config, dict): + agent_node_config = dict(target.config) + + runner_type, _ = resolve_agent_runner_config(agent_node_config) + return runner_type + async def _get_current_persona_id(self, session_id): curr = await self.context.conversation_manager.get_curr_conversation_id( session_id, @@ -36,7 +49,10 @@ async def _get_current_persona_id(self, session_id): async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" umo = message.unified_msg_origin - cfg = self.context.get_config(umo=message.unified_msg_origin) + chain_config_id = ( + message.chain_config.config_id if message.chain_config else None + ) + cfg = self.context.get_config_by_id(chain_config_id) is_unique_session = 
cfg["platform_settings"]["unique_session"] is_group = bool(message.get_group_id()) @@ -60,7 +76,7 @@ async def reset(self, message: AstrMessageEvent) -> None: ) return - agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + agent_runner_type = self._resolve_agent_runner_type(message) if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: await sp.remove_async( scope="umo", @@ -70,7 +86,7 @@ async def reset(self, message: AstrMessageEvent) -> None: message.set_result(MessageEventResult().message("重置对话成功。")) return - if not self.context.get_using_provider(umo): + if not self.context.get_chat_provider_for_event(message): message.set_result( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) @@ -100,7 +116,7 @@ async def reset(self, message: AstrMessageEvent) -> None: async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" - if not self.context.get_using_provider(message.unified_msg_origin): + if not self.context.get_chat_provider_for_event(message): message.set_result( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) @@ -143,8 +159,11 @@ async def his(self, message: AstrMessageEvent, page: int = 1) -> None: async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" - cfg = self.context.get_config(umo=message.unified_msg_origin) - agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + chain_config_id = ( + message.chain_config.config_id if message.chain_config else None + ) + cfg = self.context.get_config_by_id(chain_config_id) + agent_runner_type = self._resolve_agent_runner_type(message) if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: message.set_result( MessageEventResult().message( @@ -181,8 +200,11 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: for conv in conversations_paged: persona_id = conv.persona_id if not persona_id or persona_id == "[%None]": + chain_config_id = ( + message.chain_config.config_id if message.chain_config else None + ) persona = await self.context.persona_manager.get_default_persona_v3( - umo=message.unified_msg_origin, + config_id=chain_config_id, ) persona_id = persona["name"] title = _titles.get(conv.cid, "新对话") @@ -203,7 +225,6 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: else: ret += "\n当前对话: 无" - cfg = self.context.get_config(umo=message.unified_msg_origin) unique_session = cfg["platform_settings"]["unique_session"] if unique_session: ret += "\n会话隔离粒度: 个人" @@ -218,8 +239,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" - cfg = self.context.get_config(umo=message.unified_msg_origin) - agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + agent_runner_type = self._resolve_agent_runner_type(message) if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: await sp.remove_async( scope="umo", @@ -321,7 +341,10 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> No async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" - cfg = self.context.get_config(umo=message.unified_msg_origin) + chain_config_id = ( + message.chain_config.config_id if message.chain_config else None + ) + cfg = self.context.get_config_by_id(chain_config_id) is_unique_session = cfg["platform_settings"]["unique_session"] if message.get_group_id() and not is_unique_session and message.role != "admin": # 群聊,没开独立会话,发送人不是管理员 @@ -332,7 +355,7 @@ async def del_conv(self, message: 
AstrMessageEvent) -> None: ) return - agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + agent_runner_type = self._resolve_agent_runner_type(message) if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: await sp.remove_async( scope="umo", diff --git a/astrbot/builtin_stars/builtin_commands/commands/llm.py b/astrbot/builtin_stars/builtin_commands/commands/llm.py index ba9ba5c9b..57a5f7a00 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/llm.py +++ b/astrbot/builtin_stars/builtin_commands/commands/llm.py @@ -1,20 +1,25 @@ from astrbot.api import star -from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_LLM, + toggle_chain_runtime_flag, +) class LLMCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def llm(self, event: AstrMessageEvent) -> None: - """开启/关闭 LLM""" - cfg = self.context.get_config(umo=event.unified_msg_origin) - enable = cfg["provider_settings"].get("enable", True) - if enable: - cfg["provider_settings"]["enable"] = False - status = "关闭" - else: - cfg["provider_settings"]["enable"] = True - status = "开启" - cfg.save_config() - await event.send(MessageChain().message(f"{status} LLM 聊天功能。")) + async def llm(self, event: AstrMessageEvent): + chain_config = event.chain_config + if not chain_config: + event.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return + + enabled = await toggle_chain_runtime_flag(chain_config.chain_id, FEATURE_LLM) + status = "开启" if enabled else "关闭" + event.set_result( + MessageEventResult().message( + f"Chain `{chain_config.chain_id}` 的 LLM 功能已{status}。" + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/commands/persona.py b/astrbot/builtin_stars/builtin_commands/commands/persona.py index cf99988a2..5c1fbddc4 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/persona.py +++ b/astrbot/builtin_stars/builtin_commands/commands/persona.py @@ -1,204 +1,241 @@ import builtins -from typing import TYPE_CHECKING -from astrbot.api import sp, star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult -if TYPE_CHECKING: - from astrbot.core.db.po import Persona +from ._node_binding import get_node_target, list_nodes_with_config class PersonaCommands: def __init__(self, context: star.Context) -> None: self.context = context - def _build_tree_output( - self, - folder_tree: list[dict], - all_personas: list["Persona"], - depth: int = 0, - ) -> list[str]: - """递归构建树状输出,使用短线条表示层级""" - lines: list[str] = [] - # 使用短线条作为缩进前缀,每层只用 "│" 加一个空格 - prefix = "│ " * depth - - for folder in folder_tree: - # 输出文件夹 - lines.append(f"{prefix}├ 📁 {folder['name']}/") - - # 获取该文件夹下的人格 - folder_personas = [ - p for p in all_personas if p.folder_id == folder["folder_id"] - ] - child_prefix = "│ " * (depth + 1) - - # 输出该文件夹下的人格 - for persona in folder_personas: - lines.append(f"{child_prefix}├ 👤 {persona.persona_id}") - - # 递归处理子文件夹 - children = folder.get("children", []) - if children: - lines.extend( - self._build_tree_output( - children, - all_personas, - depth + 1, - ) - ) + @staticmethod + def _split_tokens(message: str) -> list[str]: + parts = [p for p in message.strip().split() if p] + if parts and parts[0].startswith("/"): + parts = parts[1:] + if parts and parts[0] == "persona": + parts = parts[1:] + return parts + + def _find_persona(self, persona_name: str): + return next( + builtins.filter( + 
lambda persona: persona["name"] == persona_name, + self.context.provider_manager.personas, + ), + None, + ) + + def _render_agent_nodes(self, event: AstrMessageEvent) -> str: + targets = list_nodes_with_config(self.context, event, "agent") + if not targets: + return "当前 Chain 中没有 agent 节点。" + lines = [] + for idx, target in enumerate(targets, start=1): + persona_id = target.config.get("persona_id") or "<继承>" + provider_id = target.config.get("provider_id") or "<继承>" + lines.append( + f"{idx}. 节点={target.node.uuid[:8]} persona={persona_id} provider={provider_id}" + ) + return "\n".join(lines) + + async def _render_persona_tree(self) -> str: + folders = await self.context.persona_manager.get_folder_tree() + personas = await self.context.persona_manager.get_all_personas() + + personas_by_folder: dict[str | None, list] = {} + for persona in personas: + personas_by_folder.setdefault(persona.folder_id, []).append(persona) + + for folder_personas in personas_by_folder.values(): + folder_personas.sort(key=lambda p: (p.sort_order, p.persona_id)) + + lines: list[str] = ["人格树:"] + + def append_personas(folder_id: str | None, indent: str) -> None: + for persona in personas_by_folder.get(folder_id, []): + lines.append(f"{indent}- {persona.persona_id}") + + def append_folders(folder_nodes: list[dict], indent: str) -> None: + for folder in folder_nodes: + folder_name = folder.get("name") or "<未命名文件夹>" + lines.append(f"{indent}[{folder_name}]") + + folder_id = folder.get("folder_id") + append_personas(folder_id, indent + " ") + + children = folder.get("children") or [] + if children: + append_folders(children, indent + " ") + + append_personas(None, "") + if folders: + if len(lines) > 1: + lines.append("") + append_folders(folders, "") + + if len(lines) == 1: + lines.append("(空)") - return lines + return "\n".join(lines) - async def persona(self, message: AstrMessageEvent) -> None: - l = message.message_str.split(" ") # noqa: E741 - umo = message.unified_msg_origin + async def persona(self, message: AstrMessageEvent): + chain = message.chain_config + if not chain: + message.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return - curr_persona_name = "无" - cid = await self.context.conversation_manager.get_curr_conversation_id(umo) + tokens = self._split_tokens(message.message_str) + chain_config_id = chain.config_id if chain else None default_persona = await self.context.persona_manager.get_default_persona_v3( - umo=umo, + config_id=chain_config_id, ) - force_applied_persona_id = ( - await sp.get_async( - scope="umo", scope_id=umo, key="session_service_config", default={} - ) - ).get("persona_id") - - curr_cid_title = "无" - if cid: - conv = await self.context.conversation_manager.get_conversation( - unified_msg_origin=umo, - conversation_id=cid, - create_if_not_exists=True, + if not tokens: + help_text = [ + f"当前 Chain: {chain.chain_id}", + f"默认人格: {default_persona['name']}", + "", + self._render_agent_nodes(message), + "", + "用法:", + "/persona list", + "/persona view ", + "/persona # 兼容单 agent 绑定", + "/persona unset # 兼容单 agent 解绑", + "/persona node ls", + "/persona node set ", + "/persona node unset ", + ] + message.set_result( + MessageEventResult().message("\n".join(help_text)).use_t2i(False) ) - if conv is None: - message.set_result( - MessageEventResult().message( - "当前对话不存在,请先使用 /new 新建一个对话。", - ), - ) - return - if not conv.persona_id and conv.persona_id != "[%None]": - curr_persona_name = default_persona["name"] - else: - curr_persona_name = conv.persona_id - - if 
force_applied_persona_id: - curr_persona_name = f"{curr_persona_name} (自定义规则)" + return - curr_cid_title = conv.title if conv.title else "新对话" - curr_cid_title += f"({cid[:4]})" - - if len(l) == 1: + if tokens[0] == "list": message.set_result( MessageEventResult() - .message( - f"""[Persona] - -- 人格情景列表: `/persona list` -- 设置人格情景: `/persona 人格` -- 人格情景详细信息: `/persona view 人格` -- 取消人格: `/persona unset` - -默认人格情景: {default_persona["name"]} -当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} + .message(await self._render_persona_tree()) + .use_t2i(False) + ) + return -配置人格情景请前往管理面板-配置页 -""", + if tokens[0] == "view": + if len(tokens) < 2: + message.set_result( + MessageEventResult().message("请输入 persona 名称。") ) - .use_t2i(False), - ) - elif l[1] == "list": - # 获取文件夹树和所有人格 - folder_tree = await self.context.persona_manager.get_folder_tree() - all_personas = self.context.persona_manager.personas - - lines = ["📂 人格列表:\n"] - - # 构建树状输出 - tree_lines = self._build_tree_output(folder_tree, all_personas) - lines.extend(tree_lines) - - # 输出根目录下的人格(没有文件夹的) - root_personas = [p for p in all_personas if p.folder_id is None] - if root_personas: - if tree_lines: # 如果有文件夹内容,加个空行 - lines.append("") - for persona in root_personas: - lines.append(f"👤 {persona.persona_id}") - - # 统计信息 - total_count = len(all_personas) - lines.append(f"\n共 {total_count} 个人格") - lines.append("\n*使用 `/persona <人格名>` 设置人格") - lines.append("*使用 `/persona view <人格名>` 查看详细信息") - - msg = "\n".join(lines) - message.set_result(MessageEventResult().message(msg).use_t2i(False)) - elif l[1] == "view": - if len(l) == 2: - message.set_result(MessageEventResult().message("请输入人格情景名")) return - ps = l[2].strip() - if persona := next( - builtins.filter( - lambda persona: persona["name"] == ps, - self.context.provider_manager.personas, - ), - None, - ): - msg = f"人格{ps}的详细信息:\n" - msg += f"{persona['prompt']}\n" - else: - msg = f"人格{ps}不存在" - message.set_result(MessageEventResult().message(msg)) - elif l[1] == "unset": - if not cid: + persona_name = tokens[1] + persona = self._find_persona(persona_name) + if not persona: message.set_result( - MessageEventResult().message("当前没有对话,无法取消人格。"), + MessageEventResult().message(f"未找到 persona `{persona_name}`。") ) return - await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, - "[%None]", + message.set_result( + MessageEventResult().message( + f"persona {persona_name}:\n{persona['prompt']}" + ) ) - message.set_result(MessageEventResult().message("取消人格成功。")) - else: - ps = "".join(l[1:]).strip() - if not cid: + return + + if tokens[0] == "node": + if len(tokens) >= 2 and tokens[1] == "ls": + message.set_result( + MessageEventResult() + .message(self._render_agent_nodes(message)) + .use_t2i(False) + ) + return + if len(tokens) >= 4 and tokens[1] == "set": + selector = tokens[2] + persona_name = " ".join(tokens[3:]).strip() + persona = self._find_persona(persona_name) + if not persona: + message.set_result( + MessageEventResult().message( + f"未找到 persona `{persona_name}`。" + ) + ) + return + target = get_node_target( + self.context, message, "agent", selector=selector + ) + if not target: + message.set_result( + MessageEventResult().message("agent 节点选择器无效。") + ) + return + target.config.save_config({"persona_id": persona_name}) message.set_result( MessageEventResult().message( - "当前没有对话,请先开始对话或使用 /new 创建一个对话。", - ), + f"已将 persona `{persona_name}` 绑定到 agent 节点 `{target.node.uuid[:8]}`。" + ) ) return - if persona := next( - builtins.filter( - lambda persona: 
persona["name"] == ps, - self.context.provider_manager.personas, - ), - None, - ): - await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, - ps, + if len(tokens) >= 3 and tokens[1] == "unset": + selector = tokens[2] + target = get_node_target( + self.context, message, "agent", selector=selector ) - force_warn_msg = "" - if force_applied_persona_id: - force_warn_msg = ( - "提醒:由于自定义规则,您现在切换的人格将不会生效。" + if not target: + message.set_result( + MessageEventResult().message("agent 节点选择器无效。") ) - + return + target.config.save_config({"persona_id": ""}) message.set_result( MessageEventResult().message( - f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}", - ), + f"已清除 agent 节点 `{target.node.uuid[:8]}` 的 persona 绑定。" + ) + ) + return + + message.set_result( + MessageEventResult().message( + "用法: /persona node ls | /persona node set | /persona node unset " ) - else: + ) + return + + if tokens[0] == "unset": + targets = list_nodes_with_config(self.context, message, "agent") + if len(targets) != 1: message.set_result( MessageEventResult().message( - "不存在该人格情景。使用 /persona list 查看所有。", - ), + "检测到多个 agent 节点,请使用 /persona node unset 。" + ) + ) + return + targets[0].config.save_config({"persona_id": ""}) + message.set_result(MessageEventResult().message("已清除 persona 绑定。")) + return + + persona_name = " ".join(tokens).strip() + persona = self._find_persona(persona_name) + if not persona: + message.set_result( + MessageEventResult().message( + f"未找到 persona `{persona_name}`,请使用 /persona list 查看。" ) + ) + return + + targets = list_nodes_with_config(self.context, message, "agent") + if len(targets) != 1: + message.set_result( + MessageEventResult().message( + "检测到多个 agent 节点,请使用 /persona node set 。" + ) + ) + return + + targets[0].config.save_config({"persona_id": persona_name}) + message.set_result( + MessageEventResult().message( + f"已将 persona `{persona_name}` 绑定到当前 agent 节点。" + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index ae20eb8e1..60c52efbc 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,329 +1,268 @@ -import asyncio import re -from astrbot import logger from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.core.provider.entities import ProviderType + +from ._node_binding import get_node_target, list_nodes_with_config class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - def _log_reachability_failure( - self, - provider, - provider_capability_type: ProviderType | None, - err_code: str, - err_reason: str, - ) -> None: - """记录不可达原因到日志。""" - meta = provider.meta() - logger.warning( - "Provider reachability check failed: id=%s type=%s code=%s reason=%s", - meta.id, - provider_capability_type.name if provider_capability_type else "unknown", - err_code, - err_reason, - ) + @staticmethod + def _split_tokens(message: str) -> list[str]: + parts = [p for p in message.strip().split() if p] + if parts and parts[0].startswith("/"): + parts = parts[1:] + if parts and parts[0] == "provider": + parts = parts[1:] + return parts - async def _test_provider_capability(self, provider): - """测试单个 provider 的可用性""" - meta = provider.meta() - provider_capability_type = meta.provider_type + @staticmethod + def _parse_kind(token: str | None) -> str: + token = (token 
or "").strip().lower() + if token in ("", "llm", "agent", "chat"): + return "llm" + if token in ("tts", "stt"): + return token + return "llm" - try: - await provider.test() - return True, None, None - except Exception as e: - err_code = "TEST_FAILED" - err_reason = str(e) - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason + @staticmethod + def _kind_to_node_name(kind: str) -> str: + return {"llm": "agent", "tts": "tts", "stt": "stt"}[kind] + + def _providers_by_kind(self, kind: str): + if kind == "llm": + return list(self.context.get_all_providers()) + if kind == "tts": + return list(self.context.get_all_tts_providers()) + return list(self.context.get_all_stt_providers()) + + def _resolve_provider(self, kind: str, token: str): + providers = self._providers_by_kind(kind) + if token.isdigit(): + idx = int(token) + if 1 <= idx <= len(providers): + return providers[idx - 1] + return None + for prov in providers: + if prov.meta().id == token: + return prov + return None + + def _render_node_bindings(self, event: AstrMessageEvent) -> str: + rows: list[str] = [] + mapping = {"llm": "agent", "tts": "tts", "stt": "stt"} + for kind, node_name in mapping.items(): + targets = list_nodes_with_config(self.context, event, node_name) + if not targets: + continue + rows.append(f"[{kind}] 节点绑定:") + for idx, target in enumerate(targets, start=1): + bound = target.config.get("provider_id", "") or "<继承>" + rows.append(f" {idx}. 节点={target.node.uuid[:8]} provider={bound}") + return "\n".join(rows) if rows else "当前 Chain 没有可绑定 provider 的节点。" + + def _render_provider_list(self) -> str: + parts: list[str] = [] + for kind in ("llm", "tts", "stt"): + providers = self._providers_by_kind(kind) + if not providers: + continue + parts.append(f"[{kind}] 可用 provider:") + for idx, prov in enumerate(providers, start=1): + meta = prov.meta() + model = getattr(meta, "model", "") + suffix = f" ({model})" if model else "" + parts.append(f" {idx}. {meta.id}{suffix}") + return "\n".join(parts) if parts else "当前没有已加载的 provider。" async def provider( self, event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ) -> None: - """查看或者切换 LLM Provider""" - umo = event.unified_msg_origin - cfg = self.context.get_config(umo).get("provider_settings", {}) - reachability_check_enabled = cfg.get("reachability_check", True) + ): + del idx, idx2 + chain = event.chain_config + if not chain: + event.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return - if idx is None: - parts = ["## 载入的 LLM 提供商\n"] + tokens = self._split_tokens(event.message_str) + if not tokens: + msg = [ + f"当前 Chain: {chain.chain_id}", + self._render_provider_list(), + "", + self._render_node_bindings(event), + "", + "用法:", + "/provider # 兼容单 agent 绑定", + "/provider # 兼容单节点绑定", + "/provider ", + "/provider node ls", + ] + event.set_result( + MessageEventResult().message("\n".join(msg)).use_t2i(False) + ) + return - # 获取所有类型的提供商 - llms = list(self.context.get_all_providers()) - ttss = self.context.get_all_tts_providers() - stts = self.context.get_all_stt_providers() + if tokens[0] == "node" and len(tokens) >= 2 and tokens[1] == "ls": + event.set_result( + MessageEventResult() + .message(self._render_node_bindings(event)) + .use_t2i(False) + ) + return - # 构造待检测列表: [(provider, type_label), ...] 
- all_providers = [] - all_providers.extend([(p, "llm") for p in llms]) - all_providers.extend([(p, "tts") for p in ttss]) - all_providers.extend([(p, "stt") for p in stts]) + kind = "llm" + remaining = tokens + if tokens[0] in ("llm", "agent", "chat", "tts", "stt"): + kind = self._parse_kind(tokens[0]) + remaining = tokens[1:] - # 并发测试连通性 - if reachability_check_enabled: - if all_providers: - await event.send( - MessageEventResult().message( - "正在进行提供商可达性测试,请稍候..." - ) - ) - check_results = await asyncio.gather( - *[self._test_provider_capability(p) for p, _ in all_providers], - return_exceptions=True, + node_name = self._kind_to_node_name(kind) + node_targets = list_nodes_with_config(self.context, event, node_name) + if not node_targets: + event.set_result( + MessageEventResult().message( + f"当前 Chain 中没有可用于 {kind} 绑定的 `{node_name}` 节点。" ) - else: - # 用 None 表示未检测 - check_results = [None for _ in all_providers] + ) + return - # 整合结果 - display_data = [] - for (p, p_type), reachable in zip(all_providers, check_results): - meta = p.meta() - id_ = meta.id - error_code = None + selector: str | None = None + provider_token: str | None = None - if isinstance(reachable, Exception): - # 异常情况下兜底处理,避免单个 provider 导致列表失败 - self._log_reachability_failure( - p, - None, - reachable.__class__.__name__, - str(reachable), + if len(remaining) == 1: + provider_token = remaining[0] + if len(node_targets) > 1: + event.set_result( + MessageEventResult().message( + f"检测到多个 `{node_name}` 节点,请使用 `/provider {kind} ` 指定节点。" ) - reachable_flag = False - error_code = reachable.__class__.__name__ - elif isinstance(reachable, tuple): - reachable_flag, error_code, _ = reachable - else: - reachable_flag = reachable - - # 根据类型构建显示名称 - if p_type == "llm": - info = f"{id_} ({meta.model})" - else: - info = f"{id_}" - - # 确定状态标记 - if reachable_flag is True: - mark = " ✅" - elif reachable_flag is False: - if error_code: - mark = f" ❌(错误码: {error_code})" - else: - mark = " ❌" - else: - mark = "" # 不支持检测时不显示标记 - - display_data.append( - { - "type": p_type, - "info": info, - "mark": mark, - "provider": p, - } ) + return + elif len(remaining) >= 2: + selector = remaining[0] + provider_token = remaining[1] - # 分组输出 - # 1. LLM - llm_data = [d for d in display_data if d["type"] == "llm"] - for i, d in enumerate(llm_data): - line = f"{i + 1}. {d['info']}{d['mark']}" - provider_using = self.context.get_using_provider(umo=umo) - if ( - provider_using - and provider_using.meta().id == d["provider"].meta().id - ): - line += " (当前使用)" - parts.append(line + "\n") - - # 2. TTS - tts_data = [d for d in display_data if d["type"] == "tts"] - if tts_data: - parts.append("\n## 载入的 TTS 提供商\n") - for i, d in enumerate(tts_data): - line = f"{i + 1}. {d['info']}{d['mark']}" - tts_using = self.context.get_using_tts_provider(umo=umo) - if tts_using and tts_using.meta().id == d["provider"].meta().id: - line += " (当前使用)" - parts.append(line + "\n") - - # 3. STT - stt_data = [d for d in display_data if d["type"] == "stt"] - if stt_data: - parts.append("\n## 载入的 STT 提供商\n") - for i, d in enumerate(stt_data): - line = f"{i + 1}. 
{d['info']}{d['mark']}" - stt_using = self.context.get_using_stt_provider(umo=umo) - if stt_using and stt_using.meta().id == d["provider"].meta().id: - line += " (当前使用)" - parts.append(line + "\n") + if not provider_token: + event.set_result(MessageEventResult().message("缺少 provider 参数。")) + return - parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") - ret = "".join(parts) + provider = self._resolve_provider(kind, provider_token) + if not provider: + event.set_result(MessageEventResult().message("provider 序号或 ID 无效。")) + return - if ttss: - ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" - if stts: - ret += "\n使用 /provider stt <序号> 切换 STT 提供商。" - if not reachability_check_enabled: - ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。" + target = get_node_target( + self.context, + event, + node_name, + selector=selector, + ) + if not target: + if selector: + event.set_result(MessageEventResult().message("节点选择器无效。")) + else: + event.set_result( + MessageEventResult().message( + f"检测到多个 `{node_name}` 节点,请显式指定节点。" + ) + ) + return - event.set_result(MessageEventResult().message(ret)) - elif idx == "tts": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的提供商序号。")) - return - provider = self.context.get_all_tts_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.TEXT_TO_SPEECH, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - elif idx == "stt": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的提供商序号。")) - return - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: - event.set_result(MessageEventResult().message("无效的提供商序号。")) - return - provider = self.context.get_all_providers()[idx - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.CHAT_COMPLETION, - umo=umo, + target.config.save_config({"provider_id": provider.meta().id}) + event.set_result( + MessageEventResult().message( + f"已将 {kind} provider `{provider.meta().id}` 绑定到 Chain `{chain.chain_id}` 的节点 `{target.node.uuid[:8]}`。" ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - else: - event.set_result(MessageEventResult().message("无效的参数。")) + ) async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ) -> None: - """查看或者切换模型""" - prov = self.context.get_using_provider(message.unified_msg_origin) + ): + prov = self.context.get_chat_provider_for_event(message) if not prov: message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), + MessageEventResult().message("当前没有可用的 LLM provider。") ) return - # 定义正则表达式匹配 API 密钥 api_key_pattern = re.compile(r"key=[^&'\" ]+") if idx_or_name is None: - models = [] try: models = await prov.get_models() except BaseException as e: err_msg = api_key_pattern.sub("key=***", str(e)) message.set_result( MessageEventResult() - 
.message("获取模型列表失败: " + err_msg) - .use_t2i(False), + .message("获取模型列表失败:" + err_msg) + .use_t2i(False) ) return - parts = ["下面列出了此模型提供商可用模型:"] + + parts = ["模型列表:"] for i, model in enumerate(models, 1): parts.append(f"\n{i}. {model}") - - curr_model = prov.get_model() or "无" - parts.append(f"\n当前模型: [{curr_model}]") - parts.append( - "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + parts.append(f"\n当前模型:[{prov.get_model() or '-'}]") + parts.append("\n使用 /model 切换模型。") + message.set_result( + MessageEventResult().message("".join(parts)).use_t2i(False) ) - - ret = "".join(parts) - message.set_result(MessageEventResult().message(ret).use_t2i(False)) elif isinstance(idx_or_name, int): - models = [] try: models = await prov.get_models() except BaseException as e: message.set_result( - MessageEventResult().message("获取模型列表失败: " + str(e)), + MessageEventResult().message("获取模型列表失败:" + str(e)) ) return if idx_or_name > len(models) or idx_or_name < 1: - message.set_result(MessageEventResult().message("模型序号错误。")) - else: - try: - new_model = models[idx_or_name - 1] - prov.set_model(new_model) - except BaseException as e: - message.set_result( - MessageEventResult().message("切换模型未知错误: " + str(e)), - ) - message.set_result( - MessageEventResult().message( - f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", - ), - ) + message.set_result(MessageEventResult().message("模型序号无效。")) + return + new_model = models[idx_or_name - 1] + prov.set_model(new_model) + message.set_result( + MessageEventResult().message(f"已切换到模型 {prov.get_model()}。") + ) else: prov.set_model(idx_or_name) message.set_result( - MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), + MessageEventResult().message(f"已切换到模型 {prov.get_model()}。") ) - async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: - prov = self.context.get_using_provider(message.unified_msg_origin) + async def key(self, message: AstrMessageEvent, index: int | None = None): + prov = self.context.get_chat_provider_for_event(message) if not prov: message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), + MessageEventResult().message("当前没有可用的 LLM provider。") ) return if index is None: keys_data = prov.get_keys() curr_key = prov.get_current_key() - parts = ["Key:"] + parts = ["可用密钥:"] for i, k in enumerate(keys_data, 1): parts.append(f"\n{i}. 
{k[:8]}") + parts.append(f"\n当前密钥:{curr_key[:8]}") + parts.append(f"\n当前模型:{prov.get_model()}") + parts.append("\n使用 /key 切换密钥。") + message.set_result( + MessageEventResult().message("".join(parts)).use_t2i(False) + ) + return - parts.append(f"\n当前 Key: {curr_key[:8]}") - parts.append("\n当前模型: " + prov.get_model()) - parts.append("\n使用 /key 切换 Key。") - - ret = "".join(parts) - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - else: - keys_data = prov.get_keys() - if index > len(keys_data) or index < 1: - message.set_result(MessageEventResult().message("Key 序号错误。")) - else: - try: - new_key = keys_data[index - 1] - prov.set_key(new_key) - except BaseException as e: - message.set_result( - MessageEventResult().message(f"切换 Key 未知错误: {e!s}"), - ) - message.set_result(MessageEventResult().message("切换 Key 成功。")) + keys_data = prov.get_keys() + if index > len(keys_data) or index < 1: + message.set_result(MessageEventResult().message("密钥序号无效。")) + return + new_key = keys_data[index - 1] + prov.set_key(new_key) + message.set_result(MessageEventResult().message("密钥切换成功。")) diff --git a/astrbot/builtin_stars/builtin_commands/commands/stt.py b/astrbot/builtin_stars/builtin_commands/commands/stt.py new file mode 100644 index 000000000..cb452f44b --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/stt.py @@ -0,0 +1,38 @@ +"""Speech-to-text command.""" + +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_STT, + toggle_chain_runtime_flag, +) + +from ._node_binding import get_chain_nodes + + +class STTCommand: + """Toggle speech-to-text for the current routed chain.""" + + def __init__(self, context: star.Context): + self.context = context + + async def stt(self, event: AstrMessageEvent): + chain_config = event.chain_config + if not chain_config: + event.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return + + nodes = get_chain_nodes(event, "stt") + if not nodes: + event.set_result( + MessageEventResult().message("当前 Chain 中没有 STT 节点。") + ) + return + + enabled = await toggle_chain_runtime_flag(chain_config.chain_id, FEATURE_STT) + status = "开启" if enabled else "关闭" + event.set_result( + MessageEventResult().message( + f"Chain `{chain_config.chain_id}` 的 STT 功能已{status}(共 {len(nodes)} 个节点)。" + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/commands/t2i.py b/astrbot/builtin_stars/builtin_commands/commands/t2i.py index 78d6b0df7..0a4616341 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/t2i.py +++ b/astrbot/builtin_stars/builtin_commands/commands/t2i.py @@ -1,23 +1,38 @@ -"""文本转图片命令""" +"""Text-to-image command.""" from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_T2I, + toggle_chain_runtime_flag, +) + +from ._node_binding import get_chain_nodes class T2ICommand: - """文本转图片命令类""" + """Toggle text-to-image output for the current routed chain.""" def __init__(self, context: star.Context) -> None: self.context = context - async def t2i(self, event: AstrMessageEvent) -> None: - """开关文本转图片""" - config = self.context.get_config(umo=event.unified_msg_origin) - if config["t2i"]: - config["t2i"] = False - config.save_config() - event.set_result(MessageEventResult().message("已关闭文本转图片模式。")) + async def t2i(self, event: AstrMessageEvent): + chain_config = event.chain_config + if not chain_config: + 
event.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return + + nodes = get_chain_nodes(event, "t2i") + if not nodes: + event.set_result( + MessageEventResult().message("当前 Chain 中没有 T2I 节点。") + ) return - config["t2i"] = True - config.save_config() - event.set_result(MessageEventResult().message("已开启文本转图片模式。")) + + enabled = await toggle_chain_runtime_flag(chain_config.chain_id, FEATURE_T2I) + status = "开启" if enabled else "关闭" + event.set_result( + MessageEventResult().message( + f"Chain `{chain_config.chain_id}` 的 T2I 功能已{status}(共 {len(nodes)} 个节点)。" + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/commands/tts.py b/astrbot/builtin_stars/builtin_commands/commands/tts.py index 13049ac22..ba6453c6b 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/tts.py +++ b/astrbot/builtin_stars/builtin_commands/commands/tts.py @@ -1,36 +1,38 @@ -"""文本转语音命令""" +"""Text-to-speech command.""" from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.core.star.session_llm_manager import SessionServiceManager +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_TTS, + toggle_chain_runtime_flag, +) + +from ._node_binding import get_chain_nodes class TTSCommand: - """文本转语音命令类""" + """Toggle text-to-speech for the current routed chain.""" def __init__(self, context: star.Context) -> None: self.context = context - async def tts(self, event: AstrMessageEvent) -> None: - """开关文本转语音(会话级别)""" - umo = event.unified_msg_origin - ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo) - cfg = self.context.get_config(umo=umo) - tts_enable = cfg["provider_tts_settings"]["enable"] - - # 切换状态 - new_status = not ses_tts - await SessionServiceManager.set_tts_status_for_session(umo, new_status) - - status_text = "已开启" if new_status else "已关闭" + async def tts(self, event: AstrMessageEvent): + chain_config = event.chain_config + if not chain_config: + event.set_result(MessageEventResult().message("未找到已路由的 Chain。")) + return - if new_status and not tts_enable: + nodes = get_chain_nodes(event, "tts") + if not nodes: event.set_result( - MessageEventResult().message( - f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。", - ), + MessageEventResult().message("当前 Chain 中没有 TTS 节点。") ) - else: - event.set_result( - MessageEventResult().message(f"{status_text}当前会话的文本转语音。"), + return + + enabled = await toggle_chain_runtime_flag(chain_config.chain_id, FEATURE_TTS) + status = "开启" if enabled else "关闭" + event.set_result( + MessageEventResult().message( + f"Chain `{chain_config.chain_id}` 的 TTS 功能已{status}(共 {len(nodes)} 个节点)。" ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index 9b839ca88..8eeefbda2 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -12,6 +12,7 @@ ProviderCommands, SetUnsetCommands, SIDCommand, + STTCommand, T2ICommand, TTSCommand, ) @@ -30,15 +31,31 @@ def __init__(self, context: star.Context) -> None: self.persona_c = PersonaCommands(self.context) self.alter_cmd_c = AlterCmdCommands(self.context) self.setunset_c = SetUnsetCommands(self.context) + self.sid_c = SIDCommand(self.context) self.t2i_c = T2ICommand(self.context) self.tts_c = TTSCommand(self.context) - self.sid_c = SIDCommand(self.context) + self.stt_c = STTCommand(self.context) @filter.command("help") async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" await self.help_c.help(event) + 
@filter.command("t2i") + async def t2i(self, event: AstrMessageEvent): + """开关文本转图片""" + await self.t2i_c.t2i(event) + + @filter.command("tts") + async def tts(self, event: AstrMessageEvent): + """开关文本转语音""" + await self.tts_c.tts(event) + + @filter.command("stt") + async def stt(self, event: AstrMessageEvent): + """开关语音转文本""" + await self.stt_c.stt(event) + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") async def llm(self, event: AstrMessageEvent) -> None: @@ -77,15 +94,6 @@ async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> N """获取插件帮助""" await self.plugin_c.plugin_help(event, plugin_name) - @filter.command("t2i") - async def t2i(self, event: AstrMessageEvent) -> None: - """开关文本转图片""" - await self.t2i_c.t2i(event) - - @filter.command("tts") - async def tts(self, event: AstrMessageEvent) -> None: - """开关文本转语音(会话级别)""" - await self.tts_c.tts(event) @filter.command("sid") async def sid(self, event: AstrMessageEvent) -> None: diff --git a/astrbot/builtin_stars/content_safety/_node_config_schema.json b/astrbot/builtin_stars/content_safety/_node_config_schema.json new file mode 100644 index 000000000..cbc588130 --- /dev/null +++ b/astrbot/builtin_stars/content_safety/_node_config_schema.json @@ -0,0 +1,47 @@ +{ + "internal_keywords": { + "type": "object", + "description": "内置关键词检测策略配置", + "items": { + "enable": { + "type": "bool", + "default": true, + "description": "是否启用内置关键词检测" + }, + "extra_keywords": { + "type": "list", + "description": "追加关键词列表", + "items": { + "type": "string" + }, + "default": [] + } + } + }, + "baidu_aip": { + "type": "object", + "description": "百度内容安全策略配置", + "items": { + "enable": { + "type": "bool", + "default": false, + "description": "是否启用百度内容安全检测" + }, + "app_id": { + "type": "string", + "default": "", + "description": "百度内容安全 APP ID" + }, + "api_key": { + "type": "string", + "default": "", + "description": "百度内容安全 API Key" + }, + "secret_key": { + "type": "string", + "default": "", + "description": "百度内容安全 Secret Key" + } + } + } +} diff --git a/astrbot/builtin_stars/content_safety/main.py b/astrbot/builtin_stars/content_safety/main.py new file mode 100644 index 000000000..d535a7430 --- /dev/null +++ b/astrbot/builtin_stars/content_safety/main.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import hashlib +import json +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.message.components import Plain +from astrbot.core.message.message_event_result import ( + MessageChain, + MessageEventResult, + ResultContentType, +) +from astrbot.core.star.node_star import NodeResult, NodeStar + +from .strategies import StrategySelector + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class ContentSafetyStar(NodeStar): + """Content safety checks for input/output text.""" + + def __init__(self, context, config: dict | None = None): + super().__init__(context, config) + self._strategy_selector: StrategySelector | None = None + self._config_signature: str | None = None + + def _ensure_strategy_selector(self, event: AstrMessageEvent) -> None: + config = event.node_config or {} + signature = hashlib.sha256( + json.dumps(config, sort_keys=True, ensure_ascii=False).encode() + ).hexdigest() + + if signature != self._config_signature: + self._strategy_selector = StrategySelector(config) + self._config_signature = signature + + def _check_content(self, text: str) -> tuple[bool, str]: + if not self._strategy_selector: + return True, "" + 
return self._strategy_selector.check(text) + + @staticmethod + def _block_event(event: AstrMessageEvent, reason: str) -> NodeResult: + if event.is_at_or_wake_command: + event.set_result( + MessageEventResult().message( + "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。" + ) + ) + event.stop_event() + logger.info(f"内容安全检查不通过,原因:{reason}") + return NodeResult.STOP + + async def process(self, event: AstrMessageEvent) -> NodeResult: + self._ensure_strategy_selector(event) + + # 输入检查使用当前 message_str(已反映 STT/文件提取结果)。 + # get_node_input 仅包含已执行节点输出,受链路顺序影响。 + text = event.get_message_str() + if text: + ok, info = self._check_content(text) + if not ok: + return self._block_event(event, info) + + upstream_output = await event.get_node_input(strategy="last") + output_text = "" + if isinstance(upstream_output, MessageEventResult): + if ( + upstream_output.result_content_type + == ResultContentType.STREAMING_RESULT + ): + await self.collect_stream(event, upstream_output) + result = upstream_output + else: + result = upstream_output + if result.chain: + output_text = "".join( + comp.text for comp in result.chain if isinstance(comp, Plain) + ) + elif isinstance(upstream_output, MessageChain): + output_text = "".join( + comp.text for comp in upstream_output.chain if isinstance(comp, Plain) + ) + elif isinstance(upstream_output, str): + output_text = upstream_output + elif upstream_output is not None: + output_text = str(upstream_output) + + if output_text: + ok, info = self._check_content(output_text) + if not ok: + return self._block_event(event, info) + + # 向下游传递上游输出 + if upstream_output is not None: + event.set_node_output(upstream_output) + + return NodeResult.CONTINUE diff --git a/astrbot/builtin_stars/content_safety/metadata.yaml b/astrbot/builtin_stars/content_safety/metadata.yaml new file mode 100644 index 000000000..c36f6ef62 --- /dev/null +++ b/astrbot/builtin_stars/content_safety/metadata.yaml @@ -0,0 +1,4 @@ +name: content_safety +desc: Builtin content safety checks for pipeline chains +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/content_safety/requirements.txt b/astrbot/builtin_stars/content_safety/requirements.txt new file mode 100644 index 000000000..257518dcc --- /dev/null +++ b/astrbot/builtin_stars/content_safety/requirements.txt @@ -0,0 +1 @@ +baidu-aip \ No newline at end of file diff --git a/astrbot/builtin_stars/content_safety/strategies.py b/astrbot/builtin_stars/content_safety/strategies.py new file mode 100644 index 000000000..8b2d05253 --- /dev/null +++ b/astrbot/builtin_stars/content_safety/strategies.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import abc +import re + +from astrbot.core import logger + + +class ContentSafetyStrategy(abc.ABC): + @abc.abstractmethod + def check(self, content: str) -> tuple[bool, str]: + raise NotImplementedError + + +class KeywordsStrategy(ContentSafetyStrategy): + def __init__(self, extra_keywords: list) -> None: + self.keywords = [] + if extra_keywords is None: + extra_keywords = [] + self.keywords.extend(extra_keywords) + + def check(self, content: str) -> tuple[bool, str]: + for keyword in self.keywords: + if re.search(keyword, content): + return False, "内容安全检查不通过,匹配到敏感词。" + return True, "" + + +class BaiduAipStrategy(ContentSafetyStrategy): + def __init__(self, appid: str, ak: str, sk: str) -> None: + from aip import AipContentCensor + + self.app_id = appid + self.api_key = ak + self.secret_key = sk + self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) + + def check(self, content: str) -> tuple[bool, 
str]: + res = self.client.textCensorUserDefined(content) + if "conclusionType" not in res: + return False, "" + if res["conclusionType"] == 1: + return True, "" + if "data" not in res: + return False, "" + count = len(res["data"]) + parts = [f"百度审核服务发现 {count} 处违规:\n"] + for i in res["data"]: + parts.append(f"{i['msg']};\n") + parts.append("\n判断结果:" + res["conclusion"]) + info = "".join(parts) + return False, info + + +class StrategySelector: + def __init__(self, config: dict) -> None: + self.enabled_strategies: list[ContentSafetyStrategy] = [] + if config["internal_keywords"]["enable"]: + self.enabled_strategies.append( + KeywordsStrategy(config["internal_keywords"]["extra_keywords"]) + ) + if config["baidu_aip"]["enable"]: + try: + self.enabled_strategies.append( + BaiduAipStrategy( + config["baidu_aip"]["app_id"], + config["baidu_aip"]["api_key"], + config["baidu_aip"]["secret_key"], + ) + ) + except ImportError: + logger.warning("使用百度内容审核应该先 pip install baidu-aip") + + def check(self, content: str) -> tuple[bool, str]: + for strategy in self.enabled_strategies: + ok, info = strategy.check(content) + if not ok: + return False, info + return True, "" diff --git a/astrbot/builtin_stars/file_extract/_node_config_schema.json b/astrbot/builtin_stars/file_extract/_node_config_schema.json new file mode 100644 index 000000000..eee0433a9 --- /dev/null +++ b/astrbot/builtin_stars/file_extract/_node_config_schema.json @@ -0,0 +1,15 @@ +{ + "provider": { + "type": "string", + "default": "moonshotai", + "description": "文件提取服务提供方", + "options": ["local", "moonshotai"], + "labels": ["本地解析", "Moonshot AI"] + }, + "moonshotai_api_key": { + "type": "string", + "default": "", + "description": "Moonshot AI 的 API Key", + "hint": "当提供方为 moonshotai 时必填" + } +} diff --git a/astrbot/builtin_stars/file_extract/main.py b/astrbot/builtin_stars/file_extract/main.py new file mode 100644 index 000000000..9b3d5ba58 --- /dev/null +++ b/astrbot/builtin_stars/file_extract/main.py @@ -0,0 +1,137 @@ +"""文件提取节点 - 将消息中的 File 组件转换为文本""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.message.components import File, Plain, Reply +from astrbot.core.star.node_star import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class FileExtractNode(NodeStar): + """文件提取节点 + + 用户可手动添加到 Chain 中,在 Agent 节点之前运行。 + 负责: + 1. 提取消息中 File 组件的文本内容 + 2. 将 File 替换为 Plain(提取的文本) + 3. 
后续节点无需感知文件提取,只看到纯文本 + + 支持两种提取模式(通过 provider 配置): + - local: 使用本地解析器(无需 API,支持 pdf/docx/xlsx/md/txt) + - moonshotai: 使用 Moonshot AI API + """ + + async def process(self, event: AstrMessageEvent) -> NodeResult: + node_config = event.node_config or {} + provider = node_config.get("provider", "moonshotai") + moonshotai_api_key = node_config.get("moonshotai_api_key", "") + + message = event.message_obj.message + replaced = await self._replace_files(message, provider, moonshotai_api_key) + + # 处理引用消息中的文件 + for comp in message: + if isinstance(comp, Reply) and comp.chain: + replaced += await self._replace_files( + comp.chain, provider, moonshotai_api_key + ) + + if replaced: + # 重建 message_str + event.rebuild_message_str_from_plain() + logger.debug(f"File extraction: replaced {replaced} File component(s)") + + # Write output to ctx for downstream nodes + event.set_node_output(event.message_str) + + return NodeResult.CONTINUE + + return NodeResult.SKIP + + async def _replace_files( + self, components: list, provider: str, moonshotai_api_key: str + ) -> int: + """遍历组件列表,将 File 替换为 Plain,返回替换数量""" + replaced = 0 + for idx, comp in enumerate(components): + if not isinstance(comp, File): + continue + try: + file_path = await comp.get_file() + file_name = comp.name or os.path.basename(file_path) + content = await self._extract_content( + file_path, provider, moonshotai_api_key + ) + if content: + components[idx] = Plain(f"[File: {file_name}]\n{content}\n[/File]") + replaced += 1 + except Exception as e: + logger.warning(f"Failed to extract file {comp.name}: {e}") + return replaced + + async def _extract_content( + self, file_path: str, provider: str, moonshotai_api_key: str + ) -> str | None: + """提取单个文件的文本内容""" + if provider == "local": + return await self._extract_local(file_path) + elif provider == "moonshotai": + return await self._extract_moonshotai(file_path, moonshotai_api_key) + else: + logger.error(f"Unsupported file extract provider: {provider}") + return None + + async def _extract_local(self, file_path: str) -> str | None: + """使用本地解析器提取文件内容""" + ext = os.path.splitext(file_path)[1].lower() + + try: + parser = await self._select_parser(ext) + except ValueError as e: + logger.warning(f"Local parser not available for {ext}: {e}") + return None + + try: + with open(file_path, "rb") as f: + file_content = f.read() + result = await parser.parse(file_content, os.path.basename(file_path)) + return result.text + except Exception as e: + logger.warning(f"Local parsing failed for {file_path}: {e}") + return None + + @staticmethod + async def _select_parser(ext: str): + """根据文件扩展名选择解析器""" + if ext == ".pdf": + from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser + + return PDFParser() + else: + from astrbot.core.knowledge_base.parsers.markitdown_parser import ( + MarkitdownParser, + ) + + return MarkitdownParser() + + @staticmethod + async def _extract_moonshotai( + file_path: str, moonshotai_api_key: str + ) -> str | None: + """使用 Moonshot AI API 提取文件内容""" + from astrbot.core.utils.file_extract import extract_file_moonshotai + + if not moonshotai_api_key: + logger.error("Moonshot AI API key for file extract is not set") + return None + try: + return await extract_file_moonshotai(file_path, moonshotai_api_key) + except Exception as e: + logger.warning(f"Moonshot AI extraction failed: {e}") + return None diff --git a/astrbot/builtin_stars/file_extract/metadata.yaml b/astrbot/builtin_stars/file_extract/metadata.yaml new file mode 100644 index 000000000..2a951d544 --- /dev/null +++ 
b/astrbot/builtin_stars/file_extract/metadata.yaml @@ -0,0 +1,4 @@ +name: file_extract +desc: Builtin file extraction node for pipeline chains +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/knowledge_base/_node_config_schema.json b/astrbot/builtin_stars/knowledge_base/_node_config_schema.json new file mode 100644 index 000000000..8ce22f99f --- /dev/null +++ b/astrbot/builtin_stars/knowledge_base/_node_config_schema.json @@ -0,0 +1,29 @@ +{ + "use_global_kb": { + "type": "bool", + "description": "当 kb_names 为空时,是否使用全局知识库配置", + "hint": "若关闭且 kb_names 为空,则此节点会跳过知识库检索", + "default": true + }, + "kb_names": { + "type": "list", + "description": "知识库名称列表", + "_special": "select_knowledgebase", + "items": { + "type": "string" + }, + "hint": "填写后将覆盖全局知识库;留空则在 use_global_kb 为 true 时使用全局配置" + }, + "top_k": { + "type": "int", + "description": "最终返回的知识块数量(Top K)", + "hint": "仅在配置了 kb_names 时生效。", + "default": 5 + }, + "fusion_top_k": { + "type": "int", + "description": "融合阶段召回数量(Fusion Top K)", + "hint": "设置后将覆盖全局 kb_fusion_top_k。", + "default": 20 + } +} diff --git a/astrbot/builtin_stars/knowledge_base/main.py b/astrbot/builtin_stars/knowledge_base/main.py new file mode 100644 index 000000000..7f8a965ea --- /dev/null +++ b/astrbot/builtin_stars/knowledge_base/main.py @@ -0,0 +1,174 @@ +"""知识库检索节点 - 在Agent调用之前检索相关知识并注入上下文""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.star.node_star import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class KnowledgeBaseNode(NodeStar): + """知识库检索节点 + + 用户可手动添加到 Chain 中,在 Agent 节点之前运行。 + 负责: + 1. 根据用户消息检索知识库 + 2. 将检索结果注入到 ProviderRequest 的 system_prompt 中 + 3. 
后续 Agent 节点可以使用增强后的上下文 + + 注意:此节点实现的是非Agentic模式的知识库检索。 + 如需Agentic模式(LLM主动调用知识库工具),请在provider_settings中启用kb_agentic_mode。 + """ + + async def process(self, event: AstrMessageEvent) -> NodeResult: + # 检查是否有消息内容需要检索 + merged_query = await event.get_node_input(strategy="text_concat") + if isinstance(merged_query, str) and merged_query.strip(): + query = merged_query + elif merged_query is not None: + query = str(merged_query) + else: + # 无上游输出时回退使用原始消息。 + query = event.message_str + if not query or not query.strip(): + return NodeResult.SKIP + + chain_config_id = event.chain_config.config_id if event.chain_config else None + + try: + kb_result = await self._retrieve_knowledge_base( + query, + event.node_config, + event.chain_config.chain_id if event.chain_config else "unknown", + chain_config_id, + ) + if kb_result: + event.set_node_output(kb_result) + logger.debug("[知识库节点] 检索到知识库上下文") + except Exception as e: + logger.error(f"[知识库节点] 检索知识库时发生错误: {e}") + + return NodeResult.CONTINUE + + async def _retrieve_knowledge_base( + self, + query: str, + node_config, + chain_id: str, + config_id: str | None, + ) -> str | None: + """检索知识库 + + Args: + query: 查询文本 + node_config: Node config for this node + + Returns: + 检索到的知识库内容,如果没有则返回 None + """ + kb_mgr = self.context.kb_manager + config = self.context.get_config_by_id(config_id) + node_config = node_config or {} + use_global_kb = node_config.get("use_global_kb", True) + kb_names = node_config.get("kb_names", []) or [] + + if kb_names: + top_k = node_config.get("top_k", 5) + logger.debug( + f"[知识库节点] 使用节点配置,知识库数量: {len(kb_names)}", + ) + + valid_kb_names = [] + invalid_kb_names = [] + for kb_name in kb_names: + kb_helper = await kb_mgr.get_kb_by_name(kb_name) + if kb_helper: + valid_kb_names.append(kb_helper.kb.kb_name) + else: + logger.warning(f"[知识库节点] 知识库不存在或未加载: {kb_name}") + invalid_kb_names.append(kb_name) + + if invalid_kb_names: + logger.warning( + f"[知识库节点] 配置的以下知识库名称无效: {invalid_kb_names}", + ) + + kb_names = valid_kb_names + elif use_global_kb: + kb_names = config.get("kb_names", []) + top_k = config.get("kb_final_top_k", 5) + logger.debug( + f"[知识库节点] 使用全局配置,知识库数量: {len(kb_names)}", + ) + if not kb_names: + return None + return await self._do_retrieve( + kb_mgr, + query, + kb_names, + top_k, + None, + config, + ) + else: + logger.info(f"[知识库节点] 节点已禁用知识库: {chain_id}") + return None + + if not kb_names: + return None + + fusion_top_k = node_config.get("fusion_top_k") + if fusion_top_k is not None: + try: + fusion_top_k = int(fusion_top_k) + except (TypeError, ValueError): + fusion_top_k = None + + return await self._do_retrieve( + kb_mgr, + query, + kb_names, + top_k, + fusion_top_k, + config, + ) + + @staticmethod + async def _do_retrieve( + kb_mgr, + query: str, + kb_names: list[str], + top_k: int, + fusion_top_k: int | None, + config: dict, + ) -> str | None: + """执行知识库检索""" + if fusion_top_k is None: + top_k_fusion = config.get("kb_fusion_top_k", 20) + else: + top_k_fusion = fusion_top_k + + logger.debug( + f"[知识库节点] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}" + ) + kb_context = await kb_mgr.retrieve( + query=query, + kb_names=kb_names, + top_k_fusion=top_k_fusion, + top_m_final=top_k, + ) + + if not kb_context: + return None + + formatted = kb_context.get("context_text", "") + if formatted: + results = kb_context.get("results", []) + logger.debug(f"[知识库节点] 检索到 {len(results)} 条相关知识块") + return formatted + + return None diff --git a/astrbot/builtin_stars/knowledge_base/metadata.yaml b/astrbot/builtin_stars/knowledge_base/metadata.yaml new 
file mode 100644 index 000000000..ff4e1bc4f --- /dev/null +++ b/astrbot/builtin_stars/knowledge_base/metadata.yaml @@ -0,0 +1,4 @@ +name: knowledge_base +desc: Builtin knowledge base retrieval node for pipeline chains +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/session_controller/main.py b/astrbot/builtin_stars/session_controller/main.py index 70081e03a..20988f74a 100644 --- a/astrbot/builtin_stars/session_controller/main.py +++ b/astrbot/builtin_stars/session_controller/main.py @@ -34,7 +34,9 @@ async def handle_empty_mention(self, event: AstrMessageEvent): """实现了对只有一个 @ 的消息内容的处理""" try: messages = event.get_messages() - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) p_settings = cfg["platform_settings"] wake_prefix = cfg.get("wake_prefix", []) if len(messages) == 1: diff --git a/astrbot/builtin_stars/stt/_node_config_schema.json b/astrbot/builtin_stars/stt/_node_config_schema.json new file mode 100644 index 000000000..10111cb38 --- /dev/null +++ b/astrbot/builtin_stars/stt/_node_config_schema.json @@ -0,0 +1,8 @@ +{ + "provider_id": { + "type": "string", + "default": "", + "description": "覆盖此节点的语音转文字(STT)提供商 ID", + "_special": "select_provider_stt" + } +} diff --git a/astrbot/builtin_stars/stt/main.py b/astrbot/builtin_stars/stt/main.py new file mode 100644 index 000000000..8847ccd28 --- /dev/null +++ b/astrbot/builtin_stars/stt/main.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.message.components import Plain, Record +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_STT, + is_chain_runtime_feature_enabled, +) +from astrbot.core.star.node_star import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class STTStar(NodeStar): + """Speech-to-text.""" + + async def process(self, event: AstrMessageEvent) -> NodeResult: + chain_id = event.chain_config.chain_id if event.chain_config else None + if not await is_chain_runtime_feature_enabled(chain_id, FEATURE_STT): + return NodeResult.SKIP + + stt_provider = self.context.get_stt_provider_for_event(event) + if not stt_provider: + logger.warning( + f"Session {event.unified_msg_origin} has no STT provider configured." 
+ ) + return NodeResult.SKIP + + message_chain = event.get_messages() + transcribed_texts = [] + + for idx, component in enumerate(message_chain): + if isinstance(component, Record) and component.url: + path = component.url.removeprefix("file://") + retry = 5 + for i in range(retry): + try: + result = await stt_provider.get_text(audio_url=path) + if result: + logger.info("STT result: " + result) + message_chain[idx] = Plain(result) + event.append_message_str(result) + transcribed_texts.append(result) + break + except FileNotFoundError as e: + logger.warning(f"STT retry {i + 1}/{retry}: {e}") + await asyncio.sleep(0.5) + continue + except Exception as e: + logger.error(f"STT failed: {e}") + break + + if transcribed_texts: + event.set_node_output("\n".join(transcribed_texts)) + + return NodeResult.CONTINUE diff --git a/astrbot/builtin_stars/stt/metadata.yaml b/astrbot/builtin_stars/stt/metadata.yaml new file mode 100644 index 000000000..c63434eb7 --- /dev/null +++ b/astrbot/builtin_stars/stt/metadata.yaml @@ -0,0 +1,4 @@ +name: stt +desc: Builtin speech-to-text pipeline node +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/t2i/_node_config_schema.json b/astrbot/builtin_stars/t2i/_node_config_schema.json new file mode 100644 index 000000000..70153ede4 --- /dev/null +++ b/astrbot/builtin_stars/t2i/_node_config_schema.json @@ -0,0 +1,28 @@ +{ + "word_threshold": { + "type": "int", + "default": 150, + "description": "触发文转图的文本长度阈值", + "hint": "纯文本长度超过该值时才会触发文转图" + }, + "strategy": { + "type": "string", + "description": "渲染策略", + "options": ["remote", "local"], + "labels": ["远程渲染", "本地渲染"], + "default": "remote", + "hint": "remote 使用 t2i 端点;local 由本地渲染器处理" + }, + "active_template": { + "type": "string", + "description": "渲染模板名称", + "default": "", + "hint": "留空则使用全局激活模板" + }, + "use_file_service": { + "type": "bool", + "default": false, + "description": "是否使用文件服务分发图片", + "hint": "开启后,生成图片将通过文件服务对外提供访问链接" + } +} diff --git a/astrbot/builtin_stars/t2i/main.py b/astrbot/builtin_stars/t2i/main.py new file mode 100644 index 000000000..ad33d8659 --- /dev/null +++ b/astrbot/builtin_stars/t2i/main.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import time +import traceback +from typing import TYPE_CHECKING + +from astrbot.core import file_token_service, html_renderer, logger +from astrbot.core.message.components import Image, Plain +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_T2I, + is_chain_runtime_feature_enabled, +) +from astrbot.core.star.node_star import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class T2IStar(NodeStar): + """Text-to-image.""" + + async def process(self, event: AstrMessageEvent) -> NodeResult: + chain_id = event.chain_config.chain_id if event.chain_config else None + if not await is_chain_runtime_feature_enabled(chain_id, FEATURE_T2I): + return NodeResult.SKIP + + chain_config_id = event.chain_config.config_id if event.chain_config else None + runtime_cfg = self.context.get_config_by_id(chain_config_id) + t2i_active_template = runtime_cfg.get("t2i_active_template", "base") + callback_api_base = runtime_cfg.get("callback_api_base", "") + + node_config = event.node_config or {} + word_threshold = node_config.get("word_threshold", 150) + strategy = node_config.get("strategy", "remote") + active_template = node_config.get("active_template", "") + use_file_service = 
node_config.get("use_file_service", False) + + upstream_output = await event.get_node_input(strategy="last") + if not isinstance(upstream_output, MessageEventResult): + logger.warning( + "T2I upstream output is not MessageEventResult. type=%s", + type(upstream_output).__name__, + ) + return NodeResult.SKIP + result = upstream_output + await self.collect_stream(event, result) + + if not result.chain: + return NodeResult.SKIP + + if result.use_t2i_ is False: + return NodeResult.SKIP + + parts = [] + for comp in result.chain: + if isinstance(comp, Plain): + parts.append("\n\n" + comp.text) + else: + break + plain_str = "".join(parts) + + if not plain_str: + return NodeResult.SKIP + + if result.use_t2i_ is None: + try: + threshold = max(int(word_threshold), 50) + except Exception: + threshold = 150 + + if len(plain_str) <= threshold: + return NodeResult.SKIP + + render_start = time.time() + try: + if not active_template: + active_template = t2i_active_template + url = await html_renderer.render_t2i( + plain_str, + return_url=True, + use_network=strategy == "remote", + template_name=active_template, + ) + except Exception: + logger.error(traceback.format_exc()) + logger.error("T2I render failed, fallback to text output.") + return NodeResult.SKIP + + if time.time() - render_start > 3: + logger.warning("T2I render took longer than 3s.") + + if url: + if url.startswith("http"): + result.chain = [Image.fromURL(url)] + elif use_file_service and callback_api_base: + token = await file_token_service.register_file(url) + url = f"{callback_api_base}/api/file/{token}" + logger.debug(f"Registered file service url: {url}") + result.chain = [Image.fromURL(url)] + else: + result.chain = [Image.fromFileSystem(url)] + + event.set_node_output(result) + return NodeResult.CONTINUE + + return NodeResult.SKIP diff --git a/astrbot/builtin_stars/t2i/metadata.yaml b/astrbot/builtin_stars/t2i/metadata.yaml new file mode 100644 index 000000000..27f540e7f --- /dev/null +++ b/astrbot/builtin_stars/t2i/metadata.yaml @@ -0,0 +1,4 @@ +name: t2i +desc: Builtin text-to-image pipeline node +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/tts/_node_config_schema.json b/astrbot/builtin_stars/tts/_node_config_schema.json new file mode 100644 index 000000000..d9cb887e5 --- /dev/null +++ b/astrbot/builtin_stars/tts/_node_config_schema.json @@ -0,0 +1,26 @@ +{ + "provider_id": { + "type": "string", + "default": "", + "description": "覆盖此节点的文字转语音(TTS)提供商 ID", + "_special": "select_provider_tts" + }, + "trigger_probability": { + "type": "float", + "default": 1.0, + "description": "触发概率。", + "hint": "文本转换为语音的概率,范围为 0.0 到 1.0" + }, + "use_file_service": { + "type": "bool", + "default": false, + "description": "是否使用文件服务分发语音。", + "hint": "开启后,生成音频将通过文件服务对外提供访问链接" + }, + "dual_output": { + "type": "bool", + "default": false, + "description": "是否双输出。", + "hint": "开启后将同时输出语音和文本" + } +} diff --git a/astrbot/builtin_stars/tts/main.py b/astrbot/builtin_stars/tts/main.py new file mode 100644 index 000000000..48ce134a8 --- /dev/null +++ b/astrbot/builtin_stars/tts/main.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import random +import traceback +from typing import TYPE_CHECKING + +from astrbot.core import file_token_service, logger +from astrbot.core.message.components import Plain, Record +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_TTS, + is_chain_runtime_feature_enabled, +) +from astrbot.core.star.node_star 
import NodeResult, NodeStar + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class TTSStar(NodeStar): + """Text-to-speech.""" + + def __init__(self, context, config: dict | None = None): + super().__init__(context, config) + + async def process(self, event: AstrMessageEvent) -> NodeResult: + chain_id = event.chain_config.chain_id if event.chain_config else None + if not await is_chain_runtime_feature_enabled(chain_id, FEATURE_TTS): + return NodeResult.SKIP + + node_config = event.node_config or {} + chain_config_id = event.chain_config.config_id if event.chain_config else None + runtime_cfg = self.context.get_config_by_id(chain_config_id) + callback_api_base = runtime_cfg.get("callback_api_base", "") + + use_file_service = node_config.get("use_file_service", False) + dual_output = node_config.get("dual_output", False) + trigger_probability = node_config.get("trigger_probability", 1.0) + try: + trigger_probability = max(0.0, min(float(trigger_probability), 1.0)) + except (TypeError, ValueError): + trigger_probability = 1.0 + + upstream_output = await event.get_node_input(strategy="last") + if not isinstance(upstream_output, MessageEventResult): + logger.warning( + "TTS upstream output is not MessageEventResult. type=%s", + type(upstream_output).__name__, + ) + return NodeResult.SKIP + result = upstream_output + await self.collect_stream(event, result) + + if not result.chain: + return NodeResult.SKIP + + if not result.is_llm_result(): + return NodeResult.SKIP + + if random.random() > trigger_probability: + return NodeResult.SKIP + + tts_provider = self.context.get_tts_provider_for_event(event) + if not tts_provider: + logger.warning( + f"Session {event.unified_msg_origin} has no TTS provider configured." 
+ ) + return NodeResult.SKIP + + new_chain = [] + + for comp in result.chain: + if isinstance(comp, Plain) and len(comp.text) > 1: + try: + logger.info(f"TTS request: {comp.text}") + audio_path = await tts_provider.get_audio(comp.text) + logger.info(f"TTS result: {audio_path}") + + if not audio_path: + logger.error(f"TTS audio not found: {comp.text}") + new_chain.append(comp) + continue + + url = None + if use_file_service and callback_api_base: + token = await file_token_service.register_file(audio_path) + url = f"{callback_api_base}/api/file/{token}" + logger.debug(f"Registered file service url: {url}") + + new_chain.append( + Record( + file=url or audio_path, + url=url or audio_path, + ) + ) + + if dual_output: + new_chain.append(comp) + + except Exception: + logger.error(traceback.format_exc()) + logger.error("TTS failed, fallback to text output.") + new_chain.append(comp) + else: + new_chain.append(comp) + + result.chain = new_chain + event.set_node_output(result) + + return NodeResult.CONTINUE diff --git a/astrbot/builtin_stars/tts/metadata.yaml b/astrbot/builtin_stars/tts/metadata.yaml new file mode 100644 index 000000000..4bfdf3609 --- /dev/null +++ b/astrbot/builtin_stars/tts/metadata.yaml @@ -0,0 +1,4 @@ +name: tts +desc: Builtin text-to-speech pipeline node +author: AstrBot +version: 1.0.0 diff --git a/astrbot/builtin_stars/web_searcher/main.py b/astrbot/builtin_stars/web_searcher/main.py index 85eeffd94..d75802b65 100644 --- a/astrbot/builtin_stars/web_searcher/main.py +++ b/astrbot/builtin_stars/web_searcher/main.py @@ -222,7 +222,9 @@ async def search_from_search_engine( """ logger.info(f"web_searcher - search_from_search_engine: {query}") - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) websearch_link = cfg["provider_settings"].get("web_search_link", False) results = await self._web_search_default(query, max_results) @@ -246,10 +248,10 @@ async def search_from_search_engine( return ret - async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None: + async def ensure_baidu_ai_search_mcp(self, config_id: str | None = None): if self.baidu_initialized: return - cfg = self.context.get_config(umo=umo) + cfg = self.context.get_config_by_id(config_id) key = cfg.get("provider_settings", {}).get( "websearch_baidu_app_builder_key", "", @@ -310,7 +312,9 @@ async def search_from_tavily( """ logger.info(f"web_searcher - search_from_tavily: {query}") - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) # websearch_link = cfg["provider_settings"].get("web_search_link", False) if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []): raise ValueError("Error: Tavily API key is not configured in AstrBot.") @@ -372,7 +376,9 @@ async def tavily_extract_web_page( extract_depth(string): Optional. The depth of the extraction, must be one of 'basic', 'advanced'. Default is "basic". """ - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []): raise ValueError("Error: Tavily API key is not configured in AstrBot.") @@ -500,7 +506,9 @@ async def search_from_bocha( specified count. 
""" logger.info(f"web_searcher - search_from_bocha: {query}") - cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg = self.context.get_config_by_id( + event.chain_config.config_id if event.chain_config else None, + ) # websearch_link = cfg["provider_settings"].get("web_search_link", False) if not cfg.get("provider_settings", {}).get("websearch_bocha_key", []): raise ValueError("Error: BoCha API key is not configured in AstrBot.") @@ -555,7 +563,8 @@ async def edit_web_search_tools( req: ProviderRequest, ) -> None: """Get the session conversation for the given event.""" - cfg = self.context.get_config(umo=event.unified_msg_origin) + chain_config_id = event.chain_config.config_id if event.chain_config else None + cfg = self.context.get_config_by_id(chain_config_id) prov_settings = cfg.get("provider_settings", {}) websearch_enable = prov_settings.get("web_search", False) provider = prov_settings.get("websearch_provider", "default") @@ -599,7 +608,7 @@ async def edit_web_search_tools( tool_set.remove_tool("web_search_bocha") elif provider == "baidu_ai_search": try: - await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin) + await self.ensure_baidu_ai_search_mcp(chain_config_id) aisearch_tool = func_tool_mgr.get_func("AIsearch") if not aisearch_tool: raise ValueError("Cannot get Baidu AI Search MCP tool.") diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 230faaf1c..f1a61736f 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -115,7 +115,7 @@ async def _execute_handoff( # to the current/default provider resolution. prov_id = getattr( tool, "provider_id", None - ) or await ctx.get_current_chat_provider_id(umo) + ) or await ctx.get_current_chat_provider_id(umo, event=event) # prepare begin dialogs contexts = None diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 4e70f3d59..d7152d9fd 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -10,7 +10,6 @@ from collections.abc import Coroutine from dataclasses import dataclass, field -from astrbot.api import sp from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool @@ -35,10 +34,10 @@ SEND_MESSAGE_TO_USER_TOOL, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, - retrieve_knowledge_base, ) from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Reply +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest @@ -50,7 +49,6 @@ DELETE_CRON_JOB_TOOL, LIST_CRON_JOBS_TOOL, ) -from astrbot.core.utils.file_extract import extract_file_moonshotai from astrbot.core.utils.llm_metadata import LLM_METADATAS @@ -77,12 +75,6 @@ class MainAgentBuildConfig: kb_agentic_mode: bool = False """Whether to use agentic mode for knowledge base retrieval. 
This will inject the knowledge base query tool into the main agent's toolset to allow dynamic querying.""" - file_extract_enabled: bool = False - """Whether to enable file content extraction for uploaded files.""" - file_extract_prov: str = "moonshotai" - """The file extraction provider.""" - file_extract_msh_api_key: str = "" - """The API key for Moonshot AI file extraction provider.""" context_limit_reached_strategy: str = "truncate_by_turns" """The strategy to handle context length limit reached.""" llm_compress_instruction: str = "" @@ -122,22 +114,7 @@ def _select_provider( event: AstrMessageEvent, plugin_context: Context ) -> Provider | None: """Select chat provider for the event.""" - sel_provider = event.get_extra("selected_provider") - if sel_provider and isinstance(sel_provider, str): - provider = plugin_context.get_provider_by_id(sel_provider) - if not provider: - logger.error("未找到指定的提供商: %s。", sel_provider) - if not isinstance(provider, Provider): - logger.error( - "选择的提供商类型无效(%s),跳过 LLM 请求处理。", type(provider) - ) - return None - return provider - try: - return plugin_context.get_using_provider(umo=event.unified_msg_origin) - except ValueError as exc: - logger.error("Error occurred while selecting provider: %s", exc) - return None + return plugin_context.get_chat_provider_for_event(event) async def _get_session_conv( @@ -157,82 +134,17 @@ async def _get_session_conv( return conversation -async def _apply_kb( +def _apply_kb( event: AstrMessageEvent, req: ProviderRequest, - plugin_context: Context, config: MainAgentBuildConfig, ) -> None: - if not config.kb_agentic_mode: - if req.prompt is None: - return - try: - kb_result = await retrieve_knowledge_base( - query=req.prompt, - umo=event.unified_msg_origin, - context=plugin_context, - ) - if not kb_result: - return - if req.system_prompt is not None: - req.system_prompt += ( - f"\n\n[Related Knowledge Base Results]:\n{kb_result}" - ) - except Exception as exc: # noqa: BLE001 - logger.error("Error occurred while retrieving knowledge base: %s", exc) - else: + if config.kb_agentic_mode: + # Agentic mode: add knowledge base query tool if req.func_tool is None: req.func_tool = ToolSet() req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) - - -async def _apply_file_extract( - event: AstrMessageEvent, - req: ProviderRequest, - config: MainAgentBuildConfig, -) -> None: - file_paths = [] - file_names = [] - for comp in event.message_obj.message: - if isinstance(comp, File): - file_paths.append(await comp.get_file()) - file_names.append(comp.name) - elif isinstance(comp, Reply) and comp.chain: - for reply_comp in comp.chain: - if isinstance(reply_comp, File): - file_paths.append(await reply_comp.get_file()) - file_names.append(reply_comp.name) - if not file_paths: - return - if not req.prompt: - req.prompt = "总结一下文件里面讲了什么?" 
- if config.file_extract_prov == "moonshotai": - if not config.file_extract_msh_api_key: - logger.error("Moonshot AI API key for file extract is not set") - return - file_contents = await asyncio.gather( - *[ - extract_file_moonshotai( - file_path, - config.file_extract_msh_api_key, - ) - for file_path in file_paths - ] - ) - else: - logger.error("Unsupported file extract provider: %s", config.file_extract_prov) - return - - for file_content, file_name in zip(file_contents, file_names): - req.contexts.append( - { - "role": "system", - "content": ( - "File Extract Results of user uploaded files:\n" - f"{file_content}\nFile Name: {file_name or 'Unknown'}" - ), - }, - ) + # Non-agentic mode: KB context is injected via _inject_pipeline_context() def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None: @@ -257,6 +169,7 @@ async def _ensure_persona_and_skills( cfg: dict, plugin_context: Context, event: AstrMessageEvent, + subagent_orchestrator_cfg: dict | None = None, ) -> None: """Ensure persona and skills are applied to the request's system prompt or user prompt.""" if not req.conversation: @@ -264,15 +177,11 @@ async def _ensure_persona_and_skills( # get persona ID - # 1. from session service config - highest priority - persona_id = ( - await sp.get_async( - scope="umo", - scope_id=event.unified_msg_origin, - key="session_service_config", - default={}, - ) - ).get("persona_id") + # 1. from node config - highest priority + node_config = event.node_config or {} + persona_id = "" + if isinstance(node_config, dict): + persona_id = str(node_config.get("persona_id") or "").strip() if not persona_id: # 2. from conversation setting - second priority @@ -345,7 +254,7 @@ async def _ensure_persona_and_skills( req.func_tool.merge(persona_toolset) # sub agents integration - orch_cfg = plugin_context.get_config().get("subagent_orchestrator", {}) + orch_cfg = subagent_orchestrator_cfg or {} so = plugin_context.subagent_orchestrator if orch_cfg.get("main_enable", False) and so: remove_dup = bool(orch_cfg.get("remove_main_duplicate_tools", False)) @@ -403,11 +312,7 @@ async def _ensure_persona_and_skills( continue req.func_tool.remove_tool(tool_name) - router_prompt = ( - plugin_context.get_config() - .get("subagent_orchestrator", {}) - .get("router_system_prompt", "") - ).strip() + router_prompt = str(orch_cfg.get("router_system_prompt", "") or "").strip() if router_prompt: req.system_prompt += f"\n{router_prompt}\n" try: @@ -473,6 +378,7 @@ async def _ensure_img_caption( async def _process_quote_message( event: AstrMessageEvent, req: ProviderRequest, + cfg: dict, img_cap_prov_id: str, plugin_context: Context, ) -> None: @@ -498,23 +404,34 @@ async def _process_quote_message( if image_seg: try: - prov = None + image_path = await image_seg.convert_to_file_path() + caption_text = "" + if img_cap_prov_id: - prov = plugin_context.get_provider_by_id(img_cap_prov_id) - if prov is None: - prov = plugin_context.get_using_provider(event.unified_msg_origin) - - if prov and isinstance(prov, Provider): - llm_resp = await prov.text_chat( - prompt="Please describe the image content.", - image_urls=[await image_seg.convert_to_file_path()], + caption_text = await _request_img_caption( + img_cap_prov_id, + cfg, + [image_path], + plugin_context, ) - if llm_resp.completion_text: - content_parts.append( - f"[Image Caption in quoted message]: {llm_resp.completion_text}" - ) else: - logger.warning("No provider found for image captioning in quote.") + prov = plugin_context.get_chat_provider_for_event(event) + if prov and 
isinstance(prov, Provider): + llm_resp = await prov.text_chat( + prompt=cfg.get( + "image_caption_prompt", + "Please describe the image.", + ), + image_urls=[image_path], + ) + caption_text = llm_resp.completion_text + else: + logger.warning("No provider found for image captioning in quote.") + + if caption_text: + content_parts.append( + f"[Image Caption in quoted message]: {caption_text}" + ) except BaseException as exc: logger.error("处理引用图片失败: %s", exc) @@ -567,41 +484,107 @@ def _append_system_reminders( req.extra_user_content_parts.append(TextPart(text=system_content)) +def _inject_pipeline_context(event: AstrMessageEvent, req: ProviderRequest) -> None: + """Inject upstream node output into LLM request. + + When Agent nodes are chained (e.g., [Agent A] -> [Agent B]), this ensures + Agent B receives Agent A's output as additional context. + """ + ctx = event.node_context + if ctx is None or ctx.input is None: + return + + input_data = ctx.input.data + + if isinstance(input_data, MessageChain): + # It's message content - extract text content + from astrbot.core.message.components import Plain + + parts = [] + for comp in input_data.chain or []: + if isinstance(comp, Plain): + parts.append(comp.text) + if parts: + upstream_text = "\n".join(parts) + else: + return # No text content to inject + elif isinstance(input_data, str): + upstream_text = input_data + else: + # Try to convert to string + upstream_text = str(input_data) + + if not upstream_text or not upstream_text.strip(): + return + + # Inject as pipeline context + pipeline_context = ( + f"<pipeline_context>\n" + f"The following is output from a previous node in the processing pipeline:\n" + f"{upstream_text}\n" + f"</pipeline_context>" + ) + req.extra_user_content_parts.append(TextPart(text=pipeline_context)) + + async def _decorate_llm_request( event: AstrMessageEvent, req: ProviderRequest, plugin_context: Context, config: MainAgentBuildConfig, ) -> None: - cfg = config.provider_settings or plugin_context.get_config( - umo=event.unified_msg_origin - ).get("provider_settings", {}) + chain_config_id = event.chain_config.config_id if event.chain_config else None + runtime_config = plugin_context.get_config_by_id(chain_config_id) + cfg = config.provider_settings or runtime_config.get("provider_settings", {}) + node_config = event.node_config if isinstance(event.node_config, dict) else {} + + img_caption_cfg = dict(cfg) + node_image_caption_prompt = str( + node_config.get("image_caption_prompt") or "" + ).strip() + if node_image_caption_prompt: + img_caption_cfg["image_caption_prompt"] = node_image_caption_prompt + + node_img_cap_prov_id = str( + node_config.get("image_caption_provider_id") or "" + ).strip() + default_img_cap_prov_id = str( + cfg.get("default_image_caption_provider_id") + or runtime_config.get("default_image_caption_provider_id") + or "" + ).strip() + img_cap_prov_id = node_img_cap_prov_id or default_img_cap_prov_id _apply_prompt_prefix(req, cfg) if req.conversation: - await _ensure_persona_and_skills(req, cfg, plugin_context, event) + await _ensure_persona_and_skills( + req, + cfg, + plugin_context, + event, + subagent_orchestrator_cfg=config.subagent_orchestrator, + ) - img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or "" if img_cap_prov_id and req.image_urls: await _ensure_img_caption( req, - cfg, + img_caption_cfg, plugin_context, img_cap_prov_id, ) - img_cap_prov_id = cfg.get("default_image_caption_provider_id") or "" await _process_quote_message( event, req, + img_caption_cfg, img_cap_prov_id, plugin_context, ) tz =
config.timezone if tz is None: - tz = plugin_context.get_config().get("timezone") + tz = runtime_config.get("timezone") _append_system_reminders(event, req, cfg, tz) @@ -891,12 +874,6 @@ async def build_main_agent( if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) - if config.file_extract_enabled: - try: - await _apply_file_extract(event, req, config) - except Exception as exc: # noqa: BLE001 - logger.error("Error occurred while applying file extract: %s", exc) - if not req.prompt and not req.image_urls: if not event.get_group_id() and req.extra_user_content_parts: req.prompt = "" @@ -905,7 +882,9 @@ async def build_main_agent( await _decorate_llm_request(event, req, plugin_context, config) - await _apply_kb(event, req, plugin_context, config) + _inject_pipeline_context(event, req) + + _apply_kb(event, req, config) if not req.session_id: req.session_id = event.unified_msg_origin diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 1d5c085ce..ed5be6bfc 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -6,7 +6,7 @@ from pydantic.dataclasses import dataclass import astrbot.core.message.components as Comp -from astrbot.api import logger, sp +from astrbot.api import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext @@ -160,10 +160,13 @@ async def call( query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." + event = context.context.event + chain_config_id = event.chain_config.config_id if event.chain_config else None result = await retrieve_knowledge_base( query=kwargs.get("query", ""), - umo=context.context.event.unified_msg_origin, + umo=event.unified_msg_origin, context=context.context.context, + config_id=chain_config_id, ) if not result: return "No relevant knowledge found." @@ -234,6 +237,9 @@ async def _resolve_path_from_sandbox( sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None, ) # Use shell to check if the file exists in sandbox result = await sb.shell.exec(f"test -f {path} && echo '_&exists_'") @@ -365,6 +371,7 @@ async def retrieve_knowledge_base( query: str, umo: str, context: Context, + config_id: str | None = None, ) -> str | None: """Inject knowledge base context into the provider request @@ -373,51 +380,16 @@ async def retrieve_knowledge_base( p_ctx: Pipeline context """ kb_mgr = context.kb_manager - config = context.get_config(umo=umo) - - # 1. 
优先读取会话级配置 - session_config = await sp.session_get(umo, "kb_config", default={}) - - if session_config and "kb_ids" in session_config: - # 会话级配置 - kb_ids = session_config.get("kb_ids", []) - - # 如果配置为空列表,明确表示不使用知识库 - if not kb_ids: - logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库") - return - - top_k = session_config.get("top_k", 5) - - # 将 kb_ids 转换为 kb_names - kb_names = [] - invalid_kb_ids = [] - for kb_id in kb_ids: - kb_helper = await kb_mgr.get_kb(kb_id) - if kb_helper: - kb_names.append(kb_helper.kb.kb_name) - else: - logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}") - invalid_kb_ids.append(kb_id) - - if invalid_kb_ids: - logger.warning( - f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", - ) - - if not kb_names: - return + config = context.get_config_by_id(config_id) - logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") - else: - kb_names = config.get("kb_names", []) - top_k = config.get("kb_final_top_k", 5) - logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") + kb_names = config.get("kb_names", []) + top_k = config.get("kb_final_top_k", 5) + logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") top_k_fusion = config.get("kb_fusion_top_k", 20) if not kb_names: - return + return None logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}") kb_context = await kb_mgr.retrieve( @@ -428,13 +400,14 @@ async def retrieve_knowledge_base( ) if not kb_context: - return + return None formatted = kb_context.get("context_text", "") if formatted: results = kb_context.get("results", []) logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块") return formatted + return None KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index c2bfb1c37..90d21c3d9 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -6,7 +6,6 @@ from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH from astrbot.core.config.default import DEFAULT_CONFIG from astrbot.core.platform.message_session import MessageSession -from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.utils.astrbot_path import get_astrbot_config_path from astrbot.core.utils.shared_preferences import SharedPreferences @@ -34,15 +33,14 @@ class AstrBotConfigManager: def __init__( self, default_config: AstrBotConfig, - ucr: UmopConfigRouter, sp: SharedPreferences, ) -> None: self.sp = sp - self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} """uuid / "default" -> AstrBotConfig""" self.confs["default"] = default_config self.abconf_data = None + self._runtime_config_mapping: dict[str, str] = {} self._load_all_configs() def _get_abconf_data(self) -> dict: @@ -72,33 +70,27 @@ def _load_all_configs(self) -> None: ) continue - def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo: - """获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default") - - Returns: - ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型 - - """ - # uuid -> { "path": str, "name": str } - abconf_data = self._get_abconf_data() - + @staticmethod + def _normalize_umo(umo: str | MessageSession) -> str | None: if isinstance(umo, MessageSession): - umo = str(umo) - else: - try: - umo = str(MessageSession.from_str(umo)) # validate - except Exception: - return DEFAULT_CONFIG_CONF_INFO - - conf_id = self.ucr.get_conf_id_for_umop(umo) - if conf_id: - meta = abconf_data.get(conf_id) - if meta and isinstance(meta, dict): - # the bind relation between umo and conf is defined in ucr now, so we remove "umop" here - meta.pop("umop", None) - return 
ConfInfo(**meta, id=conf_id) - - return DEFAULT_CONFIG_CONF_INFO + return str(umo) + try: + return str(MessageSession.from_str(umo)) # validate + except Exception: + return None + + def set_runtime_config_id(self, umo: str | MessageSession, config_id: str) -> None: + """保存运行时路由结果,用于按会话获取配置文件。""" + norm = self._normalize_umo(umo) + if not norm: + return + self._runtime_config_mapping[norm] = config_id + + def _get_runtime_config_id(self, umo: str | MessageSession) -> str | None: + norm = self._normalize_umo(umo) + if not norm: + return None + return self._runtime_config_mapping.get(norm) def _save_conf_mapping( self, @@ -125,14 +117,20 @@ def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig: """获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。""" if not umo: return self.confs["default"] - if isinstance(umo, MessageSession): - umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}" + config_id = self._get_runtime_config_id(umo) + if not config_id: + return self.confs["default"] - uuid_ = self._load_conf_mapping(umo)["id"] + return self.get_conf_by_id(config_id) - conf = self.confs.get(uuid_) - if not conf: - conf = self.confs["default"] # default MUST exists + def get_conf_by_id(self, config_id: str | None) -> AstrBotConfig: + """通过配置文件 ID 获取配置;无效 ID 回退到默认配置。""" + if not config_id: + return self.confs["default"] + + conf = self.confs.get(config_id) + if conf is None: + return self.confs["default"] return conf @@ -141,12 +139,24 @@ def default_conf(self) -> AstrBotConfig: """获取默认配置文件""" return self.confs["default"] - def get_conf_info(self, umo: str | MessageSession) -> ConfInfo: + def get_config_info(self, umo: str | MessageSession) -> ConfInfo: """获取指定 umo 的配置文件元数据""" - if isinstance(umo, MessageSession): - umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}" + config_id = self._get_runtime_config_id(umo) + if not config_id: + return DEFAULT_CONFIG_CONF_INFO + return self.get_config_info_by_id(config_id) - return self._load_conf_mapping(umo) + def get_config_info_by_id(self, config_id: str) -> ConfInfo: + """通过配置文件 ID 获取元数据,不进行路由.""" + if config_id == "default": + return DEFAULT_CONFIG_CONF_INFO + + abconf_data = self._get_abconf_data() + meta = abconf_data.get(config_id) + if meta and isinstance(meta, dict) and config_id in self.confs: + return ConfInfo(**meta, id=config_id) + + return DEFAULT_CONFIG_CONF_INFO def get_conf_list(self) -> list[ConfInfo]: """获取所有配置文件的元数据列表""" @@ -155,7 +165,6 @@ def get_conf_list(self) -> list[ConfInfo]: for uuid_, meta in abconf_mapping.items(): if not isinstance(meta, dict): continue - meta.pop("umop", None) conf_list.append(ConfInfo(**meta, id=uuid_)) conf_list.append(DEFAULT_CONFIG_CONF_INFO) return conf_list @@ -174,11 +183,11 @@ def create_conf( self.confs[conf_uuid] = conf return conf_uuid - def delete_conf(self, conf_id: str) -> bool: + def delete_conf(self, config_id: str) -> bool: """删除指定配置文件 Args: - conf_id: 配置文件的 UUID + config_id: 配置文件的 UUID Returns: bool: 删除是否成功 @@ -187,7 +196,7 @@ def delete_conf(self, conf_id: str) -> bool: ValueError: 如果试图删除默认配置文件 """ - if conf_id == "default": + if config_id == "default": raise ValueError("不能删除默认配置文件") # 从映射中移除 @@ -197,14 +206,14 @@ def delete_conf(self, conf_id: str) -> bool: scope="global", scope_id="global", ) - if conf_id not in abconf_data: - logger.warning(f"配置文件 {conf_id} 不存在于映射中") + if config_id not in abconf_data: + logger.warning(f"配置文件 {config_id} 不存在于映射中") return False # 获取配置文件路径 conf_path = os.path.join( get_astrbot_config_path(), - abconf_data[conf_id]["path"], + 
abconf_data[config_id]["path"], ) # 删除配置文件 @@ -217,29 +226,29 @@ def delete_conf(self, conf_id: str) -> bool: return False # 从内存中移除 - if conf_id in self.confs: - del self.confs[conf_id] + if config_id in self.confs: + del self.confs[config_id] # 从映射中移除 - del abconf_data[conf_id] + del abconf_data[config_id] self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") self.abconf_data = abconf_data - logger.info(f"成功删除配置文件 {conf_id}") + logger.info(f"成功删除配置文件 {config_id}") return True - def update_conf_info(self, conf_id: str, name: str | None = None) -> bool: + def update_conf_info(self, config_id: str, name: str | None = None) -> bool: """更新配置文件信息 Args: - conf_id: 配置文件的 UUID + config_id: 配置文件的 UUID name: 新的配置文件名称 (可选) Returns: bool: 更新是否成功 """ - if conf_id == "default": + if config_id == "default": raise ValueError("不能更新默认配置文件的信息") abconf_data = self.sp.get( @@ -248,18 +257,18 @@ def update_conf_info(self, conf_id: str, name: str | None = None) -> bool: scope="global", scope_id="global", ) - if conf_id not in abconf_data: - logger.warning(f"配置文件 {conf_id} 不存在于映射中") + if config_id not in abconf_data: + logger.warning(f"配置文件 {config_id} 不存在于映射中") return False # 更新名称 if name is not None: - abconf_data[conf_id]["name"] = name + abconf_data[config_id]["name"] = name # 保存更新 self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") self.abconf_data = abconf_data - logger.info(f"成功更新配置文件 {conf_id} 的信息") + logger.info(f"成功更新配置文件 {config_id} 的信息") return True def g( diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 9750e7b64..02f46c6bc 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -62,8 +62,9 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: async def get_booter( context: Context, session_id: str, + config_id: str | None = None, ) -> ComputerBooter: - config = context.get_config(umo=session_id) + config = context.get_config_by_id(config_id) sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) booter_type = sandbox_cfg.get("booter", "shipyard") diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py index 9cf590a61..6216b63f8 100644 --- a/astrbot/core/computer/tools/fs.py +++ b/astrbot/core/computer/tools/fs.py @@ -104,6 +104,9 @@ async def call( sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None, ) try: # Check if file exists @@ -163,6 +166,9 @@ async def call( sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None, ) try: name = os.path.basename(remote_path) diff --git a/astrbot/core/computer/tools/python.py b/astrbot/core/computer/tools/python.py index 333f442f9..c61cefc53 100644 --- a/astrbot/core/computer/tools/python.py +++ b/astrbot/core/computer/tools/python.py @@ -65,6 +65,9 @@ async def call( sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None, ) try: result = await sb.python.exec(code, silent=silent) diff --git a/astrbot/core/computer/tools/shell.py b/astrbot/core/computer/tools/shell.py index eeeb3f9d4..deb589a10 100644 --- a/astrbot/core/computer/tools/shell.py +++ 
b/astrbot/core/computer/tools/shell.py @@ -55,6 +55,9 @@ async def call( sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None, ) try: result = await sb.shell.exec(command, background=background, env=env) diff --git a/astrbot/core/config/__init__.py b/astrbot/core/config/__init__.py index 839aeef3e..213d3d097 100644 --- a/astrbot/core/config/__init__.py +++ b/astrbot/core/config/__init__.py @@ -1,9 +1,11 @@ from .astrbot_config import * from .default import DB_PATH, DEFAULT_CONFIG, VERSION +from .node_config import AstrBotNodeConfig __all__ = [ "DB_PATH", "DEFAULT_CONFIG", "VERSION", "AstrBotConfig", + "AstrBotNodeConfig", ] diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 546768812..0933678ca 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -28,6 +28,7 @@ "strategy": "stall", # stall, discard }, "reply_prefix": "", + "forward_wrapper": False, "forward_threshold": 1500, "enable_id_white_list": True, "id_whitelist": [], @@ -69,6 +70,8 @@ "default_provider_id": "", "default_image_caption_provider_id": "", "image_caption_prompt": "Please describe the image using Chinese.", + "default_stt_provider_id": "", + "default_tts_provider_id": "", "provider_pool": ["*"], # "*" 表示使用所有可用的提供者 "wake_prefix": "", "web_search": False, @@ -99,10 +102,6 @@ "streaming_response": False, "show_tool_use_status": False, "sanitize_context_by_modalities": False, - "agent_runner_type": "local", - "dify_agent_runner_provider_id": "", - "coze_agent_runner_provider_id": "", - "dashscope_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, @@ -110,11 +109,6 @@ "tool_schema_mode": "full", "llm_safety_mode": True, "safety_mode_strategy": "system_prompt", # TODO: llm judge - "file_extract": { - "enable": False, - "provider": "moonshotai", - "moonshotai_api_key": "", - }, "proactive_capability": { "add_cron_tools": True, }, @@ -142,17 +136,6 @@ ), "agents": [], }, - "provider_stt_settings": { - "enable": False, - "provider_id": "", - }, - "provider_tts_settings": { - "enable": False, - "provider_id": "", - "dual_output": False, - "use_file_service": False, - "trigger_probability": 1.0, - }, "provider_ltm_settings": { "group_icl_enable": False, "group_message_max_cnt": 300, @@ -165,17 +148,8 @@ "whitelist": [], }, }, - "content_safety": { - "also_use_in_response": False, - "internal_keywords": {"enable": True, "extra_keywords": []}, - "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, - }, "admins_id": ["astrbot"], - "t2i": False, - "t2i_word_threshold": 150, - "t2i_strategy": "remote", "t2i_endpoint": "", - "t2i_use_file_service": False, "t2i_active_template": "base", "http_proxy": "", "no_proxy": ["localhost", "127.0.0.1", "::1", "10.*", "192.168.*"], @@ -835,6 +809,10 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "机器人回复消息时带有的前缀。", }, + "forward_wrapper": { + "type": "bool", + "hint": "启用后,超过转发阈值的消息会以合并转发形式发送(仅 QQ 平台适用)。", + }, "forward_threshold": { "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。", @@ -872,42 +850,6 @@ class ChatProviderTemplate(TypedDict): }, }, }, - "content_safety": { - "type": "object", - "items": { - "also_use_in_response": { - "type": "bool", - "hint": "启用后,大模型的响应也会通过内容安全审核。", - }, - "baidu_aip": { - "type": "object", - "items": { - 
"enable": { - "type": "bool", - "hint": "启用此功能前,您需要手动在设备中安装 baidu-aip 库。一般来说,安装指令如下: `pip3 install baidu-aip`", - }, - "app_id": {"description": "APP ID", "type": "string"}, - "api_key": {"description": "API Key", "type": "string"}, - "secret_key": { - "type": "string", - }, - }, - }, - "internal_keywords": { - "type": "object", - "items": { - "enable": { - "type": "bool", - }, - "extra_keywords": { - "type": "list", - "items": {"type": "string"}, - "hint": "额外的屏蔽关键词列表,支持正则表达式。", - }, - }, - }, - }, - }, }, }, "provider_group": { @@ -2197,6 +2139,12 @@ class ChatProviderTemplate(TypedDict): "default_provider_id": { "type": "string", }, + "default_image_caption_provider_id": { + "type": "string", + }, + "image_caption_prompt": { + "type": "string", + }, "wake_prefix": { "type": "string", }, @@ -2239,18 +2187,6 @@ class ChatProviderTemplate(TypedDict): "unsupported_streaming_strategy": { "type": "string", }, - "agent_runner_type": { - "type": "string", - }, - "dify_agent_runner_provider_id": { - "type": "string", - }, - "coze_agent_runner_provider_id": { - "type": "string", - }, - "dashscope_agent_runner_provider_id": { - "type": "string", - }, "max_agent_step": { "type": "int", }, @@ -2260,16 +2196,13 @@ class ChatProviderTemplate(TypedDict): "tool_schema_mode": { "type": "string", }, - "file_extract": { + "skills": { "type": "object", "items": { "enable": { "type": "bool", }, - "provider": { - "type": "string", - }, - "moonshotai_api_key": { + "runtime": { "type": "string", }, }, @@ -2284,37 +2217,6 @@ class ChatProviderTemplate(TypedDict): }, }, }, - "provider_stt_settings": { - "type": "object", - "items": { - "enable": { - "type": "bool", - }, - "provider_id": { - "type": "string", - }, - }, - }, - "provider_tts_settings": { - "type": "object", - "items": { - "enable": { - "type": "bool", - }, - "provider_id": { - "type": "string", - }, - "dual_output": { - "type": "bool", - }, - "use_file_service": { - "type": "bool", - }, - "trigger_probability": { - "type": "float", - }, - }, - }, "provider_ltm_settings": { "type": "object", "items": { @@ -2362,12 +2264,6 @@ class ChatProviderTemplate(TypedDict): "type": "list", "items": {"type": "string"}, }, - "t2i": { - "type": "bool", - }, - "t2i_word_threshold": { - "type": "int", - }, "admins_id": { "type": "list", "items": {"type": "string"}, @@ -2410,9 +2306,6 @@ class ChatProviderTemplate(TypedDict): "t2i_endpoint": { "type": "string", }, - "t2i_use_file_service": { - "type": "bool", - }, "pip_install_arg": { "type": "string", }, @@ -2443,54 +2336,6 @@ class ChatProviderTemplate(TypedDict): "ai_group": { "name": "AI 配置", "metadata": { - "agent_runner": { - "description": "Agent 执行方式", - "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", - "type": "object", - "items": { - "provider_settings.enable": { - "description": "启用", - "type": "bool", - "hint": "AI 对话总开关", - }, - "provider_settings.agent_runner_type": { - "description": "执行器", - "type": "string", - "options": ["local", "dify", "coze", "dashscope"], - "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], - "condition": { - "provider_settings.enable": True, - }, - }, - "provider_settings.coze_agent_runner_provider_id": { - "description": "Coze Agent 执行器提供商 ID", - "type": "string", - "_special": "select_agent_runner_provider:coze", - "condition": { - "provider_settings.agent_runner_type": "coze", - "provider_settings.enable": True, - }, - }, - "provider_settings.dify_agent_runner_provider_id": { - "description": "Dify 
Agent 执行器提供商 ID", - "type": "string", - "_special": "select_agent_runner_provider:dify", - "condition": { - "provider_settings.agent_runner_type": "dify", - "provider_settings.enable": True, - }, - }, - "provider_settings.dashscope_agent_runner_provider_id": { - "description": "阿里云百炼应用 Agent 执行器提供商 ID", - "type": "string", - "_special": "select_agent_runner_provider:dashscope", - "condition": { - "provider_settings.agent_runner_type": "dashscope", - "provider_settings.enable": True, - }, - }, - }, - }, "ai": { "description": "模型", "hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。", @@ -2508,49 +2353,12 @@ class ChatProviderTemplate(TypedDict): "_special": "select_provider", "hint": "留空代表不使用,可用于非多模态模型", }, - "provider_stt_settings.enable": { - "description": "启用语音转文本", - "type": "bool", - "hint": "STT 总开关", - }, - "provider_stt_settings.provider_id": { - "description": "默认语音转文本模型", - "type": "string", - "hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。", - "_special": "select_provider_stt", - "condition": { - "provider_stt_settings.enable": True, - }, - }, - "provider_tts_settings.enable": { - "description": "启用文本转语音", - "type": "bool", - "hint": "TTS 总开关", - }, - "provider_tts_settings.provider_id": { - "description": "默认文本转语音模型", - "type": "string", - "_special": "select_provider_tts", - "condition": { - "provider_tts_settings.enable": True, - }, - }, - "provider_tts_settings.trigger_probability": { - "description": "TTS 触发概率", - "type": "float", - "slider": {"min": 0, "max": 1, "step": 0.05}, - "condition": { - "provider_tts_settings.enable": True, - }, - }, "provider_settings.image_caption_prompt": { "description": "图片转述提示词", "type": "text", }, }, - "condition": { - "provider_settings.enable": True, - }, + "condition": {}, }, "persona": { "description": "人格", @@ -2563,10 +2371,7 @@ class ChatProviderTemplate(TypedDict): "_special": "select_persona", }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, "knowledgebase": { "description": "知识库", @@ -2596,10 +2401,7 @@ class ChatProviderTemplate(TypedDict): "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。", }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, "websearch": { "description": "网页搜索", @@ -2654,10 +2456,7 @@ class ChatProviderTemplate(TypedDict): }, }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, "agent_computer_use": { "description": "Agent Computer Use", @@ -2718,41 +2517,8 @@ class ChatProviderTemplate(TypedDict): }, }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, - # "file_extract": { - # "description": "文档解析能力 [beta]", - # "type": "object", - # "items": { - # "provider_settings.file_extract.enable": { - # "description": "启用文档解析能力", - # "type": "bool", - # }, - # "provider_settings.file_extract.provider": { - # "description": "文档解析提供商", - # "type": "string", - # "options": ["moonshotai"], - # "condition": { - # "provider_settings.file_extract.enable": True, - # }, - # }, - # "provider_settings.file_extract.moonshotai_api_key": { - # "description": "Moonshot AI API Key", - # "type": "string", - # "condition": { - # "provider_settings.file_extract.provider": "moonshotai", - # "provider_settings.file_extract.enable": True, - # }, - # }, - # }, - # "condition": { - 
# "provider_settings.agent_runner_type": "local", - # "provider_settings.enable": True, - # }, - # }, "proactive_capability": { "description": "主动型 Agent", "hint": "https://docs.astrbot.app/use/proactive-agent.html", @@ -2764,10 +2530,7 @@ class ChatProviderTemplate(TypedDict): "hint": "启用后,将会传递给 Agent 相关工具来实现主动型 Agent。你可以告诉 AstrBot 未来某个时间要做的事情,它将被定时触发然后执行任务。", }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, "truncate_and_compress": { "hint": "", @@ -2778,26 +2541,20 @@ class ChatProviderTemplate(TypedDict): "description": "最多携带对话轮数", "type": "int", "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.dequeue_context_length": { "description": "丢弃对话轮数", "type": "int", "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.context_limit_reached_strategy": { "description": "超出模型上下文窗口时的处理方式", "type": "string", "options": ["truncate_by_turns", "llm_compress"], "labels": ["按对话轮数截断", "由 LLM 压缩上下文"], - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, "hint": "", }, "provider_settings.llm_compress_instruction": { @@ -2806,7 +2563,6 @@ class ChatProviderTemplate(TypedDict): "hint": "如果为空则使用默认提示词。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", - "provider_settings.agent_runner_type": "local", }, }, "provider_settings.llm_compress_keep_recent": { @@ -2815,7 +2571,6 @@ class ChatProviderTemplate(TypedDict): "hint": "始终保留的最近 N 轮对话。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", - "provider_settings.agent_runner_type": "local", }, }, "provider_settings.llm_compress_provider_id": { @@ -2825,14 +2580,10 @@ class ChatProviderTemplate(TypedDict): "hint": "留空时将降级为“按对话轮数截断”的策略。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", - "provider_settings.agent_runner_type": "local", }, }, }, - "condition": { - "provider_settings.agent_runner_type": "local", - "provider_settings.enable": True, - }, + "condition": {}, }, "others": { "description": "其他配置", @@ -2841,9 +2592,7 @@ class ChatProviderTemplate(TypedDict): "provider_settings.display_reasoning_text": { "description": "显示思考内容", "type": "bool", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.streaming_response": { "description": "流式输出", @@ -2887,38 +2636,28 @@ class ChatProviderTemplate(TypedDict): "description": "现实世界时间感知", "type": "bool", "hint": "启用后,会在系统提示词中附带当前时间信息。", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.show_tool_use_status": { "description": "输出函数调用状态", "type": "bool", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.sanitize_context_by_modalities": { "description": "按模型能力清理历史上下文", "type": "bool", "hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.max_agent_step": { "description": "工具调用轮数上限", "type": "int", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.tool_call_timeout": { "description": "工具调用超时时间(秒)", "type": "int", - 
"condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.tool_schema_mode": { "description": "工具调用模式", @@ -2926,9 +2665,7 @@ class ChatProviderTemplate(TypedDict): "options": ["skills_like", "full"], "labels": ["Skills-like(两阶段)", "Full(完整参数)"], "hint": "skills-like 先下发工具名称与描述,再下发参数;full 一次性下发完整参数。", - "condition": { - "provider_settings.agent_runner_type": "local", - }, + "condition": {}, }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", @@ -2940,19 +2677,13 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。", }, - "provider_tts_settings.dual_output": { - "description": "开启 TTS 时同时输出语音和文字内容", - "type": "bool", - }, "provider_settings.reachability_check": { "description": "提供商可达性检测", "type": "bool", "hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。", }, }, - "condition": { - "provider_settings.enable": True, - }, + "condition": {}, }, }, }, @@ -2994,6 +2725,10 @@ class ChatProviderTemplate(TypedDict): "description": "回复时引用发送人消息", "type": "bool", }, + "platform_settings.forward_wrapper": { + "description": "启用合并转发", + "type": "bool", + }, "platform_settings.forward_threshold": { "description": "转发消息的字数阈值", "type": "int", @@ -3058,66 +2793,6 @@ class ChatProviderTemplate(TypedDict): }, }, }, - "content_safety": { - "description": "内容安全", - "type": "object", - "items": { - "content_safety.also_use_in_response": { - "description": "同时检查模型的响应内容", - "type": "bool", - }, - "content_safety.baidu_aip.enable": { - "description": "使用百度内容安全审核", - "type": "bool", - "hint": "您需要手动安装 baidu-aip 库。", - }, - "content_safety.baidu_aip.app_id": { - "description": "App ID", - "type": "string", - "condition": { - "content_safety.baidu_aip.enable": True, - }, - }, - "content_safety.baidu_aip.api_key": { - "description": "API Key", - "type": "string", - "condition": { - "content_safety.baidu_aip.enable": True, - }, - }, - "content_safety.baidu_aip.secret_key": { - "description": "Secret Key", - "type": "string", - "condition": { - "content_safety.baidu_aip.enable": True, - }, - }, - "content_safety.internal_keywords.enable": { - "description": "关键词检查", - "type": "bool", - }, - "content_safety.internal_keywords.extra_keywords": { - "description": "额外关键词", - "type": "list", - "items": {"type": "string"}, - "hint": "额外的屏蔽关键词列表,支持正则表达式。", - }, - }, - }, - "t2i": { - "description": "文本转图像", - "type": "object", - "items": { - "t2i": { - "description": "文本转图像输出", - "type": "bool", - }, - "t2i_word_threshold": { - "description": "文本转图像字数阈值", - "type": "int", - }, - }, - }, "others": { "description": "其他配置", "type": "object", @@ -3322,27 +2997,15 @@ class ChatProviderTemplate(TypedDict): "description": "系统配置", "type": "object", "items": { - "t2i_strategy": { - "description": "文本转图像策略", - "type": "string", - "hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。", - "options": ["remote", "local"], - }, "t2i_endpoint": { "description": "文本转图像服务 API 地址", "type": "string", "hint": "为空时使用 AstrBot API 服务", - "condition": { - "t2i_strategy": "remote", - }, }, "t2i_template": { "description": "文本转图像自定义模版", "type": "bool", "hint": "启用后可自定义 HTML 模板用于文转图渲染。", - "condition": { - "t2i_strategy": "remote", - }, "_special": "t2i_template", }, "t2i_active_template": { diff --git a/astrbot/core/config/node_config.py b/astrbot/core/config/node_config.py new file mode 100644 index 000000000..c6fd39911 --- /dev/null +++ 
b/astrbot/core/config/node_config.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import os +import re + +from astrbot.core.utils.astrbot_path import get_astrbot_config_path + +from .astrbot_config import AstrBotConfig + +_SAFE_NAME_RE = re.compile(r"[^A-Za-z0-9._-]+") + + +def _sanitize_name(value: str) -> str: + value = (value or "").strip().lower() + value = value.replace("/", "_").replace("\\", "_") + value = _SAFE_NAME_RE.sub("_", value) + return value or "unknown" + + +def _build_node_config_path( + node_name: str, + chain_id: str, + node_uuid: str, +) -> str: + plugin_key = _sanitize_name(node_name) + chain_key = _sanitize_name(chain_id) + uuid_key = _sanitize_name(node_uuid) + filename = f"node_{plugin_key}_{chain_key}_{uuid_key}.json" + os.makedirs(get_astrbot_config_path(), exist_ok=True) + return os.path.join(get_astrbot_config_path(), filename) + + +class AstrBotNodeConfig(AstrBotConfig): + """Node config - extends AstrBotConfig with chain-specific path. + + Node config is chain-scoped and shared across config_id. + This class reuses AstrBotConfig's schema parsing, integrity checking, + and persistence logic, only overriding the config path. + """ + + node_name: str + chain_id: str + node_uuid: str + + _cache: dict[tuple[str, str, str], AstrBotNodeConfig] = {} + + def __init__( + self, + node_name: str, + chain_id: str, + node_uuid: str, + schema: dict | None = None, + ): + # Store node identifiers before parent init + object.__setattr__(self, "node_name", node_name) + object.__setattr__(self, "chain_id", chain_id) + object.__setattr__(self, "node_uuid", node_uuid) + + # Keep behavior aligned with Star plugin config: + # if schema is not declared, do not create a persisted config file. + if schema is not None: + config_path = _build_node_config_path(node_name, chain_id, node_uuid) + else: + config_path = "" + + # Initialize with empty default_config, schema will generate defaults + super().__init__( + config_path=config_path, + default_config={}, + schema=schema, + ) + + def check_exist(self) -> bool: + """Override to handle empty config_path case.""" + if not self.config_path: + return True # Skip file operations if no path + return super().check_exist() + + def save_config(self, replace_config: dict | None = None): + """Override to handle empty config_path case.""" + if not self.config_path: + return + if replace_config: + self.update(replace_config) + super().save_config() + + @classmethod + def get_cached( + cls, + node_name: str, + chain_id: str, + node_uuid: str, + schema: dict | None = None, + ) -> AstrBotNodeConfig: + cache_key = (node_name, chain_id, node_uuid) + cached = cls._cache.get(cache_key) + if cached is None: + cached = cls( + node_name=node_name, + chain_id=chain_id, + node_uuid=node_uuid, + schema=schema, + ) + cls._cache[cache_key] = cached + return cached + + if schema is not None: + if not cached.config_path: + cached = cls( + node_name=node_name, + chain_id=chain_id, + node_uuid=node_uuid, + schema=schema, + ) + cls._cache[cache_key] = cached + return cached + cached._update_schema(schema) + return cached + + def _update_schema(self, schema: dict) -> None: + if self.schema == schema: + return + object.__setattr__(self, "schema", schema) + refer_conf = self._config_schema_to_default_config(schema) + object.__setattr__(self, "default_config", refer_conf) + conf = dict(self) + has_new = self.check_config_integrity(refer_conf, conf) + self.clear() + self.update(conf) + if has_new: + self.save_config() diff --git a/astrbot/core/core_lifecycle.py 
b/astrbot/core/core_lifecycle.py index 6b36cca0d..8ab8c54ee 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -1,6 +1,6 @@ """Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. -该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 +该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineExecutor、EventBus等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 工作流程: @@ -25,7 +25,9 @@ from astrbot.core.db import BaseDatabase from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager -from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.pipeline.engine.executor import PipelineExecutor +from astrbot.core.pipeline.engine.router import ChainRouter from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.provider.manager import ProviderManager @@ -33,7 +35,6 @@ from astrbot.core.star.context import Context from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.subagent_orchestrator import SubAgentOrchestrator -from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.updator import AstrBotUpdator from astrbot.core.utils.llm_metadata import update_llm_metadata from astrbot.core.utils.migra_helper import migra @@ -45,7 +46,7 @@ class AstrBotCoreLifecycle: """AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. - 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、 + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineExecutor、 EventBus 等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 """ @@ -98,7 +99,7 @@ async def _init_or_reload_subagent_orchestrator(self) -> None: async def initialize(self) -> None: """初始化 AstrBot 核心生命周期管理类. 
- 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 + 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineExecutor、EventBus、AstrBotUpdator等。 """ # 初始化日志代理 logger.info("AstrBot v" + VERSION) @@ -115,14 +116,12 @@ async def initialize(self) -> None: await html_renderer.initialize() - # 初始化 UMOP 配置路由器 - self.umop_config_router = UmopConfigRouter(sp=sp) - await self.umop_config_router.initialize() + # 初始化 Chain 配置路由器(用于配置文件选择) + self.chain_config_router = ChainRouter() # 初始化 AstrBot 配置管理器 self.astrbot_config_mgr = AstrBotConfigManager( default_config=self.astrbot_config, - ucr=self.umop_config_router, sp=sp, ) @@ -131,13 +130,15 @@ async def initialize(self) -> None: await migra( self.db, self.astrbot_config_mgr, - self.umop_config_router, + sp, self.astrbot_config_mgr, ) except Exception as e: logger.error(f"AstrBot migration failed: {e!s}") logger.error(traceback.format_exc()) + await self.chain_config_router.load_configs(self.db) + # 初始化事件队列 self.event_queue = Queue() @@ -197,8 +198,8 @@ async def initialize(self) -> None: await self.kb_manager.initialize() - # 初始化消息事件流水线调度器 - self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() + # 初始化消息事件流水线执行器 + self.pipeline_executor_mapping = await self.load_pipeline_executors() # 初始化更新器 self.astrbot_updator = AstrBotUpdator() @@ -206,8 +207,9 @@ async def initialize(self) -> None: # 初始化事件总线 self.event_bus = EventBus( self.event_queue, - self.pipeline_scheduler_mapping, + self.pipeline_executor_mapping, self.astrbot_config_mgr, + self.chain_config_router, ) # 记录启动时间 @@ -353,34 +355,37 @@ def load_platform(self) -> list[asyncio.Task]: ) return tasks - async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]: - """加载消息事件流水线调度器. + async def load_pipeline_executors(self) -> dict[str, PipelineExecutor]: + """加载消息事件流水线执行器. Returns: - dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 + dict[str, PipelineExecutor]: 配置 ID 到流水线执行器的映射 """ mapping = {} - for conf_id, ab_config in self.astrbot_config_mgr.confs.items(): - scheduler = PipelineScheduler( - PipelineContext(ab_config, self.plugin_manager, conf_id), + for config_id, ab_config in self.astrbot_config_mgr.confs.items(): + executor = PipelineExecutor( + self.star_context, + PipelineContext( + ab_config, + self.plugin_manager, + ), ) - await scheduler.initialize() - mapping[conf_id] = scheduler + await executor.initialize() + mapping[config_id] = executor return mapping - async def reload_pipeline_scheduler(self, conf_id: str) -> None: - """重新加载消息事件流水线调度器. 
- - Returns: - dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 - - """ - ab_config = self.astrbot_config_mgr.confs.get(conf_id) + async def reload_pipeline_executor(self, config_id: str) -> None: + """重新加载消息事件流水线执行器.""" + ab_config = self.astrbot_config_mgr.confs.get(config_id) if not ab_config: - raise ValueError(f"配置文件 {conf_id} 不存在") - scheduler = PipelineScheduler( - PipelineContext(ab_config, self.plugin_manager, conf_id), + raise ValueError(f"配置文件 {config_id} 不存在") + executor = PipelineExecutor( + self.star_context, + PipelineContext( + ab_config, + self.plugin_manager, + ), ) - await scheduler.initialize() - self.pipeline_scheduler_mapping[conf_id] = scheduler + await executor.initialize() + self.pipeline_executor_mapping[config_id] = executor diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index d12878be3..b419eee91 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -297,9 +297,9 @@ async def _woke_main_agent( ) # judge user's role - umo = cron_event.unified_msg_origin - cfg = self.ctx.get_config(umo=umo) cron_payload = extras.get("cron_payload", {}) if extras else {} + chain_config_id = cron_payload.get("config_id") + cfg = self.ctx.get_config_by_id(chain_config_id) sender_id = cron_payload.get("sender_id") admin_ids = cfg.get("admins_id", []) if admin_ids: diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index 58736ab51..c9b3de7be 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -15,7 +15,7 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> # 如果任何一项带有 umop,则说明需要迁移 need_migration = False - for conf_id, conf_info in abconf_data.items(): + for config_id, conf_info in abconf_data.items(): if isinstance(conf_info, dict) and "umop" in conf_info: need_migration = True break @@ -25,20 +25,20 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> logger.info("Starting migration from version 4.5 to 4.6") - # extract umo->conf_id mapping - umo_to_conf_id = {} - for conf_id, conf_info in abconf_data.items(): + # extract umo->config_id mapping + umo_to_config_id = {} + for config_id, conf_info in abconf_data.items(): if isinstance(conf_info, dict) and "umop" in conf_info: umop_ls = conf_info.pop("umop") if not isinstance(umop_ls, list): continue for umo in umop_ls: - if isinstance(umo, str) and umo not in umo_to_conf_id: - umo_to_conf_id[umo] = conf_id + if isinstance(umo, str) and umo not in umo_to_config_id: + umo_to_config_id[umo] = config_id # update the abconf data await sp.global_put("abconf_mapping", abconf_data) # update the umop config router - await ucr.update_routing_data(umo_to_conf_id) + await ucr.update_routing_data(umo_to_config_id) logger.info("Migration from version 45 to 46 completed successfully") diff --git a/astrbot/core/db/migration/migra_4_to_5.py b/astrbot/core/db/migration/migra_4_to_5.py new file mode 100644 index 000000000..52cd21f44 --- /dev/null +++ b/astrbot/core/db/migration/migra_4_to_5.py @@ -0,0 +1,688 @@ +from __future__ import annotations + +import json +import uuid +from typing import Any + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col, select + +from astrbot.api import logger, sp +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.config.node_config import AstrBotNodeConfig +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po 
import Preference +from astrbot.core.pipeline.agent.runner_config import ( + AGENT_RUNNER_PROVIDER_KEY, + normalize_agent_runner_type, +) +from astrbot.core.pipeline.engine.chain_config import ( + ChainConfigModel, + normalize_chain_nodes, + serialize_chain_nodes, +) +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_LLM, + FEATURE_STT, + FEATURE_T2I, + FEATURE_TTS, +) +from astrbot.core.umop_config_router import UmopConfigRouter + +_MIGRATION_FLAG = "migration_done_v5" +_MIGRATION_PROVIDER_CLEANUP_FLAG = "migration_done_v5_provider_cleanup" + +_SESSION_RULE_KEYS = { + "session_service_config", + "session_plugin_config", + "provider_perf_chat_completion", + "provider_perf_text_to_speech", + "provider_perf_speech_to_text", +} + + +def _normalize_umop_pattern(pattern: str) -> str | None: + parts = [p.strip() for p in str(pattern).split(":")] + if len(parts) != 3: + return None + normalized = [p if p != "" else "*" for p in parts] + return ":".join(normalized) + + +def _build_umo_rule(pattern: str) -> dict: + return { + "type": "condition", + "condition": { + "type": "umo", + "operator": "include", + "value": pattern, + }, + } + + +def _merge_excludes(rule: dict | None, disabled_umos: list[str]) -> dict | None: + if not disabled_umos: + return rule + excludes = [ + { + "type": "condition", + "condition": { + "type": "umo", + "operator": "exclude", + "value": umo, + }, + } + for umo in disabled_umos + ] + if rule is None: + return {"type": "and", "children": excludes} + return {"type": "and", "children": [rule, *excludes]} + + +def _build_plugin_filter(plugin_cfg: dict | None) -> dict | None: + if not isinstance(plugin_cfg, dict): + return None + enabled = plugin_cfg.get("enabled_plugins") or [] + disabled = plugin_cfg.get("disabled_plugins") or [] + if disabled: + return {"mode": "blacklist", "plugins": list(disabled)} + if enabled: + return {"mode": "whitelist", "plugins": list(enabled)} + return None + + +def _read_legacy_enabled(value: object, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "on"}: + return True + if lowered in {"false", "0", "no", "off"}: + return False + return bool(value) + + +def _has_kb_binding(conf: dict) -> bool: + kb_names = conf.get("kb_names") + if isinstance(kb_names, list): + for kb_name in kb_names: + if str(kb_name or "").strip(): + return True + + default_kb_collection = str(conf.get("default_kb_collection", "") or "").strip() + return bool(default_kb_collection) + + +def _build_nodes_for_config(conf: dict) -> list[str]: + provider_settings = conf.get("provider_settings", {}) or {} + file_extract_cfg = provider_settings.get("file_extract", {}) or {} + stt_settings = conf.get("provider_stt_settings", {}) or {} + tts_settings = conf.get("provider_tts_settings", {}) or {} + + nodes: list[str] = [] + + if _read_legacy_enabled(stt_settings.get("enable"), False): + nodes.append("stt") + + if _read_legacy_enabled(file_extract_cfg.get("enable"), False): + nodes.append("file_extract") + + if not _read_legacy_enabled(conf.get("kb_agentic_mode"), False) and _has_kb_binding( + conf + ): + nodes.append("knowledge_base") + + nodes.append("agent") + + if _read_legacy_enabled(tts_settings.get("enable"), False): + nodes.append("tts") + + if _read_legacy_enabled(conf.get("t2i"), False): + nodes.append("t2i") + + return nodes + + +def 
_build_chain_runtime_flags_for_config(conf: dict) -> dict[str, bool]: + provider_settings = conf.get("provider_settings", {}) or {} + stt_settings = conf.get("provider_stt_settings", {}) or {} + tts_settings = conf.get("provider_tts_settings", {}) or {} + + return { + FEATURE_LLM: _read_legacy_enabled(provider_settings.get("enable"), True), + FEATURE_STT: _read_legacy_enabled(stt_settings.get("enable"), False), + FEATURE_TTS: _read_legacy_enabled(tts_settings.get("enable"), False), + FEATURE_T2I: _read_legacy_enabled(conf.get("t2i"), False), + } + + +async def _apply_chain_runtime_flags( + chain_runtime_flags: list[tuple[str, dict[str, bool]]], +) -> None: + if not chain_runtime_flags: + return + + raw = await sp.global_get("chain_runtime_flags", {}) + all_flags = raw if isinstance(raw, dict) else {} + + for chain_id, flags in chain_runtime_flags: + existing_flags = all_flags.get(chain_id) + merged_flags = existing_flags if isinstance(existing_flags, dict) else {} + merged_flags.update( + { + FEATURE_LLM: bool(flags.get(FEATURE_LLM, True)), + FEATURE_STT: bool(flags.get(FEATURE_STT, False)), + FEATURE_TTS: bool(flags.get(FEATURE_TTS, False)), + FEATURE_T2I: bool(flags.get(FEATURE_T2I, False)), + } + ) + all_flags[chain_id] = merged_flags + + await sp.global_put("chain_runtime_flags", all_flags) + + +def _infer_legacy_agent_runner_from_default_provider( + conf: dict, +) -> tuple[str, str] | None: + provider_settings = conf.get("provider_settings", {}) or {} + default_provider_id = str( + provider_settings.get("default_provider_id", "") or "" + ).strip() + if not default_provider_id: + return None + + for provider in conf.get("provider", []) or []: + if provider.get("id") != default_provider_id: + continue + provider_type = str(provider.get("type") or "").strip().lower() + if provider_type in AGENT_RUNNER_PROVIDER_KEY: + return provider_type, default_provider_id + return None + + return None + + +def _resolve_legacy_agent_node_config(conf: dict) -> dict: + provider_settings = conf.get("provider_settings", {}) or {} + + runner_type = normalize_agent_runner_type( + provider_settings.get("agent_runner_type") + ) + provider_key = AGENT_RUNNER_PROVIDER_KEY.get(runner_type, "") + provider_id = str(provider_settings.get(provider_key, "") or "").strip() + + if runner_type in AGENT_RUNNER_PROVIDER_KEY and not provider_id: + inferred = _infer_legacy_agent_runner_from_default_provider(conf) + if inferred is not None: + runner_type, provider_id = inferred + + node_conf: dict[str, object] = { + "agent_runner_type": runner_type, + "provider_id": str( + provider_settings.get("default_provider_id", "") or "" + ).strip(), + } + for key in AGENT_RUNNER_PROVIDER_KEY.values(): + node_conf[key] = "" + if runner_type in AGENT_RUNNER_PROVIDER_KEY: + node_conf[AGENT_RUNNER_PROVIDER_KEY[runner_type]] = provider_id + + return node_conf + + +def _apply_node_defaults(chain_id: str, nodes: list, conf: dict) -> None: + node_map = {node.name: node.uuid for node in nodes if node.name} + + if "t2i" in node_map: + t2i_conf = { + "word_threshold": conf.get("t2i_word_threshold", 150), + "strategy": conf.get("t2i_strategy", "remote"), + "active_template": conf.get("t2i_active_template", ""), + "use_file_service": conf.get("t2i_use_file_service", False), + } + AstrBotNodeConfig.get_cached( + node_name="t2i", + chain_id=chain_id, + node_uuid=node_map["t2i"], + schema={}, + ).save_config(t2i_conf) + + if "stt" in node_map: + stt_settings = conf.get("provider_stt_settings", {}) or {} + stt_conf = { + "provider_id": 
str(stt_settings.get("provider_id", "") or "").strip(), + } + AstrBotNodeConfig.get_cached( + node_name="stt", + chain_id=chain_id, + node_uuid=node_map["stt"], + schema={}, + ).save_config(stt_conf) + + if "tts" in node_map: + tts_settings = conf.get("provider_tts_settings", {}) or {} + tts_conf = { + "trigger_probability": tts_settings.get("trigger_probability", 1.0), + "use_file_service": tts_settings.get("use_file_service", False), + "dual_output": tts_settings.get("dual_output", False), + "provider_id": str(tts_settings.get("provider_id", "") or "").strip(), + } + AstrBotNodeConfig.get_cached( + node_name="tts", + chain_id=chain_id, + node_uuid=node_map["tts"], + schema={}, + ).save_config(tts_conf) + + if "file_extract" in node_map: + file_extract_cfg = ( + conf.get("provider_settings", {}).get("file_extract", {}) or {} + ) + file_extract_conf = { + "provider": file_extract_cfg.get("provider", "moonshotai"), + "moonshotai_api_key": file_extract_cfg.get("moonshotai_api_key", ""), + } + AstrBotNodeConfig.get_cached( + node_name="file_extract", + chain_id=chain_id, + node_uuid=node_map["file_extract"], + schema={}, + ).save_config(file_extract_conf) + + if "agent" in node_map: + agent_conf = _resolve_legacy_agent_node_config(conf) + AstrBotNodeConfig.get_cached( + node_name="agent", + chain_id=chain_id, + node_uuid=node_map["agent"], + schema={}, + ).save_config(agent_conf) + + +async def migrate_4_to_5( + db_helper: BaseDatabase, + acm: AstrBotConfigManager, + ucr: UmopConfigRouter, +) -> None: + """Migrate UMOP/session-manager rules to chain configs.""" + if await sp.global_get(_MIGRATION_FLAG, False): + await _run_provider_cleanup_for_v5(db_helper, acm) + return + + try: + await ucr.initialize() + except Exception as e: + logger.warning(f"Failed to initialize UmopConfigRouter: {e!s}") + + # Skip if chain configs already exist. + async with db_helper.get_db() as session: + session: AsyncSession + result = await session.execute(select(ChainConfigModel)) + existing = list(result.scalars().all()) + if existing: + logger.info( + "Chain configs already exist, skip 4->5 migration to avoid conflicts." + ) + await sp.global_put(_MIGRATION_FLAG, True) + await _run_provider_cleanup_for_v5(db_helper, acm) + return + + # Load session-level rules from preferences. + async with db_helper.get_db() as session: + session: AsyncSession + result = await session.execute( + select(Preference).where( + col(Preference.scope) == "umo", + col(Preference.key).in_(_SESSION_RULE_KEYS), + ) + ) + prefs = list(result.scalars().all()) + + rules_map: dict[str, dict[str, Any]] = {} + for pref in prefs: + umo = pref.scope_id + rules_map.setdefault(umo, {}) + if isinstance(pref.value, dict): + value = pref.value.get("val") + else: + value = pref.value + if pref.key == "session_plugin_config": + if isinstance(value, dict): + if isinstance(value.get(umo), dict): + rules_map[umo][pref.key] = value.get(umo) + elif "enabled_plugins" in value or "disabled_plugins" in value: + rules_map[umo][pref.key] = value + else: + rules_map[umo][pref.key] = value + + disabled_umos: list[str] = [] + session_chains: list[ChainConfigModel] = [] + umop_chains: list[ChainConfigModel] = [] + node_defaults: list[tuple[str, list, dict]] = [] + runtime_flag_defaults: list[tuple[str, dict[str, bool]]] = [] + + def get_config(config_id: str | None) -> dict: + if config_id and config_id in acm.confs: + return acm.confs[config_id] + return acm.confs["default"] + + # Build chains for session-specific rules. 
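+    # A session only needs its own chain when it carries a plugin whitelist or
+    # blacklist; sessions that were switched off entirely go into disabled_umos
+    # and are later excluded from every generated chain.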
+ for umo, rules in rules_map.items(): + service_cfg = rules.get("session_service_config") or {} + if ( + isinstance(service_cfg, dict) + and service_cfg.get("session_enabled") is False + ): + disabled_umos.append(umo) + continue + + plugin_filter = _build_plugin_filter(rules.get("session_plugin_config")) + needs_chain = bool(plugin_filter) + + if not needs_chain: + continue + + config_id = None + try: + config_id = ucr.get_config_id_for_umop(umo) + except Exception: + config_id = None + if config_id not in acm.confs: + config_id = "default" + + conf = get_config(config_id) + chain_id = str(uuid.uuid4()) + nodes_list = _build_nodes_for_config(conf) + normalized_nodes = normalize_chain_nodes(nodes_list, chain_id) + nodes_payload = serialize_chain_nodes(normalized_nodes) + + chain = ChainConfigModel( + chain_id=chain_id, + match_rule=_build_umo_rule(umo), + sort_order=0, + enabled=True, + nodes=nodes_payload, + plugin_filter=plugin_filter, + config_id=config_id, + ) + session_chains.append(chain) + node_defaults.append((chain_id, normalized_nodes, conf)) + runtime_flag_defaults.append( + (chain_id, _build_chain_runtime_flags_for_config(conf)) + ) + + # Build chains for UMOP routing. + for pattern, config_id in (ucr.umop_to_config_id or {}).items(): + norm = _normalize_umop_pattern(pattern) + if not norm: + continue + + if config_id not in acm.confs: + config_id = "default" + + conf = get_config(config_id) + chain_id = str(uuid.uuid4()) + nodes_list = _build_nodes_for_config(conf) + normalized_nodes = normalize_chain_nodes(nodes_list, chain_id) + nodes_payload = serialize_chain_nodes(normalized_nodes) + + chain = ChainConfigModel( + chain_id=chain_id, + match_rule=_build_umo_rule(norm), + sort_order=0, + enabled=True, + nodes=nodes_payload, + plugin_filter=None, + config_id=config_id, + ) + umop_chains.append(chain) + node_defaults.append((chain_id, normalized_nodes, conf)) + runtime_flag_defaults.append( + (chain_id, _build_chain_runtime_flags_for_config(conf)) + ) + + # Always create a default chain for legacy behavior. + default_conf = get_config("default") + default_nodes_list = _build_nodes_for_config(default_conf) + default_nodes = normalize_chain_nodes(default_nodes_list, "default") + default_nodes_payload = serialize_chain_nodes(default_nodes) + default_rule = _merge_excludes(None, disabled_umos) + default_chain = ChainConfigModel( + chain_id="default", + match_rule=default_rule, + sort_order=-1, + enabled=True, + nodes=default_nodes_payload, + plugin_filter=None, + config_id="default", + ) + node_defaults.append(("default", default_nodes, default_conf)) + runtime_flag_defaults.append( + ("default", _build_chain_runtime_flags_for_config(default_conf)) + ) + + # Apply disabled session exclusions. + if disabled_umos: + for chain in session_chains + umop_chains: + chain.match_rule = _merge_excludes(chain.match_rule, disabled_umos) + + # Assign sort_order (higher value -> higher priority) + ordered_chains = [*session_chains, *umop_chains] + total = len(ordered_chains) + for idx, chain in enumerate(ordered_chains): + chain.sort_order = total - 1 - idx + + async with db_helper.get_db() as session: + session: AsyncSession + session.add_all([*ordered_chains, default_chain]) + await session.commit() + + # Apply node config defaults from legacy config. 
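+    # Legacy per-feature settings (t2i, stt, tts, file_extract, agent) are
+    # persisted as per-chain AstrBotNodeConfig files; a failure for one chain is
+    # only logged so it does not abort the rest of the migration.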
+ for chain_id, nodes, conf in node_defaults: + try: + _apply_node_defaults(chain_id, nodes, conf) + except Exception as e: + logger.warning(f"Failed to apply node defaults for chain {chain_id}: {e!s}") + + try: + await _apply_chain_runtime_flags(runtime_flag_defaults) + except Exception as e: + logger.warning( + f"Failed to apply chain runtime flags during 4->5 migration: {e!s}" + ) + + await sp.global_put(_MIGRATION_FLAG, True) + logger.info("Migration from v4 to v5 completed successfully.") + await _run_provider_cleanup_for_v5(db_helper, acm) + + +def _read_provider_id(value: object) -> str: + if not isinstance(value, str): + return "" + return value.strip() + + +async def _migrate_chain_provider_columns_to_node_config( + db_helper: BaseDatabase, +) -> None: + table = ChainConfigModel.__tablename__ + + async with db_helper.get_db() as session: + cols = await session.execute(text(f"PRAGMA table_info({table})")) + names = {row[1] for row in cols.fetchall()} + has_tts = "tts_provider_id" in names + has_stt = "stt_provider_id" in names + if not has_tts and not has_stt: + return + + select_cols = ["chain_id", "nodes"] + if has_tts: + select_cols.append("tts_provider_id") + if has_stt: + select_cols.append("stt_provider_id") + rows = ( + await session.execute(text(f"SELECT {', '.join(select_cols)} FROM {table}")) + ).fetchall() + + for row in rows: + chain_id = row[0] + raw_nodes = row[1] + offset = 2 + tts_provider_id = _read_provider_id(row[offset]) if has_tts else "" + if has_tts: + offset += 1 + stt_provider_id = _read_provider_id(row[offset]) if has_stt else "" + + if isinstance(raw_nodes, str): + try: + raw_nodes = json.loads(raw_nodes) + except Exception: + raw_nodes = [] + + nodes = normalize_chain_nodes(raw_nodes, chain_id) + for node in nodes: + if node.name == "tts" and tts_provider_id: + cfg = AstrBotNodeConfig.get_cached( + node_name="tts", + chain_id=chain_id, + node_uuid=node.uuid, + schema={}, + ) + if not cfg.get("provider_id"): + cfg.save_config({"provider_id": tts_provider_id}) + if node.name == "stt" and stt_provider_id: + cfg = AstrBotNodeConfig.get_cached( + node_name="stt", + chain_id=chain_id, + node_uuid=node.uuid, + schema={}, + ) + if not cfg.get("provider_id"): + cfg.save_config({"provider_id": stt_provider_id}) + + +async def _drop_chain_provider_columns(db_helper: BaseDatabase) -> None: + table = ChainConfigModel.__tablename__ + + async with db_helper.get_db() as session: + cols = await session.execute(text(f"PRAGMA table_info({table})")) + names = {row[1] for row in cols.fetchall()} + has_tts = "tts_provider_id" in names + has_stt = "stt_provider_id" in names + if not has_tts and not has_stt: + return + + try: + if has_tts: + await session.execute( + text(f"ALTER TABLE {table} DROP COLUMN tts_provider_id") + ) + if has_stt: + await session.execute( + text(f"ALTER TABLE {table} DROP COLUMN stt_provider_id") + ) + await session.commit() + return + except Exception: + await session.rollback() + + old_table = f"{table}_old_v5_cleanup" + await session.execute(text(f"ALTER TABLE {table} RENAME TO {old_table}")) + await session.execute( + text( + f""" + CREATE TABLE {table} ( + id INTEGER NOT NULL PRIMARY KEY, + chain_id VARCHAR(36) NOT NULL UNIQUE, + match_rule JSON, + sort_order INTEGER NOT NULL DEFAULT 0, + enabled BOOLEAN NOT NULL DEFAULT 1, + nodes JSON, + plugin_filter JSON, + config_id VARCHAR(36), + created_at DATETIME, + updated_at DATETIME + ) + """ + ) + ) + await session.execute( + text( + f""" + INSERT INTO {table} ( + id, + chain_id, + match_rule, + 
sort_order, + enabled, + nodes, + plugin_filter, + config_id, + created_at, + updated_at + ) + SELECT + id, + chain_id, + match_rule, + sort_order, + enabled, + nodes, + plugin_filter, + config_id, + created_at, + updated_at + FROM {old_table} + """ + ) + ) + await session.execute(text(f"DROP TABLE {old_table}")) + await session.commit() + + +def _cleanup_legacy_provider_config_keys(acm: AstrBotConfigManager) -> None: + for conf in acm.confs.values(): + changed = False + if "provider_tts_settings" in conf: + conf.pop("provider_tts_settings", None) + changed = True + if "provider_stt_settings" in conf: + conf.pop("provider_stt_settings", None) + changed = True + + provider_settings = conf.get("provider_settings") + if isinstance(provider_settings, dict): + for legacy_key in ( + "enable", + "agent_runner_type", + "dify_agent_runner_provider_id", + "coze_agent_runner_provider_id", + "dashscope_agent_runner_provider_id", + ): + if legacy_key in provider_settings: + provider_settings.pop(legacy_key, None) + changed = True + + if changed: + conf.save_config() + + +async def _run_provider_cleanup_for_v5( + db_helper: BaseDatabase, + acm: AstrBotConfigManager, +) -> None: + if await sp.global_get(_MIGRATION_PROVIDER_CLEANUP_FLAG, False): + return + + logger.info("Starting v5 provider cleanup migration") + await _migrate_chain_provider_columns_to_node_config(db_helper) + _cleanup_legacy_provider_config_keys(acm) + await _drop_chain_provider_columns(db_helper) + await sp.global_put(_MIGRATION_PROVIDER_CLEANUP_FLAG, True) + logger.info("v5 provider cleanup migration completed") diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 44cdccb83..7c0e5c130 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -1,66 +1,139 @@ -"""事件总线, 用于处理事件的分发和处理 -事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理 -其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑 +"""事件总线 - 消息队列消费 + Pipeline 分发""" -class: - EventBus: 事件总线, 用于处理事件的分发和处理 - -工作流程: -1. 维护一个异步队列, 来接受各种消息事件 -2. 
无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑 -""" +from __future__ import annotations import asyncio from asyncio import Queue +from typing import TYPE_CHECKING from astrbot.core import logger -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.pipeline.scheduler import PipelineScheduler +from astrbot.core.pipeline.engine.wait_registry import build_wait_key, wait_registry +from astrbot.core.star.modality import extract_modalities -from .platform import AstrMessageEvent +if TYPE_CHECKING: + from astrbot.core.astrbot_config_mgr import AstrBotConfigManager + from astrbot.core.pipeline.engine.executor import PipelineExecutor + from astrbot.core.pipeline.engine.router import ChainRouter + from astrbot.core.platform.astr_message_event import AstrMessageEvent class EventBus: - """用于处理事件的分发和处理""" - def __init__( self, event_queue: Queue, - pipeline_scheduler_mapping: dict[str, PipelineScheduler], + pipeline_executor_mapping: dict[str, PipelineExecutor], astrbot_config_mgr: AstrBotConfigManager, + chain_router: ChainRouter, ) -> None: - self.event_queue = event_queue # 事件队列 - # abconf uuid -> scheduler - self.pipeline_scheduler_mapping = pipeline_scheduler_mapping + self.event_queue = event_queue + self.pipeline_executor_mapping = pipeline_executor_mapping self.astrbot_config_mgr = astrbot_config_mgr + self.chain_router = chain_router async def dispatch(self) -> None: + """消息队列消费循环""" while True: event: AstrMessageEvent = await self.event_queue.get() - conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) - self._print_event(event, conf_info["name"]) - scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) - if not scheduler: - logger.error( - f"PipelineScheduler not found for id: {conf_info['id']}, event ignored." + + wait_state = await wait_registry.pop(build_wait_key(event)) + if wait_state is not None: + event.message_str = event.message_str.strip() + + current_chain_config = self.chain_router.get_by_chain_id( + wait_state.chain_config.chain_id + ) + + if not wait_state.is_valid(current_chain_config): + logger.debug( + f"WaitState invalidated for {event.unified_msg_origin}, " + "falling back to normal routing." + ) + modality = extract_modalities(event.get_messages()) + routed_chain_config = self.chain_router.route( + event.unified_msg_origin, + modality, + event.message_str, + ) + if routed_chain_config is None: + logger.debug( + f"No chain matched for {event.unified_msg_origin}, " + "event ignored." + ) + continue + + if not self._dispatch_with_chain_config( + event, + routed_chain_config, + ): + continue + continue + + event.chain_config = wait_state.chain_config + event.set_extra("_resume_node_uuid", wait_state.node_uuid) + + if not self._dispatch_with_chain_config(event, wait_state.chain_config): + continue + continue + + event.message_str = event.message_str.strip() + modality = extract_modalities(event.get_messages()) + chain_config = self.chain_router.route( + event.unified_msg_origin, + modality, + event.message_str, + ) + if chain_config is None: + logger.debug( + f"No chain matched for {event.unified_msg_origin}, event ignored." 
) continue - asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None: - """用于记录事件信息 + event.chain_config = chain_config + if not self._dispatch_with_chain_config(event, chain_config): + continue - Args: - event (AstrMessageEvent): 事件对象 + @staticmethod + def _print_event(event: AstrMessageEvent, conf_name: str) -> None: + """记录事件信息""" + sender = event.get_sender_name() + sender_id = event.get_sender_id() + platform_id = event.get_platform_id() + platform_name = event.get_platform_name() + outline = event.get_message_outline() - """ - # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 - if event.get_sender_name(): + if sender: logger.info( - f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}", + f"[{conf_name}] [{platform_id}({platform_name})] " + f"{sender}/{sender_id}: {outline}" ) - # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( - f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}", + f"[{conf_name}] [{platform_id}({platform_name})] {sender_id}: {outline}" ) + + def _dispatch_with_chain_config( + self, + event: AstrMessageEvent, + chain_config, + ) -> bool: + event.chain_config = chain_config + config_id = chain_config.config_id or "default" + self.astrbot_config_mgr.set_runtime_config_id( + event.unified_msg_origin, + config_id, + ) + config_info = self.astrbot_config_mgr.get_config_info_by_id(config_id) + self._print_event(event, config_info["name"]) + + executor = self.pipeline_executor_mapping.get(config_id) + if executor is None: + executor = self.pipeline_executor_mapping.get("default") + if executor is None: + logger.error( + f"PipelineExecutor not found for config_id: {config_id}, event ignored." + ) + return False + + # 分发到 Pipeline(fire-and-forget) + asyncio.create_task(executor.execute(event)) + return True diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index eba6a4fd6..7972d17c3 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -249,3 +249,34 @@ def is_llm_result(self) -> bool: # 为了兼容旧版代码,保留 CommandResult 的别名 CommandResult = MessageEventResult + + +async def collect_streaming_result( + result: MessageEventResult, + *, + warn: bool = False, + logger=None, +) -> MessageEventResult: + """Collect streaming content into a non-streaming MessageEventResult.""" + if result.result_content_type != ResultContentType.STREAMING_RESULT: + return result + if result.async_stream is None: + return result + + if warn and logger is not None: + logger.warning( + "Streaming result detected during node merge; collecting before merge." 
+ ) + + parts: list[str] = [] + async for chunk in result.async_stream: + if hasattr(chunk, "chain") and chunk.chain: + for comp in chunk.chain: + if isinstance(comp, Plain): + parts.append(comp.text) + + collected_text = "".join(parts) + result.chain = [Plain(collected_text)] if collected_text else [] + result.result_content_type = ResultContentType.LLM_RESULT + result.async_stream = None + return result diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index f3633f20f..984c496ce 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -44,9 +44,13 @@ async def get_persona(self, persona_id: str): async def get_default_persona_v3( self, umo: str | MessageSession | None = None, + config_id: str | None = None, ) -> Personality: """获取默认 persona""" - cfg = self.acm.get_conf(umo) + if config_id: + cfg = self.acm.get_conf_by_id(config_id) + else: + cfg = self.acm.get_conf(umo) default_persona_id = cfg.get("provider_settings", {}).get( "default_personality", "default", diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 75fef84d3..e69de29bb 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -1,41 +0,0 @@ -from astrbot.core.message.message_event_result import ( - EventResultType, - MessageEventResult, -) - -from .content_safety_check.stage import ContentSafetyCheckStage -from .preprocess_stage.stage import PreProcessStage -from .process_stage.stage import ProcessStage -from .rate_limit_check.stage import RateLimitStage -from .respond.stage import RespondStage -from .result_decorate.stage import ResultDecorateStage -from .session_status_check.stage import SessionStatusCheckStage -from .waking_check.stage import WakingCheckStage -from .whitelist_check.stage import WhitelistCheckStage - -# 管道阶段顺序 -STAGES_ORDER = [ - "WakingCheckStage", # 检查是否需要唤醒 - "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 - "SessionStatusCheckStage", # 检查会话是否整体启用 - "RateLimitStage", # 检查会话是否超过频率限制 - "ContentSafetyCheckStage", # 检查内容安全 - "PreProcessStage", # 预处理 - "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 - "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 - "RespondStage", # 发送消息 -] - -__all__ = [ - "ContentSafetyCheckStage", - "EventResultType", - "MessageEventResult", - "PreProcessStage", - "ProcessStage", - "RateLimitStage", - "RespondStage", - "ResultDecorateStage", - "SessionStatusCheckStage", - "WakingCheckStage", - "WhitelistCheckStage", -] diff --git a/astrbot/core/pipeline/agent/__init__.py b/astrbot/core/pipeline/agent/__init__.py new file mode 100644 index 000000000..c24449a49 --- /dev/null +++ b/astrbot/core/pipeline/agent/__init__.py @@ -0,0 +1,3 @@ +from .executor import AgentExecutor + +__all__ = ["AgentExecutor"] diff --git a/astrbot/core/pipeline/agent/executor.py b/astrbot/core/pipeline/agent/executor.py new file mode 100644 index 000000000..d82e82632 --- /dev/null +++ b/astrbot/core/pipeline/agent/executor.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from astrbot.core import logger +from astrbot.core.pipeline.agent.internal import InternalAgentExecutor +from astrbot.core.pipeline.agent.runner_config import resolve_agent_runner_config +from astrbot.core.pipeline.agent.third_party import ThirdPartyAgentExecutor +from astrbot.core.pipeline.agent.types import AgentRunOutcome +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.pipeline.engine.chain_runtime_flags import ( + FEATURE_LLM, + is_chain_runtime_feature_enabled, +) +from 
astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class AgentExecutor: + """Native agent executor for the new pipeline.""" + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + + self.bot_wake_prefixs: list[str] = self.config["wake_prefix"] + self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"] + for bwp in self.bot_wake_prefixs: + if self.prov_wake_prefix.startswith(bwp): + logger.info( + f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", + ) + self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :] + + self.internal_executor = InternalAgentExecutor() + self.third_party_executor = ThirdPartyAgentExecutor() + await self.internal_executor.initialize(ctx) + await self.third_party_executor.initialize(ctx) + + async def run(self, event: AstrMessageEvent) -> AgentRunOutcome: + outcome = AgentRunOutcome() + chain_id = event.chain_config.chain_id if event.chain_config else None + if not await is_chain_runtime_feature_enabled(chain_id, FEATURE_LLM): + logger.debug( + "Current chain runtime LLM switch is disabled, skip processing." + ) + return outcome + + runner_type, provider_id = resolve_agent_runner_config( + event.node_config if isinstance(event.node_config, dict) else None, + ) + + if runner_type == "local": + return await self.internal_executor.run(event, self.prov_wake_prefix) + + return await self.third_party_executor.run( + event, + self.prov_wake_prefix, + runner_type=runner_type, + provider_id=provider_id, + ) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/agent/internal.py similarity index 67% rename from astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py rename to astrbot/core/pipeline/agent/internal.py index d26f67add..1401efe28 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/agent/internal.py @@ -1,13 +1,13 @@ -"""本地 Agent 模式的 LLM 调用 Stage""" +"""本地 Agent 模式的 LLM 执行器""" import asyncio import base64 -from collections.abc import AsyncGenerator from dataclasses import replace from astrbot.core import logger from astrbot.core.agent.message import Message from astrbot.core.agent.response import AgentStats +from astrbot.core.astr_agent_run_util import run_agent, run_live_agent from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, MainAgentBuildResult, @@ -19,21 +19,16 @@ MessageEventResult, ResultContentType, ) +from astrbot.core.pipeline.agent.types import AgentRunOutcome +from astrbot.core.pipeline.context import PipelineContext, call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.provider.entities import ( - LLMResponse, - ProviderRequest, -) +from astrbot.core.provider.entities import LLMResponse, ProviderRequest from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from .....astr_agent_run_util import run_agent, run_live_agent -from ....context import PipelineContext, call_event_hook -from ...stage import Stage - -class InternalAgentSubStage(Stage): +class InternalAgentExecutor: async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx conf = ctx.astrbot_config @@ -61,14 +56,6 @@ async def initialize(self, ctx: PipelineContext) -> None: ) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) - file_extract_conf: dict = 
settings.get("file_extract", {}) - self.file_extract_enabled: bool = file_extract_conf.get("enable", False) - self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai") - self.file_extract_msh_api_key: str = file_extract_conf.get( - "moonshotai_api_key", "" - ) - - # 上下文管理相关 self.context_limit_reached_strategy: str = settings.get( "context_limit_reached_strategy", "truncate_by_turns" ) @@ -79,7 +66,7 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" ) - self.max_context_length = settings["max_context_length"] # int + self.max_context_length = settings["max_context_length"] self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), self.max_context_length - 1, @@ -92,10 +79,9 @@ async def initialize(self, ctx: PipelineContext) -> None: "safety_mode_strategy", "system_prompt" ) - self.computer_use_runtime = settings.get("computer_use_runtime") + self.computer_use_runtime = settings.get("computer_use_runtime", "local") self.sandbox_cfg = settings.get("sandbox", {}) - # Proactive capability configuration proactive_cfg = settings.get("proactive_capability", {}) self.add_cron_tools = proactive_cfg.get("add_cron_tools", True) @@ -106,9 +92,6 @@ async def initialize(self, ctx: PipelineContext) -> None: tool_schema_mode=self.tool_schema_mode, sanitize_context_by_modalities=self.sanitize_context_by_modalities, kb_agentic_mode=self.kb_agentic_mode, - file_extract_enabled=self.file_extract_enabled, - file_extract_prov=self.file_extract_prov, - file_extract_msh_api_key=self.file_extract_msh_api_key, context_limit_reached_strategy=self.context_limit_reached_strategy, llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, @@ -122,12 +105,13 @@ async def initialize(self, ctx: PipelineContext) -> None: add_cron_tools=self.add_cron_tools, provider_settings=settings, subagent_orchestrator=conf.get("subagent_orchestrator", {}), - timezone=self.ctx.plugin_manager.context.get_config().get("timezone"), + timezone=conf.get("timezone"), ) - async def process( + async def run( self, event: AstrMessageEvent, provider_wake_prefix: str - ) -> AsyncGenerator[None, None]: + ) -> AgentRunOutcome: + outcome = AgentRunOutcome() try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -145,7 +129,7 @@ async def process( and not has_media_content ): logger.debug("skip llm request: empty message and no provider_request") - return + return outcome logger.debug("ready to request llm provider") @@ -168,21 +152,23 @@ async def process( ) if build_result is None: - return + return outcome agent_runner = build_result.agent_runner req = build_result.provider_request provider = build_result.provider reset_coro = build_result.reset_coro + outcome.handled = True api_base = provider.provider_config.get("api_base", "") for host in decoded_blocked: if host in api_base: logger.error( - "Provider API base %s is blocked due to security reasons. Please use another ai provider.", + "Provider API base %s is blocked due to security reasons. 
" + "Please use another ai provider.", api_base, ) - return + return outcome stream_to_general = ( self.unsupported_streaming_strategy == "turn_off" @@ -190,9 +176,8 @@ async def process( ) if await call_event_hook(event, EventType.OnLLMRequestEvent, req): - return + return outcome - # apply reset if reset_coro: await reset_coro @@ -209,91 +194,123 @@ async def process( }, ) - # 检测 Live Mode if action_type == "live": - # Live Mode: 使用 run_live_agent logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") - # 获取 TTS Provider tts_provider = ( - self.ctx.plugin_manager.context.get_using_tts_provider( - event.unified_msg_origin + self.ctx.plugin_manager.context.get_tts_provider_for_event( + event ) ) if not tts_provider: logger.warning( - "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + "[Live Mode] TTS Provider not configured, fallback to " + "normal streaming." + ) + + async def wrapped_stream(): + async for chunk in run_live_agent( + agent_runner, + tts_provider, + self.max_step, + self.show_tool_use, + show_reasoning=self.show_reasoning, + ): + yield chunk + + final_resp = agent_runner.get_final_llm_resp() + event.trace.record( + "astr_agent_complete", + stats=agent_runner.stats.to_dict(), + resp=final_resp.completion_text if final_resp else None, + ) + + if not event.is_stopped() and agent_runner.done(): + await self._save_to_history( + event, + req, + final_resp, + agent_runner.run_context.messages, + agent_runner.stats, + ) + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, + ), ) - # 使用 run_live_agent,总是使用流式响应 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_live_agent( - agent_runner, - tts_provider, - self.max_step, - self.show_tool_use, - show_reasoning=self.show_reasoning, - ), - ), + .set_async_stream(wrapped_stream()), ) - yield - - # 保存历史记录 - if not event.is_stopped() and agent_runner.done(): - await self._save_to_history( - event, - req, - agent_runner.get_final_llm_resp(), - agent_runner.run_context.messages, - agent_runner.stats, + outcome.streaming = True + outcome.result = event.get_result() + outcome.stopped = event.is_stopped() + return outcome + + if streaming_response and not stream_to_general: + + async def wrapped_stream(): + async for chunk in run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + show_reasoning=self.show_reasoning, + ): + yield chunk + + final_resp = agent_runner.get_final_llm_resp() + event.trace.record( + "astr_agent_complete", + stats=agent_runner.stats.to_dict(), + resp=final_resp.completion_text if final_resp else None, + ) + + if not event.is_stopped() and agent_runner.done(): + await self._save_to_history( + event, + req, + final_resp, + agent_runner.run_context.messages, + agent_runner.stats, + ) + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, + ), ) - elif streaming_response and not stream_to_general: - # 流式响应 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - show_reasoning=self.show_reasoning, - ), - ), + .set_async_stream(wrapped_stream()), ) - yield - if agent_runner.done(): - if final_llm_resp := agent_runner.get_final_llm_resp(): - if final_llm_resp.completion_text: - chain = ( - MessageChain() - 
.message(final_llm_resp.completion_text) - .chain - ) - elif final_llm_resp.result_chain: - chain = final_llm_resp.result_chain.chain - else: - chain = MessageChain().chain - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - async for _ in run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - stream_to_general, - show_reasoning=self.show_reasoning, - ): - yield + outcome.streaming = True + outcome.result = event.get_result() + outcome.stopped = event.is_stopped() + return outcome + + latest_result: MessageEventResult | None = None + async for _ in run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + stream_to_general, + show_reasoning=self.show_reasoning, + ): + result = event.get_result() + if result: + latest_result = result + if latest_result: + event.set_result(latest_result) final_resp = agent_runner.get_final_llm_resp() @@ -303,7 +320,6 @@ async def process( resp=final_resp.completion_text if final_resp else None, ) - # 检查事件是否被停止,如果被停止则不保存历史记录 if not event.is_stopped(): await self._save_to_history( event, @@ -329,6 +345,10 @@ async def process( ) ) + outcome.result = event.get_result() + outcome.stopped = event.is_stopped() + return outcome + async def _save_to_history( self, event: AstrMessageEvent, @@ -346,7 +366,7 @@ async def _save_to_history( return if not llm_response.completion_text and not req.tool_calls_result: - logger.debug("LLM 响应为空,不保存记录。") + logger.debug("LLM response is empty, skipping history save.") return message_to_save = [] @@ -372,7 +392,5 @@ async def _save_to_history( ) -# we prevent astrbot from connecting to known malicious hosts -# these hosts are base64 encoded BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"} decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED] diff --git a/astrbot/core/pipeline/agent/runner_config.py b/astrbot/core/pipeline/agent/runner_config.py new file mode 100644 index 000000000..410514a47 --- /dev/null +++ b/astrbot/core/pipeline/agent/runner_config.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +SUPPORTED_AGENT_RUNNER_TYPES = ("local", "dify", "coze", "dashscope") +AGENT_RUNNER_PROVIDER_KEY = { + "dify": "dify_agent_runner_provider_id", + "coze": "coze_agent_runner_provider_id", + "dashscope": "dashscope_agent_runner_provider_id", +} + + +def normalize_agent_runner_type(value: object) -> str: + runner_type = str(value or "local").strip().lower() + if runner_type in SUPPORTED_AGENT_RUNNER_TYPES: + return runner_type + return "local" + + +def resolve_agent_runner_config( + node_config: Mapping[str, Any] | None, +) -> tuple[str, str]: + """Resolve agent runner config from node config only.""" + node = node_config if isinstance(node_config, Mapping) else {} + + runner_type = normalize_agent_runner_type(node.get("agent_runner_type", "local")) + provider_key = AGENT_RUNNER_PROVIDER_KEY.get(runner_type, "") + provider_id = str(node.get(provider_key, "") or "").strip() if provider_key else "" + + return runner_type, provider_id diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/agent/third_party.py similarity index 68% rename from astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py rename to astrbot/core/pipeline/agent/third_party.py index b590bd77e..82d7b99b4 100644 --- 
a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/agent/third_party.py @@ -17,6 +17,11 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner +from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext +from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS +from astrbot.core.pipeline.agent.runner_config import AGENT_RUNNER_PROVIDER_KEY +from astrbot.core.pipeline.agent.types import AgentRunOutcome +from astrbot.core.pipeline.context import PipelineContext, call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( ProviderRequest, @@ -24,17 +29,6 @@ from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric -from .....astr_agent_context import AgentContextWrapper, AstrAgentContext -from .....astr_agent_hooks import MAIN_AGENT_HOOKS -from ....context import PipelineContext, call_event_hook -from ...stage import Stage - -AGENT_RUNNER_TYPE_KEY = { - "dify": "dify_agent_runner_provider_id", - "coze": "coze_agent_runner_provider_id", - "dashscope": "dashscope_agent_runner_provider_id", -} - async def run_third_party_agent( runner: "BaseAgentRunner", @@ -62,43 +56,50 @@ async def run_third_party_agent( yield MessageChain().message(err_msg) -class ThirdPartyAgentSubStage(Stage): +class ThirdPartyAgentExecutor: async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.conf = ctx.astrbot_config - self.runner_type = self.conf["provider_settings"]["agent_runner_type"] - self.prov_id = self.conf["provider_settings"].get( - AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""), - "", - ) settings = ctx.astrbot_config["provider_settings"] self.streaming_response: bool = settings["streaming_response"] self.unsupported_streaming_strategy: str = settings[ "unsupported_streaming_strategy" ] - async def process( - self, event: AstrMessageEvent, provider_wake_prefix: str - ) -> AsyncGenerator[None, None]: + async def run( + self, + event: AstrMessageEvent, + provider_wake_prefix: str, + *, + runner_type: str, + provider_id: str, + ) -> AgentRunOutcome: + outcome = AgentRunOutcome() req: ProviderRequest | None = None if provider_wake_prefix and not event.message_str.startswith( provider_wake_prefix ): - return + return outcome + + if runner_type not in AGENT_RUNNER_PROVIDER_KEY: + logger.error("Unsupported third party agent runner type: %s", runner_type) + return outcome + + if not provider_id: + logger.error("没有填写 Agent Runner 提供商 ID,请在 Agent 节点配置中设置。") + return outcome - self.prov_cfg: dict = next( - (p for p in astrbot_config["provider"] if p["id"] == self.prov_id), + prov_cfg: dict = next( + (p for p in astrbot_config["provider"] if p["id"] == provider_id), {}, ) - if not self.prov_id: - logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。") - return - if not self.prov_cfg: + if not prov_cfg: logger.error( - f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。" + "Agent Runner 提供商 %s 配置不存在,请检查 Agent 节点配置。", + provider_id, ) - return + return outcome # make provider request req = ProviderRequest() @@ -110,21 +111,21 @@ async def process( req.image_urls.append(image_path) if not req.prompt and not req.image_urls: - return + return outcome # call event hook if await call_event_hook(event, EventType.OnLLMRequestEvent, req): - return + return outcome - if self.runner_type == "dify": + if runner_type == "dify": runner = DifyAgentRunner[AstrAgentContext]() - elif self.runner_type == 
"coze": + elif runner_type == "coze": runner = CozeAgentRunner[AstrAgentContext]() - elif self.runner_type == "dashscope": + elif runner_type == "dashscope": runner = DashscopeAgentRunner[AstrAgentContext]() else: raise ValueError( - f"Unsupported third party agent runner type: {self.runner_type}", + f"Unsupported third party agent runner type: {runner_type}", ) astr_agent_ctx = AstrAgentContext( @@ -148,45 +149,50 @@ async def process( tool_call_timeout=60, ), agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, + provider_config=prov_cfg, streaming=streaming_response, ) + outcome.handled = True if streaming_response and not stream_to_general: # 流式响应 + async def wrapped_stream(): + async for chunk in run_third_party_agent( + runner, + stream_to_general=False, + ): + yield chunk + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=runner_type, + provider_type=runner_type, + ), + ) + event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_third_party_agent( - runner, - stream_to_general=False, - ), - ), + .set_async_stream(wrapped_stream()), ) - yield - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) + outcome.streaming = True + outcome.result = event.get_result() + outcome.stopped = event.is_stopped() + return outcome else: # 非流式响应或转换为普通响应 async for _ in run_third_party_agent( runner, stream_to_general=stream_to_general, ): - yield + pass final_resp = runner.get_final_llm_resp() if not final_resp or not final_resp.result_chain: logger.warning("Agent Runner 未返回最终结果。") - return + return outcome event.set_result( MessageEventResult( @@ -194,12 +200,15 @@ async def process( result_content_type=ResultContentType.LLM_RESULT, ), ) - yield asyncio.create_task( Metric.upload( llm_tick=1, - model_name=self.runner_type, - provider_type=self.runner_type, + model_name=runner_type, + provider_type=runner_type, ), ) + + outcome.result = event.get_result() + outcome.stopped = event.is_stopped() + return outcome diff --git a/astrbot/core/pipeline/agent/types.py b/astrbot/core/pipeline/agent/types.py new file mode 100644 index 000000000..9b1015021 --- /dev/null +++ b/astrbot/core/pipeline/agent/types.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from astrbot.core.message.message_event_result import MessageEventResult + + +@dataclass(slots=True) +class AgentRunOutcome: + handled: bool = False + streaming: bool = False + result: MessageEventResult | None = None + stopped: bool = False diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py deleted file mode 100644 index 19037eb08..000000000 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.message.message_event_result import MessageEventResult -from astrbot.core.platform.astr_message_event import AstrMessageEvent - -from ..context import PipelineContext -from ..stage import Stage, register_stage -from .strategies.strategy import StrategySelector - - -@register_stage -class ContentSafetyCheckStage(Stage): - """检查内容安全 - - 当前只会检查文本的。 - """ - - async def initialize(self, ctx: PipelineContext) -> None: - config = 
ctx.astrbot_config["content_safety"] - self.strategy_selector = StrategySelector(config) - - async def process( - self, - event: AstrMessageEvent, - check_text: str | None = None, - ) -> AsyncGenerator[None, None]: - """检查内容安全""" - text = check_text if check_text else event.get_message_str() - ok, info = self.strategy_selector.check(text) - if not ok: - if event.is_at_or_wake_command: - event.set_result( - MessageEventResult().message( - "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。", - ), - ) - yield - event.stop_event() - logger.info(f"内容安全检查不通过,原因:{info}") - return diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py deleted file mode 100644 index f0a34e73f..000000000 --- a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -import abc - - -class ContentSafetyStrategy(abc.ABC): - @abc.abstractmethod - def check(self, content: str) -> tuple[bool, str]: - raise NotImplementedError diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py deleted file mode 100644 index dd8ca629e..000000000 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ /dev/null @@ -1,32 +0,0 @@ -"""使用此功能应该先 pip install baidu-aip""" - -from typing import Any, cast - -from aip import AipContentCensor - -from . import ContentSafetyStrategy - - -class BaiduAipStrategy(ContentSafetyStrategy): - def __init__(self, appid: str, ak: str, sk: str) -> None: - self.app_id = appid - self.api_key = ak - self.secret_key = sk - self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) - - def check(self, content: str) -> tuple[bool, str]: - res = self.client.textCensorUserDefined(content) - if "conclusionType" not in res: - return False, "" - if res["conclusionType"] == 1: - return True, "" - if "data" not in res: - return False, "" - count = len(res["data"]) - parts = [f"百度审核服务发现 {count} 处违规:\n"] - for i in res["data"]: - # 百度 AIP 返回结构是动态 dict;类型检查时 i 可能被推断为序列,转成 dict 后用 get 取字段 - parts.append(f"{cast(dict[str, Any], i).get('msg', '')};\n") - parts.append("\n判断结果:" + res["conclusion"]) - info = "".join(parts) - return False, info diff --git a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py deleted file mode 100644 index 53ad900f7..000000000 --- a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py +++ /dev/null @@ -1,24 +0,0 @@ -import re - -from . 
import ContentSafetyStrategy - - -class KeywordsStrategy(ContentSafetyStrategy): - def __init__(self, extra_keywords: list) -> None: - self.keywords = [] - if extra_keywords is None: - extra_keywords = [] - self.keywords.extend(extra_keywords) - # keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words") - # internal keywords - # if os.path.exists(keywords_path): - # with open(keywords_path, "r", encoding="utf-8") as f: - # self.keywords.extend( - # json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"] - # ) - - def check(self, content: str) -> tuple[bool, str]: - for keyword in self.keywords: - if re.search(keyword, content): - return False, "内容安全检查不通过,匹配到敏感词。" - return True, "" diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py deleted file mode 100644 index c971ef26f..000000000 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ /dev/null @@ -1,34 +0,0 @@ -from astrbot import logger - -from . import ContentSafetyStrategy - - -class StrategySelector: - def __init__(self, config: dict) -> None: - self.enabled_strategies: list[ContentSafetyStrategy] = [] - if config["internal_keywords"]["enable"]: - from .keywords import KeywordsStrategy - - self.enabled_strategies.append( - KeywordsStrategy(config["internal_keywords"]["extra_keywords"]), - ) - if config["baidu_aip"]["enable"]: - try: - from .baidu_aip import BaiduAipStrategy - except ImportError: - logger.warning("使用百度内容审核应该先 pip install baidu-aip") - return - self.enabled_strategies.append( - BaiduAipStrategy( - config["baidu_aip"]["app_id"], - config["baidu_aip"]["api_key"], - config["baidu_aip"]["secret_key"], - ), - ) - - def check(self, content: str) -> tuple[bool, str]: - for strategy in self.enabled_strategies: - ok, info = strategy.check(content) - if not ok: - return False, info - return True, "" diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index a6cd567e0..85b00a2c2 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,9 +1,15 @@ from dataclasses import dataclass +from typing import TYPE_CHECKING from astrbot.core.config import AstrBotConfig -from astrbot.core.star import PluginManager -from .context_utils import call_event_hook, call_handler +from .context_utils import call_event_hook + +if TYPE_CHECKING: + from astrbot.core.star import PluginManager + + +__all__ = ["PipelineContext", "call_event_hook"] @dataclass @@ -11,7 +17,4 @@ class PipelineContext: """上下文对象,包含管道执行所需的上下文信息""" astrbot_config: AstrBotConfig # AstrBot 配置对象 - plugin_manager: PluginManager # 插件管理器对象 - astrbot_config_id: str - call_handler = call_handler - call_event_hook = call_event_hook + plugin_manager: "PluginManager" # 插件管理器对象 diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 9402ce3e6..0e22dbd07 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -1,77 +1,12 @@ import inspect import traceback -import typing as T from astrbot import logger -from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry -async def call_handler( - event: AstrMessageEvent, - handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], - 
*args, - **kwargs, -) -> T.AsyncGenerator[T.Any, None]: - """执行事件处理函数并处理其返回结果 - - 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: - 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层 - 2. 协程: 执行一次并处理返回值 - - Args: - event (AstrMessageEvent): 事件对象 - handler (Awaitable): 事件处理函数 - - Returns: - AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 - - """ - ready_to_call = None # 一个协程或者异步生成器 - - trace_ = None - - try: - ready_to_call = handler(event, *args, **kwargs) - except TypeError: - logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) - - if not ready_to_call: - return - - if inspect.isasyncgen(ready_to_call): - _has_yielded = False - try: - async for ret in ready_to_call: - # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 - # 返回值只能是 MessageEventResult 或者 None(无返回值) - _has_yielded = True - if isinstance(ret, MessageEventResult | CommandResult): - # 如果返回值是 MessageEventResult, 设置结果并继续 - event.set_result(ret) - yield - else: - # 如果返回值是 None, 则不设置结果并继续 - # 继续执行后续阶段 - yield ret - if not _has_yielded: - # 如果这个异步生成器没有执行到 yield 分支 - yield - except Exception as e: - logger.error(f"Previous Error: {trace_}") - raise e - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个协程, 直接执行 - ret = await ready_to_call - if isinstance(ret, MessageEventResult | CommandResult): - event.set_result(ret) - yield - else: - yield ret - - async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, diff --git a/astrbot/core/pipeline/engine/__init__.py b/astrbot/core/pipeline/engine/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/astrbot/core/pipeline/engine/chain_config.py b/astrbot/core/pipeline/engine/chain_config.py new file mode 100644 index 000000000..ef4227af2 --- /dev/null +++ b/astrbot/core/pipeline/engine/chain_config.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from typing import Any + +from sqlmodel import JSON, Field, SQLModel + +from astrbot.core.db.po import TimestampMixin +from astrbot.core.star.modality import Modality + + +class ChainConfigModel(TimestampMixin, SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + chain_id: str = Field( + max_length=36, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + ) + match_rule: dict | None = Field(default=None, sa_type=JSON) + sort_order: int = Field(default=0) + enabled: bool = Field(default=True) + + nodes: list[dict | str] | None = Field(default=None, sa_type=JSON) + + plugin_filter: dict | None = Field(default=None, sa_type=JSON) + + config_id: str | None = Field(default=None, max_length=36) + + +@dataclass +class PluginFilterConfig: + mode: str = "blacklist" + plugins: list[str] = field(default_factory=list) + + +_NODE_UUID_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "astrbot.chain.node") + + +@dataclass +class ChainNodeConfig: + name: str + uuid: str + + +def _stable_node_uuid(chain_id: str, name: str, occurrence: int) -> str: + seed = f"{chain_id}:{name}:{occurrence}" + return str(uuid.uuid5(_NODE_UUID_NAMESPACE, seed)) + + +def normalize_chain_nodes( + raw_nodes: list[Any] | None, chain_id: str +) -> list[ChainNodeConfig]: + if not raw_nodes: + return [] + + normalized: list[ChainNodeConfig] = [] + seen_uuids: set[str] = set() + name_occurrence: dict[str, int] = {} + + for entry in raw_nodes: + name: str | None = None + node_uuid: str | None = None + + if isinstance(entry, dict): + name = entry.get("name") or entry.get("node_name") or entry.get("node") + node_uuid = entry.get("uuid") or entry.get("id") + elif isinstance(entry, str): + name = 
entry + + if not name: + continue + + occurrence = name_occurrence.get(name, 0) + 1 + name_occurrence[name] = occurrence + + if node_uuid is not None: + node_uuid = str(node_uuid).strip() + if not node_uuid or node_uuid in seen_uuids: + node_uuid = None + + if not node_uuid: + node_uuid = _stable_node_uuid(chain_id, name, occurrence) + if node_uuid in seen_uuids: + node_uuid = str(uuid.uuid4()) + + seen_uuids.add(node_uuid) + normalized.append(ChainNodeConfig(name=name, uuid=node_uuid)) + + return normalized + + +def serialize_chain_nodes(nodes: list[ChainNodeConfig]) -> list[dict]: + return [{"name": node.name, "uuid": node.uuid} for node in nodes] + + +def clone_chain_nodes(nodes: list[ChainNodeConfig]) -> list[ChainNodeConfig]: + return [ChainNodeConfig(name=node.name, uuid=node.uuid) for node in nodes] + + +@dataclass +class ChainConfig: + chain_id: str + match_rule: dict | None = None + sort_order: int = 0 + enabled: bool = True + nodes: list[ChainNodeConfig] = field(default_factory=list) + plugin_filter: PluginFilterConfig | None = None + config_id: str | None = None + + def matches( + self, + umo: str, + modality: set[Modality] | None = None, + message_text: str = "", + ) -> bool: + if not self.enabled: + return False + + from astrbot.core.pipeline.engine.rule_matcher import rule_matcher + + return rule_matcher.matches(self.match_rule, umo, modality, message_text) + + @staticmethod + def from_model(model: ChainConfigModel) -> ChainConfig: + if model.nodes is None: + nodes = clone_chain_nodes(DEFAULT_CHAIN_CONFIG.nodes) + else: + nodes = normalize_chain_nodes(model.nodes, model.chain_id) + + plugin_filter = None + if model.plugin_filter: + mode = model.plugin_filter.get("mode", "blacklist") + plugins = model.plugin_filter.get("plugins", []) or [] + plugin_filter = PluginFilterConfig(mode=mode, plugins=list(plugins)) + + return ChainConfig( + chain_id=model.chain_id, + match_rule=model.match_rule, + sort_order=model.sort_order, + enabled=model.enabled, + nodes=nodes, + plugin_filter=plugin_filter, + config_id=model.config_id, + ) + + +_DEFAULT_NODES = normalize_chain_nodes(["agent"], "default") + +DEFAULT_CHAIN_CONFIG = ChainConfig( + chain_id="default", + match_rule=None, # None = match all + sort_order=-1, # Always last + nodes=_DEFAULT_NODES, + config_id="default", +) diff --git a/astrbot/core/pipeline/engine/chain_executor.py b/astrbot/core/pipeline/engine/chain_executor.py new file mode 100644 index 000000000..b8a8399cc --- /dev/null +++ b/astrbot/core/pipeline/engine/chain_executor.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import traceback +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.config import AstrBotNodeConfig +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.pipeline.engine.node_context import NodePacket +from astrbot.core.star import Star +from astrbot.core.star.node_star import NodeResult, NodeStar, is_node_star_metadata +from astrbot.core.star.star import StarMetadata, star_registry + +from .node_context import NodeContext, NodeExecutionStatus +from .wait_registry import WaitState, build_wait_key, wait_registry + +if TYPE_CHECKING: + from astrbot.core.pipeline.agent import AgentExecutor + from astrbot.core.pipeline.engine.chain_config import ChainConfig + from astrbot.core.pipeline.engine.send_service import SendService + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +@dataclass +class 
ChainExecutionResult: + """Chain 执行结果""" + + # TODO Extend Fields + + should_send: bool = True + + +class ChainExecutor: + """Chain 执行器""" + + @staticmethod + async def execute( + event: AstrMessageEvent, + chain_config: ChainConfig, + send_service: SendService, + agent_executor: AgentExecutor, + start_node_uuid: str | None = None, + ) -> ChainExecutionResult: + """执行 Chain + + Args: + event: 消息事件 + chain_config: Chain 配置 + send_service: 发送服务 + agent_executor: Agent 执行器 + start_node_uuid: 指定节点UUID + Returns: + ChainExecutionResult + """ + result = ChainExecutionResult() + + event.send_service = send_service + event.agent_executor = agent_executor + + nodes = chain_config.nodes + start_chain_index = 0 + + if start_node_uuid: + try: + start_chain_index = next( + idx + for idx, node in enumerate(nodes) + if node.uuid == start_node_uuid + ) + nodes = nodes[start_chain_index:] + except StopIteration: + logger.warning( + f"Start node '{start_node_uuid}' not found in chain, " + "fallback to full chain.", + ) + context_stack = event.context_stack + + for offset, node_entry in enumerate(nodes): + chain_index = start_chain_index + offset + node_name = node_entry.name + + # Create node context + node_ctx = NodeContext( + node_name=node_name, + node_uuid=node_entry.uuid, + chain_index=chain_index, + ) + + # Push first, then set input from last EXECUTED node's output + # This ordering doesn't matter because we use last_executed_output() + # which searches all EXECUTED nodes (current is still PENDING) + context_stack.push(node_ctx) + upstream_output = context_stack.last_executed_output() + if upstream_output is not None: + node_ctx.input = upstream_output + + node: NodeStar | None = None + metadata: StarMetadata | None = None + for m in star_registry: + if m.name != node_name: + continue + metadata = m + if ( + m.activated + and is_node_star_metadata(m) + and isinstance(m.star_cls, NodeStar) + ): + node = m.star_cls + break + + if not node: + logger.error(f"Node unavailable: {node_name}") + node_ctx.status = NodeExecutionStatus.FAILED + return result + + # 加载节点配置 + schema = metadata.node_schema if metadata else None + node_config = AstrBotNodeConfig.get_cached( + node_name=node_name, + chain_id=chain_config.chain_id, + node_uuid=node_entry.uuid, + schema=schema, + ) + event.node_config = node_config + + chain_id = chain_config.chain_id + init_key = (chain_id, node_entry.uuid) + if init_key not in node.initialized_node_keys: + try: + await node.node_initialize() + node.initialized_node_keys.add(init_key) + except Exception as e: + logger.error(f"Node {node_name} initialize error: {e}") + logger.error(traceback.format_exc()) + node_ctx.status = NodeExecutionStatus.FAILED + return result + + try: + node_result = await node.process(event) + + # Unified status mapping + match node_result: + case NodeResult.WAIT: + node_ctx.status = NodeExecutionStatus.WAITING + case NodeResult.SKIP: + node_ctx.status = NodeExecutionStatus.SKIPPED + case _: # CONTINUE / STOP + node_ctx.status = NodeExecutionStatus.EXECUTED + + if node_ctx.status == NodeExecutionStatus.EXECUTED: + ChainExecutor._sync_node_output(event, node_ctx) + + except Exception as e: + node_ctx.status = NodeExecutionStatus.FAILED + logger.error(f"Node {node_name} error: {e}") + logger.error(traceback.format_exc()) + return result + + if event.is_stopped(): + event.set_extra("_node_stop_event", True) + break + if node_result == NodeResult.WAIT: + wait_key = build_wait_key(event) + await wait_registry.set( + wait_key, + WaitState( + 
chain_config=chain_config, + node_uuid=node_entry.uuid, + ), + ) + result.should_send = False + break + elif node_result == NodeResult.STOP: + break + # CONTINUE / SKIP + + # Fallback to last_output if event.result not set + if result.should_send: + if not event.get_result(): + last_output = context_stack.last_executed_output() + if last_output is not None: + event.set_result(last_output.data) + result.should_send = event.get_result() is not None + + return result + + @staticmethod + def _sync_node_output(event: AstrMessageEvent, node_ctx: NodeContext) -> None: + """Align event result and node output for executed nodes.""" + + evt_result = event.get_result() + if node_ctx.output is None and evt_result is not None: + node_ctx.output = NodePacket.create(evt_result) + return + + if node_ctx.output is not None and evt_result is None: + output = node_ctx.output.data + if isinstance(output, MessageEventResult): + event.set_result(output) + + @property + def nodes(self) -> dict[str | None, Star | None]: + """Get all active nodes.""" + return { + m.name: m.star_cls + for m in star_registry + if m.activated + and m.name + and is_node_star_metadata(m) + and m.star_cls is not None + } diff --git a/astrbot/core/pipeline/engine/chain_runtime_flags.py b/astrbot/core/pipeline/engine/chain_runtime_flags.py new file mode 100644 index 000000000..96ddfc667 --- /dev/null +++ b/astrbot/core/pipeline/engine/chain_runtime_flags.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from astrbot.core import sp + +FEATURE_LLM = "llm" +FEATURE_STT = "stt" +FEATURE_TTS = "tts" +FEATURE_T2I = "t2i" + +DEFAULT_CHAIN_RUNTIME_FLAGS: dict[str, bool] = { + FEATURE_LLM: True, + FEATURE_STT: True, + FEATURE_TTS: True, + FEATURE_T2I: True, +} + + +def _normalize_flags(raw: dict | None) -> dict[str, bool]: + flags = dict(DEFAULT_CHAIN_RUNTIME_FLAGS) + if not isinstance(raw, dict): + return flags + for key in DEFAULT_CHAIN_RUNTIME_FLAGS: + if key in raw: + flags[key] = bool(raw.get(key)) + return flags + + +async def get_chain_runtime_flags(chain_id: str | None) -> dict[str, bool]: + if not chain_id: + return dict(DEFAULT_CHAIN_RUNTIME_FLAGS) + all_flags = await sp.global_get("chain_runtime_flags", {}) + if not isinstance(all_flags, dict): + return dict(DEFAULT_CHAIN_RUNTIME_FLAGS) + return _normalize_flags(all_flags.get(chain_id)) + + +async def set_chain_runtime_flag( + chain_id: str | None, + feature: str, + enabled: bool, +) -> dict[str, bool]: + if not chain_id: + return dict(DEFAULT_CHAIN_RUNTIME_FLAGS) + if feature not in DEFAULT_CHAIN_RUNTIME_FLAGS: + raise ValueError(f"Unsupported chain runtime feature: {feature}") + + all_flags = await sp.global_get("chain_runtime_flags", {}) + if not isinstance(all_flags, dict): + all_flags = {} + + chain_flags = _normalize_flags(all_flags.get(chain_id)) + chain_flags[feature] = bool(enabled) + all_flags[chain_id] = chain_flags + await sp.global_put("chain_runtime_flags", all_flags) + return chain_flags + + +async def toggle_chain_runtime_flag(chain_id: str | None, feature: str) -> bool: + flags = await get_chain_runtime_flags(chain_id) + next_value = not bool(flags.get(feature, True)) + await set_chain_runtime_flag(chain_id, feature, next_value) + return next_value + + +async def is_chain_runtime_feature_enabled(chain_id: str | None, feature: str) -> bool: + flags = await get_chain_runtime_flags(chain_id) + return bool(flags.get(feature, True)) diff --git a/astrbot/core/pipeline/engine/executor.py b/astrbot/core/pipeline/engine/executor.py new file mode 100644 index 
000000000..c0b01447d --- /dev/null +++ b/astrbot/core/pipeline/engine/executor.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import random +import traceback +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.message.components import At, AtAll, Reply +from astrbot.core.pipeline.agent import AgentExecutor +from astrbot.core.pipeline.engine.chain_executor import ChainExecutor +from astrbot.core.pipeline.system.access_control import AccessController +from astrbot.core.pipeline.system.command_dispatcher import CommandDispatcher +from astrbot.core.pipeline.system.event_preprocessor import EventPreprocessor +from astrbot.core.pipeline.system.rate_limit import RateLimiter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent +from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( + WecomAIBotMessageEvent, +) +from astrbot.core.star.node_star import NodeResult + +from .send_service import SendService + +if TYPE_CHECKING: + from astrbot.core.pipeline.context import PipelineContext + from astrbot.core.star.context import Context + + +class PipelineExecutor: + """Pipeline 执行器 + + 协调各组件完成消息处理,不直接处理业务逻辑。 + """ + + def __init__( + self, + context: Context, + pipeline_ctx: PipelineContext, + ) -> None: + self.agent_executor = AgentExecutor() + self.context = context # Star Context + self.pipeline_ctx = pipeline_ctx # Pipeline 配置上下文 + self._initialized = False + + # 基础服务 + self.preprocessor = EventPreprocessor(pipeline_ctx) + self.send_service = SendService(pipeline_ctx) + + # 命令分发器(Star 插件) + self.command_dispatcher = CommandDispatcher( + self.pipeline_ctx.astrbot_config, + self.send_service, + self.agent_executor, + ) + + # Chain 执行器(NodeStar 插件) + self.chain_executor = ChainExecutor() + + self.rate_limiter = RateLimiter(pipeline_ctx) + self.access_controller = AccessController(pipeline_ctx) + + async def initialize(self) -> None: + """初始化所有组件""" + if self._initialized: + return + + # AgentExecutor + await self.agent_executor.initialize(self.pipeline_ctx) + + await self.rate_limiter.initialize() + await self.access_controller.initialize() + + # 加载 Chain 配置 + self._initialized = True + logger.info( + f"PipelineExecutor initialized with {len(self.chain_executor.nodes)} nodes" + ) + + async def execute(self, event: AstrMessageEvent) -> None: + """执行 Pipeline""" + try: + # 预处理 + should_continue = await self.preprocessor.preprocess(event) + if not should_continue: + return + + # 获取 Chain + chain_config = event.chain_config + + if not chain_config: + raise RuntimeError("Missing chain_config on event.") + + resume_node_uuid = event.get_extra("_resume_node_uuid") + + if resume_node_uuid: + if await self._run_system_mechanisms(event) == NodeResult.STOP: + if event.get_result(): + await self.send_service.send(event) + return + + chain_result = await self.chain_executor.execute( + event, + chain_config, + self.send_service, + self.agent_executor, + start_node_uuid=resume_node_uuid, + ) + + if ( + chain_result.should_send + and event.get_result() + and ( + not event.is_stopped() + or event.get_extra("_node_stop_event", False) + ) + ): + await self.send_service.send(event) + return + + # 唤醒检测 + self._detect_wake(event) + + # 命令匹配,匹配成功会设置 is_wake,但不执行命令 + event.plugins_name = self._resolve_plugins_name(chain_config) + + logger.debug(f"enabled_plugins_name: {event.plugins_name}") + + matched_handlers = await self.command_dispatcher.match( + event, + 
event.plugins_name, + ) + + # 如果匹配过程中权限检查失败导致 stop + if event.is_stopped(): + if event.get_result(): + await self.send_service.send(event) + return + + # 如果没有命令匹配且未唤醒,直接返回 + if not matched_handlers and not event.is_wake: + return + + # 系统机制检查(限流、权限) + if await self._run_system_mechanisms(event) == NodeResult.STOP: + if event.get_result(): + await self.send_service.send(event) + return + + # 命令执行 + if matched_handlers: + command_executed = await self.command_dispatcher.execute( + event, + matched_handlers, + ) + + if event.is_stopped(): + if event.get_result(): + await self.send_service.send(event) + return + + # 如果命令已完全处理且无需继续,返回 + if command_executed and event.get_result(): + await self.send_service.send(event) + return + + # Chain 执行(NodeStar 插件) + chain_result = await self.chain_executor.execute( + event, + chain_config, + self.send_service, + self.agent_executor, + ) + + # 发送结果 + if ( + chain_result.should_send + and event.get_result() + and ( + not event.is_stopped() or event.get_extra("_node_stop_event", False) + ) + ): + await self.send_service.send(event) + + except Exception as e: + logger.error(f"Pipeline execution error: {e}") + logger.error(traceback.format_exc()) + finally: + await self._handle_special_platforms(event) + logger.debug("Pipeline 执行完毕。") + + def _resolve_plugins_name(self, chain_config) -> list[str] | None: + if chain_config and chain_config.plugin_filter: + mode = (chain_config.plugin_filter.mode or "blacklist").lower() + plugins = chain_config.plugin_filter.plugins or [] + + if mode == "inherit": + return self._resolve_global_plugins_name() + if mode == "unrestricted": + return None + if mode == "none": + return [] + if mode == "whitelist": + return plugins + if mode == "blacklist": + if not plugins: + return None + all_names = [p.name for p in self.context.get_all_stars() if p.name] + return [name for name in all_names if name not in set(plugins)] + + logger.warning("未知插件过滤模式: %s,使用全局限制。", mode) + return self._resolve_global_plugins_name() + + return self._resolve_global_plugins_name() + + def _resolve_global_plugins_name(self) -> list[str] | None: + plugins_name = self.pipeline_ctx.astrbot_config.get("plugin_set", ["*"]) + if plugins_name == ["*"]: + return None + return plugins_name + + def _detect_wake(self, event: AstrMessageEvent) -> None: + """唤醒检测""" + config = self.pipeline_ctx.astrbot_config + friend_message_needs_wake_prefix = config["platform_settings"].get( + "friend_message_needs_wake_prefix", False + ) + ignore_at_all = config["platform_settings"].get("ignore_at_all", False) + + wake_prefixes = config["wake_prefix"] + messages = event.get_messages() + is_wake = False + + # 检查唤醒前缀 + for wake_prefix in wake_prefixes: + if event.message_str.startswith(wake_prefix): + # 排除 @ 其他人的情况 + if ( + not event.is_private_chat() + and messages + and isinstance(messages[0], At) + and str(messages[0].qq) != str(event.get_self_id()) + and str(messages[0].qq) != "all" + ): + break + is_wake = True + event.is_at_or_wake_command = True + event.is_wake = True + event.message_str = event.message_str[len(wake_prefix) :].strip() + break + + # 检查 @ 和 Reply + if not is_wake: + for message in messages: + if ( + ( + isinstance(message, At) + and str(message.qq) == str(event.get_self_id()) + ) + or (isinstance(message, AtAll) and not ignore_at_all) + or ( + isinstance(message, Reply) + and str(message.sender_id) == str(event.get_self_id()) + ) + ): + event.is_wake = True + event.is_at_or_wake_command = True + break + + # 私聊默认唤醒 + if event.is_private_chat() and not 
friend_message_needs_wake_prefix: + event.is_wake = True + event.is_at_or_wake_command = True + + async def _run_system_mechanisms( + self, + event: AstrMessageEvent, + ) -> NodeResult: + # 限流检查 + result = await self.rate_limiter.apply(event) + if result == NodeResult.STOP: + return result + + # 访问控制 + result = await self.access_controller.apply(event) + if result == NodeResult.STOP: + return result + + # 预回应表情 + await self._pre_ack_emoji(event) + + return NodeResult.CONTINUE + + async def _pre_ack_emoji(self, event: AstrMessageEvent) -> None: + """预回应表情""" + if event.get_extra("_pre_ack_sent", False): + return + + supported = {"telegram", "lark"} + platform = event.get_platform_name() + cfg = ( + self.pipeline_ctx.astrbot_config.get("platform_specific", {}) + .get(platform, {}) + .get("pre_ack_emoji", {}) + ) or {} + emojis = cfg.get("emojis") or [] + + if ( + cfg.get("enable", False) + and platform in supported + and emojis + and event.is_at_or_wake_command + ): + try: + await event.react(random.choice(emojis)) + event.set_extra("_pre_ack_sent", True) + except Exception as e: + logger.warning(f"{platform} 预回应表情发送失败: {e}") + + @staticmethod + async def _handle_special_platforms(event: AstrMessageEvent) -> None: + """处理特殊平台""" + if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent): + await event.send(None) diff --git a/astrbot/core/pipeline/engine/node_context.py b/astrbot/core/pipeline/engine/node_context.py new file mode 100644 index 000000000..72247696d --- /dev/null +++ b/astrbot/core/pipeline/engine/node_context.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +NODE_PACKET_VERSION = 1 + + +class NodePacketKind(Enum): + """Supported node packet payload kinds.""" + + MESSAGE = "message" + TEXT = "text" + OBJECT = "object" + + +def _infer_node_output_kind(output: Any) -> NodePacketKind: + from astrbot.core.message.message_event_result import ( + MessageChain, + MessageEventResult, + ) + + if isinstance(output, MessageEventResult | MessageChain): + return NodePacketKind.MESSAGE + if isinstance(output, str): + return NodePacketKind.TEXT + return NodePacketKind.OBJECT + + +@dataclass +class NodePacket: + """Standard packet transmitted between pipeline nodes.""" + + version: int + kind: NodePacketKind + data: Any + + def __post_init__(self) -> None: + if isinstance(self.kind, str): + self.kind = NodePacketKind(self.kind) + + @classmethod + def create(cls, output: Any) -> NodePacket: + if isinstance(output, cls): + return output + return cls( + version=NODE_PACKET_VERSION, + kind=_infer_node_output_kind(output), + data=output, + ) + + +class NodeExecutionStatus(Enum): + """Node execution status for tracking in NodeContext.""" + + PENDING = "pending" + EXECUTED = "executed" + SKIPPED = "skipped" + FAILED = "failed" + WAITING = "waiting" + + +@dataclass +class NodeContext: + """Single node execution context.""" + + node_name: str + node_uuid: str + chain_index: int # Position in chain_config.nodes (fixed) + + status: NodeExecutionStatus = NodeExecutionStatus.PENDING + input: NodePacket | None = None # From upstream EXECUTED node's output + output: NodePacket | None = None # Standard node-to-node payload + + +@dataclass +class NodeContextStack: + """Manages node execution contexts for a chain run.""" + + _contexts: list[NodeContext] = field(default_factory=list) + + def push(self, ctx: NodeContext) -> None: + self._contexts.append(ctx) + + def current(self) -> NodeContext | None: + 
return self._contexts[-1] if self._contexts else None + + def get_contexts( + self, + *, + names: set[str] | None = None, + status: NodeExecutionStatus | None = None, + ) -> list[NodeContext]: + """Get contexts filtered by name/status, preserving chain order.""" + contexts: list[NodeContext] = [] + for ctx in self._contexts: + if status and ctx.status != status: + continue + if names and ctx.node_name not in names: + continue + contexts.append(ctx) + return contexts + + def get_outputs( + self, + *, + names: set[str] | None = None, + status: NodeExecutionStatus | None = NodeExecutionStatus.EXECUTED, + include_none: bool = False, + ) -> list[NodePacket]: + """Get node outputs filtered by name/status, preserving chain order.""" + outputs: list[NodePacket] = [] + for ctx in self.get_contexts(names=names, status=status): + output_packet = ctx.output + if output_packet is None and not include_none: + continue + if output_packet is not None: + outputs.append(output_packet) + return outputs + + def last_executed_output(self) -> NodePacket | None: + """Get output from the most recent EXECUTED node. + + Current PENDING node is naturally excluded since it has no output yet. + """ + outputs = self.get_outputs( + status=NodeExecutionStatus.EXECUTED, + include_none=False, + ) + return outputs[-1] if outputs else None diff --git a/astrbot/core/pipeline/engine/router.py b/astrbot/core/pipeline/engine/router.py new file mode 100644 index 000000000..a93aab37e --- /dev/null +++ b/astrbot/core/pipeline/engine/router.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from astrbot.core import logger +from astrbot.core.db import BaseDatabase +from astrbot.core.pipeline.engine.chain_config import ( + DEFAULT_CHAIN_CONFIG, + ChainConfig, + ChainConfigModel, +) +from astrbot.core.star.modality import Modality + + +class ChainRouter: + def __init__(self) -> None: + self._configs: list[ChainConfig] = [] + self._configs_map: dict[str, ChainConfig] = {} + + async def load_configs(self, db_helper: BaseDatabase) -> None: + db_chains = await self._load_chain_configs_from_db(db_helper) + default_chain = None + normal_chains: list[ChainConfig] = [] + for chain in db_chains: + if chain.chain_id == "default": + default_chain = chain + else: + normal_chains.append(chain) + normal_chains.sort(key=lambda c: c.sort_order, reverse=True) + self._configs = normal_chains + [default_chain or DEFAULT_CHAIN_CONFIG] + self._configs_map = {config.chain_id: config for config in self._configs} + logger.info(f"Loaded {len(self._configs)} chain configs") + + def route( + self, umo: str, modality: set[Modality] | None = None, message_text: str = "" + ) -> ChainConfig | None: + for config in self._configs: + if config.matches(umo, modality, message_text): + logger.debug(f"Routed {umo} to chain: {config.chain_id}") + return config + return None + + def get_by_chain_id(self, chain_id: str) -> ChainConfig | None: + return self._configs_map.get(chain_id) + + async def reload(self, db_helper: BaseDatabase) -> None: + await self.load_configs(db_helper) + + @staticmethod + async def _load_chain_configs_from_db(db_helper: BaseDatabase) -> list[ChainConfig]: + async with db_helper.get_db() as session: + session: AsyncSession + result = await session.execute(select(ChainConfigModel)) + records = result.scalars().all() + + return [ChainConfig.from_model(record) for record in records] diff --git a/astrbot/core/pipeline/engine/rule_matcher.py 
b/astrbot/core/pipeline/engine/rule_matcher.py new file mode 100644 index 000000000..2b49c084f --- /dev/null +++ b/astrbot/core/pipeline/engine/rule_matcher.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import fnmatch +import re + +from astrbot.core.star.modality import Modality + + +class RuleMatcher: + """规则匹配引擎""" + + def matches( + self, + rule: dict | None, + umo: str, + modality: set[Modality] | None, + message_text: str, + ) -> bool: + """评估规则是否匹配""" + if rule is None: + return True # 无规则 = 匹配所有 + + return self._evaluate(rule, umo, modality, message_text) + + def _evaluate( + self, + node: dict, + umo: str, + modality: set[Modality] | None, + text: str, + ) -> bool: + node_type = node.get("type") + + if node_type == "and": + children = node.get("children", []) + return all(self._evaluate(c, umo, modality, text) for c in children) + + elif node_type == "or": + children = node.get("children", []) + return any(self._evaluate(c, umo, modality, text) for c in children) + + elif node_type == "not": + children = node.get("children", []) + if children: + return not self._evaluate(children[0], umo, modality, text) + return True + + elif node_type == "condition": + return self._evaluate_condition( + node.get("condition", {}), umo, modality, text + ) + + return False + + def _evaluate_condition( + self, + condition: dict, + umo: str, + modality: set[Modality] | None, + text: str, + ) -> bool: + cond_type = condition.get("type") + value = condition.get("value", "") + operator = condition.get("operator", "include") + + # 计算基础匹配结果 + result = self._evaluate_condition_value(cond_type, value, umo, modality, text) + + # 根据 operator 决定是否取反 + if operator == "exclude": + return not result + return result + + @staticmethod + def _evaluate_condition_value( + cond_type: str | None, + value: str, + umo: str, + modality: set[Modality] | None, + text: str, + ) -> bool: + if cond_type == "umo": + return fnmatch.fnmatch(umo, value) + + elif cond_type == "modality": + if modality is None: + return False + try: + target = Modality(value) + return target in modality + except ValueError: + return False + + elif cond_type == "text_regex": + try: + return bool(re.search(value, text, re.IGNORECASE)) + except re.error: + return False + + return False + + +# 单例 +rule_matcher = RuleMatcher() diff --git a/astrbot/core/pipeline/engine/send_service.py b/astrbot/core/pipeline/engine/send_service.py new file mode 100644 index 000000000..27858bb98 --- /dev/null +++ b/astrbot/core/pipeline/engine/send_service.py @@ -0,0 +1,505 @@ +from __future__ import annotations + +import asyncio +import math +import random +import re +from typing import Any + +import astrbot.core.message.components as Comp +from astrbot.core import logger +from astrbot.core.message.components import ( + At, + BaseMessageComponent, + ComponentType, + File, + Node, + Plain, + Reply, +) +from astrbot.core.message.message_event_result import MessageChain, ResultContentType +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.message_type import MessageType +from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.path_util import path_Mapping + +from ..context_utils import call_event_hook + + +class SendService: + """独立发送服务,包含可配置的发送前装饰功能 + + 此服务从 RespondStage + ResultDecorateStage 抽取,包含: + - 空消息过滤 + - 路径映射 + - 可配置装饰功能(reply_prefix, segment_split, forward_wrapper, at_mention, quote_reply) + - 分段发送 + - 流式发送 + """ + + # 
组件类型到其非空判断函数的映射 + _component_validators: dict[type, Any] = { + Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), + Comp.Face: lambda comp: comp.id is not None, + Comp.Record: lambda comp: bool(comp.file), + Comp.Video: lambda comp: bool(comp.file), + Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), + Comp.Image: lambda comp: bool(comp.file), + Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, + Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, + Comp.Node: lambda comp: bool(comp.content), + Comp.Nodes: lambda comp: bool(comp.nodes), + Comp.File: lambda comp: bool(comp.file_ or comp.url), + Comp.WechatEmoji: lambda comp: comp.md5 is not None, + } + + def __init__(self, ctx: PipelineContext): + self.ctx = ctx + config = ctx.astrbot_config + platform_settings = config.get("platform_settings", {}) + provider_cfg = config.get("provider_settings", {}) + + self.reply_prefix: str = platform_settings.get("reply_prefix", "") + self.forward_wrapper: bool = platform_settings.get("forward_wrapper", False) + self.forward_threshold: int = platform_settings.get("forward_threshold", 1500) + self.reply_with_mention: bool = platform_settings.get( + "reply_with_mention", False + ) + self.reply_with_quote: bool = platform_settings.get("reply_with_quote", False) + + # 分段相关 + segmented_reply_cfg = platform_settings.get("segmented_reply", {}) + self.enable_segment = segmented_reply_cfg.get("enable", False) + self.segment_mode = segmented_reply_cfg.get("split_mode", "regex") + + self.words_count_threshold: int = segmented_reply_cfg.get( + "words_count_threshold", 150 + ) + try: + self.words_count_threshold = max(int(self.words_count_threshold), 1) + except (TypeError, ValueError): + self.words_count_threshold = 150 + self.only_llm_result: bool = segmented_reply_cfg.get("only_llm_result", False) + self.regex_pattern: str = segmented_reply_cfg.get( + "regex", r".*?[。?!~…]+|.+$" + ) + self.split_words: list[str] = segmented_reply_cfg.get( + "split_words", ["。", "?", "!", "~", "…"] + ) + self.content_cleanup_rule: str = segmented_reply_cfg.get( + "content_cleanup_rule", "" + ) + + # 分段回复时间间隔 + self.interval_method: str = segmented_reply_cfg.get("interval_method", "random") + self.log_base: float = float(segmented_reply_cfg.get("log_base", 2.6)) + interval_str: str = segmented_reply_cfg.get("interval", "1.5, 3.5") + try: + self.interval = [float(t) for t in interval_str.replace(" ", "").split(",")] + except Exception: + self.interval = [1.5, 3.5] + + # 构建 split_words pattern + if self.split_words: + escaped_words = sorted( + [re.escape(word) for word in self.split_words], key=len, reverse=True + ) + self.split_words_pattern = re.compile( + f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL + ) + else: + self.split_words_pattern = None + + # 路径映射 + self.path_mapping: list[str] = platform_settings.get("path_mapping", []) + + # reasoning 输出 + self.show_reasoning: bool = provider_cfg.get("display_reasoning_text", False) + + async def send(self, event: AstrMessageEvent) -> None: + """发送消息(由 Agent 的 send tool 或 Pipeline 末端调用)""" + result = event.get_result() + if result is None: + return + + # 已经流式发送完成 + if event.get_extra("_streaming_finished", False): + return + + # 流式发送 + if result.result_content_type == ResultContentType.STREAMING_RESULT: + await self._send_streaming(event, result) + return + + if result.result_content_type == ResultContentType.STREAMING_FINISH: + if result.chain: + if await self._trigger_decorate_hook(event, is_stream=True): + return + 
event.set_extra("_streaming_finished", True) + return + + if not result.chain: + return + + # 发送消息前事件钩子 + if await self._trigger_decorate_hook(event, is_stream=False): + return + + # 需要再获取一次。 + result = event.get_result() + if result is None or not result.chain: + return + + # reasoning 内容插入(仅在未启用 TTS/T2I 时) + self._maybe_inject_reasoning(event, result) + + # 应用路径映射 + if self.path_mapping: + for idx, comp in enumerate(result.chain): + if isinstance(comp, Comp.File) and comp.file: + comp.file = path_Mapping(self.path_mapping, comp.file) + result.chain[idx] = comp + + # 检查消息链是否为空 + try: + if await self._is_empty_message_chain(result.chain): + logger.info("消息为空,跳过发送阶段") + return + except Exception as e: + logger.warning(f"空内容检查异常: {e}") + + # 将 Plain 为空的消息段移除 + result.chain = [ + comp + for comp in result.chain + if not ( + isinstance(comp, Comp.Plain) + and (not comp.text or not comp.text.strip()) + ) + ] + + if not result.chain: + return + + logger.info( + f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" + ) + + # 应用装饰 + chain = result.chain + + # 1. 回复前缀 + if self.reply_prefix: + chain = self._add_reply_prefix(chain) + + # 2. 文本分段(仅在 segmented_reply 启用时) + chain = self._split_chain_for_segmented_reply(event, result, chain) + + # 3. 合并转发包装(aiocqhttp) + if self.forward_wrapper and self._should_forward_wrap(event, chain): + chain = self._wrap_forward(event, chain) + + # 4. 是否需要分段发送 + need_segment = self._is_seg_reply_required(event, result) + + # 5. @提及 和 引用回复 + has_plain = any(isinstance(item, Plain) for item in chain) + if has_plain: + if ( + self.reply_with_mention + and event.get_message_type() != MessageType.FRIEND_MESSAGE + ): + chain = self._add_at_mention(chain, event) + if self.reply_with_quote: + chain = self._add_quote_reply(chain, event) + + # 发送 + need_separately = {ComponentType.Record} + if need_segment: + await self._send_segmented(event, chain, need_separately) + else: + await self._send_normal(event, chain, need_separately) + + # 触发 OnAfterMessageSentEvent + await self._trigger_post_send_hook(event) + event.clear_result() + + @staticmethod + async def _trigger_decorate_hook(event: AstrMessageEvent, is_stream: bool) -> bool: + if is_stream: + logger.warning( + "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", + ) + return await call_event_hook(event, EventType.OnDecoratingResultEvent) + + def _maybe_inject_reasoning(self, event: AstrMessageEvent, result) -> None: + if not self.show_reasoning: + return + reasoning_content = event.get_extra("_llm_reasoning_content") + if not reasoning_content: + return + if result.chain and isinstance(result.chain[0], Plain): + if result.chain[0].text.startswith("🤔 思考:"): + return + result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) + + async def _send_streaming(self, event: AstrMessageEvent, result) -> None: + """流式发送""" + if result.async_stream is None: + logger.warning("async_stream 为空,跳过发送。") + return + + realtime_segmenting = ( + self.ctx.astrbot_config.get("provider_settings", {}).get( + "unsupported_streaming_strategy", "realtime_segmenting" + ) + == "realtime_segmenting" + ) + logger.info(f"应用流式输出({event.get_platform_id()})") + await event.send_streaming(result.async_stream, realtime_segmenting) + + async def _send_normal( + self, + event: AstrMessageEvent, + chain: list[BaseMessageComponent], + need_separately: set[ComponentType], + ) -> None: + """普通发送(非分段)""" + if all(comp.type in {ComponentType.Reply, ComponentType.At} for comp in chain): + logger.warning("消息链全为 Reply 和 At 消息段, 
跳过发送阶段。") + return + + # 提取需要单独发送的组件 + sep_comps = self._extract_comp(chain, need_separately, modify_raw_chain=True) + + for comp in sep_comps: + try: + await event.send(MessageChain([comp])) + except Exception as e: + logger.error(f"发送消息链失败: {e}", exc_info=True) + + if chain: + try: + await event.send(MessageChain(chain)) + except Exception as e: + logger.error(f"发送消息链失败: {e}", exc_info=True) + + async def _send_segmented( + self, + event: AstrMessageEvent, + chain: list[BaseMessageComponent], + need_separately: set[ComponentType], + ) -> None: + """分段发送""" + # 提取 header 组件(Reply, At) + header_comps = self._extract_comp( + chain, {ComponentType.Reply, ComponentType.At}, modify_raw_chain=True + ) + + if not chain: + logger.warning("实际消息链为空, 跳过发送阶段。") + return + + for comp in chain: + interval = await self._calc_comp_interval(comp) + await asyncio.sleep(interval) + try: + if comp.type in need_separately: + await event.send(MessageChain([comp])) + else: + await event.send(MessageChain([*header_comps, comp])) + header_comps.clear() + except Exception as e: + logger.error(f"发送消息链失败: {e}", exc_info=True) + + def _split_text_by_words(self, text: str) -> list[str]: + if not self.split_words_pattern: + return [text] + + segments = self.split_words_pattern.findall(text) + result: list[str] = [] + for seg in segments: + if isinstance(seg, tuple): + content = seg[0] + if not isinstance(content, str): + continue + for word in self.split_words: + if content.endswith(word): + content = content[: -len(word)] + break + if content.strip(): + result.append(content) + elif seg and seg.strip(): + result.append(seg) + return result if result else [text] + + def _split_chain_for_segmented_reply( + self, + event: AstrMessageEvent, + result, + chain: list[BaseMessageComponent], + ) -> list[BaseMessageComponent]: + if not self._is_seg_reply_required(event, result): + return chain + + new_chain: list[BaseMessageComponent] = [] + for comp in chain: + if isinstance(comp, Plain): + if len(comp.text) > self.words_count_threshold: + new_chain.append(comp) + continue + + if self.segment_mode == "words": + split_response = self._split_text_by_words(comp.text) + else: + try: + split_response = re.findall( + self.regex_pattern, + comp.text, + re.DOTALL | re.MULTILINE, + ) + except re.error: + logger.error( + "分段回复正则表达式错误,使用默认分段方式。", + exc_info=True, + ) + split_response = re.findall( + r".*?[。?!~…]+|.+$", + comp.text, + re.DOTALL | re.MULTILINE, + ) + + if not split_response: + new_chain.append(comp) + continue + + for seg in split_response: + if self.content_cleanup_rule: + seg = re.sub(self.content_cleanup_rule, "", seg) + if seg.strip(): + new_chain.append(Plain(seg)) + else: + new_chain.append(comp) + + return new_chain + + def _add_reply_prefix( + self, chain: list[BaseMessageComponent] + ) -> list[BaseMessageComponent]: + """添加回复前缀""" + for comp in chain: + if isinstance(comp, Plain): + comp.text = self.reply_prefix + comp.text + break + return chain + + @staticmethod + def _add_at_mention( + chain: list[BaseMessageComponent], event: AstrMessageEvent + ) -> list[BaseMessageComponent]: + """添加 @提及""" + chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name())) + if len(chain) > 1 and isinstance(chain[1], Plain): + chain[1].text = "\n" + chain[1].text + return chain + + @staticmethod + def _add_quote_reply( + chain: list[BaseMessageComponent], event: AstrMessageEvent + ) -> list[BaseMessageComponent]: + """添加引用回复""" + if not any(isinstance(item, File) for item in chain): + chain.insert(0, 
Reply(id=event.message_obj.message_id)) + return chain + + def _should_forward_wrap( + self, event: AstrMessageEvent, chain: list[BaseMessageComponent] + ) -> bool: + """判断是否需要合并转发""" + if event.get_platform_name() != "aiocqhttp": + return False + word_cnt = sum(len(comp.text) for comp in chain if isinstance(comp, Plain)) + return word_cnt > self.forward_threshold + + @staticmethod + def _wrap_forward( + event: AstrMessageEvent, chain: list[BaseMessageComponent] + ) -> list[BaseMessageComponent]: + """合并转发包装""" + if event.get_platform_name() != "aiocqhttp": + return chain + node = Node( + uin=event.get_self_id(), + name="AstrBot", + content=[*chain], + ) + return [node] + + def _is_seg_reply_required(self, event: AstrMessageEvent, result) -> bool: + """检查是否需要分段回复""" + if not self.enable_segment: + return False + if self.only_llm_result and not result.is_llm_result(): + return False + if event.get_platform_name() in [ + "qq_official", + "weixin_official_account", + "dingtalk", + ]: + return False + return True + + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: + """检查消息链是否为空""" + if not chain: + return True + for comp in chain: + comp_type = type(comp) + if comp_type in self._component_validators: + if self._component_validators[comp_type](comp): + return False + return True + + @staticmethod + def _extract_comp( + raw_chain: list[BaseMessageComponent], + extract_types: set[ComponentType], + modify_raw_chain: bool = True, + ) -> list[BaseMessageComponent]: + """提取特定类型的组件""" + extracted = [] + if modify_raw_chain: + remaining = [] + for comp in raw_chain: + if comp.type in extract_types: + extracted.append(comp) + else: + remaining.append(comp) + raw_chain[:] = remaining + else: + extracted = [comp for comp in raw_chain if comp.type in extract_types] + return extracted + + async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: + """计算分段回复间隔时间""" + if self.interval_method == "log": + if isinstance(comp, Comp.Plain): + wc = await self._word_cnt(comp.text) + i = math.log(wc + 1, self.log_base) + return random.uniform(i, i + 0.5) + return random.uniform(1, 1.75) + return random.uniform(self.interval[0], self.interval[1]) + + @staticmethod + async def _word_cnt(text: str) -> int: + """统计字数""" + if all(ord(c) < 128 for c in text): + return len(text.split()) + else: + return len([c for c in text if c.isalnum()]) + + @staticmethod + async def _trigger_post_send_hook(event: AstrMessageEvent) -> None: + """触发发送后钩子""" + await call_event_hook(event, EventType.OnAfterMessageSentEvent) diff --git a/astrbot/core/pipeline/engine/wait_registry.py b/astrbot/core/pipeline/engine/wait_registry.py new file mode 100644 index 000000000..bc068d017 --- /dev/null +++ b/astrbot/core/pipeline/engine/wait_registry.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +if TYPE_CHECKING: + from .chain_config import ChainConfig + + +@dataclass +class WaitState: + chain_config: ChainConfig + node_uuid: str + + def is_valid(self, current_chain_config: ChainConfig | None) -> bool: + """检查 WaitState 是否仍然有效""" + if current_chain_config is None: + return False + + if current_chain_config != self.chain_config: + return False + + return True + + +def build_wait_key(event: AstrMessageEvent) -> str: + """Build a stable wait key that is independent of preprocessing.""" + return ( + f"{event.get_platform_id()}:" + 
f"{event.get_message_type().value}:" + f"{event.get_sender_id()}:" + f"{event.get_group_id()}" + ) + + +class WaitRegistry: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._by_key: dict[str, WaitState] = {} + + async def set(self, key: str, state: WaitState) -> None: + async with self._lock: + self._by_key[key] = state + + async def pop(self, key: str) -> WaitState | None: + async with self._lock: + return self._by_key.pop(key, None) + + +wait_registry = WaitRegistry() diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py deleted file mode 100644 index 6544f85c1..000000000 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ /dev/null @@ -1,100 +0,0 @@ -import asyncio -import random -import traceback -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.message.components import Image, Plain, Record -from astrbot.core.platform.astr_message_event import AstrMessageEvent - -from ..context import PipelineContext -from ..stage import Stage, register_stage - - -@register_stage -class PreProcessStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.config = ctx.astrbot_config - self.plugin_manager = ctx.plugin_manager - - self.stt_settings: dict = self.config.get("provider_stt_settings", {}) - self.platform_settings: dict = self.config.get("platform_settings", {}) - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - """在处理事件之前的预处理""" - # 平台特异配置:platform_specific..pre_ack_emoji - supported = {"telegram", "lark"} - platform = event.get_platform_name() - cfg = ( - self.config.get("platform_specific", {}) - .get(platform, {}) - .get("pre_ack_emoji", {}) - ) or {} - emojis = cfg.get("emojis") or [] - if ( - cfg.get("enable", False) - and platform in supported - and emojis - and event.is_at_or_wake_command - ): - try: - await event.react(random.choice(emojis)) - except Exception as e: - logger.warning(f"{platform} 预回应表情发送失败: {e}") - - # 路径映射 - if mappings := self.platform_settings.get("path_mapping", []): - # 支持 Record,Image 消息段的路径映射。 - message_chain = event.get_messages() - - for idx, component in enumerate(message_chain): - if isinstance(component, Record | Image) and component.url: - for mapping in mappings: - from_, to_ = mapping.split(":") - from_ = from_.removesuffix("/") - to_ = to_.removesuffix("/") - - url = component.url.removeprefix("file://") - if url.startswith(from_): - component.url = url.replace(from_, to_, 1) - logger.debug(f"路径映射: {url} -> {component.url}") - message_chain[idx] = component - - # STT - if self.stt_settings.get("enable", False): - # TODO: 独立 - ctx = self.plugin_manager.context - stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin) - if not stt_provider: - logger.warning( - f"会话 {event.unified_msg_origin} 未配置语音转文本模型。", - ) - return - message_chain = event.get_messages() - for idx, component in enumerate(message_chain): - if isinstance(component, Record) and component.url: - path = component.url.removeprefix("file://") - retry = 5 - for i in range(retry): - try: - result = await stt_provider.get_text(audio_url=path) - if result: - logger.info("语音转文本结果: " + result) - message_chain[idx] = Plain(result) - event.message_str += result - event.message_obj.message_str += result - break - except FileNotFoundError as e: - # napcat workaround - logger.warning(e) - logger.warning(f"重试中: {i + 1}/{retry}") - await asyncio.sleep(0.5) - continue - except 
BaseException as e: - logger.error(traceback.format_exc()) - logger.error(f"语音转文本失败: {e}") - break diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py deleted file mode 100644 index 9efe53814..000000000 --- a/astrbot/core/pipeline/process_stage/method/agent_request.py +++ /dev/null @@ -1,48 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.session_llm_manager import SessionServiceManager - -from ...context import PipelineContext -from ..stage import Stage -from .agent_sub_stages.internal import InternalAgentSubStage -from .agent_sub_stages.third_party import ThirdPartyAgentSubStage - - -class AgentRequestSubStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.config = ctx.astrbot_config - - self.bot_wake_prefixs: list[str] = self.config["wake_prefix"] - self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"] - for bwp in self.bot_wake_prefixs: - if self.prov_wake_prefix.startswith(bwp): - logger.info( - f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", - ) - self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :] - - agent_runner_type = self.config["provider_settings"]["agent_runner_type"] - if agent_runner_type == "local": - self.agent_sub_stage = InternalAgentSubStage() - else: - self.agent_sub_stage = ThirdPartyAgentSubStage() - await self.agent_sub_stage.initialize(ctx) - - async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]: - if not self.ctx.astrbot_config["provider_settings"]["enable"]: - logger.debug( - "This pipeline does not enable AI capability, skip processing." - ) - return - - if not await SessionServiceManager.should_process_llm_request(event): - logger.debug( - f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing." 
- ) - return - - async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix): - yield resp diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py deleted file mode 100644 index 8a79b96c9..000000000 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ /dev/null @@ -1,60 +0,0 @@ -"""本地 Agent 模式的 AstrBot 插件调用 Stage""" - -import traceback -from collections.abc import AsyncGenerator -from typing import Any - -from astrbot.core import logger -from astrbot.core.message.message_event_result import MessageEventResult -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.star import star_map -from astrbot.core.star.star_handler import StarHandlerMetadata - -from ...context import PipelineContext, call_handler -from ..stage import Stage - - -class StarRequestSubStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: - self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"] - self.identifier = ctx.astrbot_config["provider_settings"]["identifier"] - self.ctx = ctx - - async def process( - self, - event: AstrMessageEvent, - ) -> AsyncGenerator[Any, None]: - activated_handlers: list[StarHandlerMetadata] = event.get_extra( - "activated_handlers", - ) - handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra( - "handlers_parsed_params", - ) - if not handlers_parsed_params: - handlers_parsed_params = {} - - for handler in activated_handlers: - params = handlers_parsed_params.get(handler.handler_full_name, {}) - md = star_map.get(handler.handler_module_path) - if not md: - logger.warning( - f"Cannot find plugin for given handler module path: {handler.handler_module_path}", - ) - continue - logger.debug(f"plugin -> {md.name} - {handler.handler_name}") - try: - wrapper = call_handler(event, handler.handler, **params) - async for ret in wrapper: - yield ret - event.clear_result() # 清除上一个 handler 的结果 - except Exception as e: - logger.error(traceback.format_exc()) - logger.error(f"Star {handler.handler_full_name} handle error: {e}") - - if event.is_at_or_wake_command: - ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}" - event.set_result(MessageEventResult().message(ret)) - yield - event.clear_result() - - event.stop_event() diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py deleted file mode 100644 index 076f7f12a..000000000 --- a/astrbot/core/pipeline/process_stage/stage.py +++ /dev/null @@ -1,66 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.provider.entities import ProviderRequest -from astrbot.core.star.star_handler import StarHandlerMetadata - -from ..context import PipelineContext -from ..stage import Stage, register_stage -from .method.agent_request import AgentRequestSubStage -from .method.star_request import StarRequestSubStage - - -@register_stage -class ProcessStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.config = ctx.astrbot_config - self.plugin_manager = ctx.plugin_manager - - # initialize agent sub stage - self.agent_sub_stage = AgentRequestSubStage() - await self.agent_sub_stage.initialize(ctx) - - # initialize star request sub stage - self.star_request_sub_stage = StarRequestSubStage() - await self.star_request_sub_stage.initialize(ctx) - - async def process( - self, - event: 
AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - """处理事件""" - activated_handlers: list[StarHandlerMetadata] = event.get_extra( - "activated_handlers", - ) - # 有插件 Handler 被激活 - if activated_handlers: - async for resp in self.star_request_sub_stage.process(event): - # 生成器返回值处理 - if isinstance(resp, ProviderRequest): - # Handler 的 LLM 请求 - event.set_extra("provider_request", resp) - _t = False - async for _ in self.agent_sub_stage.process(event): - _t = True - yield - if not _t: - yield - else: - yield - - # 调用 LLM 相关请求 - if not self.ctx.astrbot_config["provider_settings"].get("enable", True): - return - - if ( - not event._has_send_oper - and event.is_at_or_wake_command - and not event.call_llm - ): - # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 - if ( - event.get_result() and not event.is_stopped() - ) or not event.get_result(): - async for _ in self.agent_sub_stage.process(event): - yield diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py deleted file mode 100644 index 392bceff3..000000000 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ /dev/null @@ -1,99 +0,0 @@ -import asyncio -from collections import defaultdict, deque -from collections.abc import AsyncGenerator -from datetime import datetime, timedelta - -from astrbot.core import logger -from astrbot.core.config.astrbot_config import RateLimitStrategy -from astrbot.core.platform.astr_message_event import AstrMessageEvent - -from ..context import PipelineContext -from ..stage import Stage, register_stage - - -@register_stage -class RateLimitStage(Stage): - """检查是否需要限制消息发送的限流器。 - - 使用 Fixed Window 算法。 - 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 - """ - - def __init__(self) -> None: - # 存储每个会话的请求时间队列 - self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) - # 为每个会话设置一个锁,避免并发冲突 - self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock) - # 限流参数 - self.rate_limit_count: int = 0 - self.rate_limit_time: timedelta = timedelta(0) - - async def initialize(self, ctx: PipelineContext) -> None: - """初始化限流器,根据配置设置限流参数。""" - self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][ - "count" - ] - self.rate_limit_time = timedelta( - seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"], - ) - self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][ - "strategy" - ] # stall or discard - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - """检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 - - Args: - event (AstrMessageEvent): 当前消息事件。 - ctx (PipelineContext): 流水线上下文。 - - Returns: - MessageEventResult: 继续或停止事件处理的结果。 - - """ - session_id = event.session_id - now = datetime.now() - - async with self.locks[session_id]: # 确保同一会话不会并发修改队列 - # 检查并处理限流,可能需要多次检查直到满足条件 - while True: - timestamps = self.event_timestamps[session_id] - self._remove_expired_timestamps(timestamps, now) - - if len(timestamps) < self.rate_limit_count: - timestamps.append(now) - break - next_window_time = timestamps[0] + self.rate_limit_time - stall_duration = (next_window_time - now).total_seconds() + 0.3 - - match self.rl_strategy: - case RateLimitStrategy.STALL.value: - logger.info( - f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。", - ) - await asyncio.sleep(stall_duration) - now = datetime.now() - case RateLimitStrategy.DISCARD.value: - logger.info( - f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。", - ) - return event.stop_event() - - def 
_remove_expired_timestamps( - self, - timestamps: deque[datetime], - now: datetime, - ) -> None: - """移除时间窗口外的时间戳。 - - Args: - timestamps (Deque[datetime]): 当前会话的时间戳队列。 - now (datetime): 当前时间,用于计算过期时间。 - - """ - expiry_threshold: datetime = now - self.rate_limit_time - while timestamps and timestamps[0] < expiry_threshold: - timestamps.popleft() diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py deleted file mode 100644 index b57fed29e..000000000 --- a/astrbot/core/pipeline/respond/stage.py +++ /dev/null @@ -1,280 +0,0 @@ -import asyncio -import math -import random -from collections.abc import AsyncGenerator - -import astrbot.core.message.components as Comp -from astrbot.core import logger -from astrbot.core.message.components import BaseMessageComponent, ComponentType -from astrbot.core.message.message_event_result import MessageChain, ResultContentType -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.star_handler import EventType -from astrbot.core.utils.path_util import path_Mapping - -from ..context import PipelineContext, call_event_hook -from ..stage import Stage, register_stage - - -@register_stage -class RespondStage(Stage): - # 组件类型到其非空判断函数的映射 - _component_validators = { - Comp.Plain: lambda comp: bool( - comp.text and comp.text.strip(), - ), # 纯文本消息需要strip - Comp.Face: lambda comp: comp.id is not None, # QQ表情 - Comp.Record: lambda comp: bool(comp.file), # 语音 - Comp.Video: lambda comp: bool(comp.file), # 视频 - Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ - Comp.Image: lambda comp: bool(comp.file), # 图片 - Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 - Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 - Comp.Node: lambda comp: bool(comp.content), # 转发节点 - Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 - Comp.File: lambda comp: bool(comp.file_ or comp.url), - Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 - } - - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.config = ctx.astrbot_config - self.platform_settings: dict = self.config.get("platform_settings", {}) - - self.reply_with_mention = ctx.astrbot_config["platform_settings"][ - "reply_with_mention" - ] - self.reply_with_quote = ctx.astrbot_config["platform_settings"][ - "reply_with_quote" - ] - - # 分段回复 - self.enable_seg: bool = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["enable"] - self.only_llm_result = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["only_llm_result"] - - self.interval_method = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["interval_method"] - self.log_base = float( - ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"], - ) - interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][ - "interval" - ] - interval_str_ls = interval_str.replace(" ", "").split(",") - try: - self.interval = [float(t) for t in interval_str_ls] - except BaseException as e: - logger.error(f"解析分段回复的间隔时间失败。{e}") - self.interval = [1.5, 3.5] - logger.info(f"分段回复间隔时间:{self.interval}") - - async def _word_cnt(self, text: str) -> int: - """分段回复 统计字数""" - if all(ord(c) < 128 for c in text): - word_count = len(text.split()) - else: - word_count = len([c for c in text if c.isalnum()]) - return word_count - - async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: - """分段回复 计算间隔时间""" - if self.interval_method == "log": - if 
isinstance(comp, Comp.Plain): - wc = await self._word_cnt(comp.text) - i = math.log(wc + 1, self.log_base) - return random.uniform(i, i + 0.5) - return random.uniform(1, 1.75) - # random - return random.uniform(self.interval[0], self.interval[1]) - - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: - """检查消息链是否为空 - - Args: - chain (list[BaseMessageComponent]): 包含消息对象的列表 - - """ - if not chain: - return True - - for comp in chain: - comp_type = type(comp) - - # 检查组件类型是否在字典中 - if comp_type in self._component_validators: - if self._component_validators[comp_type](comp): - return False - - # 如果所有组件都为空 - return True - - def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: - """检查是否需要分段回复""" - if not self.enable_seg: - return False - - if (result := event.get_result()) is None: - return False - if self.only_llm_result and not result.is_llm_result(): - return False - - if event.get_platform_name() in [ - "qq_official", - "weixin_official_account", - "dingtalk", - ]: - return False - - return True - - def _extract_comp( - self, - raw_chain: list[BaseMessageComponent], - extract_types: set[ComponentType], - modify_raw_chain: bool = True, - ): - extracted = [] - if modify_raw_chain: - remaining = [] - for comp in raw_chain: - if comp.type in extract_types: - extracted.append(comp) - else: - remaining.append(comp) - raw_chain[:] = remaining - else: - extracted = [comp for comp in raw_chain if comp.type in extract_types] - - return extracted - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - result = event.get_result() - if result is None: - return - if event.get_extra("_streaming_finished", False): - # prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again - return - if result.result_content_type == ResultContentType.STREAMING_FINISH: - event.set_extra("_streaming_finished", True) - return - - logger.info( - f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}", - ) - - if result.result_content_type == ResultContentType.STREAMING_RESULT: - if result.async_stream is None: - logger.warning("async_stream 为空,跳过发送。") - return - # 流式结果直接交付平台适配器处理 - realtime_segmenting = ( - self.config.get("provider_settings", {}).get( - "unsupported_streaming_strategy", - "realtime_segmenting", - ) - == "realtime_segmenting" - ) - logger.info(f"应用流式输出({event.get_platform_id()})") - await event.send_streaming(result.async_stream, realtime_segmenting) - return - if len(result.chain) > 0: - # 检查路径映射 - if mappings := self.platform_settings.get("path_mapping", []): - for idx, component in enumerate(result.chain): - if isinstance(component, Comp.File) and component.file: - # 支持 File 消息段的路径映射。 - component.file = path_Mapping(mappings, component.file) - result.chain[idx] = component - - # 检查消息链是否为空 - try: - if await self._is_empty_message_chain(result.chain): - logger.info("消息为空,跳过发送阶段") - return - except Exception as e: - logger.warning(f"空内容检查异常: {e}") - - # 将 Plain 为空的消息段移除 - result.chain = [ - comp - for comp in result.chain - if not ( - isinstance(comp, Comp.Plain) - and (not comp.text or not comp.text.strip()) - ) - ] - - # 发送消息链 - # Record 需要强制单独发送 - need_separately = {ComponentType.Record} - if self.is_seg_reply_required(event): - header_comps = self._extract_comp( - result.chain, - {ComponentType.Reply, ComponentType.At}, - modify_raw_chain=True, - ) - if not result.chain or len(result.chain) == 0: - # may fix #2670 - 
logger.warning( - f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}", - ) - return - for comp in result.chain: - i = await self._calc_comp_interval(comp) - await asyncio.sleep(i) - try: - if comp.type in need_separately: - await event.send(MessageChain([comp])) - else: - await event.send(MessageChain([*header_comps, comp])) - header_comps.clear() - except Exception as e: - logger.error( - f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}", - exc_info=True, - ) - else: - if all( - comp.type in {ComponentType.Reply, ComponentType.At} - for comp in result.chain - ): - # may fix #2670 - logger.warning( - f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}", - ) - return - sep_comps = self._extract_comp( - result.chain, - need_separately, - modify_raw_chain=True, - ) - for comp in sep_comps: - chain = MessageChain([comp]) - try: - await event.send(chain) - except Exception as e: - logger.error( - f"发送消息链失败: chain = {chain}, error = {e}", - exc_info=True, - ) - chain = MessageChain(result.chain) - if result.chain and len(result.chain) > 0: - try: - await event.send(chain) - except Exception as e: - logger.error( - f"发送消息链失败: chain = {chain}, error = {e}", - exc_info=True, - ) - - if await call_event_hook(event, EventType.OnAfterMessageSentEvent): - return - - event.clear_result() diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py deleted file mode 100644 index 823aa0eaa..000000000 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ /dev/null @@ -1,402 +0,0 @@ -import random -import re -import time -import traceback -from collections.abc import AsyncGenerator - -from astrbot.core import file_token_service, html_renderer, logger -from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply -from astrbot.core.message.message_event_result import ResultContentType -from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.message_type import MessageType -from astrbot.core.star.session_llm_manager import SessionServiceManager -from astrbot.core.star.star import star_map -from astrbot.core.star.star_handler import EventType, star_handlers_registry - -from ..context import PipelineContext -from ..stage import Stage, register_stage, registered_stages - - -@register_stage -class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] - self.reply_with_mention = ctx.astrbot_config["platform_settings"][ - "reply_with_mention" - ] - self.reply_with_quote = ctx.astrbot_config["platform_settings"][ - "reply_with_quote" - ] - self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"] - try: - self.t2i_word_threshold = int(self.t2i_word_threshold) - self.t2i_word_threshold = max(self.t2i_word_threshold, 50) - except BaseException: - self.t2i_word_threshold = 150 - self.t2i_strategy = ctx.astrbot_config["t2i_strategy"] - self.t2i_use_network = self.t2i_strategy == "remote" - self.t2i_active_template = ctx.astrbot_config["t2i_active_template"] - - self.forward_threshold = ctx.astrbot_config["platform_settings"][ - "forward_threshold" - ] - - trigger_probability = ctx.astrbot_config["provider_tts_settings"].get( - "trigger_probability", - 1, - ) - try: - self.tts_trigger_probability = max( - 0.0, - min(float(trigger_probability), 1.0), - 
) - except (TypeError, ValueError): - self.tts_trigger_probability = 1.0 - - # 分段回复 - self.words_count_threshold = int( - ctx.astrbot_config["platform_settings"]["segmented_reply"][ - "words_count_threshold" - ], - ) - self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["enable"] - self.only_llm_result = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["only_llm_result"] - self.split_mode = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ].get("split_mode", "regex") - self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"] - self.split_words = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ].get("split_words", ["。", "?", "!", "~", "…"]) - if self.split_words: - escaped_words = sorted( - [re.escape(word) for word in self.split_words], key=len, reverse=True - ) - self.split_words_pattern = re.compile( - f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL - ) - else: - self.split_words_pattern = None - self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][ - "segmented_reply" - ]["content_cleanup_rule"] - - # exception - self.content_safe_check_reply = ctx.astrbot_config["content_safety"][ - "also_use_in_response" - ] - self.content_safe_check_stage = None - if self.content_safe_check_reply: - for stage_cls in registered_stages: - if stage_cls.__name__ == "ContentSafetyCheckStage": - self.content_safe_check_stage = stage_cls() - await self.content_safe_check_stage.initialize(ctx) - - provider_cfg = ctx.astrbot_config.get("provider_settings", {}) - self.show_reasoning = provider_cfg.get("display_reasoning_text", False) - - def _split_text_by_words(self, text: str) -> list[str]: - """使用分段词列表分段文本""" - if not self.split_words_pattern: - return [text] - - segments = self.split_words_pattern.findall(text) - result = [] - for seg in segments: - if isinstance(seg, tuple): - content = seg[0] - if not isinstance(content, str): - continue - for word in self.split_words: - if content.endswith(word): - content = content[: -len(word)] - break - if content.strip(): - result.append(content) - elif seg and seg.strip(): - result.append(seg) - return result if result else [text] - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - result = event.get_result() - if result is None or not result.chain: - return - - if result.result_content_type == ResultContentType.STREAMING_RESULT: - return - - is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH - - # 回复时检查内容安全 - if ( - self.content_safe_check_reply - and self.content_safe_check_stage - and result.is_llm_result() - and not is_stream # 流式输出不检查内容安全 - ): - text = "" - for comp in result.chain: - if isinstance(comp, Plain): - text += comp.text - - if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): - async for _ in self.content_safe_check_stage.process( - event, - check_text=text, - ): - yield - - # 发送消息前事件钩子 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnDecoratingResultEvent, - plugins_name=event.plugins_name, - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", - ) - if is_stream: - logger.warning( - "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", - ) - await handler.handler(event) - - if (result := event.get_result()) is None or not result.chain: - logger.debug( - f"hook(on_decorating_result) -> 
{star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", - ) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。", - ) - return - - # 流式输出不执行下面的逻辑 - if is_stream: - logger.info("流式输出已启用,跳过结果装饰阶段") - return - - # 需要再获取一次。插件可能直接对 chain 进行了替换。 - result = event.get_result() - if result is None: - return - - if len(result.chain) > 0: - # 回复前缀 - if self.reply_prefix: - for comp in result.chain: - if isinstance(comp, Plain): - comp.text = self.reply_prefix + comp.text - break - - # 分段回复 - if self.enable_segmented_reply and event.get_platform_name() not in [ - "qq_official", - "weixin_official_account", - "dingtalk", - ]: - if ( - self.only_llm_result and result.is_llm_result() - ) or not self.only_llm_result: - new_chain = [] - for comp in result.chain: - if isinstance(comp, Plain): - if len(comp.text) > self.words_count_threshold: - # 不分段回复 - new_chain.append(comp) - continue - - # 根据 split_mode 选择分段方式 - if self.split_mode == "words": - split_response = self._split_text_by_words(comp.text) - else: # regex 模式 - try: - split_response = re.findall( - self.regex, - comp.text, - re.DOTALL | re.MULTILINE, - ) - except re.error: - logger.error( - f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}", - ) - split_response = re.findall( - r".*?[。?!~…]+|.+$", - comp.text, - re.DOTALL | re.MULTILINE, - ) - - if not split_response: - new_chain.append(comp) - continue - for seg in split_response: - if self.content_cleanup_rule: - seg = re.sub(self.content_cleanup_rule, "", seg) - if seg.strip(): - new_chain.append(Plain(seg)) - else: - # 非 Plain 类型的消息段不分段 - new_chain.append(comp) - result.chain = new_chain - - # TTS - tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( - event.unified_msg_origin, - ) - - should_tts = ( - bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"]) - and result.is_llm_result() - and await SessionServiceManager.should_process_tts_request(event) - and random.random() <= self.tts_trigger_probability - and tts_provider - ) - if should_tts and not tts_provider: - logger.warning( - f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", - ) - - if ( - not should_tts - and self.show_reasoning - and event.get_extra("_llm_reasoning_content") - ): - # inject reasoning content to chain - reasoning_content = event.get_extra("_llm_reasoning_content") - result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) - - if should_tts and tts_provider: - new_chain = [] - for comp in result.chain: - if isinstance(comp, Plain) and len(comp.text) > 1: - try: - logger.info(f"TTS 请求: {comp.text}") - audio_path = await tts_provider.get_audio(comp.text) - logger.info(f"TTS 结果: {audio_path}") - if not audio_path: - logger.error( - f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}", - ) - new_chain.append(comp) - continue - - use_file_service = self.ctx.astrbot_config[ - "provider_tts_settings" - ]["use_file_service"] - callback_api_base = self.ctx.astrbot_config[ - "callback_api_base" - ] - dual_output = self.ctx.astrbot_config[ - "provider_tts_settings" - ]["dual_output"] - - url = None - if use_file_service and callback_api_base: - token = await file_token_service.register_file( - audio_path, - ) - url = f"{callback_api_base}/api/file/{token}" - logger.debug(f"已注册:{url}") - - new_chain.append( - Record( - file=url or audio_path, - url=url or audio_path, - ), - ) - if dual_output: - new_chain.append(comp) - except Exception: - 
logger.error(traceback.format_exc()) - logger.error("TTS 失败,使用文本发送。") - new_chain.append(comp) - else: - new_chain.append(comp) - result.chain = new_chain - - # 文本转图片 - elif ( - result.use_t2i_ is None and self.ctx.astrbot_config["t2i"] - ) or result.use_t2i_: - parts = [] - for comp in result.chain: - if isinstance(comp, Plain): - parts.append("\n\n" + comp.text) - else: - break - plain_str = "".join(parts) - if plain_str and len(plain_str) > self.t2i_word_threshold: - render_start = time.time() - try: - url = await html_renderer.render_t2i( - plain_str, - return_url=True, - use_network=self.t2i_use_network, - template_name=self.t2i_active_template, - ) - except BaseException: - logger.error("文本转图片失败,使用文本发送。") - return - if time.time() - render_start > 3: - logger.warning( - "文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。", - ) - if url: - if url.startswith("http"): - result.chain = [Image.fromURL(url)] - elif ( - self.ctx.astrbot_config["t2i_use_file_service"] - and self.ctx.astrbot_config["callback_api_base"] - ): - token = await file_token_service.register_file(url) - url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}" - logger.debug(f"已注册:{url}") - result.chain = [Image.fromURL(url)] - else: - result.chain = [Image.fromFileSystem(url)] - - # 触发转发消息 - if event.get_platform_name() == "aiocqhttp": - word_cnt = 0 - for comp in result.chain: - if isinstance(comp, Plain): - word_cnt += len(comp.text) - if word_cnt > self.forward_threshold: - node = Node( - uin=event.get_self_id(), - name="AstrBot", - content=[*result.chain], - ) - result.chain = [node] - - has_plain = any(isinstance(item, Plain) for item in result.chain) - if has_plain: - # at 回复 - if ( - self.reply_with_mention - and event.get_message_type() != MessageType.FRIEND_MESSAGE - ): - result.chain.insert( - 0, - At(qq=event.get_sender_id(), name=event.get_sender_name()), - ) - if len(result.chain) > 1 and isinstance(result.chain[1], Plain): - result.chain[1].text = "\n" + result.chain[1].text - - # 引用回复 - if self.reply_with_quote: - if not any(isinstance(item, File) for item in result.chain): - result.chain.insert(0, Reply(id=event.message_obj.message_id)) diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py deleted file mode 100644 index 71c98778f..000000000 --- a/astrbot/core/pipeline/scheduler.py +++ /dev/null @@ -1,88 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.platform import AstrMessageEvent -from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent -from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( - WecomAIBotMessageEvent, -) - -from . 
import STAGES_ORDER -from .context import PipelineContext -from .stage import registered_stages - - -class PipelineScheduler: - """管道调度器,负责调度各个阶段的执行""" - - def __init__(self, context: PipelineContext) -> None: - registered_stages.sort( - key=lambda x: STAGES_ORDER.index(x.__name__), - ) # 按照顺序排序 - self.ctx = context # 上下文对象 - self.stages = [] # 存储阶段实例 - - async def initialize(self) -> None: - """初始化管道调度器时, 初始化所有阶段""" - for stage_cls in registered_stages: - stage_instance = stage_cls() # 创建实例 - await stage_instance.initialize(self.ctx) - self.stages.append(stage_instance) - - async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None: - """依次执行各个阶段 - - Args: - event (AstrMessageEvent): 事件对象 - from_stage (int): 从第几个阶段开始执行, 默认从0开始 - - """ - for i in range(from_stage, len(self.stages)): - stage = self.stages[i] # 获取当前要执行的阶段 - # logger.debug(f"执行阶段 {stage.__class__.__name__}") - coroutine = stage.process( - event, - ) # 调用阶段的process方法, 返回协程或者异步生成器 - - if isinstance(coroutine, AsyncGenerator): - # 如果返回的是异步生成器, 实现洋葱模型的核心 - async for _ in coroutine: - # 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段 - if event.is_stopped(): - logger.debug( - f"阶段 {stage.__class__.__name__} 已终止事件传播。", - ) - break - - # 递归调用, 处理所有后续阶段 - await self._process_stages(event, i + 1) - - # 此处是后续所有阶段处理完毕后返回的点, 执行后置处理 - if event.is_stopped(): - logger.debug( - f"阶段 {stage.__class__.__name__} 已终止事件传播。", - ) - break - else: - # 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件) - # 简单地等待它执行完成, 然后继续执行下一个阶段 - await coroutine - - if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") - break - - async def execute(self, event: AstrMessageEvent) -> None: - """执行 pipeline - - Args: - event (AstrMessageEvent): 事件对象 - - """ - await self._process_stages(event) - - # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent): - await event.send(None) - - logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/pipeline/session_status_check/stage.py b/astrbot/core/pipeline/session_status_check/stage.py deleted file mode 100644 index 26c3c235a..000000000 --- a/astrbot/core/pipeline/session_status_check/stage.py +++ /dev/null @@ -1,37 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.session_llm_manager import SessionServiceManager - -from ..context import PipelineContext -from ..stage import Stage, register_stage - - -@register_stage -class SessionStatusCheckStage(Stage): - """检查会话是否整体启用""" - - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - self.conv_mgr = ctx.plugin_manager.context.conversation_manager - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - # 检查会话是否整体启用 - if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin): - logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。") - - # workaround for #2309 - conv_id = await self.conv_mgr.get_curr_conversation_id( - event.unified_msg_origin, - ) - if not conv_id: - await self.conv_mgr.new_conversation( - event.unified_msg_origin, - platform_id=event.get_platform_id(), - ) - - event.stop_event() diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py deleted file mode 100644 index 74aca4ef1..000000000 --- a/astrbot/core/pipeline/stage.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import abc -from collections.abc import AsyncGenerator - -from 
astrbot.core.platform.astr_message_event import AstrMessageEvent - -from .context import PipelineContext - -registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 - - -def register_stage(cls): - """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" - registered_stages.append(cls) - return cls - - -class Stage(abc.ABC): - """描述一个 Pipeline 的某个阶段""" - - @abc.abstractmethod - async def initialize(self, ctx: PipelineContext) -> None: - """初始化阶段 - - Args: - ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 - - """ - raise NotImplementedError - - @abc.abstractmethod - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - """处理事件 - - Args: - event (AstrMessageEvent): 事件对象,包含事件的相关信息 - Returns: - Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) - - """ - raise NotImplementedError diff --git a/astrbot/core/pipeline/system/__init__.py b/astrbot/core/pipeline/system/__init__.py new file mode 100644 index 000000000..734abb23d --- /dev/null +++ b/astrbot/core/pipeline/system/__init__.py @@ -0,0 +1,11 @@ +from .access_control import AccessController +from .command_dispatcher import CommandDispatcher +from .event_preprocessor import EventPreprocessor +from .rate_limit import RateLimiter + +__all__ = [ + "AccessController", + "CommandDispatcher", + "EventPreprocessor", + "RateLimiter", +] diff --git a/astrbot/core/pipeline/system/access_control.py b/astrbot/core/pipeline/system/access_control.py new file mode 100644 index 000000000..5d1aae794 --- /dev/null +++ b/astrbot/core/pipeline/system/access_control.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from astrbot.core import logger +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.message_type import MessageType +from astrbot.core.star.node_star import NodeResult + + +class AccessController: + """Whitelist check""" + + def __init__(self, ctx: PipelineContext): + self._ctx = ctx + self._initialized = False + self.enable_whitelist_check: bool = False + self.whitelist: list[str] = [] + self.wl_ignore_admin_on_group: bool = False + self.wl_ignore_admin_on_friend: bool = False + self.wl_log: bool = False + + async def initialize(self) -> None: + if self._initialized: + return + cfg = self._ctx.astrbot_config["platform_settings"] + self.enable_whitelist_check = cfg["enable_id_white_list"] + self.whitelist = [ + str(i).strip() for i in cfg["id_whitelist"] if str(i).strip() != "" + ] + self.wl_ignore_admin_on_group = cfg["wl_ignore_admin_on_group"] + self.wl_ignore_admin_on_friend = cfg["wl_ignore_admin_on_friend"] + self.wl_log = cfg["id_whitelist_log"] + self._initialized = True + + async def apply(self, event: AstrMessageEvent) -> NodeResult: + if not self.enable_whitelist_check: + return NodeResult.CONTINUE + + if not self.whitelist: + return NodeResult.CONTINUE + + if event.get_platform_name() == "webchat": + return NodeResult.CONTINUE + + if self.wl_ignore_admin_on_group: + if ( + event.role == "admin" + and event.get_message_type() == MessageType.GROUP_MESSAGE + ): + return NodeResult.CONTINUE + + if self.wl_ignore_admin_on_friend: + if ( + event.role == "admin" + and event.get_message_type() == MessageType.FRIEND_MESSAGE + ): + return NodeResult.CONTINUE + + if ( + event.unified_msg_origin not in self.whitelist + and str(event.get_group_id()).strip() not in self.whitelist + ): + if self.wl_log: + logger.info( + f"会话 ID 
{event.unified_msg_origin} 不在会话白名单中,已终止事件传播。" + ) + event.stop_event() + return NodeResult.STOP + + return NodeResult.CONTINUE diff --git a/astrbot/core/pipeline/system/command_dispatcher.py b/astrbot/core/pipeline/system/command_dispatcher.py new file mode 100644 index 000000000..5b9919b84 --- /dev/null +++ b/astrbot/core/pipeline/system/command_dispatcher.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import traceback +from typing import TYPE_CHECKING, Any + +from astrbot.core import logger +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from astrbot.core.pipeline.system.star_yield import StarHandlerAdapter, StarYieldDriver +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import ( + EventType, + StarHandlerMetadata, + star_handlers_registry, +) + +if TYPE_CHECKING: + from astrbot.core.config import AstrBotConfig + from astrbot.core.pipeline.agent.executor import AgentExecutor + from astrbot.core.pipeline.engine.send_service import SendService + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class CommandDispatcher: + def __init__( + self, + config: AstrBotConfig, + send_service: SendService, + agent_executor: AgentExecutor | None = None, + ) -> None: + self._config = config + self._send_service = send_service + self._agent_executor = agent_executor + + # 初始化 yield 驱动器 + self._yield_driver = StarYieldDriver( + self._send_message, + ) + self._handler_adapter = StarHandlerAdapter(self._yield_driver) + + # 配置 + self._no_permission_reply = config.get("platform_settings", {}).get( + "no_permission_reply", True + ) + self._disable_builtin = config.get("disable_builtin_commands", False) + + async def _send_message(self, event: AstrMessageEvent) -> None: + """发送消息回调""" + if event.get_result(): + await self._send_service.send(event) + + async def _handle_provider_request( + self, + event: AstrMessageEvent, + ) -> None: + """收到 ProviderRequest 时立即执行 Agent 并发送结果""" + if not self._agent_executor: + return + await self._agent_executor.run(event) + await self._send_service.send(event) + event.set_extra("_provider_request_consumed", True) + event.set_extra("provider_request", None) + event.set_extra("has_provider_request", False) + + async def match( + self, + event: AstrMessageEvent, + plugins_name: list[str] | None = None, + ) -> list[tuple[StarHandlerMetadata, dict[str, Any]]]: + """仅匹配命令,不执行 + + 用于在系统机制检查之前确定是否有命令匹配,并设置 is_wake 标志。 + 匹配结果应传递给 execute() 方法在系统机制检查之后执行。 + + Args: + event: 消息事件 + plugins_name: 启用的插件列表(None 表示全部) + + Returns: + 匹配的 handler 列表 [(handler, parsed_params), ...] 
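+
+        Example (illustrative sketch; ``dispatcher`` is an assumed local name):
+
+            matched = await dispatcher.match(event, plugins_name=None)
+            # run the system checks here (rate limit, whitelist, ...) first
+            handled = await dispatcher.execute(event, matched)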
+ """ + return await self._match_handlers(event, plugins_name) + + async def execute( + self, + event: AstrMessageEvent, + matched_handlers: list[tuple[StarHandlerMetadata, dict[str, Any]]], + ) -> bool: + """执行已匹配的命令 + + 应在系统机制检查(限流、权限等)之后调用。 + + Args: + event: 消息事件 + matched_handlers: match() 返回的匹配结果 + + Returns: + bool: 是否有命令被执行 + """ + if not matched_handlers: + return False + + for handler, parsed_params in matched_handlers: + plugin_meta = star_map.get(handler.handler_module_path) + if not plugin_meta: + logger.warning( + f"Plugin not found for handler: {handler.handler_module_path}" + ) + continue + + logger.debug(f"Dispatching to {plugin_meta.name}.{handler.handler_name}") + + try: + # 使用适配器调用,完整支持 yield + result = await self._handler_adapter.invoke( + handler.handler, + event, + **parsed_params, + ) + + if result.error: + # handler 执行出错 + if event.is_at_or_wake_command: + error_msg = ( + f":(\n\n调用插件 {plugin_meta.name} 的 " + f"{handler.handler_name} 时出现异常:{result.error}" + ) + event.set_result(MessageEventResult().message(error_msg)) + event.stop_event() + return True + + # 检查是否有 LLM 请求 + if result.llm_requests and not event.get_extra( + "_provider_request_consumed", + False, + ): + event.set_extra("has_provider_request", True) + + if result.stopped or event.is_stopped(): + return True + + except Exception as e: + logger.error(f"Dispatch error: {e}") + logger.error(traceback.format_exc()) + event.stop_event() + return True + + return len(matched_handlers) > 0 + + async def _match_handlers( + self, + event: AstrMessageEvent, + plugins_name: list[str] | None, + ) -> list[tuple[StarHandlerMetadata, dict[str, Any]]]: + """匹配所有适用的 handler + + Returns: + [(handler, parsed_params), ...] + """ + matched: list[tuple[StarHandlerMetadata, dict[str, Any]]] = [] + + for handler in star_handlers_registry.get_handlers_by_event_type( + EventType.AdapterMessageEvent, + plugins_name=plugins_name, + ): + # 跳过内置命令(如配置) + if ( + self._disable_builtin + and handler.handler_module_path + == "astrbot.builtin_stars.builtin_commands.main" + ): + continue + + # 必须有过滤器 + if not handler.event_filters: + continue + + # 应用过滤器 + passed = True + permission_failed = False + permission_raise_error = False + parsed_params: dict[str, Any] = {} + + for f in handler.event_filters: + try: + from astrbot.core.star.filter.permission import PermissionTypeFilter + + if isinstance(f, PermissionTypeFilter): + if not f.filter(event, self._config): + permission_failed = True + permission_raise_error = f.raise_error + elif not f.filter(event, self._config): + passed = False + break + except Exception as e: + # 过滤器执行出错 — 发送错误消息并停止 + plugin_meta = star_map.get(handler.handler_module_path) + plugin_name = plugin_meta.name if plugin_meta else "unknown" + await event.send( + MessageEventResult().message(f"插件 {plugin_name}: {e}") + ) + event.stop_event() + passed = False + break + + # 获取解析的参数 + if "parsed_params" in event.get_extra(default={}): + parsed_params = event.get_extra("parsed_params") + event._extras.pop("parsed_params", None) + + if not passed: + continue + + if permission_failed: + if not permission_raise_error: + continue + if self._no_permission_reply: + await self._handle_permission_denied(event, handler) + event.stop_event() + return [] + + # 跳过 CommandGroup 的空 handler + from astrbot.core.star.filter.command_group import CommandGroupFilter + + is_group_cmd = any( + isinstance(f, CommandGroupFilter) for f in handler.event_filters + ) + if not is_group_cmd: + matched.append((handler, parsed_params)) + event.is_wake = 
True + + return matched + + @staticmethod + async def _handle_permission_denied( + event: AstrMessageEvent, + handler: StarHandlerMetadata, + ) -> None: + """处理权限不足""" + plugin_meta = star_map.get(handler.handler_module_path) + plugin_name = plugin_meta.name if plugin_meta else "unknown" + + await event.send( + MessageChain().message( + f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。" + f"通过 /sid 获取 ID 并请管理员添加。" + ) + ) + logger.info( + f"触发 {plugin_name} 时, 用户(ID={event.get_sender_id()}) 权限不足。" + ) diff --git a/astrbot/core/pipeline/system/event_preprocessor.py b/astrbot/core/pipeline/system/event_preprocessor.py new file mode 100644 index 000000000..0f3f478b3 --- /dev/null +++ b/astrbot/core/pipeline/system/event_preprocessor.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from astrbot.core.message.components import Image, Record +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.pipeline.system.session_utils import build_unique_session_id +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.message_type import MessageType +from astrbot.core.utils.path_util import path_Mapping + + +class EventPreprocessor: + """系统级事件预处理""" + + def __init__(self, ctx: PipelineContext): + platform_settings = ctx.astrbot_config.get("platform_settings", {}) + self.ignore_bot_self_message = platform_settings.get( + "ignore_bot_self_message", False + ) + self.unique_session = platform_settings.get("unique_session", False) + self.admins_id: list[str] = ctx.astrbot_config.get("admins_id", []) + self.path_mapping: list[str] = platform_settings.get("path_mapping", []) + + async def preprocess(self, event: AstrMessageEvent) -> bool: + """ + 系统级预处理。 + + Returns: + 是否继续处理 + """ + # 应用唯一会话 ID + if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE: + sid = build_unique_session_id(event) + if sid: + event.session_id = sid + + # 过滤机器人自身消息 + if ( + self.ignore_bot_self_message + and event.get_self_id() == event.get_sender_id() + ): + event.stop_event() + return False + + # 识别管理员身份 + for admin_id in self.admins_id: + if str(event.get_sender_id()) == admin_id: + event.role = "admin" + break + + # 入站 Record/Image 路径映射 + if self.path_mapping: + message_chain = event.get_messages() + for idx, component in enumerate(message_chain): + if isinstance(component, Record | Image): + if component.url: + component.url = path_Mapping(self.path_mapping, component.url) + if component.file: + component.file = path_Mapping(self.path_mapping, component.file) + message_chain[idx] = component + + return True diff --git a/astrbot/core/pipeline/system/rate_limit.py b/astrbot/core/pipeline/system/rate_limit.py new file mode 100644 index 000000000..41d85915a --- /dev/null +++ b/astrbot/core/pipeline/system/rate_limit.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import asyncio +from collections import defaultdict, deque +from datetime import datetime, timedelta + +from astrbot.core import logger +from astrbot.core.config.astrbot_config import RateLimitStrategy +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.node_star import NodeResult + + +class RateLimiter: + """Fixed-window rate limiter""" + + def __init__(self, ctx: PipelineContext): + self._ctx = ctx + self._initialized = False + self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) + self.locks: defaultdict[str, asyncio.Lock] = 
defaultdict(asyncio.Lock) + self.rate_limit_count: int = 0 + self.rate_limit_time: timedelta = timedelta(0) + self.rl_strategy: str = "" + + async def initialize(self) -> None: + if self._initialized: + return + cfg = self._ctx.astrbot_config["platform_settings"]["rate_limit"] + self.rate_limit_count = cfg["count"] + self.rate_limit_time = timedelta(seconds=cfg["time"]) + self.rl_strategy = cfg["strategy"] + self._initialized = True + + async def apply(self, event: AstrMessageEvent) -> NodeResult: + session_id = event.session_id + now = datetime.now() + + async with self.locks[session_id]: + while True: + timestamps = self.event_timestamps[session_id] + self._remove_expired_timestamps(timestamps, now) + + if len(timestamps) < self.rate_limit_count: + timestamps.append(now) + break + + next_window_time = timestamps[0] + self.rate_limit_time + stall_duration = (next_window_time - now).total_seconds() + 0.3 + + match self.rl_strategy: + case RateLimitStrategy.STALL.value: + logger.info( + f"会话 {session_id} 被限流。暂停 {stall_duration:.2f} 秒。" + ) + await asyncio.sleep(stall_duration) + now = datetime.now() + case RateLimitStrategy.DISCARD.value: + logger.info(f"会话 {session_id} 被限流。此请求已被丢弃。") + event.stop_event() + return NodeResult.STOP + + return NodeResult.CONTINUE + + def _remove_expired_timestamps( + self, timestamps: deque[datetime], now: datetime + ) -> None: + expiry_threshold = now - self.rate_limit_time + while timestamps and timestamps[0] < expiry_threshold: + timestamps.popleft() diff --git a/astrbot/core/pipeline/system/session_utils.py b/astrbot/core/pipeline/system/session_utils.py new file mode 100644 index 000000000..d63eab78a --- /dev/null +++ b/astrbot/core/pipeline/system/session_utils.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from collections.abc import Callable + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = { + "aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", + "slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", + "dingtalk": lambda e: e.get_sender_id(), + "qq_official": lambda e: e.get_sender_id(), + "qq_official_webhook": lambda e: e.get_sender_id(), + "lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}", + "misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}", +} + + +def build_unique_session_id(event: AstrMessageEvent) -> str | None: + platform = event.get_platform_name() + builder = UNIQUE_SESSION_ID_BUILDERS.get(platform) + return builder(event) if builder else None diff --git a/astrbot/core/pipeline/system/star_yield.py b/astrbot/core/pipeline/system/star_yield.py new file mode 100644 index 000000000..7534b16f9 --- /dev/null +++ b/astrbot/core/pipeline/system/star_yield.py @@ -0,0 +1,200 @@ +"""Star 插件 yield 模式兼容层。 + +提供 StarYieldDriver 和 StarHandlerAdapter,用于在新架构中 +支持传统 Star 插件的 AsyncGenerator (yield) 模式。 +""" + +from __future__ import annotations + +import inspect +import traceback +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from astrbot.core import logger +from astrbot.core.message.message_event_result import CommandResult, MessageEventResult + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.provider.entities import ProviderRequest + + +@dataclass +class YieldDriverResult: + """yield 驱动执行结果""" + + 
messages_sent: int = 0 + llm_requests: list[Any] = field(default_factory=list) + stopped: bool = False + error: Exception | None = None + + +class StarYieldDriver: + """Star 插件 yield 模式驱动器""" + + def __init__( + self, + send_callback: Callable, + provider_request_callback: Callable[ + [AstrMessageEvent, ProviderRequest], Awaitable[None] + ] + | None = None, + ) -> None: + """ + Args: + send_callback: async def (event: AstrMessageEvent) -> None + 发送消息的回调,由外部提供 + """ + self._send = send_callback + self._on_provider_request = provider_request_callback + + async def drive( + self, + generator: AsyncGenerator, + event: AstrMessageEvent, + ) -> YieldDriverResult: + """驱动 AsyncGenerator 执行 + + Args: + generator: 插件 handler 返回的 AsyncGenerator + event: 消息事件 + + Returns: + YieldDriverResult 包含执行统计 + """ + result = YieldDriverResult() + + while True: + try: + yielded = await generator.asend(None) + except StopAsyncIteration: + break + except Exception as e: + result.error = e + logger.error(traceback.format_exc()) + break + + try: + await self._handle_yielded(yielded, event, result) + except Exception as e: + # 将异常传回 generator,让插件有机会 catch + try: + yielded = await generator.athrow(type(e), e, e.__traceback__) + except StopAsyncIteration: + break + except Exception as inner_e: + result.error = inner_e + logger.error(traceback.format_exc()) + break + await self._handle_yielded(yielded, event, result) + + if event.is_stopped(): + result.stopped = True + break + + return result + + async def _handle_yielded( + self, + yielded: Any, + event: AstrMessageEvent, + result: YieldDriverResult, + ) -> None: + """处理 yield 出来的值""" + from astrbot.core.provider.entities import ProviderRequest + + if yielded is None: + # yield 空值 — 检查 event 上是否已有 result + if event.get_result(): + await self._send_and_clear(event, result) + return + + if isinstance(yielded, ProviderRequest): + # LLM 请求 + result.llm_requests.append(yielded) + event.set_extra("has_provider_request", True) + event.set_extra("provider_request", yielded) + if self._on_provider_request: + await self._on_provider_request(event, yielded) + return + + if isinstance(yielded, MessageEventResult | CommandResult): + event.set_result(yielded) + await self._send_and_clear(event, result) + return + + if isinstance(yielded, str): + event.set_result(MessageEventResult().message(yielded)) + await self._send_and_clear(event, result) + return + + # 未知类型 — 检查 event 上是否有 result (插件可能直接 set_result) + if event.get_result(): + await self._send_and_clear(event, result) + + async def _send_and_clear( + self, + event: AstrMessageEvent, + result: YieldDriverResult, + ) -> None: + """发送消息并清理 result""" + if event.get_result(): + await self._send(event) + result.messages_sent += 1 + event.clear_result() + + +class StarHandlerAdapter: + """Star Handler 适配器 + + 统一处理 async def 和 async generator 两种 handler 形式。 + 自动检测 handler 返回类型并选择合适的执行方式。 + """ + + def __init__(self, yield_driver: StarYieldDriver) -> None: + self._driver = yield_driver + + async def invoke( + self, + handler: Callable, + event: AstrMessageEvent, + *args: Any, + **kwargs: Any, + ) -> YieldDriverResult: + """调用 handler + + 自动检测 handler 类型: + - AsyncGenerator: 使用 yield driver 驱动 + - Coroutine: 直接 await + + Returns: + YieldDriverResult + """ + result = YieldDriverResult() + + try: + ready_to_call = handler(event, *args, **kwargs) + except TypeError: + logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + result.error = TypeError("handler parameter mismatch") + return result + + if ready_to_call is None: + return 
result + + if inspect.isasyncgen(ready_to_call): + return await self._driver.drive(ready_to_call, event) + + if inspect.iscoroutine(ready_to_call): + try: + ret = await ready_to_call + if ret is not None: + await self._driver._handle_yielded(ret, event, result) + except Exception as e: + result.error = e + logger.error(traceback.format_exc()) + + return result diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py deleted file mode 100644 index 2dcb840e9..000000000 --- a/astrbot/core/pipeline/waking_check/stage.py +++ /dev/null @@ -1,237 +0,0 @@ -from collections.abc import AsyncGenerator, Callable - -from astrbot import logger -from astrbot.core.message.components import At, AtAll, Reply -from astrbot.core.message.message_event_result import MessageChain, MessageEventResult -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.message_type import MessageType -from astrbot.core.star.filter.command_group import CommandGroupFilter -from astrbot.core.star.filter.permission import PermissionTypeFilter -from astrbot.core.star.session_plugin_manager import SessionPluginManager -from astrbot.core.star.star import star_map -from astrbot.core.star.star_handler import EventType, star_handlers_registry - -from ..context import PipelineContext -from ..stage import Stage, register_stage - -UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = { - "aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", - "slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}", - "dingtalk": lambda e: e.get_sender_id(), - "qq_official": lambda e: e.get_sender_id(), - "qq_official_webhook": lambda e: e.get_sender_id(), - "lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}", - "misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}", -} - - -def build_unique_session_id(event: AstrMessageEvent) -> str | None: - platform = event.get_platform_name() - builder = UNIQUE_SESSION_ID_BUILDERS.get(platform) - return builder(event) if builder else None - - -@register_stage -class WakingCheckStage(Stage): - """检查是否需要唤醒。唤醒机器人有如下几点条件: - - 1. 机器人被 @ 了 - 2. 机器人的消息被提到了 - 3. 以 wake_prefix 前缀开头,并且消息没有以 At 消息段开头 - 4. 插件(Star)的 handler filter 通过 - 5. 
私聊情况下,位于 admins_id 列表中的管理员的消息(在白名单阶段中) - """ - - async def initialize(self, ctx: PipelineContext) -> None: - """初始化唤醒检查阶段 - - Args: - ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 - - """ - self.ctx = ctx - self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get( - "no_permission_reply", - True, - ) - # 私聊是否需要 wake_prefix 才能唤醒机器人 - self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[ - "platform_settings" - ].get("friend_message_needs_wake_prefix", False) - # 是否忽略机器人自己发送的消息 - self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get( - "ignore_bot_self_message", - False, - ) - self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get( - "ignore_at_all", - False, - ) - self.disable_builtin_commands = self.ctx.astrbot_config.get( - "disable_builtin_commands", False - ) - platform_settings = self.ctx.astrbot_config.get("platform_settings", {}) - self.unique_session = platform_settings.get("unique_session", False) - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - # apply unique session - if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE: - sid = build_unique_session_id(event) - if sid: - event.session_id = sid - - # ignore bot self message - if ( - self.ignore_bot_self_message - and event.get_self_id() == event.get_sender_id() - ): - event.stop_event() - return - - # 设置 sender 身份 - event.message_str = event.message_str.strip() - for admin_id in self.ctx.astrbot_config["admins_id"]: - if str(event.get_sender_id()) == admin_id: - event.role = "admin" - break - - # 检查 wake - wake_prefixes = self.ctx.astrbot_config["wake_prefix"] - messages = event.get_messages() - is_wake = False - for wake_prefix in wake_prefixes: - if event.message_str.startswith(wake_prefix): - if ( - not event.is_private_chat() - and isinstance(messages[0], At) - and str(messages[0].qq) != str(event.get_self_id()) - and str(messages[0].qq) != "all" - ): - # 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒 - break - is_wake = True - event.is_at_or_wake_command = True - event.is_wake = True - event.message_str = event.message_str[len(wake_prefix) :].strip() - break - if not is_wake: - # 检查是否有at消息 / at全体成员消息 / 引用了bot的消息 - for message in messages: - if ( - ( - isinstance(message, At) - and (str(message.qq) == str(event.get_self_id())) - ) - or (isinstance(message, AtAll) and not self.ignore_at_all) - or ( - isinstance(message, Reply) - and str(message.sender_id) == str(event.get_self_id()) - ) - ): - is_wake = True - event.is_wake = True - wake_prefix = "" - event.is_at_or_wake_command = True - break - # 检查是否是私聊 - if event.is_private_chat() and not self.friend_message_needs_wake_prefix: - is_wake = True - event.is_wake = True - event.is_at_or_wake_command = True - wake_prefix = "" - - # 检查插件的 handler filter - activated_handlers = [] - handlers_parsed_params = {} # 注册了指令的 handler - - # 将 plugins_name 设置到 event 中 - enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"]) - if enabled_plugins_name == ["*"]: - # 如果是 *,则表示所有插件都启用 - event.plugins_name = None - else: - event.plugins_name = enabled_plugins_name - logger.debug(f"enabled_plugins_name: {enabled_plugins_name}") - - for handler in star_handlers_registry.get_handlers_by_event_type( - EventType.AdapterMessageEvent, - plugins_name=event.plugins_name, - ): - if ( - self.disable_builtin_commands - and handler.handler_module_path - == "astrbot.builtin_stars.builtin_commands.main" - ): - continue - - # filter 需满足 AND 逻辑关系 - 
passed = True - permission_not_pass = False - permission_filter_raise_error = False - if len(handler.event_filters) == 0: - continue - - for filter in handler.event_filters: - try: - if isinstance(filter, PermissionTypeFilter): - if not filter.filter(event, self.ctx.astrbot_config): - permission_not_pass = True - permission_filter_raise_error = filter.raise_error - elif not filter.filter(event, self.ctx.astrbot_config): - passed = False - break - except Exception as e: - await event.send( - MessageEventResult().message( - f"插件 {star_map[handler.handler_module_path].name}: {e}", - ), - ) - event.stop_event() - passed = False - break - if passed: - if permission_not_pass: - if not permission_filter_raise_error: - # 跳过 - continue - if self.no_permission_reply: - await event.send( - MessageChain().message( - f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。", - ), - ) - logger.info( - f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。", - ) - event.stop_event() - return - - is_wake = True - event.is_wake = True - - is_group_cmd_handler = any( - isinstance(f, CommandGroupFilter) for f in handler.event_filters - ) - if not is_group_cmd_handler: - activated_handlers.append(handler) - if "parsed_params" in event.get_extra(default={}): - handlers_parsed_params[handler.handler_full_name] = ( - event.get_extra("parsed_params") - ) - - event._extras.pop("parsed_params", None) - - # 根据会话配置过滤插件处理器 - activated_handlers = await SessionPluginManager.filter_handlers_by_session( - event, - activated_handlers, - ) - - event.set_extra("activated_handlers", activated_handlers) - event.set_extra("handlers_parsed_params", handlers_parsed_params) - - if not is_wake: - event.stop_event() diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py deleted file mode 100644 index ea9c55228..000000000 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ /dev/null @@ -1,68 +0,0 @@ -from collections.abc import AsyncGenerator - -from astrbot.core import logger -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.message_type import MessageType - -from ..context import PipelineContext -from ..stage import Stage, register_stage - - -@register_stage -class WhitelistCheckStage(Stage): - """检查是否在群聊/私聊白名单""" - - async def initialize(self, ctx: PipelineContext) -> None: - self.enable_whitelist_check = ctx.astrbot_config["platform_settings"][ - "enable_id_white_list" - ] - self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"] - self.whitelist = [ - str(i).strip() for i in self.whitelist if str(i).strip() != "" - ] - self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][ - "wl_ignore_admin_on_group" - ] - self.wl_ignore_admin_on_friend = ctx.astrbot_config["platform_settings"][ - "wl_ignore_admin_on_friend" - ] - self.wl_log = ctx.astrbot_config["platform_settings"]["id_whitelist_log"] - - async def process( - self, - event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: - if not self.enable_whitelist_check: - # 白名单检查未启用 - return - - if len(self.whitelist) == 0: - # 白名单为空,不检查 - return - - if event.get_platform_name() == "webchat": - # WebChat 豁免 - return - - # 检查是否在白名单 - if self.wl_ignore_admin_on_group: - if ( - event.role == "admin" - and event.get_message_type() == MessageType.GROUP_MESSAGE - ): - return - if self.wl_ignore_admin_on_friend: - if ( - event.role == "admin" - and event.get_message_type() == 
MessageType.FRIEND_MESSAGE - ): - return - if ( - event.unified_msg_origin not in self.whitelist - and str(event.get_group_id()).strip() not in self.whitelist - ): - if self.wl_log: - logger.info( - f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。", - ) - event.stop_event() diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 83b9813e0..22e02251b 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import asyncio import hashlib @@ -5,7 +7,7 @@ import uuid from collections.abc import AsyncGenerator from time import time -from typing import Any +from typing import TYPE_CHECKING, Any from astrbot import logger from astrbot.core.agent.tool import ToolSet @@ -20,7 +22,13 @@ Plain, Reply, ) -from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from astrbot.core.message.message_event_result import ( + MessageChain, + MessageEventResult, + ResultContentType, + collect_streaming_result, +) +from astrbot.core.pipeline.engine.node_context import NodePacket, NodePacketKind from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric @@ -30,6 +38,14 @@ from .message_session import MessageSesion, MessageSession # noqa from .platform_metadata import PlatformMetadata +if TYPE_CHECKING: + from astrbot.core.pipeline.agent import AgentExecutor + from astrbot.core.pipeline.engine.chain_config import ChainConfig + from astrbot.core.pipeline.engine.node_context import NodeContext, NodeContextStack + from astrbot.core.pipeline.engine.send_service import SendService + +ASTR_MESSAGE_EVENT_VERSION = 5 + class AstrMessageEvent(abc.ABC): def __init__( @@ -62,6 +78,9 @@ def __init__( self._result: MessageEventResult | None = None """消息事件的结果""" + self.version: int = ASTR_MESSAGE_EVENT_VERSION + """Event payload schema version.""" + self.created_at = time() """事件创建时间(Unix timestamp)""" self.trace = TraceSpan( @@ -85,6 +104,31 @@ def __init__( # back_compability self.platform = platform_meta + # Chain-level configs (set by ChainExecutor) + self.chain_config: ChainConfig | None = None + self.node_config: Any | None = None + + # Pipeline services (set by ChainExecutor) + self.send_service: SendService | None = None + self.agent_executor: AgentExecutor | None = None + + # NodeContext stack (lazily initialized) + self._node_context_stack: NodeContextStack | None = None + + @property + def context_stack(self) -> NodeContextStack: + """Get the NodeContextStack for this event, lazily initialized.""" + if self._node_context_stack is None: + from astrbot.core.pipeline.engine.node_context import NodeContextStack + + self._node_context_stack = NodeContextStack() + return self._node_context_stack + + @property + def node_context(self) -> NodeContext | None: + """Get the current NodeContext from the stack.""" + return self.context_stack.current() + @property def unified_msg_origin(self) -> str: """统一的消息来源字符串。格式为 platform_name:message_type:session_id""" @@ -124,6 +168,28 @@ def get_message_str(self) -> str: """获取消息字符串。""" return self.message_str + def set_message_str(self, text: str) -> None: + """Update input message_str and keep message_obj.message_str in sync.""" + self.message_str = text + self.message_obj.message_str = text + + def append_message_str(self, text: str) -> None: + """Append text to 
input message_str and keep message_obj.message_str in sync.""" + if not text: + return + self.message_str += text + self.message_obj.message_str += text + + def rebuild_message_str_from_plain(self) -> str: + """Rebuild input message_str from top-level Plain components.""" + parts: list[str] = [] + for comp in self.message_obj.message: + if isinstance(comp, Plain): + parts.append(comp.text) + merged = "".join(parts) + self.set_message_str(merged) + return merged + def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: if not chain: return "" @@ -244,11 +310,11 @@ async def send_streaming( ) self._has_send_oper = True - async def _pre_send(self) -> None: - """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" + async def _pre_send(self, message: MessageChain | None = None, **_): + """发送前钩子(平台可覆写)""" - async def _post_send(self) -> None: - """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" + async def _post_send(self, message: MessageChain | None = None, **_): + """发送后钩子(平台可覆写)""" def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 @@ -279,7 +345,126 @@ async def check_count(self, event: AstrMessageEvent): result.chain = [] self._result = result - def stop_event(self) -> None: + def set_node_output(self, output: Any) -> None: + """Set node output for downstream nodes and sync MessageEventResult.""" + ctx = self.node_context + if not ctx: + raise RuntimeError("Node context is not available for this event.") + + packet = NodePacket.create(output) + ctx.output = packet + + if isinstance(packet.data, MessageEventResult): + self.set_result(packet.data) + + @staticmethod + def _normalize_node_names(names: str | list[str] | None) -> set[str] | None: + if names is None: + return None + if isinstance(names, str): + names_list = [names] + else: + names_list = list(names) + cleaned = {str(name).strip() for name in names_list if str(name).strip()} + return cleaned or None + + @staticmethod + async def _collect_streaming_output( + output: MessageEventResult, + ) -> MessageEventResult: + if output.result_content_type != ResultContentType.STREAMING_RESULT: + return output + if output.async_stream is None: + return output + return await collect_streaming_result(output, warn=True, logger=logger) + + @staticmethod + def _output_to_text(output: Any) -> str: + if isinstance(output, MessageChain): + return output.get_plain_text() + if isinstance(output, str): + return output + return str(output) + + async def get_node_input( + self, + strategy: str = "last", + names: str | list[str] | None = None, + ) -> Any: + """Get upstream node output with optional merge strategy. + + strategy: + - "last" (default): last output in chain order + - "first": first output in chain order + - "list": list of outputs + - "text_concat": merge as concatenated text + - "chain_concat": merge as MessageEventResult (chain) + names: + Limit outputs to specific node names (str or list[str]). 
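+
+        Example (illustrative sketch only; the node name "upstream" below is
+        hypothetical):
+            # take the most recent upstream output, whichever node produced it
+            last = await event.get_node_input()
+            # merge the text output of one specific upstream node
+            text = await event.get_node_input(strategy="text_concat", names="upstream")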
+ """ + ctx = self.node_context + if not ctx: + raise RuntimeError("Node context is not available for this event.") + + from astrbot.core.pipeline.engine.node_context import NodeExecutionStatus + + name_set = self._normalize_node_names(names) + packets = self.context_stack.get_outputs( + names=name_set, + status=NodeExecutionStatus.EXECUTED, + include_none=False, + ) + if not packets: + return None + + strategy = (strategy or "last").lower() + + if strategy == "last": + return packets[-1].data + if strategy == "first": + return packets[0].data + if strategy == "list": + return [packet.data for packet in packets] + + if strategy == "text_concat": + texts: list[str] = [] + for packet in packets: + output = packet.data + if packet.kind == NodePacketKind.MESSAGE: + if isinstance(output, MessageEventResult): + output = await self._collect_streaming_output(output) + text = self._output_to_text(output) + elif packet.kind == NodePacketKind.TEXT: + text = output if isinstance(output, str) else str(output) + else: + text = self._output_to_text(output) + if text and text.strip(): + texts.append(text) + return "\n".join(texts) + + if strategy == "chain_concat": + chain = [] + for packet in packets: + output = packet.data + if packet.kind == NodePacketKind.MESSAGE: + if isinstance(output, MessageEventResult): + output = await self._collect_streaming_output(output) + chain.extend(output.chain or []) + elif isinstance(output, MessageChain): + chain.extend(output.chain or []) + else: + chain.append(Plain(str(output))) + elif packet.kind == NodePacketKind.TEXT: + chain.append( + Plain(output if isinstance(output, str) else str(output)) + ) + else: + chain.append(Plain(str(output))) + return MessageEventResult(chain=chain) + + raise ValueError(f"Unsupported node input strategy: {strategy}") + + def stop_event(self): """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 5eb62e6b3..fe0e5ecc6 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -201,9 +201,9 @@ async def handle_msg(self, message: AstrBotMessage) -> None: _, _, payload = message.raw_message # type: ignore message_event.set_extra("selected_provider", payload.get("selected_provider")) message_event.set_extra("selected_model", payload.get("selected_model")) - message_event.set_extra( - "enable_streaming", payload.get("enable_streaming", True) - ) + # 只有当 payload 明确提供 enable_streaming 时才设置,否则使用全局配置 + if "enable_streaming" in payload: + message_event.set_extra("enable_streaming", payload.get("enable_streaming")) message_event.set_extra("action_type", payload.get("action_type")) self.commit_event(message_event) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ff0bb303d..5e9df9b0b 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -4,6 +4,8 @@ import traceback from typing import Protocol, runtime_checkable +from deprecated import deprecated + from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase @@ -41,8 +43,6 @@ def __init__( self.providers_config: list = config["provider"] self.provider_sources_config: list = config.get("provider_sources", []) self.provider_settings: dict = config["provider_settings"] - 
self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) - self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager self.default_persona_name = persona_mgr.default_persona @@ -87,6 +87,13 @@ def selected_default_persona(self): """动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()""" return self.persona_mgr.selected_default_persona_v3 + @deprecated( + version="5.0", + reason=( + "Legacy session/global provider switching is not Node-aware and may not " + "work as expected in multi-node chains." + ), + ) async def set_provider( self, provider_id: str, @@ -103,6 +110,11 @@ async def set_provider( Version 4.0.0: 这个版本下已经默认隔离提供商 """ + logger.warning( + "ProviderManager.set_provider is deprecated and may not work as expected " + "in multi-node architecture. Use node-level provider binding instead." + ) + if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") if umo: @@ -153,6 +165,13 @@ async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" return self.inst_map.get(provider_id) + @deprecated( + version="5.0", + reason=( + "Legacy session/global provider resolution is not Node-aware and may not " + "work as expected in multi-node chains." + ), + ) def get_using_provider( self, provider_type: ProviderType, umo=None ) -> Providers | None: @@ -166,6 +185,11 @@ def get_using_provider( Provider: 正在使用的提供商实例。 """ + logger.warning( + "ProviderManager.get_using_provider is deprecated and may not work as " + "expected in multi-node architecture. Use node-aware resolver instead." + ) + provider = None provider_id = None if umo: @@ -186,19 +210,29 @@ def get_using_provider( if not provider: provider = self.provider_insts[0] if self.provider_insts else None elif provider_type == ProviderType.SPEECH_TO_TEXT: - provider_id = config["provider_stt_settings"].get("provider_id") - if not provider_id: - return None - provider = self.inst_map.get(provider_id) + global_stt_provider_id = sp.get( + "curr_provider_stt", + None, + scope="global", + scope_id="global", + ) + if isinstance(global_stt_provider_id, str) and global_stt_provider_id: + provider = self.inst_map.get(global_stt_provider_id) + provider_id = global_stt_provider_id if not provider: provider = ( self.stt_provider_insts[0] if self.stt_provider_insts else None ) elif provider_type == ProviderType.TEXT_TO_SPEECH: - provider_id = config["provider_tts_settings"].get("provider_id") - if not provider_id: - return None - provider = self.inst_map.get(provider_id) + global_tts_provider_id = sp.get( + "curr_provider_tts", + None, + scope="global", + scope_id="global", + ) + if isinstance(global_tts_provider_id, str) and global_tts_provider_id: + provider = self.inst_map.get(global_tts_provider_id) + provider_id = global_tts_provider_id if not provider: provider = ( self.tts_provider_insts[0] if self.tts_provider_insts else None @@ -230,13 +264,13 @@ async def initialize(self) -> None: ) selected_stt_provider_id = await sp.get_async( key="curr_provider_stt", - default=self.provider_stt_settings.get("provider_id"), + default=None, scope="global", scope_id="global", ) selected_tts_provider_id = await sp.get_async( key="curr_provider_tts", - default=self.provider_tts_settings.get("provider_id"), + default=None, scope="global", scope_id="global", ) @@ -497,14 +531,6 @@ async def load_provider(self, provider_config: dict) -> None: await inst.initialize() self.stt_provider_insts.append(inst) - if ( - 
self.provider_stt_settings.get("provider_id") - == provider_config["id"] - ): - self.curr_stt_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", - ) if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst @@ -520,14 +546,6 @@ async def load_provider(self, provider_config: dict) -> None: await inst.initialize() self.tts_provider_insts.append(inst) - if ( - self.provider_settings.get("provider_id") - == provider_config["id"] - ): - self.curr_tts_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", - ) if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 2bf86872e..d3b636122 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -5,64 +5,26 @@ from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin from .context import Context +from .modality import Modality, extract_modalities +from .node_star import NodeResult, NodeStar from .star import StarMetadata, star_map, star_registry +from .star_base import Star from .star_manager import PluginManager - -class Star(CommandParserMixin, PluginKVStoreMixin): - """所有插件(Star)的父类,所有插件都应该继承于这个类""" - - author: str - name: str - - def __init__(self, context: Context, config: dict | None = None) -> None: - StarTools.initialize(context) - self.context = context - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if not star_map.get(cls.__module__): - metadata = StarMetadata( - star_cls_type=cls, - module_path=cls.__module__, - ) - star_map[cls.__module__] = metadata - star_registry.append(metadata) - else: - star_map[cls.__module__].star_cls_type = cls - star_map[cls.__module__].module_path = cls.__module__ - - async def text_to_image(self, text: str, return_url=True) -> str: - """将文本转换为图片""" - return await html_renderer.render_t2i( - text, - return_url=return_url, - template_name=self.context._config.get("t2i_active_template"), - ) - - async def html_render( - self, - tmpl: str, - data: dict, - return_url=True, - options: dict | None = None, - ) -> str: - """渲染 HTML""" - return await html_renderer.render_custom_template( - tmpl, - data, - return_url=return_url, - options=options, - ) - - async def initialize(self) -> None: - """当插件被激活时会调用这个方法""" - - async def terminate(self) -> None: - """当插件被禁用、重载插件时会调用这个方法""" - - def __del__(self) -> None: - """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" - - -__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"] +__all__ = [ + "Context", + "CommandParserMixin", + "PluginKVStoreMixin", + "html_renderer", + "NodeResult", + "NodeStar", + "Modality", + "extract_modalities", + "PluginManager", + "Provider", + "Star", + "StarMetadata", + "StarTools", + "star_map", + "star_registry", +] diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 6a74580f6..7104a6e81 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -21,7 +21,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager -from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType +from astrbot.core.provider.entities import LLMResponse, ProviderRequest from astrbot.core.provider.func_tool_manager 
import FunctionTool, FunctionToolManager from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( @@ -239,7 +239,11 @@ async def tool_loop_agent( raise Exception("Agent did not produce a final LLM response") return llm_resp - async def get_current_chat_provider_id(self, umo: str) -> str: + async def get_current_chat_provider_id( + self, + umo: str, + event: AstrMessageEvent | None = None, + ) -> str: """获取当前使用的聊天模型 Provider ID。 Args: @@ -251,11 +255,123 @@ async def get_current_chat_provider_id(self, umo: str) -> str: Raises: ProviderNotFoundError: 未找到。 """ - prov = self.get_using_provider(umo) + prov: Provider | None = None + if event is not None: + prov = self.get_chat_provider_for_event(event) + if not prov: + prov = self.get_using_provider(umo) if not prov: raise ProviderNotFoundError("Provider not found") return prov.meta().id + def _resolve_provider_for_event( + self, + event: AstrMessageEvent, + *, + expected_type: type, + default_provider_key: str, + kind_name: str, + providers_getter: Callable[[], list], + ): + selected_provider = event.get_extra("selected_provider") + if isinstance(selected_provider, str) and selected_provider: + provider = self.get_provider_by_id(selected_provider) + if isinstance(provider, expected_type): + return provider + if provider is None: + logger.error( + "Configured selected_provider is not found: %s", + selected_provider, + ) + else: + logger.warning( + "selected_provider is not a %s provider: %s", + kind_name, + selected_provider, + ) + + node_config = event.node_config or {} + if isinstance(node_config, dict): + node_provider_id = str(node_config.get("provider_id") or "").strip() + if node_provider_id: + provider = self.get_provider_by_id(node_provider_id) + if isinstance(provider, expected_type): + return provider + if provider is None: + logger.error( + "Configured %s provider is not found: %s", + kind_name, + node_provider_id, + ) + else: + logger.warning( + "Configured provider_id is not a %s provider: %s", + kind_name, + node_provider_id, + ) + return None + + chain_config_id = event.chain_config.config_id if event.chain_config else None + runtime_config = self.get_config_by_id(chain_config_id) + provider_settings = runtime_config.get("provider_settings", {}) + default_provider_id = str( + provider_settings.get(default_provider_key) or "" + ).strip() + + if default_provider_id: + provider = self.get_provider_by_id(default_provider_id) + if isinstance(provider, expected_type): + return provider + if provider is None: + logger.error( + "Configured %s is not found: %s", + default_provider_key, + default_provider_id, + ) + else: + logger.warning( + "Configured %s is not a %s provider: %s", + default_provider_key, + kind_name, + default_provider_id, + ) + + providers = providers_getter() + return providers[0] if providers else None + + def get_chat_provider_for_event(self, event: AstrMessageEvent) -> Provider | None: + """Resolve chat provider for a specific event with node-aware rules.""" + provider = self._resolve_provider_for_event( + event, + expected_type=Provider, + default_provider_key="default_provider_id", + kind_name="chat", + providers_getter=self.get_all_providers, + ) + return provider if isinstance(provider, Provider) else None + + def get_tts_provider_for_event(self, event: AstrMessageEvent) -> TTSProvider | None: + """Resolve TTS provider for a specific event with node-aware rules.""" + provider = self._resolve_provider_for_event( + event, + expected_type=TTSProvider, + 
default_provider_key="default_tts_provider_id", + kind_name="tts", + providers_getter=self.get_all_tts_providers, + ) + return provider if isinstance(provider, TTSProvider) else None + + def get_stt_provider_for_event(self, event: AstrMessageEvent) -> STTProvider | None: + """Resolve STT provider for a specific event with node-aware rules.""" + provider = self._resolve_provider_for_event( + event, + expected_type=STTProvider, + default_provider_key="default_stt_provider_id", + kind_name="stt", + providers_getter=self.get_all_stt_providers, + ) + return provider if isinstance(provider, STTProvider) else None + def get_registered_star(self, star_name: str) -> StarMetadata | None: """根据插件名获取插件的 Metadata""" for star in star_registry: @@ -335,6 +451,13 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts + @deprecated( + version="5.0", + reason=( + "Use get_chat_provider_for_event(event) for node-aware resolution. " + "This fallback only uses session config defaults." + ), + ) def get_using_provider(self, umo: str | None = None) -> Provider | None: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 @@ -348,18 +471,35 @@ def get_using_provider(self, umo: str | None = None) -> Provider | None: Raises: ValueError: 该会话来源配置的的对话模型(提供商)的类型不正确。 """ - prov = self.provider_manager.get_using_provider( - provider_type=ProviderType.CHAT_COMPLETION, - umo=umo, - ) - if prov is None: - return None - if not isinstance(prov, Provider): - raise ValueError( - f"该会话来源的对话模型(提供商)的类型不正确: {type(prov)}" - ) - return prov - + config = self.get_config(umo) + provider_id = str( + config.get("provider_settings", {}).get("default_provider_id") or "" + ).strip() + if provider_id: + prov = self.get_provider_by_id(provider_id) + if isinstance(prov, Provider): + return prov + if prov is None: + logger.warning( + "Configured default_provider_id is not found: %s", + provider_id, + ) + else: + logger.warning( + "Configured default_provider_id is not a chat provider: %s", + provider_id, + ) + + providers = self.get_all_providers() + return providers[0] if providers else None + + @deprecated( + version="5.0", + reason=( + "Use node-level provider binding or node-aware resolver. " + "This fallback only uses session config defaults." + ), + ) def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: """获取当前使用的用于 TTS 任务的 Provider。 @@ -372,14 +512,35 @@ def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: Raises: ValueError: 返回的提供者不是 TTSProvider 类型。 """ - prov = self.provider_manager.get_using_provider( - provider_type=ProviderType.TEXT_TO_SPEECH, - umo=umo, - ) - if prov and not isinstance(prov, TTSProvider): - raise ValueError("返回的 Provider 不是 TTSProvider 类型") - return prov - + config = self.get_config(umo) + provider_id = str( + config.get("provider_settings", {}).get("default_tts_provider_id") or "" + ).strip() + if provider_id: + prov = self.get_provider_by_id(provider_id) + if isinstance(prov, TTSProvider): + return prov + if prov is None: + logger.warning( + "Configured default_tts_provider_id is not found: %s", + provider_id, + ) + else: + logger.warning( + "Configured default_tts_provider_id is not a TTS provider: %s", + provider_id, + ) + + providers = self.get_all_tts_providers() + return providers[0] if providers else None + + @deprecated( + version="5.0", + reason=( + "Use node-level provider binding or node-aware resolver. 
" + "This fallback only uses session config defaults." + ), + ) def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: """获取当前使用的用于 STT 任务的 Provider。 @@ -392,13 +553,27 @@ def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: Raises: ValueError: 返回的提供者不是 STTProvider 类型。 """ - prov = self.provider_manager.get_using_provider( - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - if prov and not isinstance(prov, STTProvider): - raise ValueError("返回的 Provider 不是 STTProvider 类型") - return prov + config = self.get_config(umo) + provider_id = str( + config.get("provider_settings", {}).get("default_stt_provider_id") or "" + ).strip() + if provider_id: + prov = self.get_provider_by_id(provider_id) + if isinstance(prov, STTProvider): + return prov + if prov is None: + logger.warning( + "Configured default_stt_provider_id is not found: %s", + provider_id, + ) + else: + logger.warning( + "Configured default_stt_provider_id is not a STT provider: %s", + provider_id, + ) + + providers = self.get_all_stt_providers() + return providers[0] if providers else None def get_config(self, umo: str | None = None) -> AstrBotConfig: """获取 AstrBot 的配置。 @@ -417,6 +592,10 @@ def get_config(self, umo: str | None = None) -> AstrBotConfig: return self._config return self.astrbot_config_mgr.get_conf(umo) + def get_config_by_id(self, config_id: str | None) -> AstrBotConfig: + """通过配置文件 ID 获取配置,不依赖 umo 映射。""" + return self.astrbot_config_mgr.get_conf_by_id(config_id) + async def send_message( self, session: str | MessageSesion, diff --git a/astrbot/core/star/modality.py b/astrbot/core/star/modality.py new file mode 100644 index 000000000..d7c96d0b2 --- /dev/null +++ b/astrbot/core/star/modality.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import enum + +from astrbot.core.message.components import BaseMessageComponent, ComponentType + + +class Modality(enum.Enum): + """单一模态类型""" + + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + FILE = "file" + + +# 组件类型到模态的映射 +_COMPONENT_TYPE_TO_MODALITY: dict[ComponentType, Modality] = { + ComponentType.Plain: Modality.TEXT, + ComponentType.Image: Modality.IMAGE, + ComponentType.Record: Modality.AUDIO, + ComponentType.Video: Modality.VIDEO, + ComponentType.File: Modality.FILE, +} + + +def extract_modalities(components: list[BaseMessageComponent]) -> set[Modality]: + """从消息组件列表提取实际模态集合""" + result: set[Modality] = set() + for comp in components: + modality = _COMPONENT_TYPE_TO_MODALITY.get(comp.type) + if modality is not None: + result.add(modality) + return result or {Modality.TEXT} # 默认 TEXT diff --git a/astrbot/core/star/node_star.py b/astrbot/core/star/node_star.py new file mode 100644 index 000000000..53aa26da0 --- /dev/null +++ b/astrbot/core/star/node_star.py @@ -0,0 +1,76 @@ +"""NodeStar base class for pipeline nodes.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +from .star_base import Star + +if TYPE_CHECKING: + from astrbot.core.message.message_event_result import MessageEventResult + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.star.star import StarMetadata + + +class NodeResult(Enum): + CONTINUE = "continue" + STOP = "stop" + WAIT = "wait" + SKIP = "skip" + + +def is_node_star_metadata(metadata: StarMetadata) -> bool: + """Return whether metadata represents a NodeStar plugin.""" + if metadata.star_cls_type: + try: + return issubclass(metadata.star_cls_type, NodeStar) + except 
TypeError: + return False + return isinstance(metadata.star_cls, NodeStar) + + +class NodeStar(Star): + """Star subclass that can be mounted into pipeline chains.""" + + def __init__(self, context, config: dict | None = None): + super().__init__(context, config) + self.initialized_node_keys: set[tuple[str, str]] = set() + + async def node_initialize(self) -> None: + pass + + async def process( + self, + event: AstrMessageEvent, + ) -> NodeResult: + raise NotImplementedError + + @staticmethod + async def collect_stream( + event: AstrMessageEvent, + result: MessageEventResult | None = None, + ) -> str | None: + from astrbot.core.message.components import Plain + from astrbot.core.message.message_event_result import ( + ResultContentType, + collect_streaming_result, + ) + + if result is None: + result = event.get_result() + if not result: + return None + + if result.result_content_type != ResultContentType.STREAMING_RESULT: + return None + + if result.async_stream is None: + return None + + await collect_streaming_result(result) + + parts: list[str] = [ + comp.text for comp in result.chain if isinstance(comp, Plain) + ] + return "".join(parts) diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index 617cd5ff7..c1a0ce10c 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -1,6 +1,6 @@ import warnings -from astrbot.core.star import StarMetadata, star_map +from astrbot.core.star.star import StarMetadata, star_map _warned_register_star = False diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py deleted file mode 100644 index ad4a473b4..000000000 --- a/astrbot/core/star/session_llm_manager.py +++ /dev/null @@ -1,185 +0,0 @@ -"""会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态""" - -from astrbot.core import logger, sp -from astrbot.core.platform.astr_message_event import AstrMessageEvent - - -class SessionServiceManager: - """管理会话级别的服务启停状态,包括LLM和TTS""" - - # ============================================================================= - # LLM 相关方法 - # ============================================================================= - - @staticmethod - async def is_llm_enabled_for_session(session_id: str) -> bool: - """检查LLM是否在指定会话中启用 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - bool: True表示启用,False表示禁用 - - """ - # 获取会话服务配置 - session_services = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - default={}, - ) - - # 如果配置了该会话的LLM状态,返回该状态 - llm_enabled = session_services.get("llm_enabled") - if llm_enabled is not None: - return llm_enabled - - # 如果没有配置,默认为启用(兼容性考虑) - return True - - @staticmethod - async def set_llm_status_for_session(session_id: str, enabled: bool) -> None: - """设置LLM在指定会话中的启停状态 - - Args: - session_id: 会话ID (unified_msg_origin) - enabled: True表示启用,False表示禁用 - - """ - session_config = ( - await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - default={}, - ) - or {} - ) - session_config["llm_enabled"] = enabled - await sp.put_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - value=session_config, - ) - - @staticmethod - async def should_process_llm_request(event: AstrMessageEvent) -> bool: - """检查是否应该处理LLM请求 - - Args: - event: 消息事件 - - Returns: - bool: True表示应该处理,False表示跳过 - - """ - session_id = event.unified_msg_origin - return await SessionServiceManager.is_llm_enabled_for_session(session_id) - - # 
============================================================================= - # TTS 相关方法 - # ============================================================================= - - @staticmethod - async def is_tts_enabled_for_session(session_id: str) -> bool: - """检查TTS是否在指定会话中启用 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - bool: True表示启用,False表示禁用 - - """ - # 获取会话服务配置 - session_services = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - default={}, - ) - - # 如果配置了该会话的TTS状态,返回该状态 - tts_enabled = session_services.get("tts_enabled") - if tts_enabled is not None: - return tts_enabled - - # 如果没有配置,默认为启用(兼容性考虑) - return True - - @staticmethod - async def set_tts_status_for_session(session_id: str, enabled: bool) -> None: - """设置TTS在指定会话中的启停状态 - - Args: - session_id: 会话ID (unified_msg_origin) - enabled: True表示启用,False表示禁用 - - """ - session_config = ( - await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - default={}, - ) - or {} - ) - session_config["tts_enabled"] = enabled - await sp.put_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - value=session_config, - ) - - logger.info( - f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}", - ) - - @staticmethod - async def should_process_tts_request(event: AstrMessageEvent) -> bool: - """检查是否应该处理TTS请求 - - Args: - event: 消息事件 - - Returns: - bool: True表示应该处理,False表示跳过 - - """ - session_id = event.unified_msg_origin - return await SessionServiceManager.is_tts_enabled_for_session(session_id) - - # ============================================================================= - # 会话整体启停相关方法 - # ============================================================================= - - @staticmethod - async def is_session_enabled(session_id: str) -> bool: - """检查会话是否整体启用 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - bool: True表示启用,False表示禁用 - - """ - # 获取会话服务配置 - session_services = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_service_config", - default={}, - ) - - # 如果配置了该会话的整体状态,返回该状态 - session_enabled = session_services.get("session_enabled") - if session_enabled is not None: - return session_enabled - - # 如果没有配置,默认为启用(兼容性考虑) - return True diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py deleted file mode 100644 index a81113415..000000000 --- a/astrbot/core/star/session_plugin_manager.py +++ /dev/null @@ -1,101 +0,0 @@ -"""会话插件管理器 - 负责管理每个会话的插件启停状态""" - -from astrbot.core import logger, sp -from astrbot.core.platform.astr_message_event import AstrMessageEvent - - -class SessionPluginManager: - """管理会话级别的插件启停状态""" - - @staticmethod - async def is_plugin_enabled_for_session( - session_id: str, - plugin_name: str, - ) -> bool: - """检查插件是否在指定会话中启用 - - Args: - session_id: 会话ID (unified_msg_origin) - plugin_name: 插件名称 - - Returns: - bool: True表示启用,False表示禁用 - - """ - # 获取会话插件配置 - session_plugin_config = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_plugin_config", - default={}, - ) - session_config = session_plugin_config.get(session_id, {}) - - enabled_plugins = session_config.get("enabled_plugins", []) - disabled_plugins = session_config.get("disabled_plugins", []) - - # 如果插件在禁用列表中,返回False - if plugin_name in disabled_plugins: - return False - - # 如果插件在启用列表中,返回True - if plugin_name in enabled_plugins: - return True - - # 如果都没有配置,默认为启用(兼容性考虑) - return True - - @staticmethod - async def filter_handlers_by_session( - 
event: AstrMessageEvent, - handlers: list, - ) -> list: - """根据会话配置过滤处理器列表 - - Args: - event: 消息事件 - handlers: 原始处理器列表 - - Returns: - List: 过滤后的处理器列表 - - """ - from astrbot.core.star.star import star_map - - session_id = event.unified_msg_origin - filtered_handlers = [] - - session_plugin_config = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_plugin_config", - default={}, - ) - session_config = session_plugin_config.get(session_id, {}) - disabled_plugins = session_config.get("disabled_plugins", []) - - for handler in handlers: - # 获取处理器对应的插件 - plugin = star_map.get(handler.handler_module_path) - if not plugin: - # 如果找不到插件元数据,允许执行(可能是系统插件) - filtered_handlers.append(handler) - continue - - # 跳过保留插件(系统插件) - if plugin.reserved: - filtered_handlers.append(handler) - continue - - if plugin.name is None: - continue - - # 检查插件是否在当前会话中启用 - if plugin.name in disabled_plugins: - logger.debug( - f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}", - ) - else: - filtered_handlers.append(handler) - - return filtered_handlers diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index c5b7b1243..362f2068b 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -11,7 +11,7 @@ """key 是模块路径,__module__""" if TYPE_CHECKING: - from . import Star + from .star_base import Star @dataclass @@ -61,6 +61,9 @@ class StarMetadata: logo_path: str | None = None """插件 Logo 的路径""" + node_schema: dict | None = None + """Node 参数 Schema,仅对 node 类型插件有效""" + def __str__(self) -> str: return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}" diff --git a/astrbot/core/star/star_base.py b/astrbot/core/star/star_base.py new file mode 100644 index 000000000..4abfb3b50 --- /dev/null +++ b/astrbot/core/star/star_base.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from astrbot.core import html_renderer +from astrbot.core.star.star_tools import StarTools +from astrbot.core.utils.command_parser import CommandParserMixin +from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin + +from .context import Context +from .star import StarMetadata, star_map, star_registry + + +class Star(CommandParserMixin, PluginKVStoreMixin): + """所有插件(Star)的父类,所有插件都应该继承于这个类""" + + author: str + name: str + + def __init__(self, context: Context, config: dict | None = None): + StarTools.initialize(context) + self.context = context + self.config = config + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not star_map.get(cls.__module__): + metadata = StarMetadata( + star_cls_type=cls, + module_path=cls.__module__, + ) + star_map[cls.__module__] = metadata + star_registry.append(metadata) + else: + star_map[cls.__module__].star_cls_type = cls + star_map[cls.__module__].module_path = cls.__module__ + + async def text_to_image(self, text: str, return_url=True) -> str: + """将文本转换为图片""" + return await html_renderer.render_t2i( + text, + return_url=return_url, + template_name=self.context._config.get("t2i_active_template"), + ) + + async def html_render( + self, + tmpl: str, + data: dict, + return_url=True, + options: dict | None = None, + ) -> str: + """渲染 HTML""" + return await html_renderer.render_custom_template( + tmpl, + data, + return_url=return_url, + options=options, + ) + + async def initialize(self): + """当插件被激活时会调用这个方法""" + + async def terminate(self): + """当插件被禁用、重载插件时会调用这个方法""" + + def __del__(self): + """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/star_manager.py 
b/astrbot/core/star/star_manager.py index 2c8c940f2..ff4e5be16 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -15,6 +15,7 @@ from astrbot.core import logger, pip_installer, sp from astrbot.core.agent.handoff import FunctionTool, HandoffTool from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.default import DEFAULT_VALUE_MAP from astrbot.core.platform.register import unregister_platform_adapters_by_module from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import ( @@ -29,6 +30,7 @@ from .command_management import sync_command_configs from .context import Context from .filter.permission import PermissionType, PermissionTypeFilter +from .node_star import is_node_star_metadata from .star import star_map, star_registry from .star_handler import star_handlers_registry from .updator import PluginUpdator @@ -57,6 +59,7 @@ def __init__(self, context: Context, config: AstrBotConfig) -> None: ) """保留插件的路径。在 astrbot/builtin_stars 目录下""" self.conf_schema_fname = "_conf_schema.json" + self.node_conf_schema_fname = "_node_config_schema.json" self.logo_fname = "logo.png" """插件配置 Schema 文件名""" self._pm_lock = asyncio.Lock() @@ -66,7 +69,55 @@ def __init__(self, context: Context, config: AstrBotConfig) -> None: if os.getenv("ASTRBOT_RELOAD", "0") == "1": asyncio.create_task(self._watch_plugins_changes()) - async def _watch_plugins_changes(self) -> None: + @staticmethod + def _schema_to_default_config(schema: dict) -> dict: + """Convert a schema to default config dict, matching AstrBotConfig behavior.""" + conf: dict = {} + + def _parse_schema(schema: dict, conf: dict): + for k, v in schema.items(): + if v["type"] not in DEFAULT_VALUE_MAP: + raise TypeError( + f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}", + ) + if "default" in v: + default = v["default"] + else: + default = DEFAULT_VALUE_MAP[v["type"]] + + if v["type"] == "object": + conf[k] = {} + _parse_schema(v["items"], conf[k]) + elif v["type"] == "template_list": + conf[k] = default + else: + conf[k] = default + + _parse_schema(schema, conf) + return conf + + def _load_node_schema(self, metadata: StarMetadata, plugin_dir_path: str) -> None: + """Load node schema for NodeStar plugins when available.""" + if not is_node_star_metadata(metadata): + metadata.node_schema = None + return + node_schema_path = os.path.join( + plugin_dir_path, + self.node_conf_schema_fname, + ) + if not os.path.exists(node_schema_path): + metadata.node_schema = None + return + try: + with open(node_schema_path, encoding="utf-8") as f: + metadata.node_schema = json.loads(f.read()) + except Exception as e: + logger.warning( + f"插件 {plugin_dir_path} 读取节点配置 Schema 失败: {e!s}", + ) + metadata.node_schema = None + + async def _watch_plugins_changes(self): """监视插件文件变化""" try: async for changes in awatch( @@ -421,15 +472,15 @@ async def load(self, specified_module_path=None, specified_dir_name=None): self.conf_schema_fname, ) if os.path.exists(plugin_schema_path): - # 加载插件配置 with open(plugin_schema_path, encoding="utf-8") as f: - plugin_config = AstrBotConfig( - config_path=os.path.join( - self.plugin_config_path, - f"{root_dir_name}_config.json", - ), - schema=json.loads(f.read()), - ) + schema_payload = json.loads(f.read()) + plugin_config = AstrBotConfig( + config_path=os.path.join( + self.plugin_config_path, + f"{root_dir_name}_config.json", + ), + schema=schema_payload, + ) logo_path = os.path.join(plugin_dir_path, self.logo_fname) if path in star_map: 
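A rough sketch of what _schema_to_default_config produces for a plugin's _node_config_schema.json before it is attached to the plugin metadata; the schema keys below are made up for illustration, and it is assumed that DEFAULT_VALUE_MAP contains the "string", "bool" and "object" type names used here:

    node_schema = {
        # hypothetical node options, following the same shape as _conf_schema.json
        "provider_id": {"type": "string", "default": "", "description": "override the chat provider for this node"},
        "advanced": {"type": "object", "items": {"debug": {"type": "bool", "default": False}}},
    }
    defaults = PluginManager._schema_to_default_config(node_schema)
    # defaults == {"provider_id": "", "advanced": {"debug": False}}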
@@ -452,6 +503,7 @@ async def load(self, specified_module_path=None, specified_dir_name=None): logger.warning( f"插件 {root_dir_name} 元数据载入失败: {e!s}。使用默认元数据。", ) + self._load_node_schema(metadata, plugin_dir_path) logger.info(metadata) metadata.config = plugin_config if path not in inactivated_plugins: @@ -568,6 +620,7 @@ async def load(self, specified_module_path=None, specified_dir_name=None): metadata.module_path = path star_map[path] = metadata star_registry.append(metadata) + self._load_node_schema(metadata, plugin_dir_path) # 禁用/启用插件 if metadata.module_path in inactivated_plugins: diff --git a/astrbot/core/tools/cron_tools.py b/astrbot/core/tools/cron_tools.py index ee22b943d..e2c4b5a65 100644 --- a/astrbot/core/tools/cron_tools.py +++ b/astrbot/core/tools/cron_tools.py @@ -77,6 +77,11 @@ async def call( "sender_id": context.context.event.get_sender_id(), "note": note, "origin": "tool", + "config_id": ( + context.context.event.chain_config.config_id + if context.context.event.chain_config + else None + ), } job = await cron_mgr.add_active_job( diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index d8b010d50..e0e703e6c 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -6,8 +6,8 @@ class UmopConfigRouter: """UMOP 配置路由器""" - def __init__(self, sp: SharedPreferences) -> None: - self.umop_to_conf_id: dict[str, str] = {} + def __init__(self, sp: SharedPreferences): + self.umop_to_config_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" self.sp = sp @@ -16,14 +16,14 @@ async def initialize(self) -> None: async def _load_routing_table(self) -> None: """加载路由表""" - # 从 SharedPreferences 中加载 umop_to_conf_id 映射 + # 从 SharedPreferences 中加载 umop_to_config_id 映射 sp_data = await self.sp.get_async( key="umop_config_routing", default={}, scope="global", scope_id="global", ) - self.umop_to_conf_id = sp_data + self.umop_to_config_id = sp_data def _is_umo_match(self, p1: str, p2: str) -> bool: """判断 p2 umo 是否逻辑包含于 p1 umo""" @@ -35,7 +35,7 @@ def _is_umo_match(self, p1: str, p2: str) -> bool: return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls)) - def get_conf_id_for_umop(self, umo: str) -> str | None: + def get_config_id_for_umop(self, umo: str) -> str | None: """根据 UMO 获取对应的配置文件 ID Args: @@ -45,9 +45,9 @@ def get_conf_id_for_umop(self, umo: str) -> str | None: str | None: 配置文件 ID,如果没有找到则返回 None """ - for pattern, conf_id in self.umop_to_conf_id.items(): + for pattern, config_id in self.umop_to_config_id.items(): if self._is_umo_match(pattern, umo): - return conf_id + return config_id return None async def update_routing_data(self, new_routing: dict[str, str]) -> None: @@ -67,15 +67,15 @@ async def update_routing_data(self, new_routing: dict[str, str]) -> None: "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) - self.umop_to_conf_id = new_routing - await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + self.umop_to_config_id = new_routing + await self.sp.global_put("umop_config_routing", self.umop_to_config_id) - async def update_route(self, umo: str, conf_id: str) -> None: + async def update_route(self, umo: str, config_id: str): """更新一条路由 Args: umo (str): UMO 字符串 - conf_id (str): 配置文件 ID + config_id (str): 配置文件 ID Raises: ValueError: 如果 umo 格式不正确 @@ -86,8 +86,8 @@ async def update_route(self, umo: str, conf_id: str) -> None: "umop must be a string in the format [platform_id]:[message_type]:[session_id], with 
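A standalone sketch of the UMOP wildcard matching behind `get_config_id_for_umop`: a pattern segment that is empty matches anything, otherwise it is treated as a glob. The three-segment length check and the platform names used here are assumptions; only the final `all(...)` expression is visible in the patch.

```python
import fnmatch


def umo_match(pattern: str, umo: str) -> bool:
    p_parts = pattern.split(":")
    u_parts = umo.split(":")
    if len(p_parts) != 3 or len(u_parts) != 3:
        return False
    # empty pattern segment matches everything, otherwise glob-match case-sensitively
    return all(p == "" or fnmatch.fnmatchcase(u, p) for p, u in zip(p_parts, u_parts))


assert umo_match("aiocqhttp:GroupMessage:*", "aiocqhttp:GroupMessage:12345")
assert umo_match("::", "telegram:FriendMessage:42")            # empty segments match all
assert not umo_match("aiocqhttp:GroupMessage:*", "telegram:GroupMessage:12345")
```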
optional wildcards * or empty for all", ) - self.umop_to_conf_id[umo] = conf_id - await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + self.umop_to_config_id[umo] = config_id + await self.sp.global_put("umop_config_routing", self.umop_to_config_id) async def delete_route(self, umo: str) -> None: """删除一条路由 @@ -104,6 +104,6 @@ async def delete_route(self, umo: str) -> None: "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) - if umo in self.umop_to_conf_id: - del self.umop_to_conf_id[umo] - await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + if umo in self.umop_to_config_id: + del self.umop_to_config_id[umo] + await self.sp.global_put("umop_config_routing", self.umop_to_config_id) diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 6a300302d..283b05c12 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -2,35 +2,11 @@ from astrbot.core import astrbot_config, logger from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager +from astrbot.core.db.migration.migra_4_to_5 import migrate_4_to_5 from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 from astrbot.core.db.migration.migra_token_usage import migrate_token_usage from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session - - -def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: - """ - Migra agent runner configs from provider configs. - """ - try: - default_prov_id = conf["provider_settings"]["default_provider_id"] - if default_prov_id in ids_map: - conf["provider_settings"]["default_provider_id"] = "" - p = ids_map[default_prov_id] - if p["type"] == "dify": - conf["provider_settings"]["dify_agent_runner_provider_id"] = p["id"] - conf["provider_settings"]["agent_runner_type"] = "dify" - elif p["type"] == "coze": - conf["provider_settings"]["coze_agent_runner_provider_id"] = p["id"] - conf["provider_settings"]["agent_runner_type"] = "coze" - elif p["type"] == "dashscope": - conf["provider_settings"]["dashscope_agent_runner_provider_id"] = p[ - "id" - ] - conf["provider_settings"]["agent_runner_type"] = "dashscope" - conf.save_config() - except Exception as e: - logger.error(f"Migration for third party agent runner configs failed: {e!s}") - logger.error(traceback.format_exc()) +from astrbot.core.umop_config_router import UmopConfigRouter def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: @@ -120,12 +96,17 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: async def migra( - db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager + db, + astrbot_config_mgr, + sp, + acm: AstrBotConfigManager, ) -> None: """ Stores the migration logic here. 
btw, i really don't like migration :( """ + umop_config_router = UmopConfigRouter(sp) + # 4.5 to 4.6 migration for umop_config_router try: await migrate_45_to_46(astrbot_config_mgr, umop_config_router) @@ -140,6 +121,13 @@ async def migra( logger.error(f"Migration for webchat session failed: {e!s}") logger.error(traceback.format_exc()) + # migration for chain configs (v4 to v5) + try: + await migrate_4_to_5(db, astrbot_config_mgr, umop_config_router) + except Exception as e: + logger.error(f"Migration from version 4 to 5 failed: {e!s}") + logger.error(traceback.format_exc()) + # migration for token_usage column try: await migrate_token_usage(db) @@ -147,25 +135,18 @@ async def migra( logger.error(f"Migration for token_usage column failed: {e!s}") logger.error(traceback.format_exc()) - # migra third party agent runner configs - _c = False + # normalize legacy third-party provider type for later migration steps + changed = False providers = astrbot_config["provider"] - ids_map = {} for prov in providers: type_ = prov.get("type") if type_ in ["dify", "coze", "dashscope"]: - prov["provider_type"] = "agent_runner" - ids_map[prov["id"]] = { - "type": type_, - "id": prov["id"], - } - _c = True - if _c: + if prov.get("provider_type") != "agent_runner": + prov["provider_type"] = "agent_runner" + changed = True + if changed: astrbot_config.save_config() - for conf in acm.confs.values(): - _migra_agent_runner_configs(conf, ids_map) - # Migrate providers to new structure: extract source fields to provider_sources try: _migra_provider_to_source_structure(astrbot_config) diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index b327a6184..423d0f25e 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -188,6 +188,10 @@ async def wrapper( *args, **kwargs, ): + if event.chain_config is not None: + raise RuntimeError( + "SessionWaiter is not allowed in NodeStar context.", + ) if not session_filter: session_filter = DefaultSessionFilter() if not isinstance(session_filter, SessionFilter): diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 481be2f89..2d97abb9b 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -1,5 +1,6 @@ from .auth import AuthRoute from .backup import BackupRoute +from .chain_management import ChainManagementRoute from .chat import ChatRoute from .chatui_project import ChatUIProjectRoute from .command import CommandRoute @@ -12,7 +13,6 @@ from .persona import PersonaRoute from .platform import PlatformRoute from .plugin import PluginRoute -from .session_management import SessionManagementRoute from .skills import SkillsRoute from .stat import StatRoute from .static_file import StaticFileRoute @@ -24,6 +24,7 @@ "AuthRoute", "BackupRoute", "ChatRoute", + "ChainManagementRoute", "ChatUIProjectRoute", "CommandRoute", "ConfigRoute", @@ -35,7 +36,6 @@ "PersonaRoute", "PlatformRoute", "PluginRoute", - "SessionManagementRoute", "StatRoute", "StaticFileRoute", "SubAgentRoute", diff --git a/astrbot/dashboard/routes/chain_management.py b/astrbot/dashboard/routes/chain_management.py new file mode 100644 index 000000000..6fd9520a6 --- /dev/null +++ b/astrbot/dashboard/routes/chain_management.py @@ -0,0 +1,531 @@ +import traceback +import uuid + +from quart import request +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from astrbot.core import logger +from astrbot.core.config.node_config import 
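A standalone sketch of the provider normalization step that replaces the removed `_migra_agent_runner_configs` logic in `migra_helper.py`: any legacy provider whose `type` is dify/coze/dashscope is tagged with `provider_type="agent_runner"` so later migration steps can tell chat providers and agent runners apart. Provider IDs and the non-agent provider type used below are illustrative.

```python
def normalize_agent_runner_providers(providers: list[dict]) -> bool:
    """Return True when at least one provider entry was re-tagged."""
    changed = False
    for prov in providers:
        if prov.get("type") in ("dify", "coze", "dashscope"):
            if prov.get("provider_type") != "agent_runner":
                prov["provider_type"] = "agent_runner"
                changed = True
    return changed


providers = [
    {"id": "dify-main", "type": "dify"},
    {"id": "gpt", "type": "openai_chat_completion"},   # illustrative chat provider type
]
assert normalize_agent_runner_providers(providers) is True
assert providers[0]["provider_type"] == "agent_runner"
assert "provider_type" not in providers[1]
```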
AstrBotNodeConfig +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase +from astrbot.core.pipeline.engine.chain_config import ( + DEFAULT_CHAIN_CONFIG, + ChainConfigModel, + normalize_chain_nodes, + serialize_chain_nodes, +) +from astrbot.core.star.modality import Modality +from astrbot.core.star.node_star import is_node_star_metadata +from astrbot.core.star.star import StarMetadata + +from .config import validate_config +from .route import Response, Route, RouteContext + + +class ChainManagementRoute(Route): + def __init__( + self, + context: RouteContext, + db_helper: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.db_helper = db_helper + self.core_lifecycle = core_lifecycle + self.routes = { + "/chain/list": ("GET", self.list_chains), + "/chain/get": ("GET", self.get_chain), + "/chain/create": ("POST", self.create_chain), + "/chain/update": ("POST", self.update_chain), + "/chain/delete": ("POST", self.delete_chain), + "/chain/reorder": ("POST", self.reorder_chains), + "/chain/available-options": ("GET", self.get_available_options), + "/chain/node-config": ("GET", self.get_node_config), + "/chain/node-config/update": ("POST", self.update_node_config), + } + self.register_routes() + + def _default_nodes(self) -> list[dict]: + return serialize_chain_nodes(DEFAULT_CHAIN_CONFIG.nodes) + + async def _reload_chain_configs(self) -> None: + await self.core_lifecycle.chain_config_router.reload( + self.db_helper, + ) + + def _serialize_chain(self, chain: ChainConfigModel) -> dict: + is_default = chain.chain_id == "default" + nodes_payload = None + if chain.nodes is not None: + normalized = normalize_chain_nodes(chain.nodes, chain.chain_id) + nodes_payload = serialize_chain_nodes(normalized) + return { + "chain_id": chain.chain_id, + "match_rule": chain.match_rule, + "sort_order": chain.sort_order, + "enabled": chain.enabled, + "nodes": nodes_payload, + "plugin_filter": chain.plugin_filter, + "config_id": chain.config_id, + "created_at": chain.created_at.isoformat() if chain.created_at else None, + "updated_at": chain.updated_at.isoformat() if chain.updated_at else None, + "is_default": is_default, + } + + def _serialize_default_chain_virtual(self) -> dict: + return { + "chain_id": "default", + "match_rule": None, + "sort_order": -1, + "enabled": True, + "nodes": None, + "plugin_filter": None, + "config_id": "default", + "created_at": None, + "updated_at": None, + "is_default": True, + } + + def _get_node_plugin_map(self) -> dict[str, StarMetadata]: + return { + p.name: p + for p in self.core_lifecycle.plugin_manager.context.get_all_stars() + if p.name and is_node_star_metadata(p) + } + + def _get_node_schema(self, node_name: str) -> dict | None: + node = self._get_node_plugin_map().get(node_name) + return node.node_schema if node else None + + async def list_chains(self): + try: + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 10, type=int) + search = request.args.get("search", "", type=str).strip() + + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + if page_size > 100: + page_size = 100 + + async with self.db_helper.get_db() as session: + session: AsyncSession + result = await session.execute(select(ChainConfigModel)) + chains = list(result.scalars().all()) + + default_chain = None + normal_chains: list[ChainConfigModel] = [] + for chain in chains: + if chain.chain_id == "default": + default_chain = chain + else: + 
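A hedged client-side sketch of calling the newly registered `/chain/list` endpoint. The base URL, `/api` prefix, auth header, and the `data` envelope key are assumptions about how the dashboard mounts these routes; the query parameters match the handler above.

```python
import requests

BASE = "http://localhost:6185/api"                      # assumed dashboard address and prefix
headers = {"Authorization": "Bearer <token>"}           # assumed auth scheme

resp = requests.get(
    f"{BASE}/chain/list",
    params={"page": 1, "page_size": 10, "search": ""},
    headers=headers,
    timeout=10,
)
payload = resp.json().get("data", {})                   # envelope key is an assumption
print(payload.get("total"), [c["chain_id"] for c in payload.get("chains", [])])
```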
normal_chains.append(chain) + + if search: + search_lower = search.lower() + normal_chains = [ + chain + for chain in normal_chains + if chain.match_rule + and search_lower in str(chain.match_rule).lower() + ] + + chains = sorted(normal_chains, key=lambda c: c.sort_order, reverse=True) + + if default_chain: + include_default = True + if search: + search_lower = search.lower() + include_default = search_lower in "default" or ( + default_chain.match_rule + and search_lower in str(default_chain.match_rule).lower() + ) + if include_default: + chains.append(default_chain) + elif not search or "default" in search.lower(): + chains.append(self._serialize_default_chain_virtual()) + + total = len(chains) + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paginated = chains[start_idx:end_idx] + + return ( + Response() + .ok( + { + "chains": [ + self._serialize_chain(chain) + if isinstance(chain, ChainConfigModel) + else chain + for chain in paginated + ], + "total": total, + "page": page, + "page_size": page_size, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取 Chain 列表失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取 Chain 列表失败: {e!s}").__dict__ + + async def get_chain(self): + try: + chain_id = request.args.get("chain_id", "") + if not chain_id: + return Response().error("缺少必要参数: chain_id").__dict__ + + async with self.db_helper.get_db() as session: + session: AsyncSession + result = await session.execute( + select(ChainConfigModel).where( + ChainConfigModel.chain_id == chain_id + ) + ) + chain = result.scalar_one_or_none() + + if not chain: + if chain_id == "default": + return ( + Response().ok(self._serialize_default_chain_virtual()).__dict__ + ) + return Response().error("Chain 不存在").__dict__ + + return Response().ok(self._serialize_chain(chain)).__dict__ + except Exception as e: + logger.error(f"获取 Chain 失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取 Chain 失败: {e!s}").__dict__ + + async def create_chain(self): + try: + data = await request.get_json() + + chain_id = data.get("chain_id") or str(uuid.uuid4()) + if chain_id == "default": + return ( + Response().error("默认 Chain 不允许创建,请使用编辑功能。").__dict__ + ) + nodes = data.get("nodes") + nodes_payload = ( + serialize_chain_nodes(normalize_chain_nodes(nodes, chain_id)) + if nodes is not None + else None + ) + + # 获取当前最大 sort_order,新建的放到最后 + async with self.db_helper.get_db() as session: + session: AsyncSession + result = await session.execute(select(ChainConfigModel)) + existing_chains = list(result.scalars().all()) + max_sort_order = max( + (c.sort_order for c in existing_chains), default=-1 + ) + + chain = ChainConfigModel( + chain_id=chain_id, + match_rule=data.get("match_rule"), + sort_order=max_sort_order + 1, + enabled=data.get("enabled", True), + nodes=nodes_payload, + plugin_filter=data.get("plugin_filter"), + config_id=data.get("config_id"), + ) + + async with self.db_helper.get_db() as session: + session: AsyncSession + session.add(chain) + await session.commit() + await session.refresh(chain) + + await self._reload_chain_configs() + + return Response().ok(self._serialize_chain(chain)).__dict__ + except Exception as e: + logger.error(f"创建 Chain 失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"创建 Chain 失败: {e!s}").__dict__ + + async def update_chain(self): + try: + data = await request.get_json() + chain_id = data.get("chain_id", "") + if not chain_id: + return Response().error("缺少必要参数: chain_id").__dict__ + + async with 
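A small pure-Python sketch of the ordering rule used by `list_chains`: user-defined chains sort by `sort_order` descending (higher means matched first) and the default chain is always appended last. Chain IDs are illustrative.

```python
chains = [
    {"chain_id": "group-only", "sort_order": 0},
    {"chain_id": "vip", "sort_order": 2},
    {"chain_id": "default", "sort_order": -1},
]

default = next(c for c in chains if c["chain_id"] == "default")
normal = [c for c in chains if c["chain_id"] != "default"]
ordered = sorted(normal, key=lambda c: c["sort_order"], reverse=True) + [default]

assert [c["chain_id"] for c in ordered] == ["vip", "group-only", "default"]
```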
self.db_helper.get_db() as session: + session: AsyncSession + result = await session.execute( + select(ChainConfigModel).where( + ChainConfigModel.chain_id == chain_id + ) + ) + chain = result.scalar_one_or_none() + if not chain and chain_id != "default": + return Response().error("Chain 不存在").__dict__ + + if chain_id == "default" and not chain: + nodes = data.get("nodes") + nodes_payload = ( + serialize_chain_nodes(normalize_chain_nodes(nodes, chain_id)) + if nodes is not None + else None + ) + chain = ChainConfigModel( + chain_id="default", + match_rule=None, + sort_order=-1, + enabled=data.get("enabled", True), + nodes=nodes_payload, + plugin_filter=data.get("plugin_filter"), + config_id="default", + ) + session.add(chain) + await session.commit() + await session.refresh(chain) + await self._reload_chain_configs() + return Response().ok(self._serialize_chain(chain)).__dict__ + + for field in [ + "match_rule", + "enabled", + "nodes", + "plugin_filter", + "config_id", + ]: + if field in data: + value = data.get(field) + if field == "nodes": + value = ( + serialize_chain_nodes( + normalize_chain_nodes(value, chain_id) + ) + if value is not None + else None + ) + setattr(chain, field, value) + + if chain.chain_id == "default": + chain.match_rule = None + chain.sort_order = -1 + chain.config_id = "default" + + session.add(chain) + await session.commit() + await session.refresh(chain) + + await self._reload_chain_configs() + + return Response().ok(self._serialize_chain(chain)).__dict__ + except Exception as e: + logger.error(f"更新 Chain 失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"更新 Chain 失败: {e!s}").__dict__ + + async def delete_chain(self): + try: + data = await request.get_json() + chain_id = data.get("chain_id", "") + if not chain_id: + return Response().error("缺少必要参数: chain_id").__dict__ + if chain_id == "default": + return Response().error("默认 Chain 不允许删除。").__dict__ + + async with self.db_helper.get_db() as session: + session: AsyncSession + result = await session.execute( + select(ChainConfigModel).where( + ChainConfigModel.chain_id == chain_id + ) + ) + chain = result.scalar_one_or_none() + if not chain: + return Response().error("Chain 不存在").__dict__ + + await session.delete(chain) + await session.commit() + + await self._reload_chain_configs() + + return Response().ok({"message": "Chain 已删除"}).__dict__ + except Exception as e: + logger.error(f"删除 Chain 失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"删除 Chain 失败: {e!s}").__dict__ + + async def reorder_chains(self): + """接收有序的 chain_id 列表,按顺序分配 sort_order(列表顺序即匹配顺序)""" + try: + data = await request.get_json() + chain_ids = data.get("chain_ids", []) + if not chain_ids: + return Response().error("chain_ids 不能为空").__dict__ + chain_ids = [cid for cid in chain_ids if cid != "default"] + + async with self.db_helper.get_db() as session: + session: AsyncSession + total = len(chain_ids) + for index, chain_id in enumerate(chain_ids): + result = await session.execute( + select(ChainConfigModel).where( + ChainConfigModel.chain_id == chain_id + ) + ) + chain = result.scalar_one_or_none() + if chain: + # 列表第一个元素 sort_order 最大,最先匹配 + chain.sort_order = total - 1 - index + session.add(chain) + await session.commit() + + await self._reload_chain_configs() + + return Response().ok({"message": "排序已更新"}).__dict__ + except Exception as e: + logger.error(f"更新排序失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"更新排序失败: {e!s}").__dict__ + + async def get_available_options(self): + try: + provider_manager 
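A sketch of the `sort_order` assignment behind `/chain/reorder`: the submitted list order is the match order, so the first chain ID receives the largest `sort_order`. IDs are illustrative; `"default"` is filtered out before this step.

```python
chain_ids = ["vip", "group-only", "fallback"]

total = len(chain_ids)
sort_orders = {cid: total - 1 - index for index, cid in enumerate(chain_ids)}

assert sort_orders == {"vip": 2, "group-only": 1, "fallback": 0}
```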
= self.core_lifecycle.provider_manager + plugin_manager = self.core_lifecycle.plugin_manager + + available_stt_providers = [ + { + "id": p.meta().id, + "name": p.meta().id, + "model": p.meta().model, + } + for p in provider_manager.stt_provider_insts + ] + + available_tts_providers = [ + { + "id": p.meta().id, + "name": p.meta().id, + "model": p.meta().model, + } + for p in provider_manager.tts_provider_insts + ] + + available_plugins = [ + { + "name": p.name, + "display_name": p.display_name or p.name, + "desc": p.desc, + } + for p in plugin_manager.context.get_all_stars() + if not p.reserved and p.name and not is_node_star_metadata(p) + ] + + node_plugins = [ + p + for p in plugin_manager.context.get_all_stars() + if p.name and is_node_star_metadata(p) + ] + available_nodes = [ + { + "name": p.name, + "display_name": p.display_name or p.name, + "schema": p.node_schema or {}, + } + for p in node_plugins + ] + default_nodes = self._default_nodes() + available_nodes = {node["name"]: node for node in available_nodes} + available_nodes = list(available_nodes.values()) + + available_configs = self.core_lifecycle.astrbot_config_mgr.get_conf_list() + + available_modalities = [ + {"label": m.value, "value": m.value} for m in Modality + ] + + return ( + Response() + .ok( + { + "available_stt_providers": available_stt_providers, + "available_tts_providers": available_tts_providers, + "available_plugins": available_plugins, + "available_nodes": available_nodes, + "default_nodes": default_nodes, + "available_configs": available_configs, + "available_modalities": available_modalities, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取可用选项失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取可用选项失败: {e!s}").__dict__ + + async def get_node_config(self): + try: + chain_id = request.args.get("chain_id", "").strip() + node_name = request.args.get("node_name", "").strip() + node_uuid = request.args.get("node_uuid", "").strip() + + if not chain_id or not node_name: + return Response().error("缺少必要参数: chain_id 或 node_name").__dict__ + + raw_schema = self._get_node_schema(node_name) + schema = raw_schema or {} + node_config = AstrBotNodeConfig.get_cached( + node_name=node_name, + chain_id=chain_id, + node_uuid=node_uuid, + schema=raw_schema, + ) + + return ( + Response() + .ok( + { + "config": dict(node_config), + "schema": schema, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取节点配置失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取节点配置失败: {e!s}").__dict__ + + async def update_node_config(self): + try: + data = await request.get_json() + chain_id = (data.get("chain_id") or "").strip() + node_name = (data.get("node_name") or "").strip() + node_uuid = (data.get("node_uuid") or "").strip() + config = data.get("config") + + if not chain_id or not node_name: + return Response().error("缺少必要参数: chain_id 或 node_name").__dict__ + if not isinstance(config, dict): + return Response().error("配置内容必须是对象").__dict__ + + raw_schema = self._get_node_schema(node_name) + if raw_schema is None: + return ( + Response().error("该节点未声明配置 Schema,无法保存配置").__dict__ + ) + + if raw_schema: + errors, config = validate_config(config, raw_schema, is_core=False) + if errors: + return Response().error(f"配置校验失败: {errors}").__dict__ + + node_config = AstrBotNodeConfig.get_cached( + node_name=node_name, + chain_id=chain_id, + node_uuid=node_uuid, + schema=raw_schema, + ) + node_config.save_config(config) + + return Response().ok({"message": "节点配置已保存"}).__dict__ + except Exception 
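A hedged sketch of the per-node config access that the two node-config handlers wrap: `AstrBotNodeConfig.get_cached` is keyed by node name, chain ID, and node UUID. The `"agent"` node name, chain ID, and trimmed schema below are illustrative.

```python
from astrbot.core.config.node_config import AstrBotNodeConfig

schema = {"agent_runner_type": {"type": "string", "default": "local"}}

node_config = AstrBotNodeConfig.get_cached(
    node_name="agent",
    chain_id="my-chain",
    node_uuid="",        # assumed empty when the node appears only once in the chain
    schema=schema,
)

print(dict(node_config))                                 # current effective node config
node_config.save_config({"agent_runner_type": "dify"})   # persist an override for this node
```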
as e: + logger.error(f"保存节点配置失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"保存节点配置失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a7c0e3a57..79b2d47ad 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -59,7 +59,6 @@ def __init__( self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.db = db - self.umop_config_router = core_lifecycle.umop_config_router self.running_convs: dict[str, bool] = {} @@ -619,16 +618,6 @@ async def delete_webchat_session(self): offset_sec=99999999, ) - # 删除与会话关联的配置路由 - try: - await self.umop_config_router.delete_route(unified_msg_origin) - except ValueError as exc: - logger.warning( - "Failed to delete UMO route %s during session cleanup: %s", - unified_msg_origin, - exc, - ) - # 清理队列(仅对 webchat) if session.platform_id == "webchat": webchat_queue_mgr.remove_queues(session_id) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index efea4c7cf..da065738f 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -244,17 +244,12 @@ def __init__( self.config: AstrBotConfig = core_lifecycle.astrbot_config self._logo_token_cache = {} # 缓存logo token,避免重复注册 self.acm = core_lifecycle.astrbot_config_mgr - self.ucr = core_lifecycle.umop_config_router self.routes = { "/config/abconf/new": ("POST", self.create_abconf), "/config/abconf": ("GET", self.get_abconf), "/config/abconfs": ("GET", self.get_abconf_list), "/config/abconf/delete": ("POST", self.delete_abconf), "/config/abconf/update": ("POST", self.update_abconf), - "/config/umo_abconf_routes": ("GET", self.get_uc_table), - "/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all), - "/config/umo_abconf_route/update": ("POST", self.update_ucr), - "/config/umo_abconf_route/delete": ("POST", self.delete_ucr), "/config/get": ("GET", self.get_configs), "/config/default": ("GET", self.get_default_config), "/config/astrbot/update": ("POST", self.post_astrbot_configs), @@ -431,67 +426,6 @@ async def get_provider_template(self): } return Response().ok(data=data).__dict__ - async def get_uc_table(self): - """获取 UMOP 配置路由表""" - return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__ - - async def update_ucr_all(self): - """更新 UMOP 配置路由表的全部内容""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - new_routing = post_data.get("routing", None) - - if not new_routing or not isinstance(new_routing, dict): - return Response().error("缺少或错误的路由表数据").__dict__ - - try: - await self.ucr.update_routing_data(new_routing) - return Response().ok(message="更新成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {e!s}").__dict__ - - async def update_ucr(self): - """更新 UMOP 配置路由表""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - umo = post_data.get("umo", None) - conf_id = post_data.get("conf_id", None) - - if not umo or not conf_id: - return Response().error("缺少 UMO 或配置文件 ID").__dict__ - - try: - await self.ucr.update_route(umo, conf_id) - return Response().ok(message="更新成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {e!s}").__dict__ - - async def delete_ucr(self): - """删除 UMOP 配置路由表中的一项""" - post_data = await request.json - if not post_data: 
- return Response().error("缺少配置数据").__dict__ - - umo = post_data.get("umo", None) - - if not umo: - return Response().error("缺少 UMO").__dict__ - - try: - if umo in self.ucr.umop_to_conf_id: - del self.ucr.umop_to_conf_id[umo] - await self.ucr.update_routing_data(self.ucr.umop_to_conf_id) - return Response().ok(message="删除成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"删除路由表项失败: {e!s}").__dict__ - async def get_default_config(self): """获取默认配置文件""" metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) @@ -511,8 +445,12 @@ async def create_abconf(self): config = post_data.get("config", DEFAULT_CONFIG) try: - conf_id = self.acm.create_conf(name=name, config=config) - return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__ + config_id = self.acm.create_conf(name=name, config=config) + return ( + Response() + .ok(message="创建成功", data={"config_id": config_id}) + .__dict__ + ) except ValueError as e: return Response().error(str(e)).__dict__ @@ -544,12 +482,12 @@ async def delete_abconf(self): if not post_data: return Response().error("缺少配置数据").__dict__ - conf_id = post_data.get("id") - if not conf_id: + config_id = post_data.get("config_id") or post_data.get("id") + if not config_id: return Response().error("缺少配置文件 ID").__dict__ try: - success = self.acm.delete_conf(conf_id) + success = self.acm.delete_conf(config_id) if success: return Response().ok(message="删除成功").__dict__ return Response().error("删除失败").__dict__ @@ -565,14 +503,14 @@ async def update_abconf(self): if not post_data: return Response().error("缺少配置数据").__dict__ - conf_id = post_data.get("id") - if not conf_id: + config_id = post_data.get("config_id") or post_data.get("id") + if not config_id: return Response().error("缺少配置文件 ID").__dict__ name = post_data.get("name") try: - success = self.acm.update_conf_info(conf_id, name=name) + success = self.acm.update_conf_info(config_id, name=name) if success: return Response().ok(message="更新成功").__dict__ return Response().error("更新失败").__dict__ @@ -900,18 +838,18 @@ async def get_platform_list(self): async def post_astrbot_configs(self): data = await request.json config = data.get("config", None) - conf_id = data.get("conf_id", None) + config_id = data.get("config_id") try: # 不更新 provider_sources, provider, platform # 这些配置有单独的接口进行更新 - if conf_id == "default": + if config_id == "default": no_update_keys = ["provider_sources", "provider", "platform"] for key in no_update_keys: config[key] = self.acm.default_conf[key] - await self._save_astrbot_configs(config, conf_id) - await self.core_lifecycle.reload_pipeline_scheduler(conf_id) + await self._save_astrbot_configs(config, config_id) + await self.core_lifecycle.reload_pipeline_executor(config_id) return Response().ok(None, "保存成功~").__dict__ except Exception as e: logger.error(traceback.format_exc()) @@ -1358,12 +1296,12 @@ async def _get_plugin_config(self, plugin_name: str): return ret async def _save_astrbot_configs( - self, post_configs: dict, conf_id: str | None = None - ) -> None: + self, post_configs: dict, config_id: str | None = None + ): try: - if conf_id not in self.acm.confs: - raise ValueError(f"配置文件 {conf_id} 不存在") - astrbot_config = self.acm.confs[conf_id] + if config_id not in self.acm.confs: + raise ValueError(f"配置文件 {config_id} 不存在") + astrbot_config = self.acm.confs[config_id] # 保留服务端的 t2i_active_template 值 if "t2i_active_template" in astrbot_config: diff --git a/astrbot/dashboard/routes/cron.py b/astrbot/dashboard/routes/cron.py index 
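A tiny sketch of the backward-compatible payload handling introduced by the `conf_id` → `config_id` rename in the abconf endpoints: new clients send `config_id`, while older clients may still send `id`.

```python
def extract_config_id(post_data: dict) -> str | None:
    """Prefer the new key, fall back to the legacy one."""
    return post_data.get("config_id") or post_data.get("id")


assert extract_config_id({"config_id": "abc"}) == "abc"
assert extract_config_id({"id": "abc"}) == "abc"   # legacy payload still accepted
assert extract_config_id({}) is None
```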
8861fc5cc..1c6a41396 100644 --- a/astrbot/dashboard/routes/cron.py +++ b/astrbot/dashboard/routes/cron.py @@ -98,6 +98,7 @@ async def create_job(self): Response().error("run_at must be ISO datetime").__dict__ ) + config_id = payload.get("config_id") job_payload = { "session": session, "note": note, @@ -105,6 +106,7 @@ async def create_job(self): "provider_id": provider_id, "run_at": run_at, "origin": "api", + "config_id": config_id, } job = await cron_mgr.add_active_job( diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py deleted file mode 100644 index ffe5372a0..000000000 --- a/astrbot/dashboard/routes/session_management.py +++ /dev/null @@ -1,938 +0,0 @@ -from quart import request -from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import col, select - -from astrbot.core import logger, sp -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import ConversationV2, Preference -from astrbot.core.provider.entities import ProviderType - -from .route import Response, Route, RouteContext - -AVAILABLE_SESSION_RULE_KEYS = [ - "session_service_config", - "session_plugin_config", - "kb_config", - f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", - f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", - f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", -] - - -class SessionManagementRoute(Route): - def __init__( - self, - context: RouteContext, - db_helper: BaseDatabase, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.db_helper = db_helper - self.routes = { - "/session/list-rule": ("GET", self.list_session_rule), - "/session/update-rule": ("POST", self.update_session_rule), - "/session/delete-rule": ("POST", self.delete_session_rule), - "/session/batch-delete-rule": ("POST", self.batch_delete_session_rule), - "/session/active-umos": ("GET", self.list_umos), - "/session/list-all-with-status": ("GET", self.list_all_umos_with_status), - "/session/batch-update-service": ("POST", self.batch_update_service), - "/session/batch-update-provider": ("POST", self.batch_update_provider), - # 分组管理 API - "/session/groups": ("GET", self.list_groups), - "/session/group/create": ("POST", self.create_group), - "/session/group/update": ("POST", self.update_group), - "/session/group/delete": ("POST", self.delete_group), - } - self.conv_mgr = core_lifecycle.conversation_manager - self.core_lifecycle = core_lifecycle - self.register_routes() - - async def _get_umo_rules( - self, page: int = 1, page_size: int = 10, search: str = "" - ) -> tuple[dict, int]: - """获取所有带有自定义规则的 umo 及其规则内容(支持分页和搜索)。 - - 如果某个 umo 在 preference 中有以下字段,则表示有自定义规则: - - 1. session_service_config (包含了 是否启用这个umo, 这个umo是否启用 llm, 这个umo是否启用tts, umo自定义名称。) - 2. session_plugin_config (包含了 这个 umo 的 plugin set) - 3. provider_perf_{ProviderType.value} (包含了这个 umo 所选择使用的 provider 信息) - 4. 
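A hedged sketch of a create-job request body after the cron change above: the optional `config_id` is now forwarded into the stored job payload so the scheduled run can resolve the right chain config. All concrete values are illustrative.

```python
job_request = {
    "session": "aiocqhttp:GroupMessage:12345",
    "note": "daily summary",
    "prompt": "Summarize today's group messages.",
    "provider_id": "",
    "run_at": "2025-06-01T08:00:00",    # must be an ISO datetime
    "config_id": "my-config-id",        # may be omitted or None when no chain config applies
}
```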
kb_config (包含了这个 umo 的知识库相关配置) - - Args: - page: 页码,从 1 开始 - page_size: 每页数量 - search: 搜索关键词,匹配 umo 或 custom_name - - Returns: - tuple[dict, int]: (umo_rules, total) - 分页后的 umo 规则和总数 - """ - umo_rules = {} - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(Preference).where( - col(Preference.scope) == "umo", - col(Preference.key).in_(AVAILABLE_SESSION_RULE_KEYS), - ) - ) - prefs = result.scalars().all() - for pref in prefs: - umo_id = pref.scope_id - if umo_id not in umo_rules: - umo_rules[umo_id] = {} - if pref.key == "session_plugin_config" and umo_id in pref.value["val"]: - umo_rules[umo_id][pref.key] = pref.value["val"][umo_id] - else: - umo_rules[umo_id][pref.key] = pref.value["val"] - - # 搜索过滤 - if search: - search_lower = search.lower() - filtered_rules = {} - for umo_id, rules in umo_rules.items(): - # 匹配 umo - if search_lower in umo_id.lower(): - filtered_rules[umo_id] = rules - continue - # 匹配 custom_name - svc_config = rules.get("session_service_config", {}) - custom_name = svc_config.get("custom_name", "") if svc_config else "" - if custom_name and search_lower in custom_name.lower(): - filtered_rules[umo_id] = rules - umo_rules = filtered_rules - - # 获取总数 - total = len(umo_rules) - - # 分页处理 - all_umo_ids = list(umo_rules.keys()) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - paginated_umo_ids = all_umo_ids[start_idx:end_idx] - - # 只返回分页后的数据 - paginated_rules = {umo_id: umo_rules[umo_id] for umo_id in paginated_umo_ids} - - return paginated_rules, total - - async def list_session_rule(self): - """获取所有自定义的规则(支持分页和搜索) - - 返回已配置规则的 umo 列表及其规则内容,以及可用的 personas 和 providers - - Query 参数: - page: 页码,默认为 1 - page_size: 每页数量,默认为 10 - search: 搜索关键词,匹配 umo 或 custom_name - """ - try: - # 获取分页和搜索参数 - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 10, type=int) - search = request.args.get("search", "", type=str).strip() - - # 参数校验 - if page < 1: - page = 1 - if page_size < 1: - page_size = 10 - if page_size > 100: - page_size = 100 - - umo_rules, total = await self._get_umo_rules( - page=page, page_size=page_size, search=search - ) - - # 构建规则列表 - rules_list = [] - for umo, rules in umo_rules.items(): - rule_info = { - "umo": umo, - "rules": rules, - } - # 解析 umo 格式: 平台:消息类型:会话ID - parts = umo.split(":") - if len(parts) >= 3: - rule_info["platform"] = parts[0] - rule_info["message_type"] = parts[1] - rule_info["session_id"] = parts[2] - rules_list.append(rule_info) - - # 获取可用的 providers 和 personas - provider_manager = self.core_lifecycle.provider_manager - persona_mgr = self.core_lifecycle.persona_mgr - - available_personas = [ - {"name": p["name"], "prompt": p.get("prompt", "")} - for p in persona_mgr.personas_v3 - ] - - available_chat_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.provider_insts - ] - - available_stt_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.stt_provider_insts - ] - - available_tts_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.tts_provider_insts - ] - - # 获取可用的插件列表(排除 reserved 的系统插件) - plugin_manager = self.core_lifecycle.plugin_manager - available_plugins = [ - { - "name": p.name, - "display_name": p.display_name or p.name, - "desc": p.desc, - } - for p in plugin_manager.context.get_all_stars() - if not p.reserved and 
p.name - ] - - # 获取可用的知识库列表 - available_kbs = [] - kb_manager = self.core_lifecycle.kb_manager - if kb_manager: - try: - kbs = await kb_manager.list_kbs() - available_kbs = [ - { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "emoji": kb.emoji, - } - for kb in kbs - ] - except Exception as e: - logger.warning(f"获取知识库列表失败: {e!s}") - - return ( - Response() - .ok( - { - "rules": rules_list, - "total": total, - "page": page, - "page_size": page_size, - "available_personas": available_personas, - "available_chat_providers": available_chat_providers, - "available_stt_providers": available_stt_providers, - "available_tts_providers": available_tts_providers, - "available_plugins": available_plugins, - "available_kbs": available_kbs, - "available_rule_keys": AVAILABLE_SESSION_RULE_KEYS, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取规则列表失败: {e!s}") - return Response().error(f"获取规则列表失败: {e!s}").__dict__ - - async def update_session_rule(self): - """更新某个 umo 的自定义规则 - - 请求体: - { - "umo": "平台:消息类型:会话ID", - "rule_key": "session_service_config" | "session_plugin_config" | "kb_config" | "provider_perf_xxx", - "rule_value": {...} // 规则值,具体结构根据 rule_key 不同而不同 - } - """ - try: - data = await request.get_json() - umo = data.get("umo") - rule_key = data.get("rule_key") - rule_value = data.get("rule_value") - - if not umo: - return Response().error("缺少必要参数: umo").__dict__ - if not rule_key: - return Response().error("缺少必要参数: rule_key").__dict__ - if rule_key not in AVAILABLE_SESSION_RULE_KEYS: - return Response().error(f"不支持的规则键: {rule_key}").__dict__ - - if rule_key == "session_plugin_config": - rule_value = { - umo: rule_value, - } - - # 使用 shared preferences 更新规则 - await sp.session_put(umo, rule_key, rule_value) - - return ( - Response() - .ok({"message": f"规则 {rule_key} 已更新", "umo": umo}) - .__dict__ - ) - except Exception as e: - logger.error(f"更新会话规则失败: {e!s}") - return Response().error(f"更新会话规则失败: {e!s}").__dict__ - - async def delete_session_rule(self): - """删除某个 umo 的自定义规则 - - 请求体: - { - "umo": "平台:消息类型:会话ID", - "rule_key": "session_service_config" | "session_plugin_config" | ... (可选,不传则删除所有规则) - } - """ - try: - data = await request.get_json() - umo = data.get("umo") - rule_key = data.get("rule_key") - - if not umo: - return Response().error("缺少必要参数: umo").__dict__ - - if rule_key: - # 删除单个规则 - if rule_key not in AVAILABLE_SESSION_RULE_KEYS: - return Response().error(f"不支持的规则键: {rule_key}").__dict__ - await sp.session_remove(umo, rule_key) - return ( - Response() - .ok({"message": f"规则 {rule_key} 已删除", "umo": umo}) - .__dict__ - ) - else: - # 删除该 umo 的所有规则 - await sp.clear_async("umo", umo) - return Response().ok({"message": "所有规则已删除", "umo": umo}).__dict__ - except Exception as e: - logger.error(f"删除会话规则失败: {e!s}") - return Response().error(f"删除会话规则失败: {e!s}").__dict__ - - async def batch_delete_session_rule(self): - """批量删除多个 umo 的自定义规则 - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...] 
// umo 列表 - } - """ - try: - data = await request.get_json() - umos = data.get("umos", []) - - if not umos: - return Response().error("缺少必要参数: umos").__dict__ - - if not isinstance(umos, list): - return Response().error("参数 umos 必须是数组").__dict__ - - # 批量删除 - deleted_count = 0 - failed_umos = [] - for umo in umos: - try: - await sp.clear_async("umo", umo) - deleted_count += 1 - except Exception as e: - logger.error(f"删除 umo {umo} 的规则失败: {e!s}") - failed_umos.append(umo) - - if failed_umos: - return ( - Response() - .ok( - { - "message": f"已删除 {deleted_count} 条规则,{len(failed_umos)} 条删除失败", - "deleted_count": deleted_count, - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - else: - return ( - Response() - .ok( - { - "message": f"已删除 {deleted_count} 条规则", - "deleted_count": deleted_count, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量删除会话规则失败: {e!s}") - return Response().error(f"批量删除会话规则失败: {e!s}").__dict__ - - async def list_umos(self): - """列出所有有对话记录的 umo,从 Conversations 表中找 - - 仅返回 umo 字符串列表,用于用户在创建规则时选择 umo - """ - try: - # 从 Conversation 表获取所有 distinct user_id (即 umo) - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(ConversationV2.user_id) - .distinct() - .order_by(ConversationV2.user_id) - ) - umos = [row[0] for row in result.fetchall()] - - return Response().ok({"umos": umos}).__dict__ - except Exception as e: - logger.error(f"获取 UMO 列表失败: {e!s}") - return Response().error(f"获取 UMO 列表失败: {e!s}").__dict__ - - async def list_all_umos_with_status(self): - """获取所有有对话记录的 UMO 及其服务状态(支持分页、搜索、筛选) - - Query 参数: - page: 页码,默认为 1 - page_size: 每页数量,默认为 20 - search: 搜索关键词 - message_type: 筛选消息类型 (group/private/all) - platform: 筛选平台 - """ - try: - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - search = request.args.get("search", "", type=str).strip() - message_type = request.args.get("message_type", "all", type=str) - platform = request.args.get("platform", "", type=str) - - if page < 1: - page = 1 - if page_size < 1: - page_size = 20 - if page_size > 100: - page_size = 100 - - # 从 Conversation 表获取所有 distinct user_id (即 umo) - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(ConversationV2.user_id) - .distinct() - .order_by(ConversationV2.user_id) - ) - all_umos = [row[0] for row in result.fetchall()] - - # 获取所有 umo 的规则配置 - umo_rules, _ = await self._get_umo_rules(page=1, page_size=99999, search="") - - # 构建带状态的 umo 列表 - umos_with_status = [] - for umo in all_umos: - parts = umo.split(":") - umo_platform = parts[0] if len(parts) >= 1 else "unknown" - umo_message_type = parts[1] if len(parts) >= 2 else "unknown" - umo_session_id = parts[2] if len(parts) >= 3 else umo - - # 筛选消息类型 - if message_type != "all": - if message_type == "group" and umo_message_type not in [ - "group", - "GroupMessage", - ]: - continue - if message_type == "private" and umo_message_type not in [ - "private", - "FriendMessage", - "friend", - ]: - continue - - # 筛选平台 - if platform and umo_platform != platform: - continue - - # 获取服务配置 - rules = umo_rules.get(umo, {}) - svc_config = rules.get("session_service_config", {}) - - custom_name = svc_config.get("custom_name", "") if svc_config else "" - session_enabled = ( - svc_config.get("session_enabled", True) if svc_config else True - ) - llm_enabled = ( - svc_config.get("llm_enabled", True) if svc_config else True - ) - tts_enabled = ( - svc_config.get("tts_enabled", 
True) if svc_config else True - ) - - # 搜索过滤 - if search: - search_lower = search.lower() - if ( - search_lower not in umo.lower() - and search_lower not in custom_name.lower() - ): - continue - - # 获取 provider 配置 - chat_provider_key = ( - f"provider_perf_{ProviderType.CHAT_COMPLETION.value}" - ) - tts_provider_key = f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}" - stt_provider_key = f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}" - - umos_with_status.append( - { - "umo": umo, - "platform": umo_platform, - "message_type": umo_message_type, - "session_id": umo_session_id, - "custom_name": custom_name, - "session_enabled": session_enabled, - "llm_enabled": llm_enabled, - "tts_enabled": tts_enabled, - "has_rules": umo in umo_rules, - "chat_provider": rules.get(chat_provider_key), - "tts_provider": rules.get(tts_provider_key), - "stt_provider": rules.get(stt_provider_key), - } - ) - - # 分页 - total = len(umos_with_status) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - paginated = umos_with_status[start_idx:end_idx] - - # 获取可用的平台列表 - platforms = list({u["platform"] for u in umos_with_status}) - - # 获取可用的 providers - provider_manager = self.core_lifecycle.provider_manager - available_chat_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.provider_insts - ] - available_tts_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.tts_provider_insts - ] - available_stt_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.stt_provider_insts - ] - - return ( - Response() - .ok( - { - "sessions": paginated, - "total": total, - "page": page, - "page_size": page_size, - "platforms": platforms, - "available_chat_providers": available_chat_providers, - "available_tts_providers": available_tts_providers, - "available_stt_providers": available_stt_providers, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取会话状态列表失败: {e!s}") - return Response().error(f"获取会话状态列表失败: {e!s}").__dict__ - - async def batch_update_service(self): - """批量更新多个 UMO 的服务状态 (LLM/TTS/Session) - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...], // 可选,如果不传则根据 scope 筛选 - "scope": "all" | "group" | "private" | "custom_group", // 可选,批量范围 - "group_id": "分组ID", // 当 scope 为 custom_group 时必填 - "llm_enabled": true/false/null, // 可选,null表示不修改 - "tts_enabled": true/false/null, // 可选 - "session_enabled": true/false/null // 可选 - } - """ - try: - data = await request.get_json() - umos = data.get("umos", []) - scope = data.get("scope", "") - group_id = data.get("group_id", "") - llm_enabled = data.get("llm_enabled") - tts_enabled = data.get("tts_enabled") - session_enabled = data.get("session_enabled") - - # 如果没有任何修改 - if llm_enabled is None and tts_enabled is None and session_enabled is None: - return Response().error("至少需要指定一个要修改的状态").__dict__ - - # 如果指定了 scope,获取符合条件的所有 umo - if scope and not umos: - # 如果是自定义分组 - if scope == "custom_group": - if not group_id: - return Response().error("请指定分组 ID").__dict__ - groups = self._get_groups() - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - umos = groups[group_id].get("umos", []) - else: - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(ConversationV2.user_id).distinct() - ) - all_umos = [row[0] for row in result.fetchall()] - - if scope == "group": - umos = [ - u - for u in all_umos - if 
":group:" in u.lower() or ":groupmessage:" in u.lower() - ] - elif scope == "private": - umos = [ - u - for u in all_umos - if ":private:" in u.lower() or ":friend" in u.lower() - ] - elif scope == "all": - umos = all_umos - - if not umos: - return Response().error("没有找到符合条件的会话").__dict__ - - # 批量更新 - success_count = 0 - failed_umos = [] - - for umo in umos: - try: - # 获取现有配置 - session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=umo) - or {} - ) - - # 更新状态 - if llm_enabled is not None: - session_config["llm_enabled"] = llm_enabled - if tts_enabled is not None: - session_config["tts_enabled"] = tts_enabled - if session_enabled is not None: - session_config["session_enabled"] = session_enabled - - # 保存 - sp.put( - "session_service_config", - session_config, - scope="umo", - scope_id=umo, - ) - success_count += 1 - except Exception as e: - logger.error(f"更新 {umo} 服务状态失败: {e!s}") - failed_umos.append(umo) - - status_changes = [] - if llm_enabled is not None: - status_changes.append(f"LLM={'启用' if llm_enabled else '禁用'}") - if tts_enabled is not None: - status_changes.append(f"TTS={'启用' if tts_enabled else '禁用'}") - if session_enabled is not None: - status_changes.append(f"会话={'启用' if session_enabled else '禁用'}") - - return ( - Response() - .ok( - { - "message": f"已更新 {success_count} 个会话 ({', '.join(status_changes)})", - "success_count": success_count, - "failed_count": len(failed_umos), - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量更新服务状态失败: {e!s}") - return Response().error(f"批量更新服务状态失败: {e!s}").__dict__ - - async def batch_update_provider(self): - """批量更新多个 UMO 的 Provider 配置 - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...], // 可选 - "scope": "all" | "group" | "private", // 可选 - "provider_type": "chat_completion" | "text_to_speech" | "speech_to_text", - "provider_id": "provider_id" - } - """ - try: - data = await request.get_json() - umos = data.get("umos", []) - scope = data.get("scope", "") - provider_type = data.get("provider_type") - provider_id = data.get("provider_id") - - if not provider_type or not provider_id: - return ( - Response() - .error("缺少必要参数: provider_type, provider_id") - .__dict__ - ) - - # 转换 provider_type - provider_type_map = { - "chat_completion": ProviderType.CHAT_COMPLETION, - "text_to_speech": ProviderType.TEXT_TO_SPEECH, - "speech_to_text": ProviderType.SPEECH_TO_TEXT, - } - if provider_type not in provider_type_map: - return ( - Response() - .error(f"不支持的 provider_type: {provider_type}") - .__dict__ - ) - - provider_type_enum = provider_type_map[provider_type] - - # 如果指定了 scope,获取符合条件的所有 umo - group_id = data.get("group_id", "") - if scope and not umos: - # 如果是自定义分组 - if scope == "custom_group": - if not group_id: - return Response().error("请指定分组 ID").__dict__ - groups = self._get_groups() - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - umos = groups[group_id].get("umos", []) - else: - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(ConversationV2.user_id).distinct() - ) - all_umos = [row[0] for row in result.fetchall()] - - if scope == "group": - umos = [ - u - for u in all_umos - if ":group:" in u.lower() or ":groupmessage:" in u.lower() - ] - elif scope == "private": - umos = [ - u - for u in all_umos - if ":private:" in u.lower() or ":friend" in u.lower() - ] - elif scope == "all": - umos = all_umos - - if not umos: - return Response().error("没有找到符合条件的会话").__dict__ - - # 批量更新 - 
success_count = 0 - failed_umos = [] - provider_manager = self.core_lifecycle.provider_manager - - for umo in umos: - try: - await provider_manager.set_provider( - provider_id=provider_id, - provider_type=provider_type_enum, - umo=umo, - ) - success_count += 1 - except Exception as e: - logger.error(f"更新 {umo} Provider 失败: {e!s}") - failed_umos.append(umo) - - return ( - Response() - .ok( - { - "message": f"已更新 {success_count} 个会话的 {provider_type} 为 {provider_id}", - "success_count": success_count, - "failed_count": len(failed_umos), - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量更新 Provider 失败: {e!s}") - return Response().error(f"批量更新 Provider 失败: {e!s}").__dict__ - - # ==================== 分组管理 API ==================== - - def _get_groups(self) -> dict: - """获取所有分组""" - return sp.get("session_groups", {}) - - def _save_groups(self, groups: dict) -> None: - """保存分组""" - sp.put("session_groups", groups) - - async def list_groups(self): - """获取所有分组列表""" - try: - groups = self._get_groups() - # 转换为列表格式,方便前端使用 - groups_list = [] - for group_id, group_data in groups.items(): - groups_list.append( - { - "id": group_id, - "name": group_data.get("name", ""), - "umos": group_data.get("umos", []), - "umo_count": len(group_data.get("umos", [])), - } - ) - return Response().ok({"groups": groups_list}).__dict__ - except Exception as e: - logger.error(f"获取分组列表失败: {e!s}") - return Response().error(f"获取分组列表失败: {e!s}").__dict__ - - async def create_group(self): - """创建新分组""" - try: - data = await request.json - name = data.get("name", "").strip() - umos = data.get("umos", []) - - if not name: - return Response().error("分组名称不能为空").__dict__ - - groups = self._get_groups() - - # 生成唯一 ID - import uuid - - group_id = str(uuid.uuid4())[:8] - - groups[group_id] = { - "name": name, - "umos": umos, - } - - self._save_groups(groups) - - return ( - Response() - .ok( - { - "message": f"分组 '{name}' 创建成功", - "group": { - "id": group_id, - "name": name, - "umos": umos, - "umo_count": len(umos), - }, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"创建分组失败: {e!s}") - return Response().error(f"创建分组失败: {e!s}").__dict__ - - async def update_group(self): - """更新分组(改名、增删成员)""" - try: - data = await request.json - group_id = data.get("id") - name = data.get("name") - umos = data.get("umos") - add_umos = data.get("add_umos", []) - remove_umos = data.get("remove_umos", []) - - if not group_id: - return Response().error("分组 ID 不能为空").__dict__ - - groups = self._get_groups() - - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - - group = groups[group_id] - - # 更新名称 - if name is not None: - group["name"] = name.strip() - - # 直接设置 umos 列表 - if umos is not None: - group["umos"] = umos - else: - # 增量更新 - current_umos = set(group.get("umos", [])) - if add_umos: - current_umos.update(add_umos) - if remove_umos: - current_umos.difference_update(remove_umos) - group["umos"] = list(current_umos) - - self._save_groups(groups) - - return ( - Response() - .ok( - { - "message": f"分组 '{group['name']}' 更新成功", - "group": { - "id": group_id, - "name": group["name"], - "umos": group["umos"], - "umo_count": len(group["umos"]), - }, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"更新分组失败: {e!s}") - return Response().error(f"更新分组失败: {e!s}").__dict__ - - async def delete_group(self): - """删除分组""" - try: - data = await request.json - group_id = data.get("id") - - if not group_id: - return Response().error("分组 ID 不能为空").__dict__ - - 
groups = self._get_groups() - - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - - group_name = groups[group_id].get("name", group_id) - del groups[group_id] - - self._save_groups(groups) - - return Response().ok({"message": f"分组 '{group_name}' 已删除"}).__dict__ - except Exception as e: - logger.error(f"删除分组失败: {e!s}") - return Response().error(f"删除分组失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index 8d06826be..c56fbf118 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -1,5 +1,3 @@ -# astrbot/dashboard/routes/t2i.py - from dataclasses import asdict from quart import jsonify, request @@ -38,6 +36,11 @@ def __init__( ] self.register_routes() + async def _reload_all_pipeline_executors(self) -> None: + config_ids = list(self.core_lifecycle.astrbot_config_mgr.confs.keys()) + for config_id in config_ids: + await self.core_lifecycle.reload_pipeline_executor(config_id) + async def list_templates(self): """获取所有T2I模板列表""" try: @@ -133,7 +136,7 @@ async def update_template(self, name: str): # 检查更新的是否为当前激活的模板,如果是,则热重载 active_template = self.config.get("t2i_active_template", "base") if name == active_template: - await self.core_lifecycle.reload_pipeline_scheduler("default") + await self._reload_all_pipeline_executors() message = f"模板 '{name}' 已更新并重新加载。" else: message = f"模板 '{name}' 已更新。" @@ -188,7 +191,7 @@ async def set_active_template(self): config.save_config(config) # 热重载以应用更改 - await self.core_lifecycle.reload_pipeline_scheduler("default") + await self._reload_all_pipeline_executors() return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。"))) @@ -215,7 +218,7 @@ async def reset_default_template(self): config.save_config(config) # 热重载以应用更改 - await self.core_lifecycle.reload_pipeline_scheduler("default") + await self._reload_all_pipeline_executors() return jsonify( asdict( diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 604866a87..9e97182f7 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -24,8 +24,6 @@ from .routes.live_chat import LiveChatRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext -from .routes.session_management import SessionManagementRoute -from .routes.subagent import SubAgentRoute from .routes.t2i import T2iRoute @@ -88,7 +86,7 @@ def __init__( self.skills_route = SkillsRoute(self.context, core_lifecycle) self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) self.file_route = FileRoute(self.context) - self.session_management_route = SessionManagementRoute( + self.chain_management_route = ChainManagementRoute( self.context, db, core_lifecycle, diff --git a/dashboard/src/components/RuleEditor.vue b/dashboard/src/components/RuleEditor.vue new file mode 100644 index 000000000..61e94a395 --- /dev/null +++ b/dashboard/src/components/RuleEditor.vue @@ -0,0 +1,99 @@ + + + + + diff --git a/dashboard/src/components/RuleNode.vue b/dashboard/src/components/RuleNode.vue new file mode 100644 index 000000000..211fb01d6 --- /dev/null +++ b/dashboard/src/components/RuleNode.vue @@ -0,0 +1,323 @@ + + + + + diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 35ec22cd3..1c4628080 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -53,11 +53,6 @@ - - - diff --git a/dashboard/src/components/shared/ObjectEditor.vue 
diff --git a/dashboard/src/components/shared/ObjectEditor.vue b/dashboard/src/components/shared/ObjectEditor.vue
index ee6dc84bf..cb459133f 100644
--- a/dashboard/src/components/shared/ObjectEditor.vue
+++ b/dashboard/src/components/shared/ObjectEditor.vue
@@ -183,7 +183,7 @@
-
+
 { return props.itemMeta?.template_schema || {} })
+const allowCustomKeys = computed(() => {
+  return !props.itemMeta?.lock_template_keys
+})
+
 const hasTemplateSchema = computed(() => {
   return Object.keys(templateSchema.value).length > 0
 })
@@ -271,6 +275,9 @@ const displayKeys = computed(() => {
 
 // 分离模板字段和普通字段
 const nonTemplatePairs = computed(() => {
+  if (!allowCustomKeys.value) {
+    return []
+  }
   return localKeyValuePairs.value.filter(pair => !templateSchema.value[pair.key])
 })
@@ -282,9 +289,28 @@ watch(() => props.modelValue, (newValue) => {
 
 function initializeLocalKeyValuePairs() {
   localKeyValuePairs.value = []
+  if (hasTemplateSchema.value && !allowCustomKeys.value) {
+    for (const [key, template] of Object.entries(templateSchema.value)) {
+      const rawValue = props.modelValue[key]
+      let _type = template.type || (typeof rawValue === 'object' ? 'json' : typeof rawValue)
+      let _value = _type === 'json' ? JSON.stringify(rawValue) : rawValue
+      if (_value === undefined || _value === null) {
+        _value = template.default !== undefined ? template.default : getDefaultValueForType(_type)
+      }
+      localKeyValuePairs.value.push({
+        key,
+        value: _value,
+        type: _type,
+        slider: template?.slider,
+        template
+      })
+    }
+    return
+  }
+
  for (const [key, value] of Object.entries(props.modelValue)) {
-    let _type = (typeof value) === 'object' ? 'json':(typeof value)
-    let _value = _type === 'json'?JSON.stringify(value):value
+    let _type = (typeof value) === 'object' ? 'json' : (typeof value)
+    let _value = _type === 'json' ? JSON.stringify(value) : value
 
     // Check if this key has a template schema
     const template = templateSchema.value[key]
@@ -316,6 +342,9 @@ function openDialog() {
 }
 
 function addKeyValuePair() {
+  if (!allowCustomKeys.value) {
+    return
+  }
   const key = newKey.value.trim()
   if (key !== '') {
     const isKeyExists = localKeyValuePairs.value.some(pair => pair.key === key)
@@ -366,6 +395,9 @@ function removeKeyValuePairByKey(key) {
 }
 
 function updateKey(index, newKey) {
+  if (!allowCustomKeys.value) {
+    return
+  }
   const originalKey = localKeyValuePairs.value[index].key
   // 如果键名没有改变,则不执行任何操作
   if (originalKey === newKey) return
@@ -515,4 +547,4 @@ function cancelDialog() {
 }
 .template-field-inactive {
   opacity: 0.8;
 }
-
\ No newline at end of file
+
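The ObjectEditor.vue hunks above wire a new `lock_template_keys` flag into the item metadata that already carries `template_schema`: when the flag is set, custom key/value rows are hidden (`nonTemplatePairs` returns an empty list), `addKeyValuePair` and `updateKey` become no-ops, and the editor pre-populates one row per schema entry, falling back to `template.default`. Below is a minimal sketch of the kind of `itemMeta` object that could drive this behaviour; the concrete schema fields and values are illustrative assumptions, not taken from the repository.

```ts
// Illustrative only: a "locked" ObjectEditor item, assuming the node-config
// schema is forwarded as `template_schema` alongside the new flag.
const itemMeta = {
  template_schema: {
    agent_runner_type: { type: 'string', default: 'local' },
    provider_id: { type: 'string', default: '' },
  },
  // With lock_template_keys set, only the keys declared above are rendered;
  // attempts to add or rename custom keys are ignored by the editor.
  lock_template_keys: true,
}
```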
diff --git a/dashboard/src/i18n/loader.ts b/dashboard/src/i18n/loader.ts
index e9914554a..29992ed1a 100644
--- a/dashboard/src/i18n/loader.ts
+++ b/dashboard/src/i18n/loader.ts
@@ -39,7 +39,7 @@
       { name: 'features/chat', path: 'features/chat.json' },
       { name: 'features/extension', path: 'features/extension.json' },
       { name: 'features/conversation', path: 'features/conversation.json' },
-      { name: 'features/session-management', path: 'features/session-management.json' },
+      { name: 'features/chain-management', path: 'features/chain-management.json' },
       { name: 'features/tooluse', path: 'features/tool-use.json' },
       { name: 'features/provider', path: 'features/provider.json' },
       { name: 'features/platform', path: 'features/platform.json' },
diff --git a/dashboard/src/i18n/locales/en-US/core/navigation.json b/dashboard/src/i18n/locales/en-US/core/navigation.json
index 3c70afd80..cdbc34bbc 100644
--- a/dashboard/src/i18n/locales/en-US/core/navigation.json
+++ b/dashboard/src/i18n/locales/en-US/core/navigation.json
@@ -19,7 +19,7 @@
     "components": "Handlers"
   },
   "conversation": "Conversations",
-  "sessionManagement": "Custom Rules",
+  "chainManagement": "Chain Routing",
   "console": "Console",
   "trace": "Trace",
   "alkaid": "Alkaid Lab",
diff --git a/dashboard/src/i18n/locales/en-US/features/chain-management.json b/dashboard/src/i18n/locales/en-US/features/chain-management.json
new file mode 100644
index 000000000..dfd933143
--- /dev/null
+++ b/dashboard/src/i18n/locales/en-US/features/chain-management.json
@@ -0,0 +1,120 @@
+{
+  "title": "Chain Routing",
+  "chainsCount": "chains",
+  "defaultTag": "Default",
+  "defaultHint": "The default chain matches all messages when enabled. You can only edit nodes, overrides, and the enabled state.",
+  "search": {
+    "placeholder": "Search by rule"
+  },
+  "buttons": {
+    "create": "Create Chain",
+    "refresh": "Refresh",
+    "save": "Save",
+    "cancel": "Cancel",
+    "delete": "Delete",
+    "add": "Add",
+    "addNode": "Add Node",
+    "sort": "Sort"
+  },
+  "table": {
+    "matchRule": "Match Rule",
+    "services": "Services",
+    "nodes": "Nodes",
+    "providers": "Providers",
+    "actions": "Actions"
+  },
+  "sections": {
+    "providers": "Provider Overrides",
+    "nodes": "Node Configs",
+    "pluginFilter": "Plugin Filter"
+  },
+  "fields": {
+    "config": "Config File",
+    "enabled": "Enabled",
+    "sttProvider": "STT Provider",
+    "ttsProvider": "TTS Provider",
+    "selectNode": "Select Node",
+    "pluginFilterMode": "Filter Mode",
+    "pluginList": "Plugin List"
+  },
+  "providers": {
+    "tts": "TTS",
+    "stt": "STT",
+    "followDefault": "Follow default"
+  },
+  "chips": {
+    "session": "Session",
+    "llm": "LLM"
+  },
+  "dialogs": {
+    "createTitle": "Create Chain",
+    "editTitle": "Edit Chain",
+    "deleteTitle": "Delete Chain",
+    "deleteConfirm": "This will delete the chain configuration.",
+    "addNodeTitle": "Add Node",
+    "nodeConfigTitle": "Node Configuration",
+    "sortTitle": "Sort Chains",
+    "sortHint": "Drag to reorder. Higher items match first.",
+    "sortDefaultHint": "Default chain stays last"
+  },
+  "messages": {
+    "loadError": "Failed to load chains",
+    "optionsError": "Failed to load options",
+    "saveSuccess": "Saved",
+    "saveError": "Save failed",
+    "deleteSuccess": "Deleted",
+    "deleteError": "Delete failed",
+    "nodeConfigEmpty": "This node has no configurable options.",
+    "nodeConfigNeedSave": "Save the chain before editing node config.",
+    "nodeConfigLoadError": "Failed to load node config",
+    "nodeConfigSaveError": "Failed to save node config",
+    "nodeConfigJsonError": "Invalid JSON format",
+    "nodeConfigRawHint": "No schema provided. Node config is unavailable.",
+    "nodeConfigNoSchema": "No schema provided. Node config is unavailable.",
+    "sortSuccess": "Order updated",
+    "sortError": "Failed to update order",
+    "sortLoadError": "Failed to load chain order"
+  },
+  "warnings": {
+    "modalityMismatch": "Modality mismatch",
+    "outputs": "outputs",
+    "accepts": "accepts"
+  },
+  "empty": {
+    "title": "No chain routes",
+    "subtitle": "Create a chain to start routing messages",
+    "nodes": "No nodes",
+    "providers": "No providers"
+  },
+  "ruleEditor": {
+    "label": "Match Rule",
+    "addGroup": "Add Group",
+    "empty": "No rule (matches all messages)",
+    "addCondition": "Add Condition",
+    "addSubGroup": "Add Sub-group",
+    "changeType": "Change Type",
+    "umoPlaceholder": "e.g. aiocqhttp:group_message:*",
+    "regexPlaceholder": "e.g. .*hello.*",
+    "types": {
+      "umo": "UMO Match",
+      "modality": "Modality",
+      "textRegex": "Text Regex"
+    },
+    "operators": {
+      "include": "Include",
+      "exclude": "Exclude"
+    }
+  },
+  "pluginConfig": {
+    "inherit": "Follow global",
+    "blacklist": "Blacklist",
+    "whitelist": "Whitelist",
+    "noRestriction": "Enable all plugins",
+    "disableAll": "Disable all plugins",
+    "inheritHint": "Follow global plugin restrictions",
+    "blacklistHint": "Selected plugins will not be executed",
+    "whitelistHint": "Only selected plugins will be executed",
+    "noRestrictionHint": "This chain allows any plugin to be executed",
+    "disableAllHint": "No plugins will be executed for this chain"
+  }
+}
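The `ruleEditor` strings in the locale bundle above (condition types `umo`, `modality`, `textRegex`; `include`/`exclude` operators; nested groups via "Add Sub-group") outline the match-rule model that RuleEditor.vue edits, but the component source itself is not preserved in this diff. The following type sketch is only an inference from those keys and is not the actual data structure used by the dashboard.

```ts
// Hypothetical shape inferred from the locale keys only.
type RuleCondition = {
  type: 'umo' | 'modality' | 'textRegex'
  operator: 'include' | 'exclude'
  // e.g. "aiocqhttp:group_message:*" for umo, ".*hello.*" for textRegex
  value: string
}

type RuleGroup = {
  conditions: RuleCondition[]
  groups?: RuleGroup[] // "Add Sub-group"
}
```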
diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
index 2166d5391..eb120a340 100644
--- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
@@ -44,28 +44,6 @@
         "image_caption_prompt": {
           "description": "Image Caption Prompt"
         }
-      },
-      "provider_stt_settings": {
-        "enable": {
-          "description": "Enable Speech-to-Text",
-          "hint": "Master switch for STT"
-        },
-        "provider_id": {
-          "description": "Default Speech-to-Text Model",
-          "hint": "Users can also select session-specific STT models using the /provider command."
-        }
-      },
-      "provider_tts_settings": {
-        "enable": {
-          "description": "Enable Text-to-Speech",
-          "hint": "Master switch for TTS"
-        },
-        "provider_id": {
-          "description": "Default Text-to-Speech Model"
-        },
-        "trigger_probability": {
-          "description": "TTS Trigger Probability"
-        }
       }
     },
     "persona": {
@@ -121,22 +99,6 @@
       }
     }
   },
-  "file_extract": {
-    "description": "File Extract",
-    "provider_settings": {
-      "file_extract": {
-        "enable": {
-          "description": "Enable File Extract"
-        },
-        "provider": {
-          "description": "File Extract Provider"
-        },
-        "moonshotai_api_key": {
-          "description": "Moonshot AI API Key"
-        }
-      }
-    }
-  },
   "agent_computer_use": {
     "description": "Agent Computer Use",
     "hint": "Allows the AstrBot to access and use your computer or an sandbox environment to perform more complex tasks. See [Sandbox Mode](https://docs.astrbot.app/use/astrbot-agent-sandbox.html), [Skills](https://docs.astrbot.app/use/skills.html)",
@@ -284,11 +246,6 @@
       "description": "Provider Reachability Check",
       "hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
     }
-  },
-  "provider_tts_settings": {
-    "dual_output": {
-      "description": "Output Both Voice and Text When TTS is Enabled"
-    }
   }
   }
 },
diff --git a/dashboard/src/i18n/locales/en-US/features/session-management.json b/dashboard/src/i18n/locales/en-US/features/session-management.json
deleted file mode 100644
index fdd6b4f82..000000000
--- a/dashboard/src/i18n/locales/en-US/features/session-management.json
+++ /dev/null
@@ -1,148 +0,0 @@
-{
-  "title": "Custom Rules",
-  "subtitle": "Set custom rules for specific sessions, which take priority over global settings",
-  "buttons": {
-    "refresh": "Refresh",
-    "edit": "Edit",
-    "editRule": "Edit Rules",
-    "deleteAllRules": "Delete All Rules",
-    "addRule": "Add Rule",
-    "save": "Save",
-    "cancel": "Cancel",
-    "delete": "Delete",
-    "clear": "Clear",
-    "next": "Next",
-    "editCustomName": "Edit Note",
-    "batchDelete": "Batch Delete"
-  },
-  "customRules": {
-    "title": "Custom Rules",
-    "rulesCount": "rules",
-    "hasRules": "Configured",
-    "noRules": "No Custom Rules",
-    "noRulesDesc": "Click 'Add Rule' to configure custom rules for specific sessions",
-    "serviceConfig": "Service Config",
-    "pluginConfig": "Plugin Config",
-    "kbConfig": "Knowledge Base",
-    "providerConfig": "Provider Config",
-    "configured": "Configured",
-    "noCustomName": "No note set"
-  },
-  "quickEditName": {
-    "title": "Edit Note"
-  },
-  "search": {
-    "placeholder": "Search sessions..."
-  },
-  "table": {
-    "headers": {
-      "umoInfo": "Unified Message Origin",
-      "rulesOverview": "Rules Overview",
-      "actions": "Actions"
-    }
-  },
-  "persona": {
-    "none": "Follow Config"
-  },
-  "provider": {
-    "followConfig": "Follow Config"
-  },
-  "addRule": {
-    "title": "Add Custom Rule",
-    "description": "Select a session (UMO) to configure custom rules. Custom rules take priority over global settings.",
-    "selectUmo": "Select Session",
-    "noUmos": "No sessions available"
-  },
-  "ruleEditor": {
-    "title": "Edit Custom Rules",
-    "description": "Configure custom rules for this session. These rules take priority over global settings.",
-    "serviceConfig": {
-      "title": "Service Configuration",
-      "sessionEnabled": "Enable Session",
-      "llmEnabled": "Enable LLM",
-      "ttsEnabled": "Enable TTS",
-      "customName": "Custom Name"
-    },
-    "providerConfig": {
-      "title": "Provider Configuration",
-      "chatProvider": "Chat Provider",
-      "sttProvider": "STT Provider",
-      "ttsProvider": "TTS Provider"
-    },
-    "personaConfig": {
-      "title": "Persona Configuration",
-      "selectPersona": "Select Persona",
-      "hint": "Persona settings affect the conversation style and behavior of the LLM"
-    },
-    "pluginConfig": {
-      "title": "Plugin Configuration",
-      "disabledPlugins": "Disabled Plugins",
-      "hint": "Select plugins to disable for this session. Unselected plugins will remain enabled."
-    },
-    "kbConfig": {
-      "title": "Knowledge Base Configuration",
-      "selectKbs": "Select Knowledge Bases",
-      "topK": "Top K Results",
-      "enableRerank": "Enable Reranking"
-    }
-  },
-  "deleteConfirm": {
-    "title": "Confirm Delete",
-    "message": "Are you sure you want to delete all custom rules for this session? Global settings will be used after deletion."
-  },
-  "batchDeleteConfirm": {
-    "title": "Confirm Batch Delete",
-    "message": "Are you sure you want to delete {count} selected rules? Global settings will be used after deletion."
-  },
-  "batchOperations": {
-    "title": "Batch Operations",
-    "hint": "Quick batch modify session settings",
-    "scope": "Apply to",
-    "scopeSelected": "Selected sessions",
-    "scopeAll": "All sessions",
-    "scopeGroup": "All groups",
-    "scopePrivate": "All private chats",
-    "llmStatus": "LLM Status",
-    "ttsStatus": "TTS Status",
-    "chatProvider": "Chat Model",
-    "ttsProvider": "TTS Model",
-    "apply": "Apply Changes"
-  },
-  "status": {
-    "enabled": "Enabled",
-    "disabled": "Disabled"
-  },
-  "batchOperations": {
-    "title": "Batch Operations",
-    "hint": "Quick batch modify session settings",
-    "scope": "Apply to",
-    "scopeSelected": "Selected sessions",
-    "scopeAll": "All sessions",
-    "scopeGroup": "All groups",
-    "scopePrivate": "All private chats",
-    "llmStatus": "LLM Status",
-    "ttsStatus": "TTS Status",
-    "chatProvider": "Chat Model",
-    "ttsProvider": "TTS Model",
-    "apply": "Apply Changes"
-  },
-  "status": {
-    "enabled": "Enabled",
-    "disabled": "Disabled"
-  },
-  "messages": {
-    "refreshSuccess": "Data refreshed",
-    "loadError": "Failed to load data",
-    "saveSuccess": "Saved successfully",
-    "saveError": "Failed to save",
-    "clearSuccess": "Cleared successfully",
-    "clearError": "Failed to clear",
-    "deleteSuccess": "Deleted successfully",
-    "deleteError": "Failed to delete",
-    "noChanges": "No changes to save",
-    "batchDeleteSuccess": "Batch delete successful",
-    "batchDeleteError": "Batch delete failed",
-    "batchUpdateError": "Batch update failed",
-    "batchUpdateSuccess": "Batch update success"
-  }
-}
diff --git a/dashboard/src/i18n/locales/zh-CN/core/navigation.json b/dashboard/src/i18n/locales/zh-CN/core/navigation.json
index 9481cc259..1b60137fb 100644
--- a/dashboard/src/i18n/locales/zh-CN/core/navigation.json
+++ b/dashboard/src/i18n/locales/zh-CN/core/navigation.json
@@ -19,7 +19,7 @@
   "chat": "聊天",
   "cron": "未来任务",
   "conversation": "对话数据",
-  "sessionManagement": "自定义规则",
+  "chainManagement": "Chain 路由",
   "console": "平台日志",
   "trace": "追踪",
   "alkaid": "Alkaid",
diff --git a/dashboard/src/i18n/locales/zh-CN/features/chain-management.json b/dashboard/src/i18n/locales/zh-CN/features/chain-management.json
new file mode 100644
index 000000000..baea1cc0e
--- /dev/null
+++ b/dashboard/src/i18n/locales/zh-CN/features/chain-management.json
@@ -0,0 +1,120 @@
+{
+  "title": "Chain 路由",
+  "chainsCount": "条",
+  "defaultTag": "默认",
+  "defaultHint": "默认 Chain 在启用时匹配所有消息,仅支持编辑节点、覆盖设置与启用状态。",
+  "search": {
+    "placeholder": "按匹配规则搜索"
+  },
+  "buttons": {
+    "create": "新建 Chain",
+    "refresh": "刷新",
+    "save": "保存",
+    "cancel": "取消",
+    "delete": "删除",
+    "add": "添加",
+    "addNode": "添加节点",
+    "sort": "排序"
+  },
+  "table": {
+    "matchRule": "匹配规则",
+    "services": "服务",
+    "nodes": "节点",
+    "providers": "Provider",
+    "actions": "操作"
+  },
+  "sections": {
+    "providers": "Provider 覆盖",
+    "nodes": "节点配置",
+    "pluginFilter": "插件过滤"
+  },
+  "fields": {
+    "config": "配置文件",
+    "enabled": "启用",
+    "sttProvider": "STT Provider",
+    "ttsProvider": "TTS Provider",
+    "selectNode": "选择节点",
+    "pluginFilterMode": "过滤模式",
+    "pluginList": "插件列表"
+  },
+  "providers": {
+    "tts": "TTS",
+    "stt": "STT",
+    "followDefault": "跟随默认"
+  },
+  "chips": {
+    "session": "会话",
+    "llm": "LLM"
+  },
+  "dialogs": {
+    "createTitle": "新建 Chain",
+    "editTitle": "编辑 Chain",
+    "deleteTitle": "删除 Chain",
+    "deleteConfirm": "确认删除该 Chain 路由?",
+    "addNodeTitle": "添加节点",
+    "nodeConfigTitle": "节点配置",
+    "sortTitle": "排序 Chain",
+    "sortHint": "拖拽调整顺序,越靠前匹配优先级越高。",
+    "sortDefaultHint": "默认 Chain 固定在最后"
+  },
+  "messages": {
+    "loadError": "加载失败",
+    "optionsError": "加载选项失败",
+    "saveSuccess": "保存成功",
+    "saveError": "保存失败",
+    "deleteSuccess": "删除成功",
+    "deleteError": "删除失败",
+    "nodeConfigEmpty": "该节点没有可配置项。",
+    "nodeConfigNeedSave": "请先保存 Chain 再编辑节点配置。",
+    "nodeConfigLoadError": "加载节点配置失败",
+    "nodeConfigSaveError": "保存节点配置失败",
+    "nodeConfigJsonError": "JSON 格式错误",
+    "nodeConfigRawHint": "未提供 Schema,节点配置不可编辑。",
+    "nodeConfigNoSchema": "未提供 Schema,节点配置不可编辑。",
+    "sortSuccess": "排序已更新",
+    "sortError": "排序更新失败",
+    "sortLoadError": "加载排序失败"
+  },
+  "warnings": {
+    "modalityMismatch": "模态不匹配",
+    "outputs": "输出",
+    "accepts": "接受"
+  },
+  "empty": {
+    "title": "暂无 Chain 路由",
+    "subtitle": "创建 Chain 以启用路由",
+    "nodes": "暂无节点",
+    "providers": "暂无 Provider"
+  },
+  "ruleEditor": {
+    "label": "匹配规则",
+    "addGroup": "添加规则组",
+    "empty": "无规则(匹配所有消息)",
+    "addCondition": "添加条件",
+    "addSubGroup": "添加子组",
+    "changeType": "切换类型",
+    "umoPlaceholder": "如 aiocqhttp:group_message:*",
+    "regexPlaceholder": "如 .*hello.*",
+    "types": {
+      "umo": "UMO 匹配",
+      "modality": "消息模态",
+      "textRegex": "文本正则"
+    },
+    "operators": {
+      "include": "包含",
+      "exclude": "不包含"
+    }
+  },
+  "pluginConfig": {
+    "inherit": "跟随全局限制",
+    "blacklist": "黑名单",
+    "whitelist": "白名单",
+    "noRestriction": "启用所有插件",
+    "disableAll": "禁用所有插件",
+    "inheritHint": "跟随全局插件限制",
+    "blacklistHint": "选中的插件将不会执行",
+    "whitelistHint": "只有选中的插件会执行",
+    "noRestrictionHint": "该 Chain 允许任何插件执行",
+    "disableAllHint": "该 Chain 不允许任何插件执行"
+  }
+}
diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
index 2d1c11cda..e1ca7f9d4 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
@@ -44,28 +44,6 @@
         "image_caption_prompt": {
           "description": "图片转述提示词"
         }
-      },
-      "provider_stt_settings": {
-        "enable": {
-          "description": "启用语音转文本",
-          "hint": "STT 总开关"
-        },
-        "provider_id": {
-          "description": "默认语音转文本模型",
-          "hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。"
-        }
-      },
-      "provider_tts_settings": {
-        "enable": {
-          "description": "启用文本转语音",
-          "hint": "TTS 总开关"
-        },
-        "provider_id": {
-          "description": "默认文本转语音模型"
-        },
-        "trigger_probability": {
-          "description": "TTS 触发概率"
-        }
       }
     },
     "persona": {
@@ -124,22 +102,6 @@
       }
     }
   },
-  "file_extract": {
-    "description": "文档解析能力",
-    "provider_settings": {
-      "file_extract": {
-        "enable": {
-          "description": "启用文档解析能力"
-        },
-        "provider": {
-          "description": "文档解析提供商"
-        },
-        "moonshotai_api_key": {
-          "description": "Moonshot AI API Key"
-        }
-      }
-    }
-  },
   "agent_computer_use": {
     "description": "使用电脑能力",
     "hint": "让 AstrBot 访问和使用你的电脑或者隔离的沙盒环境,以执行更复杂的任务。详见: [沙盒模式](https://docs.astrbot.app/use/astrbot-agent-sandbox.html), [Skills](https://docs.astrbot.app/use/skills.html)。",
@@ -287,11 +249,6 @@
       "description": "提供商可达性检测",
      "hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
     }
-  },
-  "provider_tts_settings": {
-    "dual_output": {
-      "description": "开启 TTS 时同时输出语音和文字内容"
-    }
   }
   }
 },
diff --git a/dashboard/src/i18n/locales/zh-CN/features/session-management.json b/dashboard/src/i18n/locales/zh-CN/features/session-management.json
deleted file mode 100644
index 33b387cd2..000000000
--- a/dashboard/src/i18n/locales/zh-CN/features/session-management.json
+++ /dev/null
@@ -1,128 +0,0 @@
-{
-  "title": "自定义规则",
-  "subtitle": "为特定会话设置自定义规则,优先级高于全局配置",
-  "buttons": {
-    "refresh": "刷新",
-    "edit": "编辑",
-    "editRule": "编辑规则",
-    "deleteAllRules": "删除所有规则",
-    "addRule": "添加规则",
-    "save": "保存",
-    "cancel": "取消",
-    "delete": "删除",
-    "clear": "清除",
-    "next": "下一步",
"编辑备注", - "batchDelete": "批量删除" - }, - "customRules": { - "title": "自定义规则", - "rulesCount": "条规则", - "hasRules": "已配置", - "noRules": "暂无自定义规则", - "noRulesDesc": "点击「添加规则」为特定会话配置自定义规则", - "serviceConfig": "服务配置", - "pluginConfig": "插件配置", - "kbConfig": "知识库配置", - "providerConfig": "模型配置", - "configured": "已配置", - "noCustomName": "未设置备注" - }, - "quickEditName": { - "title": "编辑备注名" - }, - "search": { - "placeholder": "搜索会话..." - }, - "table": { - "headers": { - "umoInfo": "消息会话来源", - "rulesOverview": "规则概览", - "actions": "操作" - } - }, - "persona": { - "none": "跟随配置文件" - }, - "provider": { - "followConfig": "跟随配置文件" - }, - "addRule": { - "title": "添加自定义规则", - "description": "选择一个消息会话来源 (UMO) 来配置自定义规则。自定义规则的优先级高于该来源所属的配置文件中的全局规则。可以使用 /sid 指令获取该来源的 UMO 信息。", - "selectUmo": "选择会话", - "noUmos": "暂无可用会话" - }, - "ruleEditor": { - "title": "编辑自定义规则", - "description": "为此会话配置自定义规则,这些规则将优先于全局配置生效。", - "serviceConfig": { - "title": "服务配置", - "sessionEnabled": "启用该消息会话来源的消息处理", - "llmEnabled": "启用 LLM", - "ttsEnabled": "启用 TTS", - "customName": "消息会话来源备注名称" - }, - "providerConfig": { - "title": "模型配置", - "chatProvider": "聊天模型", - "sttProvider": "语音识别模型", - "ttsProvider": "语音合成模型" - }, - "personaConfig": { - "title": "人格配置", - "selectPersona": "选择人格", - "hint": "应用人格配置后,将会强制该来源的所有对话使用该人格。" - }, - "pluginConfig": { - "title": "插件配置", - "disabledPlugins": "禁用的插件", - "hint": "选择要在此会话中禁用的插件。未选择的插件将保持启用状态。" - }, - "kbConfig": { - "title": "知识库配置", - "selectKbs": "选择知识库", - "topK": "返回结果数量 (Top K)", - "enableRerank": "启用重排序" - } - }, - "deleteConfirm": { - "title": "确认删除", - "message": "确定要删除此会话的所有自定义规则吗?删除后将恢复使用全局配置。" - }, - "batchDeleteConfirm": { - "title": "确认批量删除", - "message": "确定要删除选中的 {count} 条规则吗?删除后将恢复使用全局配置。" - }, - "batchOperations": { - "title": "批量操作", - "hint": "快速批量修改会话配置", - "scope": "应用范围", - "scopeSelected": "选中的会话", - "scopeAll": "所有会话", - "scopeGroup": "所有群聊", - "scopePrivate": "所有私聊", - "llmStatus": "LLM 状态", - "ttsStatus": "TTS 状态", - "chatProvider": "聊天模型", - "ttsProvider": "TTS 模型", - "apply": "应用更改" - }, - "status": { - "enabled": "启用", - "disabled": "禁用" - }, - "messages": { - "refreshSuccess": "数据已刷新", - "loadError": "加载数据失败", - "saveSuccess": "保存成功", - "saveError": "保存失败", - "clearSuccess": "已清除", - "clearError": "清除失败", - "deleteSuccess": "删除成功", - "deleteError": "删除失败", - "noChanges": "没有需要保存的更改", - "batchDeleteSuccess": "批量删除成功", - "batchDeleteError": "批量删除失败" - } -} diff --git a/dashboard/src/i18n/translations.ts b/dashboard/src/i18n/translations.ts index d72cc9114..6b7cfcb35 100644 --- a/dashboard/src/i18n/translations.ts +++ b/dashboard/src/i18n/translations.ts @@ -12,7 +12,7 @@ import zhCNShared from './locales/zh-CN/core/shared.json'; import zhCNChat from './locales/zh-CN/features/chat.json'; import zhCNExtension from './locales/zh-CN/features/extension.json'; import zhCNConversation from './locales/zh-CN/features/conversation.json'; -import zhCNSessionManagement from './locales/zh-CN/features/session-management.json'; +import zhCNChainManagement from './locales/zh-CN/features/chain-management.json'; import zhCNToolUse from './locales/zh-CN/features/tool-use.json'; import zhCNProvider from './locales/zh-CN/features/provider.json'; import zhCNPlatform from './locales/zh-CN/features/platform.json'; @@ -53,7 +53,7 @@ import enUSShared from './locales/en-US/core/shared.json'; import enUSChat from './locales/en-US/features/chat.json'; import enUSExtension from './locales/en-US/features/extension.json'; import enUSConversation from './locales/en-US/features/conversation.json'; 
-import enUSSessionManagement from './locales/en-US/features/session-management.json'; +import enUSChainManagement from './locales/en-US/features/chain-management.json'; import enUSToolUse from './locales/en-US/features/tool-use.json'; import enUSProvider from './locales/en-US/features/provider.json'; import enUSPlatform from './locales/en-US/features/platform.json'; @@ -98,7 +98,7 @@ export const translations = { chat: zhCNChat, extension: zhCNExtension, conversation: zhCNConversation, - 'session-management': zhCNSessionManagement, + 'chain-management': zhCNChainManagement, tooluse: zhCNToolUse, provider: zhCNProvider, platform: zhCNPlatform, @@ -147,7 +147,7 @@ export const translations = { chat: enUSChat, extension: enUSExtension, conversation: enUSConversation, - 'session-management': enUSSessionManagement, + 'chain-management': enUSChainManagement, tooluse: enUSToolUse, provider: enUSProvider, platform: enUSPlatform, diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index 69a3791fa..ebb09cef7 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -90,9 +90,9 @@ const sidebarItem: menu[] = [ to: '/conversation' }, { - title: 'core.navigation.sessionManagement', - icon: 'mdi-pencil-ruler', - to: '/session-management' + title: 'core.navigation.chainManagement', + icon: 'mdi-shuffle-variant', + to: '/chain-management' }, { title: 'core.navigation.cron', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index ce0706498..c08160c50 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -52,9 +52,9 @@ const MainRoutes = { component: () => import('@/views/ConversationPage.vue') }, { - name: 'SessionManagement', - path: '/session-management', - component: () => import('@/views/SessionManagementPage.vue') + name: 'ChainManagement', + path: '/chain-management', + component: () => import('@/views/ChainManagementPage.vue') }, { name: 'Persona', diff --git a/dashboard/src/views/ChainManagementPage.vue b/dashboard/src/views/ChainManagementPage.vue new file mode 100644 index 000000000..e9e267806 --- /dev/null +++ b/dashboard/src/views/ChainManagementPage.vue @@ -0,0 +1,1028 @@ + + + + + diff --git a/dashboard/src/views/SessionManagementPage.vue b/dashboard/src/views/SessionManagementPage.vue deleted file mode 100644 index 5008e1dd3..000000000 --- a/dashboard/src/views/SessionManagementPage.vue +++ /dev/null @@ -1,1586 +0,0 @@ - - - - -