Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 95 additions & 4 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import traceback
import typing as T
import uuid
from collections.abc import Sequence
from collections.abc import Set as AbstractSet

import mcp

Expand All @@ -26,6 +28,7 @@
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
Expand All @@ -35,9 +38,93 @@
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.image_ref_utils import (
ALLOWED_IMAGE_EXTENSIONS,
is_supported_image_ref,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
_ALLOWED_IMAGE_EXTENSIONS = ALLOWED_IMAGE_EXTENSIONS
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated

@classmethod
def _is_supported_image_ref(cls, image_ref: str) -> bool:
return is_supported_image_ref(image_ref)

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
candidates: list[str] = []
if image_urls_raw is None:
pass
elif isinstance(image_urls_raw, str):
candidates.append(image_urls_raw)
elif isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance(
image_urls_raw, (str, bytes, bytearray)
):
non_string_count = 0
for item in image_urls_raw:
if isinstance(item, str):
candidates.append(item)
else:
non_string_count += 1
if non_string_count > 0:
logger.debug(
"Dropped %d non-string image_urls entries in handoff tool args.",
non_string_count,
)
else:
logger.debug(
"Unsupported image_urls type in handoff tool args: %s",
type(image_urls_raw).__name__,
)
return candidates

@classmethod
async def _collect_image_urls_from_message(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
cls, run_context: ContextWrapper[AstrAgentContext]
) -> list[str]:
urls: list[str] = []
event = getattr(run_context.context, "event", None)
message_obj = getattr(event, "message_obj", None)
message = getattr(message_obj, "message", None)
if message:
for idx, component in enumerate(message):
if not isinstance(component, Image):
continue
try:
path = await component.convert_to_file_path()
if path:
urls.append(path)
except Exception as e:
logger.error(
"Failed to convert handoff image component at index %d: %s",
idx,
e,
exc_info=True,
)
return urls

@classmethod
async def _collect_handoff_image_urls(
cls,
run_context: ContextWrapper[AstrAgentContext],
image_urls_raw: T.Any,
) -> list[str]:
candidates: list[str] = []
candidates.extend(cls._collect_image_urls_from_args(image_urls_raw))
candidates.extend(await cls._collect_image_urls_from_message(run_context))

normalized = normalize_and_dedupe_strings(candidates)
sanitized = [item for item in normalized if cls._is_supported_image_ref(item)]
dropped_count = len(normalized) - len(sanitized)
if dropped_count > 0:
logger.debug(
"Dropped %d invalid image_urls entries in handoff image inputs.",
dropped_count,
)
return sanitized

@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Expand All @@ -58,7 +145,7 @@ async def execute(cls, tool, run_context, **tool_args):
):
yield r
return
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, tool_args):
yield r
return

Expand Down Expand Up @@ -161,10 +248,14 @@ async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
tool_args: dict[str, T.Any],
):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
image_urls = await cls._collect_handoff_image_urls(
run_context,
tool_args.get("image_urls"),
)
tool_args["image_urls"] = image_urls

# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
Expand Down Expand Up @@ -264,7 +355,7 @@ async def _do_handoff_background(
"""Run the subagent handoff and, on completion, wake the main agent."""
result_text = ""
try:
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(tool, run_context, tool_args):
if isinstance(r, mcp.types.CallToolResult):
for content in r.content:
if isinstance(content, mcp.types.TextContent):
Expand Down
61 changes: 61 additions & 0 deletions astrbot/core/utils/image_ref_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import os
from urllib.parse import unquote, urlparse

ALLOWED_IMAGE_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".gif",
".webp",
".bmp",
".tif",
".tiff",
".svg",
".heic",
}


def resolve_file_url_path(image_ref: str) -> str:
parsed = urlparse(image_ref)
if parsed.scheme != "file":
return image_ref

path = unquote(parsed.path or "")
netloc = unquote(parsed.netloc or "")

# Keep support for file://<host>/path and file://<path> forms.
if netloc and netloc.lower() != "localhost":
path = f"//{netloc}{path}" if path else netloc
elif not path and netloc:
path = netloc

if os.name == "nt" and len(path) > 2 and path[0] == "/" and path[2] == ":":
path = path[1:]

return path or image_ref


def is_supported_image_ref(
image_ref: str,
*,
allow_extensionless_existing_local_file: bool = True,
) -> bool:
if not image_ref:
return False

lowered = image_ref.lower()
if lowered.startswith(("http://", "https://", "base64://")):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
return True

file_path = (
resolve_file_url_path(image_ref) if lowered.startswith("file://") else image_ref
)
ext = os.path.splitext(file_path)[1].lower()
if ext in ALLOWED_IMAGE_EXTENSIONS:
return True
if not allow_extensionless_existing_local_file:
return False
# Keep support for extension-less temp files returned by image converters.
return ext == "" and os.path.exists(file_path)
13 changes: 3 additions & 10 deletions astrbot/core/utils/quoted_message/image_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@
import os
from urllib.parse import urlsplit

IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".webp",
".bmp",
".tif",
".tiff",
".gif",
}
from astrbot.core.utils.image_ref_utils import ALLOWED_IMAGE_EXTENSIONS

IMAGE_EXTENSIONS = ALLOWED_IMAGE_EXTENSIONS


def normalize_file_like_url(path: str | None) -> str | None:
Expand Down
Loading