Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions astrbot/api/star/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star import Context, NodeResult, NodeStar, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
)

__all__ = ["Context", "Star", "StarTools", "register"]
__all__ = ["Context", "NodeResult", "NodeStar", "Star", "StarTools", "register"]
59 changes: 59 additions & 0 deletions astrbot/builtin_stars/agent/_node_config_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"agent_runner_type": {
"type": "string",
"default": "local",
"description": "此节点使用的 Agent 运行器类型",
"options": ["local", "dify", "coze", "dashscope"],
"labels": ["内置 Agent", "Dify", "Coze", "DashScope"]
},
"coze_agent_runner_provider_id": {
"type": "string",
"default": "",
"description": "Coze Agent 运行器的提供商 ID",
"_special": "select_agent_runner_provider:coze",
"condition": {
"agent_runner_type": "coze"
}
},
"dify_agent_runner_provider_id": {
"type": "string",
"default": "",
"description": "Dify Agent 运行器的提供商 ID",
"_special": "select_agent_runner_provider:dify",
"condition": {
"agent_runner_type": "dify"
}
},
"dashscope_agent_runner_provider_id": {
"type": "string",
"default": "",
"description": "DashScope Agent 运行器的提供商 ID",
"_special": "select_agent_runner_provider:dashscope",
"condition": {
"agent_runner_type": "dashscope"
}
},
"provider_id": {
"type": "string",
"default": "",
"description": "覆盖此节点的对话模型提供商 ID",
"_special": "select_provider"
},
"image_caption_provider_id": {
"type": "string",
"default": "",
"description": "覆盖此节点的图像描述模型提供商 ID",
"_special": "select_provider"
},
"image_caption_prompt": {
"type": "string",
"default": "",
"description": "覆盖此节点的图像描述提示词"
},
"persona_id": {
"type": "string",
"default": "",
"description": "覆盖此节点使用的人格 ID",
"_special": "select_persona"
}
}
56 changes: 56 additions & 0 deletions astrbot/builtin_stars/agent/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from astrbot.core import logger
from astrbot.core.pipeline.engine.node_context import NodePacket
from astrbot.core.star.node_star import NodeResult, NodeStar

if TYPE_CHECKING:
from astrbot.core.platform.astr_message_event import AstrMessageEvent


class AgentNode(NodeStar):
"""Agent execution node (local + third-party)."""

async def process(self, event: AstrMessageEvent) -> NodeResult:
ctx = event.node_context

# 合并上游输出作为 agent 输入
if ctx:
merged_input = await event.get_node_input(strategy="text_concat")
if isinstance(merged_input, str):
if merged_input.strip():
ctx.input = NodePacket.create(merged_input)
elif merged_input is not None:
ctx.input = NodePacket.create(merged_input)

if event.get_extra("_provider_request_consumed", False):
return NodeResult.SKIP

has_provider_request = event.get_extra("has_provider_request", False)
if not has_provider_request:
has_upstream_input = bool(ctx and ctx.input is not None)
should_wake = (
not event._has_send_oper
and event.is_at_or_wake_command
and not event.call_llm
)
if not (has_upstream_input or should_wake):
return NodeResult.SKIP

# 从 event 获取 AgentExecutor
agent_executor = event.agent_executor
if not agent_executor:
logger.warning("AgentExecutor missing in event services.")
return NodeResult.SKIP

outcome = await agent_executor.run(event)

if outcome.result:
event.set_node_output(outcome.result)

if outcome.stopped or event.is_stopped():
return NodeResult.STOP

return NodeResult.CONTINUE
4 changes: 4 additions & 0 deletions astrbot/builtin_stars/agent/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: agent
desc: Builtin agent pipeline node
author: AstrBot
version: 1.0.0
13 changes: 11 additions & 2 deletions astrbot/builtin_stars/astrbot/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
"""记录群成员的群聊记录"""

def cfg(self, event: AstrMessageEvent):
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg = self.context.get_config_by_id(
event.chain_config.config_id if event.chain_config else None,
)
try:
max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
except BaseException as e:
Expand Down Expand Up @@ -68,9 +70,15 @@ async def get_image_caption(
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
event: AstrMessageEvent | None = None,
) -> str:
if not image_caption_provider_id:
provider = self.context.get_using_provider()
provider = (
self.context.get_chat_provider_for_event(event) if event else None
)
if provider is None:
providers = self.context.get_all_providers()
provider = providers[0] if providers else None
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
Expand Down Expand Up @@ -133,6 +141,7 @@ async def handle_message(self, event: AstrMessageEvent) -> None:
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
event,
)
parts.append(f" [Image: {caption}]")
except Exception as e:
Expand Down
34 changes: 27 additions & 7 deletions astrbot/builtin_stars/astrbot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ def __init__(self, context: star.Context) -> None:
logger.error(f"聊天增强 err: {e}")

def ltm_enabled(self, event: AstrMessageEvent):
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
"provider_ltm_settings"
]
chain_config_id = event.chain_config.config_id if event.chain_config else None
ltmse = self.context.get_config_by_id(chain_config_id)["provider_ltm_settings"]
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]

