diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 31a0b0b48..596fc82da 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -1,3 +1,5 @@ +import typing as T +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Protocol, runtime_checkable from ..message import Message @@ -154,6 +156,7 @@ def __init__( keep_recent: int = 4, instruction_text: str | None = None, compression_threshold: float = 0.82, + use_compact_api: bool = True, ) -> None: """Initialize the LLM summary compressor. @@ -162,10 +165,12 @@ def __init__( keep_recent: The number of latest messages to keep (default: 4). instruction_text: Custom instruction for summary generation. compression_threshold: The compression trigger threshold (default: 0.82). + use_compact_api: Whether to prefer provider native compact API when available. """ self.provider = provider self.keep_recent = keep_recent self.compression_threshold = compression_threshold + self.use_compact_api = use_compact_api self.instruction_text = instruction_text or ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" @@ -193,13 +198,54 @@ def should_compress( usage_rate = current_tokens / max_tokens return usage_rate > self.compression_threshold + def _supports_native_compact(self) -> bool: + support_native_compact = getattr(self.provider, "supports_native_compact", None) + if not callable(support_native_compact): + return False + try: + return bool(support_native_compact()) + except Exception: + return False + + async def _try_native_compact( + self, + system_messages: list[Message], + messages_to_summarize: list[Message], + recent_messages: list[Message], + ) -> list[Message] | None: + compact_context = getattr(self.provider, "compact_context", None) + if not callable(compact_context): + return None + + compact_context_callable = T.cast( + "Callable[[list[Message]], Awaitable[list[Message]]]", + compact_context, + ) + + try: + compacted_messages = await compact_context_callable(messages_to_summarize) + except Exception as e: + logger.warning( + f"Native compact failed, fallback to summary compression: {e}" + ) + return None + + if not compacted_messages: + return None + + result: list[Message] = [] + result.extend(system_messages) + result.extend(compacted_messages) + result.extend(recent_messages) + return result + async def __call__(self, messages: list[Message]) -> list[Message]: """Use LLM to generate a summary of the conversation history. Process: 1. Divide messages: keep the system message and the latest N messages. - 2. Send the old messages + the instruction message to the LLM. - 3. Reconstruct the message list: [system message, summary message, latest messages]. + 2. Prefer native compact when provider supports it. + 3. Fallback to LLM summary and reconstruct message list. 
""" if len(messages) <= self.keep_recent + 1: return messages @@ -207,15 +253,21 @@ async def __call__(self, messages: list[Message]) -> list[Message]: system_messages, messages_to_summarize, recent_messages = split_history( messages, self.keep_recent ) - if not messages_to_summarize: return messages - # build payload + # Only try native compact if user allows it and provider supports it + if self.use_compact_api and self._supports_native_compact(): + compacted = await self._try_native_compact( + system_messages, + messages_to_summarize, + recent_messages, + ) + if compacted is not None: + return compacted instruction_message = Message(role="user", content=self.instruction_text) llm_payload = messages_to_summarize + [instruction_message] - # generate summary try: response = await self.provider.text_chat(contexts=llm_payload) summary_content = response.completion_text @@ -223,8 +275,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]: logger.error(f"Failed to generate summary: {e}") return messages - # build result - result = [] + result: list[Message] = [] result.extend(system_messages) result.append( diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py index b8fd8eb96..4d861a43a 100644 --- a/astrbot/core/agent/context/config.py +++ b/astrbot/core/agent/context/config.py @@ -29,6 +29,8 @@ class ContextConfig: """Number of recent messages to keep during LLM-based compression.""" llm_compress_provider: "Provider | None" = None """LLM provider used for compression tasks. If None, truncation strategy is used.""" + llm_compress_use_compact_api: bool = True + """Whether to prefer provider native compact API when available.""" custom_token_counter: TokenCounter | None = None """Custom token counting method. If None, the default method is used.""" custom_compressor: ContextCompressor | None = None diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e1..2deecff82 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -35,6 +35,7 @@ def __init__( provider=config.llm_compress_provider, keep_recent=config.llm_compress_keep_recent, instruction_text=config.llm_compress_instruction, + use_compact_api=config.llm_compress_use_compact_api, ) else: self.compressor = TruncateByTurnsCompressor( @@ -55,7 +56,7 @@ async def process( try: result = messages - # 1. 基于轮次的截断 (Enforce max turns) + # 1. Enforce max turns if self.config.enforce_max_turns != -1: result = self.truncator.truncate_by_turns( result, @@ -63,7 +64,7 @@ async def process( drop_turns=self.config.truncate_turns, ) - # 2. 基于 token 的压缩 + # 2. 
Token-based compression if self.config.max_context_tokens > 0: total_tokens = self.token_counter.count_tokens( result, trusted_token_usage diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 8fb01bfcb..b974da75e 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -85,6 +85,7 @@ async def reset( llm_compress_instruction: str | None = None, llm_compress_keep_recent: int = 0, llm_compress_provider: Provider | None = None, + llm_compress_use_compact_api: bool = True, # truncate by turns compressor truncate_turns: int = 1, # customize @@ -99,6 +100,7 @@ async def reset( self.llm_compress_instruction = llm_compress_instruction self.llm_compress_keep_recent = llm_compress_keep_recent self.llm_compress_provider = llm_compress_provider + self.llm_compress_use_compact_api = llm_compress_use_compact_api self.truncate_turns = truncate_turns self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor @@ -114,6 +116,7 @@ async def reset( llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider=self.llm_compress_provider, + llm_compress_use_compact_api=self.llm_compress_use_compact_api, custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) @@ -659,24 +662,24 @@ async def _handle_function_tools( ), ) - # yield the last tool call result - if tool_call_result_blocks: - last_tcr_content = str(tool_call_result_blocks[-1].content) - yield _HandleFunctionToolsResult.from_message_chain( - MessageChain( - type="tool_call_result", - chain=[ - Json( - data={ - "id": func_tool_id, - "ts": time.time(), - "result": last_tcr_content, - } - ) - ], + # yield the tool call result + if tool_call_result_blocks: + last_tcr_content = str(tool_call_result_blocks[-1].content) + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": last_tcr_content, + } + ) + ], + ) ) - ) - logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") + logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") # 处理函数调用响应 if tool_call_result_blocks: diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 4e70f3d59..3018ae63c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -91,6 +91,8 @@ class MainAgentBuildConfig: """The number of most recent turns to keep during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context compression.""" + llm_compress_use_compact_api: bool = True + """Whether to prefer provider native compact API when available.""" max_context_length: int = -1 """The maximum number of turns to keep in context. -1 means no limit. This enforce max turns before compression""" @@ -742,17 +744,22 @@ async def _handle_webchat( if not user_prompt or not chatui_session_id or not session or session.display_name: return - llm_resp = await prov.text_chat( - system_prompt=( - "You are a conversation title generator. " - "Generate a concise title in the same language as the user’s input, " - "no more than 10 words, capturing only the core topic." - "If the input is a greeting, small talk, or has no clear topic, " - "(e.g., “hi”, “hello”, “haha”), return . 
" - "Output only the title itself or , with no explanations." - ), - prompt=f"Generate a concise title for the following user query:\n{user_prompt}", - ) + try: + llm_resp = await prov.text_chat( + system_prompt=( + "You are a conversation title generator. " + "Generate a concise title in the same language as the user's input, " + "no more than 10 words, capturing only the core topic." + "If the input is a greeting, small talk, or has no clear topic, " + '(e.g., "hi", "hello", "haha"), return . ' + "Output only the title itself or , with no explanations." + ), + prompt=f"Generate a concise title for the following user query:\n{user_prompt}", + ) + except Exception as e: + logger.warning("Failed to generate chatui title: %s", e) + return + if llm_resp and llm_resp.completion_text: title = llm_resp.completion_text.strip() if not title or "" in title: @@ -807,26 +814,33 @@ def _proactive_cron_job_tools(req: ProviderRequest) -> None: def _get_compress_provider( - config: MainAgentBuildConfig, plugin_context: Context + config: MainAgentBuildConfig, + plugin_context: Context, + active_provider: Provider | None, ) -> Provider | None: - if not config.llm_compress_provider_id: - return None if config.context_limit_reached_strategy != "llm_compress": return None - provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) - if provider is None: + + if not config.llm_compress_provider_id: + return None + + selected_provider = plugin_context.get_provider_by_id( + config.llm_compress_provider_id + ) + if selected_provider is None: logger.warning( - "未找到指定的上下文压缩模型 %s,将跳过压缩。", + "Configured llm_compress_provider_id not found: %s. Skip compression.", config.llm_compress_provider_id, ) return None - if not isinstance(provider, Provider): + if not isinstance(selected_provider, Provider): logger.warning( - "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", + "Configured llm_compress_provider_id is not a Provider: %s. 
Skip compression.", config.llm_compress_provider_id, ) return None - return provider + + return selected_provider async def build_main_agent( @@ -970,7 +984,8 @@ async def build_main_agent( streaming=config.streaming_response, llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, - llm_compress_provider=_get_compress_provider(config, plugin_context), + llm_compress_provider=_get_compress_provider(config, plugin_context, provider), + llm_compress_use_compact_api=config.llm_compress_use_compact_api, truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 546768812..92257910f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -94,6 +94,7 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", + "llm_compress_use_compact_api": True, "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -929,6 +930,19 @@ class ChatProviderTemplate(TypedDict): "proxy": "", "custom_headers": {}, }, + "OpenAI Responses": { + "id": "openai_responses", + "provider": "openai", + "type": "openai_responses", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.openai.com/v1", + "timeout": 120, + "proxy": "", + "custom_headers": {}, + "custom_extra_body": {}, + }, "Google Gemini": { "id": "google_gemini", "provider": "google", @@ -2828,6 +2842,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.llm_compress_use_compact_api": { + "description": "Prefer compact API when available", + "type": "bool", + "hint": "When enabled, local runner first tries provider native compact API and falls back to LLM summary compression.", + "condition": { + "provider_settings.context_limit_reached_strategy": "llm_compress", + "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index d26f67add..12f8437f6 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -79,6 +79,9 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" ) + self.llm_compress_use_compact_api: bool = settings.get( + "llm_compress_use_compact_api", True + ) self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -113,6 +116,7 @@ async def initialize(self, ctx: PipelineContext) -> None: llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider_id=self.llm_compress_provider_id, + llm_compress_use_compact_api=self.llm_compress_use_compact_api, max_context_length=self.max_context_length, dequeue_context_length=self.dequeue_context_length, llm_safety_mode=self.llm_safety_mode, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947..0ade94246 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -9,6 +9,7 @@ 
from anthropic.types import Message as AnthropicMessage from google.genai.types import GenerateContentResponse from openai.types.chat.chat_completion import ChatCompletion +from openai.types.responses.response import Response as OpenAIResponse import astrbot.core.message.components as Comp from astrbot import logger @@ -276,7 +277,11 @@ class LLMResponse: """The signature of the reasoning content, if any.""" raw_completion: ( - ChatCompletion | GenerateContentResponse | AnthropicMessage | None + ChatCompletion + | GenerateContentResponse + | AnthropicMessage + | OpenAIResponse + | None ) = None """The raw completion response from the LLM provider.""" @@ -305,6 +310,7 @@ def __init__( raw_completion: ChatCompletion | GenerateContentResponse | AnthropicMessage + | OpenAIResponse | None = None, is_chunk: bool = False, id: str | None = None, diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ff0bb303d..4bbba07b2 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -291,6 +291,10 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.openai_source import ( ProviderOpenAIOfficial as ProviderOpenAIOfficial, ) + case "openai_responses": + from .sources.openai_responses_source import ( + ProviderOpenAIResponses as ProviderOpenAIResponses, + ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "groq_chat_completion": diff --git a/astrbot/core/provider/sources/openai_responses_source.py b/astrbot/core/provider/sources/openai_responses_source.py new file mode 100644 index 000000000..f928e3bd7 --- /dev/null +++ b/astrbot/core/provider/sources/openai_responses_source.py @@ -0,0 +1,771 @@ +import inspect +import json +from collections.abc import AsyncGenerator +from typing import Any + +from openai.types.responses.response import Response as OpenAIResponse +from openai.types.responses.response_usage import ResponseUsage + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core.agent.message import ImageURLPart, Message, TextPart +from astrbot.core.agent.tool import ToolSet +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import LLMResponse, TokenUsage +from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure + +from ..register import register_provider_adapter +from .openai_source import ProviderOpenAIOfficial + + +@register_provider_adapter( + "openai_responses", + "OpenAI API Responses Provider Adapter", + default_config_tmpl={ + "id": "openai_responses", + "provider": "openai", + "type": "openai_responses", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.openai.com/v1", + "timeout": 120, + "proxy": "", + "custom_headers": {}, + "custom_extra_body": {}, + }, + provider_display_name="OpenAI Responses", +) +class ProviderOpenAIResponses(ProviderOpenAIOfficial): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.default_params = inspect.signature( + self.client.responses.create, + ).parameters.keys() + self.reasoning_key = "reasoning" + + def supports_native_compact(self) -> bool: + return True + + async def compact_context(self, messages: list[Message]) -> list[Message]: + if not messages: + return messages + + message_dicts = self._ensure_message_to_dicts(messages) + request_payload = { + "model": 
self.get_model(), + "input": self._messages_to_response_input(message_dicts), + } + + request_options: dict[str, Any] = {} + extra_headers = self._build_extra_headers() + if extra_headers: + request_options["extra_headers"] = extra_headers + + try: + compact_response = await self.client.responses.compact( + **request_payload, + **request_options, + ) + except Exception as e: + if is_connection_error(e): + proxy = self.provider_config.get("proxy", "") + log_connection_failure("OpenAI", e, proxy) + raise + + if hasattr(compact_response, "model_dump"): + compact_data = compact_response.model_dump(mode="json") + elif isinstance(compact_response, dict): + compact_data = compact_response + else: + compact_data = { + "output": getattr(compact_response, "output", []), + } + + compact_input = self._extract_compact_input(compact_data) + compact_messages = self._response_input_to_messages(compact_input) + if not compact_messages: + raise ValueError("Responses compact returned empty context.") + return compact_messages + + def _extract_compact_input(self, compact_data: Any) -> list[dict[str, Any]]: + if not isinstance(compact_data, dict): + raise ValueError("Invalid compact response payload.") + + candidate_keys = ( + "input", + "items", + "input_items", + "compacted_items", + "compacted_input", + ) + for key in candidate_keys: + value = compact_data.get(key) + if isinstance(value, list): + return [item for item in value if isinstance(item, dict)] + + response_obj = compact_data.get("response") + if isinstance(response_obj, dict): + for key in candidate_keys: + value = response_obj.get(key) + if isinstance(value, list): + return [item for item in value if isinstance(item, dict)] + + output = compact_data.get("output") + if isinstance(output, list): + return self._response_output_to_input_items(output) + + raise ValueError("Responses compact payload does not contain compacted items.") + + def _response_output_to_input_items( + self, output: list[Any] + ) -> list[dict[str, Any]]: + converted: list[dict[str, Any]] = [] + for item in output: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "message": + role = item.get("role", "assistant") + if role not in {"system", "developer", "user", "assistant"}: + role = "assistant" + content = item.get("content", []) + converted_content = self._output_content_to_input_content(content) + converted.append( + { + "type": "message", + "role": role, + "content": converted_content, + } + ) + elif item_type == "function_call": + converted.append( + { + "type": "function_call", + "call_id": item.get("call_id", item.get("id", "")), + "name": item.get("name", ""), + "arguments": item.get("arguments", "{}"), + } + ) + return converted + + def _output_content_to_input_content(self, content: Any) -> list[dict[str, Any]]: + converted: list[dict[str, Any]] = [] + if not isinstance(content, list): + return converted + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "output_text": + text = part.get("text") + if text: + converted.append({"type": "input_text", "text": str(text)}) + elif part_type == "input_text": + text = part.get("text") + if text: + converted.append({"type": "input_text", "text": str(text)}) + return converted + + def _response_input_to_messages( + self, input_items: list[dict[str, Any]] + ) -> list[Message]: + messages: list[Message] = [] + for item in input_items: + item_type = item.get("type") + if item_type == "message": + role = item.get("role") + if 
role not in {"system", "developer", "user", "assistant"}: + continue + content = self._response_content_to_message_content(item.get("content")) + if content is None: + content = "" + messages.append(Message(role=role, content=content)) + elif item_type == "function_call": + call_id = item.get("call_id") or item.get("id") + name = item.get("name") + arguments = item.get("arguments", "{}") + if not call_id or not name: + continue + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + messages.append( + Message( + role="assistant", + content="", + tool_calls=[ + { + "type": "function", + "id": str(call_id), + "function": { + "name": str(name), + "arguments": arguments, + }, + } + ], + ) + ) + elif item_type == "function_call_output": + call_id = item.get("call_id") + output = item.get("output", "") + if not call_id: + continue + messages.append( + Message( + role="tool", + tool_call_id=str(call_id), + content=str(output), + ) + ) + + return messages + + def _response_content_to_message_content(self, content: Any) -> str | list: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + parts: list[Any] = [] + plain_text: list[str] = [] + has_non_text = False + + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text"}: + text = part.get("text") + if text is not None: + plain_text.append(str(text)) + parts.append(TextPart(text=str(text))) + elif part_type in {"input_image", "image_url"}: + image_url = None + if part_type == "input_image": + image_url = part.get("image_url") or part.get("file_url") + else: + image_data = part.get("image_url") + if isinstance(image_data, dict): + image_url = image_data.get("url") + elif isinstance(image_data, str): + image_url = image_data + if image_url: + has_non_text = True + parts.append( + ImageURLPart( + image_url=ImageURLPart.ImageURL(url=str(image_url)) + ) + ) + + if has_non_text: + return parts + return "\n".join(plain_text).strip() + + def _messages_to_response_input( + self, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + response_input: list[dict[str, Any]] = [] + + for message in messages: + role = message.get("role") + content = message.get("content") + tool_calls = message.get("tool_calls") + + if role in {"system", "developer", "user", "assistant"}: + converted_content = self._message_content_to_response_content(content) + message_item: dict[str, Any] = { + "type": "message", + "role": role, + } + if isinstance(converted_content, str): + message_item["content"] = converted_content + else: + message_item["content"] = converted_content + response_input.append(message_item) + + if role == "assistant" and isinstance(tool_calls, list): + for tool_call in tool_calls: + normalized = self._normalize_tool_call(tool_call) + if not normalized: + continue + response_input.append( + { + "type": "function_call", + "call_id": normalized["id"], + "name": normalized["name"], + "arguments": normalized["arguments"], + } + ) + + if role == "tool": + call_id = message.get("tool_call_id") + if call_id: + response_input.append( + { + "type": "function_call_output", + "call_id": str(call_id), + "output": self._extract_text_from_content(content), + } + ) + + return response_input + + def _normalize_tool_call(self, tool_call: Any) -> dict[str, str] | None: + if isinstance(tool_call, str): + try: + tool_call = json.loads(tool_call) + except Exception: + return None + if not 
isinstance(tool_call, dict): + return None + + tool_type = tool_call.get("type") + if tool_type != "function": + return None + + function_data = tool_call.get("function", {}) + if not isinstance(function_data, dict): + return None + + name = function_data.get("name") + call_id = tool_call.get("id") + arguments = function_data.get("arguments", "{}") + if not name or not call_id: + return None + + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + + return { + "id": str(call_id), + "name": str(name), + "arguments": arguments, + } + + def _message_content_to_response_content( + self, content: Any + ) -> str | list[dict[str, Any]]: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + converted: list[dict[str, Any]] = [] + for part in content: + if not isinstance(part, dict): + continue + + part_type = part.get("type") + if part_type in {"text", "input_text", "output_text"}: + text = part.get("text") + if text is not None: + converted.append({"type": "input_text", "text": str(text)}) + elif part_type in {"image_url", "input_image"}: + image_part = self._normalize_image_part(part) + if image_part: + converted.append(image_part) + elif part_type == "input_file": + file_id = part.get("file_id") + file_url = part.get("file_url") + file_data = part.get("file_data") + input_file: dict[str, Any] = {"type": "input_file"} + if file_id: + input_file["file_id"] = file_id + elif file_url: + input_file["file_url"] = file_url + elif file_data: + input_file["file_data"] = file_data + filename = part.get("filename") + if filename: + input_file["filename"] = filename + if len(input_file) > 1: + converted.append(input_file) + + if not converted: + return "" + if len(converted) == 1 and converted[0].get("type") == "input_text": + return str(converted[0].get("text", "")) + return converted + + def _normalize_image_part(self, part: dict[str, Any]) -> dict[str, Any] | None: + if part.get("type") == "input_image": + image_url = part.get("image_url") + if image_url: + normalized = {"type": "input_image", "image_url": str(image_url)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + file_id = part.get("file_id") + if file_id: + normalized = {"type": "input_image", "file_id": str(file_id)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + return None + + image_data = part.get("image_url") + image_url = None + if isinstance(image_data, dict): + image_url = image_data.get("url") + elif isinstance(image_data, str): + image_url = image_data + + if not image_url: + return None + + normalized = {"type": "input_image", "image_url": str(image_url)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + + def _extract_text_from_content(self, content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + texts: list[str] = [] + for part in content: + if isinstance(part, dict): + text = part.get("text") + if text is not None: + texts.append(str(text)) + return "\n".join(texts) + return str(content) if content is not None else "" + + def _build_responses_input_and_instructions( + self, + messages: list[dict[str, Any]], + ) -> tuple[list[dict[str, Any]], str | None]: + response_input = self._messages_to_response_input(messages) + filtered_input: list[dict[str, Any]] = [] + instruction_chunks: list[str] = [] + + for item in response_input: + if item.get("type") == 
"message" and item.get("role") in { + "system", + "developer", + }: + instruction_text = self._extract_text_from_content(item.get("content")) + if instruction_text: + instruction_chunks.append(instruction_text) + continue + filtered_input.append(item) + + instructions = "\n\n".join(instruction_chunks).strip() + return filtered_input, instructions or None + + def _build_extra_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if isinstance(self.custom_headers, dict): + for key, value in self.custom_headers.items(): + if str(key).lower() == "authorization": + continue + headers[str(key)] = str(value) + return headers + + def _resolve_tool_strict( + self, + tool: dict[str, Any], + function_body: dict[str, Any] | None, + ) -> bool | None: + if isinstance(function_body, dict) and isinstance( + function_body.get("strict"), bool + ): + return function_body["strict"] + if isinstance(tool.get("strict"), bool): + return tool["strict"] + default_strict = self.provider_config.get("responses_tool_strict") + if isinstance(default_strict, bool): + return default_strict + return None + + def _convert_tools_to_responses( + self, tool_list: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + response_tools: list[dict[str, Any]] = [] + for tool in tool_list: + if not isinstance(tool, dict): + continue + if tool.get("type") != "function": + continue + + function_body = tool.get("function") + if isinstance(function_body, dict): + name = function_body.get("name") + if not name: + continue + response_tool = { + "type": "function", + "name": name, + "description": function_body.get("description", ""), + "parameters": function_body.get("parameters", {}), + } + strict = self._resolve_tool_strict(tool, function_body) + if strict is not None: + response_tool["strict"] = strict + response_tools.append(response_tool) + continue + + name = tool.get("name") + if not name: + continue + response_tool = { + "type": "function", + "name": name, + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + } + strict = self._resolve_tool_strict(tool, None) + if strict is not None: + response_tool["strict"] = strict + response_tools.append(response_tool) + + return response_tools + + def _extract_response_usage(self, usage: ResponseUsage | None) -> TokenUsage | None: + if usage is None: + return None + cached = 0 + if usage.input_tokens_details: + cached = usage.input_tokens_details.cached_tokens or 0 + input_other = max(0, (usage.input_tokens or 0) - cached) + output = usage.output_tokens or 0 + return TokenUsage(input_other=input_other, input_cached=cached, output=output) + + async def _parse_openai_response( + self, + response: OpenAIResponse, + tools: ToolSet | None, + ) -> LLMResponse: + llm_response = LLMResponse("assistant") + + if response.error: + raise Exception(f"Responses API error: {response.error.message}") + + completion_text = response.output_text.strip() + if completion_text: + llm_response.result_chain = MessageChain().message(completion_text) + + reasoning_segments: list[str] = [] + for output_item in response.output: + output_type = getattr(output_item, "type", "") + if output_type == "reasoning": + summary = getattr(output_item, "summary", []) + for summary_part in summary: + text = getattr(summary_part, "text", "") + if text: + reasoning_segments.append(str(text)) + if output_type == "function_call" and tools is not None: + arguments = getattr(output_item, "arguments", "{}") + function_name = getattr(output_item, "name", "") + call_id = getattr(output_item, 
"call_id", "") + parsed_arguments: dict[str, Any] + try: + parsed_arguments = json.loads(arguments) if arguments else {} + except Exception: + parsed_arguments = {} + llm_response.tools_call_args.append(parsed_arguments) + llm_response.tools_call_name.append(str(function_name)) + llm_response.tools_call_ids.append(str(call_id)) + + if reasoning_segments: + llm_response.reasoning_content = "\n".join(reasoning_segments) + + if not llm_response.completion_text and not llm_response.tools_call_args: + raise Exception(f"Responses API returned empty output: {response}") + + llm_response.raw_completion = response + llm_response.id = response.id + llm_response.usage = self._extract_response_usage(response.usage) + return llm_response + + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + request_payload = dict(payloads) + response_input, instructions = self._build_responses_input_and_instructions( + request_payload.pop("messages", []) + ) + request_payload["input"] = response_input + if instructions and not request_payload.get("instructions"): + request_payload["instructions"] = instructions + + if tools: + model = request_payload.get("model", "").lower() + omit_empty_param_field = "gemini" in model + tool_list = tools.get_func_desc_openai_style( + omit_empty_parameter_field=omit_empty_param_field, + ) + response_tools = self._convert_tools_to_responses(tool_list) + if response_tools: + request_payload["tools"] = response_tools + + extra_body: dict[str, Any] = {} + for key in list(request_payload.keys()): + if key not in self.default_params: + extra_body[key] = request_payload.pop(key) + + custom_extra_body = self.provider_config.get("custom_extra_body", {}) + if isinstance(custom_extra_body, dict): + extra_body.update(custom_extra_body) + + extra_headers = self._build_extra_headers() + completion = await self.client.responses.create( + **request_payload, + stream=False, + extra_body=extra_body, + extra_headers=extra_headers, + ) + + if not isinstance(completion, OpenAIResponse): + raise TypeError(f"Unexpected response object: {type(completion)}") + + logger.debug(f"response: {completion}") + return await self._parse_openai_response(completion, tools) + + async def _query_stream( + self, + payloads: dict, + tools: ToolSet | None, + ) -> AsyncGenerator[LLMResponse, None]: + request_payload = dict(payloads) + response_input, instructions = self._build_responses_input_and_instructions( + request_payload.pop("messages", []) + ) + request_payload["input"] = response_input + if instructions and not request_payload.get("instructions"): + request_payload["instructions"] = instructions + + if tools: + model = request_payload.get("model", "").lower() + omit_empty_param_field = "gemini" in model + tool_list = tools.get_func_desc_openai_style( + omit_empty_parameter_field=omit_empty_param_field, + ) + response_tools = self._convert_tools_to_responses(tool_list) + if response_tools: + request_payload["tools"] = response_tools + + extra_body: dict[str, Any] = {} + for key in list(request_payload.keys()): + if key not in self.default_params: + extra_body[key] = request_payload.pop(key) + + custom_extra_body = self.provider_config.get("custom_extra_body", {}) + if isinstance(custom_extra_body, dict): + extra_body.update(custom_extra_body) + + response_id: str | None = None + extra_headers = self._build_extra_headers() + try: + async with self.client.responses.stream( + **request_payload, + extra_body=extra_body, + extra_headers=extra_headers, + ) as stream: + async for event in stream: 
+ event_type = getattr(event, "type", "") + if event_type == "response.created": + response_obj = getattr(event, "response", None) + if response_obj: + response_id = getattr(response_obj, "id", None) + continue + + if event_type == "response.output_text.delta": + delta = getattr(event, "delta", "") + if delta: + yield LLMResponse( + role="assistant", + result_chain=MessageChain( + chain=[Comp.Plain(str(delta))] + ), + is_chunk=True, + id=response_id, + ) + continue + + if event_type == "response.reasoning_summary_text.delta": + delta = getattr(event, "delta", "") + if delta: + yield LLMResponse( + role="assistant", + reasoning_content=str(delta), + is_chunk=True, + id=response_id, + ) + continue + + if event_type == "error": + raise Exception( + f"Responses stream error: {getattr(event, 'code', 'unknown')} {getattr(event, 'message', '')}" + ) + + if event_type == "response.failed": + response_obj = getattr(event, "response", None) + error_obj = ( + getattr(response_obj, "error", None) + if response_obj + else None + ) + if error_obj is not None: + raise Exception( + f"Responses stream failed: {getattr(error_obj, 'code', 'unknown')} {getattr(error_obj, 'message', '')}" + ) + raise Exception("Responses stream failed.") + + final_response = await stream.get_final_response() + except Exception as e: + if self._is_retryable_upstream_error(e) or is_connection_error(e): + logger.warning( + "Responses stream failed, fallback to non-stream create: %s", + e, + ) + yield await self._query(payloads, tools) + return + raise + + yield await self._parse_openai_response(final_response, tools) + + def _is_retryable_upstream_error(self, e: Exception) -> bool: + status_code = getattr(e, "status_code", None) + if status_code in {500, 502, 503, 504}: + return True + + message = str(e).lower() + if "upstream request failed" in message: + return True + + body = getattr(e, "body", None) + if isinstance(body, dict): + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + if str(error_obj.get("type", "")).lower() == "upstream_error": + return True + + return False + + async def _handle_api_error( + self, + e: Exception, + payloads: dict, + context_query: list, + func_tool: ToolSet | None, + chosen_key: str, + available_api_keys: list[str], + retry_cnt: int, + max_retries: int, + ) -> tuple: + if is_connection_error(e): + proxy = self.provider_config.get("proxy", "") + log_connection_failure("OpenAI", e, proxy) + return await super()._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a7c0e3a57..3ea9772b1 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -375,6 +375,7 @@ async def stream(): try: async with track_conversation(self.running_convs, webchat_conv_id): while True: + result = None try: result = await asyncio.wait_for(back_queue.get(), timeout=1) except asyncio.TimeoutError: @@ -382,8 +383,10 @@ async def stream(): except asyncio.CancelledError: logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") client_disconnected = True + break except Exception as e: logger.error(f"WebChat stream error: {e}") + continue if not result: continue @@ -400,6 +403,15 @@ async def stream(): streaming = result.get("streaming", False) chain_type = result.get("chain_type") + if ( + enable_streaming + and msg_type == "plain" + and chain_type in {"tool_call", "tool_call_result"} + and not streaming + ): + 
result["streaming"] = True + streaming = True + if chain_type == "agent_stats": stats_info = { "type": "agent_stats", diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index 620dd6a15..c413cb5cf 100644 --- a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -479,19 +479,38 @@ export function useMessages( } } } else if (chain_type === 'tool_call_result') { - // 解析工具调用结果数据 + // Parse tool call result payload const resultData = JSON.parse(chunk_json.data); - if (message_obj) { - // 遍历所有 tool_call parts 找到对应的 tool_call - for (const part of message_obj.message) { - if (part.type === 'tool_call' && part.tool_calls) { - const toolCall = part.tool_calls.find((tc: ToolCall) => tc.id === resultData.id); - if (toolCall) { - toolCall.result = resultData.result; - toolCall.finished_ts = resultData.ts; - break; - } + const updateToolCallInContent = (content: MessageContent | null | undefined): boolean => { + if (!content || !Array.isArray(content.message)) { + return false; + } + for (const part of content.message) { + if (part.type !== 'tool_call' || !part.tool_calls) { + continue; + } + const toolCall = part.tool_calls.find((tc: ToolCall) => tc.id === resultData.id); + if (!toolCall) { + continue; + } + toolCall.result = resultData.result; + toolCall.finished_ts = resultData.ts; + return true; + } + return false; + }; + + let updated = updateToolCallInContent(message_obj); + if (!updated) { + for (let i = messages.value.length - 1; i >= 0; i--) { + const message = messages.value[i]?.content; + if (message?.type !== 'bot') { + continue; + } + if (updateToolCallInContent(message)) { + updated = true; + break; } } } @@ -548,7 +567,14 @@ export function useMessages( } } - if ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) { + const isToolCallEvent = + chunk_json.type === 'plain' && + (chunk_json.chain_type === 'tool_call' || chunk_json.chain_type === 'tool_call_result'); + + if ( + ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) && + !isToolCallEvent + ) { in_streaming = false; if (!chunk_json.streaming) { isStreaming.value = false; diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 2166d5391..d79149ef2 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -211,6 +211,10 @@ "llm_compress_provider_id": { "description": "Model Provider ID for Context Compression", "hint": "When left empty, will fall back to the 'Truncate by Turns' strategy." + }, + "llm_compress_use_compact_api": { + "description": "Prefer compact API when available", + "hint": "When enabled, local runner first tries provider native compact API and falls back to LLM summary compression." 
} } }, diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 2d1c11cda..a4b996eb9 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -214,6 +214,10 @@ "llm_compress_provider_id": { "description": "用于上下文压缩的模型提供商 ID", "hint": "留空时将降级为\"按对话轮数截断\"的策略。" + }, + "llm_compress_use_compact_api": { + "description": "\u4f18\u5148\u4f7f\u7528 compact \u63a5\u53e3", + "hint": "\u542f\u7528\u540e\u4f1a\u4f18\u5148\u8c03\u7528\u5f53\u524d\u4f1a\u8bdd\u4f9b\u5e94\u5546\u7684 native compact \u63a5\u53e3\uff1b\u82e5\u4e0d\u652f\u6301\u5219\u56de\u9000\u5230\u5e38\u89c4 LLM \u538b\u7f29\u3002" } } }, diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 0b955ff40..600090a2d 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -344,6 +344,82 @@ async def test_sequential_processing_order(self): # Truncator should be called first mock_truncate.assert_called_once() + @pytest.mark.asyncio + async def test_native_compact_is_preferred_when_available(self): + """Native compact should be used before summary generation when available.""" + + class NativeCompactProvider(MockProvider): + def __init__(self): + super().__init__() + self.compact_context = AsyncMock( + return_value=[Message(role="user", content="compacted")] + ) + self.text_chat = AsyncMock() + + def supports_native_compact(self): + return True + + provider = NativeCompactProvider() + config = ContextConfig( + max_context_tokens=10, + llm_compress_provider=provider, # type: ignore[arg-type] + llm_compress_keep_recent=2, + ) + manager = ContextManager(config) + + messages = self.create_messages(8) + + with patch.object( + manager.compressor, "should_compress", side_effect=[True, False] + ): + result = await manager.process(messages) + + provider.compact_context.assert_awaited_once() + provider.text_chat.assert_not_called() + assert any(msg.content == "compacted" for msg in result) + + @pytest.mark.asyncio + async def test_native_compact_fallback_to_summary_on_failure(self): + """Fallback to summary compression when native compact fails.""" + + class FailingCompactProvider(MockProvider): + def __init__(self): + super().__init__() + self.compact_context = AsyncMock( + side_effect=RuntimeError("compact failed") + ) + self.text_chat = AsyncMock( + return_value=LLMResponse( + role="assistant", + completion_text="summary fallback", + ) + ) + + def supports_native_compact(self): + return True + + provider = FailingCompactProvider() + config = ContextConfig( + max_context_tokens=10, + llm_compress_provider=provider, # type: ignore[arg-type] + llm_compress_keep_recent=2, + ) + manager = ContextManager(config) + + messages = self.create_messages(8) + + with patch.object( + manager.compressor, "should_compress", side_effect=[True, False] + ): + result = await manager.process(messages) + + provider.compact_context.assert_awaited_once() + provider.text_chat.assert_awaited_once() + assert any( + msg.content == "Our previous history conversation summary: summary fallback" + for msg in result + ) + # ==================== Error Handling Tests ==================== @pytest.mark.asyncio diff --git a/tests/test_openai_responses_source.py b/tests/test_openai_responses_source.py new file mode 100644 index 000000000..192425305 --- /dev/null +++ b/tests/test_openai_responses_source.py @@ -0,0 +1,220 @@ +import os 
+import sys +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +# Add project root to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from astrbot.core.agent.message import Message +from astrbot.core.provider.sources.openai_responses_source import ( + ProviderOpenAIResponses, +) + + +class _DummyError(Exception): + def __init__(self, message: str, status_code=None, body=None): + super().__init__(message) + self.status_code = status_code + self.body = body + + +def _provider() -> ProviderOpenAIResponses: + return ProviderOpenAIResponses.__new__(ProviderOpenAIResponses) + + +def test_is_retryable_upstream_error_with_5xx_status(): + err = _DummyError("server error", status_code=502) + + assert _provider()._is_retryable_upstream_error(err) + + +def test_is_retryable_upstream_error_with_upstream_error_type(): + err = _DummyError( + "bad gateway", + status_code=400, + body={"error": {"type": "upstream_error"}}, + ) + + assert _provider()._is_retryable_upstream_error(err) + + +def test_is_retryable_upstream_error_returns_false_for_non_retryable_error(): + err = _DummyError( + "invalid request", + status_code=400, + body={"error": {"type": "invalid_request_error"}}, + ) + + assert not _provider()._is_retryable_upstream_error(err) + + +def test_build_responses_input_and_instructions_moves_system_messages(): + provider = _provider() + provider.custom_headers = {} + + response_input, instructions = provider._build_responses_input_and_instructions( + [ + {"role": "system", "content": "sys text"}, + {"role": "developer", "content": [{"type": "text", "text": "dev text"}]}, + {"role": "user", "content": "hello"}, + ] + ) + + assert instructions == "sys text\n\ndev text" + assert all( + item.get("role") not in {"system", "developer"} for item in response_input + ) + assert any(item.get("role") == "user" for item in response_input) + + +def test_build_extra_headers_keeps_custom_headers_and_ignores_authorization(): + provider = _provider() + provider.custom_headers = { + "X-Test": "value", + "Authorization": "Bearer should-not-pass", + } + + headers = provider._build_extra_headers() + + assert "User-Agent" not in headers + assert headers["X-Test"] == "value" + assert "Authorization" not in headers + + +@pytest.mark.asyncio +async def test_compact_context_uses_sdk_compact_api(): + provider = _provider() + provider.provider_config = {"proxy": ""} + provider.get_model = lambda: "gpt-5.3-codex" + provider._ensure_message_to_dicts = lambda messages: [ + {"role": "user", "content": "hello"} + ] + provider._messages_to_response_input = lambda _: [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + } + ] + provider._build_extra_headers = lambda: {"X-Test": "1"} + + compact_mock = AsyncMock( + return_value=SimpleNamespace( + model_dump=lambda mode="json": { + "output": [ + { + "type": "message", + "role": "user", + "content": [{"type": "output_text", "text": "compacted"}], + } + ] + } + ) + ) + provider.client = SimpleNamespace(responses=SimpleNamespace(compact=compact_mock)) + + result = await provider.compact_context([Message(role="user", content="hello")]) + + compact_mock.assert_awaited_once_with( + model="gpt-5.3-codex", + input=[ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + } + ], + extra_headers={"X-Test": "1"}, + ) + assert result + + +@pytest.mark.asyncio +async def 
test_compact_context_raises_when_compact_returns_empty_messages(): + provider = _provider() + provider.provider_config = {"proxy": ""} + provider.get_model = lambda: "gpt-5.3-codex" + provider._ensure_message_to_dicts = lambda messages: [ + {"role": "user", "content": "hello"} + ] + provider._messages_to_response_input = lambda _: [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + } + ] + provider._build_extra_headers = lambda: {} + + compact_mock = AsyncMock( + return_value=SimpleNamespace(model_dump=lambda mode="json": {"output": []}) + ) + provider.client = SimpleNamespace(responses=SimpleNamespace(compact=compact_mock)) + + with pytest.raises(ValueError, match="empty context"): + await provider.compact_context([Message(role="user", content="hello")]) + + +def test_convert_tools_to_responses_does_not_force_strict_false(): + provider = _provider() + provider.provider_config = {} + + response_tools = provider._convert_tools_to_responses( + [ + { + "type": "function", + "function": { + "name": "demo", + "description": "desc", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + ) + + assert response_tools + assert "strict" not in response_tools[0] + + +def test_convert_tools_to_responses_keeps_explicit_strict_setting(): + provider = _provider() + provider.provider_config = {} + + response_tools = provider._convert_tools_to_responses( + [ + { + "type": "function", + "function": { + "name": "demo", + "description": "desc", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + }, + } + ] + ) + + assert response_tools[0]["strict"] is True + + +def test_convert_tools_to_responses_supports_provider_default_strict(): + provider = _provider() + provider.provider_config = {"responses_tool_strict": True} + + response_tools = provider._convert_tools_to_responses( + [ + { + "type": "function", + "function": { + "name": "demo", + "description": "desc", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + ) + + assert response_tools[0]["strict"] is True
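
For reference, `LLMSummaryCompressor` only duck-types the provider: it looks up `supports_native_compact` and `compact_context` with `getattr`, so any provider-like object exposing those two members opts into the native compact path, and an exception or empty result falls back to summary compression. A minimal sketch of such a provider follows; only the two hook names and their signatures come from this diff, while the class itself and the `_call_native_compact_api` helper are hypothetical:

    from astrbot.core.agent.message import Message

    class MyCompactingProvider:  # in practice a Provider subclass; shown standalone here
        def supports_native_compact(self) -> bool:
            # Advertise native compaction; the compressor checks this hook first.
            return True

        async def compact_context(self, messages: list[Message]) -> list[Message]:
            # Return the compacted middle of the history. Raising, or returning an
            # empty list, makes the compressor fall back to LLM summary compression.
            compacted_text = await self._call_native_compact_api(messages)  # hypothetical helper
            return [Message(role="user", content=compacted_text)]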