@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
Expand All @@ -36,9 +35,12 @@ async def on_message(self, event: AstrMessageEvent):
if self.ltm_enabled(event) and self.ltm and has_image_or_plain:
need_active = await self.ltm.need_active_reply(event)

group_icl_enable = self.context.get_config()["provider_ltm_settings"][
"group_icl_enable"
]
chain_config_id = (
event.chain_config.config_id if event.chain_config else None
)
group_icl_enable = self.context.get_config_by_id(chain_config_id)[
"provider_ltm_settings"
]["group_icl_enable"]
if group_icl_enable:
"""记录对话"""
try:
Expand All @@ -48,7 +50,25 @@ async def on_message(self, event: AstrMessageEvent):

if need_active:
"""主动回复"""
provider = self.context.get_using_provider(event.unified_msg_origin)
chain_config_id = (
event.chain_config.config_id if event.chain_config else None
)
runtime_cfg = self.context.get_config_by_id(chain_config_id)
default_provider_id = str(
runtime_cfg.get("provider_settings", {}).get(
"default_provider_id",
"",
)
or ""
).strip()
provider = (
self.context.get_provider_by_id(default_provider_id)
if default_provider_id
else None
)
if not provider:
all_providers = self.context.get_all_providers()
provider = all_providers[0] if all_providers else None
if not provider:
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
Expand Down
2 changes: 2 additions & 0 deletions astrbot/builtin_stars/builtin_commands/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .provider import ProviderCommands
from .setunset import SetUnsetCommands
from .sid import SIDCommand
from .stt import STTCommand
from .t2i import T2ICommand
from .tts import TTSCommand

Expand All @@ -24,6 +25,7 @@
"ProviderCommands",
"SIDCommand",
"SetUnsetCommands",
"STTCommand",
"T2ICommand",
"TTSCommand",
]
100 changes: 100 additions & 0 deletions astrbot/builtin_stars/builtin_commands/commands/_node_binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from dataclasses import dataclass

from astrbot.api.event import AstrMessageEvent
from astrbot.core.config.node_config import AstrBotNodeConfig
from astrbot.core.pipeline.engine.chain_config import ChainNodeConfig
from astrbot.core.star.context import Context


@dataclass
class NodeTarget:
node: ChainNodeConfig
config: AstrBotNodeConfig


def _get_node_schema(context: Context, node_name: str) -> dict | None:
meta = context.get_registered_star(node_name)
if meta:
return meta.node_schema
return None


def get_chain_nodes(event: AstrMessageEvent, node_name: str) -> list[ChainNodeConfig]:
chain_config = event.chain_config
if not chain_config:
return []
return [node for node in chain_config.nodes if node.name == node_name]


def resolve_node_selector(
nodes: list[ChainNodeConfig], selector: str
) -> ChainNodeConfig | None:
selector = (selector or "").strip()
if not selector:
return None
if selector.isdigit():
idx = int(selector)
if idx < 1 or idx > len(nodes):
return None
return nodes[idx - 1]
for node in nodes:
if node.uuid == selector:
return node
return None


def get_node_target(
context: Context,
event: AstrMessageEvent,
node_name: str,
selector: str | None = None,
) -> NodeTarget | None:
chain_config = event.chain_config
if not chain_config:
return None

nodes = get_chain_nodes(event, node_name)
if not nodes:
return None

target: ChainNodeConfig | None = None
if selector:
target = resolve_node_selector(nodes, selector)
elif len(nodes) == 1:
target = nodes[0]

if target is None:
return None

schema = _get_node_schema(context, node_name)
cfg = AstrBotNodeConfig.get_cached(
node_name=node_name,
chain_id=chain_config.chain_id,
node_uuid=target.uuid,
schema=schema,
)
return NodeTarget(node=target, config=cfg)


def list_nodes_with_config(
context: Context,
event: AstrMessageEvent,
node_name: str,
) -> list[NodeTarget]:
chain_config = event.chain_config
if not chain_config:
return []

schema = _get_node_schema(context, node_name)
ret: list[NodeTarget] = []
for node in get_chain_nodes(event, node_name):
cfg = AstrBotNodeConfig.get_cached(
node_name=node_name,
chain_id=chain_config.chain_id,
node_uuid=node.uuid,
schema=schema,
)
ret.append(NodeTarget(node=node, config=cfg))
return ret
8 changes: 6 additions & 2 deletions astrbot/builtin_stars/builtin_commands/commands/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
),
)
return
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg = self.context.get_config_by_id(
event.chain_config.config_id if event.chain_config else None,
)
cfg["platform_settings"]["id_whitelist"].append(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
Expand All @@ -63,7 +65,9 @@ async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
)
return
try:
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg = self.context.get_config_by_id(
event.chain_config.config_id if event.chain_config else None,
)
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("删除白名单成功。"))
Expand Down
Loading
Loading