From b8a4c1c6df4bf8697eb0e7994bde8173c379ea16 Mon Sep 17 00:00:00 2001
From: xunxi
Date: Wed, 11 Feb 2026 23:57:02 +0800
Subject: [PATCH 1/7] Initial implementation of Responses API integration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/agent/context/compressor.py     |  56 +-
 astrbot/core/agent/context/manager.py        |   4 +-
 astrbot/core/astr_main_agent.py              |  56 +-
 astrbot/core/config/default.py               |  23 +
 .../method/agent_sub_stages/internal.py      |   4 +
 astrbot/core/provider/manager.py             |   4 +
 .../sources/openai_responses_source.py       | 750 ++++++++++++++++++
 astrbot/dashboard/routes/chat.py             |   7 +-
 astrbot/dashboard/routes/config.py           | 148 +++-
 .../en-US/features/config-metadata.json      |   4 +
 .../zh-CN/features/config-metadata.json      |   4 +
 tests/agent/test_context_manager.py          |  76 ++
 tests/test_openai_responses_source.py        |  79 ++
 tests/test_provider_config_sanitization.py   |  86 ++
 14 files changed, 1268 insertions(+), 33 deletions(-)
 create mode 100644 astrbot/core/provider/sources/openai_responses_source.py
 create mode 100644 tests/test_openai_responses_source.py
 create mode 100644 tests/test_provider_config_sanitization.py

diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py
index 31a0b0b48..6686b5279 100644
--- a/astrbot/core/agent/context/compressor.py
+++ b/astrbot/core/agent/context/compressor.py
@@ -193,13 +193,49 @@ def should_compress(
         usage_rate = current_tokens / max_tokens
         return usage_rate > self.compression_threshold
 
+    def _supports_native_compact(self) -> bool:
+        support_native_compact = getattr(self.provider, "supports_native_compact", None)
+        if not callable(support_native_compact):
+            return False
+        try:
+            return bool(support_native_compact())
+        except Exception:
+            return False
+
+    async def _try_native_compact(
+        self,
+        system_messages: list[Message],
+        messages_to_summarize: list[Message],
+        recent_messages: list[Message],
+    ) -> list[Message] | None:
+        compact_context = getattr(self.provider, "compact_context", None)
+        if not callable(compact_context):
+            return None
+
+        try:
+            compacted_messages = await compact_context(messages_to_summarize)
+        except Exception as e:
+            logger.warning(
+                f"Native compact failed, fallback to summary compression: {e}"
+            )
+            return None
+
+        if not compacted_messages:
+            return None
+
+        result: list[Message] = []
+        result.extend(system_messages)
+        result.extend(compacted_messages)
+        result.extend(recent_messages)
+        return result
+
     async def __call__(self, messages: list[Message]) -> list[Message]:
         """Use LLM to generate a summary of the conversation history.
 
         Process:
         1. Divide messages: keep the system message and the latest N messages.
-        2. Send the old messages + the instruction message to the LLM.
-        3. Reconstruct the message list: [system message, summary message, latest messages].
+        2. Prefer native compact when provider supports it.
+        3. Fallback to LLM summary and reconstruct message list.
""" if len(messages) <= self.keep_recent + 1: return messages @@ -207,15 +243,22 @@ async def __call__(self, messages: list[Message]) -> list[Message]: system_messages, messages_to_summarize, recent_messages = split_history( messages, self.keep_recent ) - if not messages_to_summarize: return messages - # build payload + native_compact_supported = self._supports_native_compact() + + if native_compact_supported: + compacted = await self._try_native_compact( + system_messages, + messages_to_summarize, + recent_messages, + ) + if compacted is not None: + return compacted instruction_message = Message(role="user", content=self.instruction_text) llm_payload = messages_to_summarize + [instruction_message] - # generate summary try: response = await self.provider.text_chat(contexts=llm_payload) summary_content = response.completion_text @@ -223,8 +266,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]: logger.error(f"Failed to generate summary: {e}") return messages - # build result - result = [] + result: list[Message] = [] result.extend(system_messages) result.append( diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e1..6f1c5abdb 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -55,7 +55,7 @@ async def process( try: result = messages - # 1. 基于轮次的截断 (Enforce max turns) + # 1. ???????(Enforce max turns) if self.config.enforce_max_turns != -1: result = self.truncator.truncate_by_turns( result, @@ -63,7 +63,7 @@ async def process( drop_turns=self.config.truncate_turns, ) - # 2. 基于 token 的压缩 + # 2. ?? token ??? if self.config.max_context_tokens > 0: total_tokens = self.token_counter.count_tokens( result, trusted_token_usage diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 4e70f3d59..9b4efb79d 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -91,6 +91,8 @@ class MainAgentBuildConfig: """The number of most recent turns to keep during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context compression.""" + llm_compress_use_compact_api: bool = True + """Whether to prefer provider native compact API when available.""" max_context_length: int = -1 """The maximum number of turns to keep in context. -1 means no limit. This enforce max turns before compression""" @@ -742,17 +744,22 @@ async def _handle_webchat( if not user_prompt or not chatui_session_id or not session or session.display_name: return - llm_resp = await prov.text_chat( - system_prompt=( - "You are a conversation title generator. " - "Generate a concise title in the same language as the user’s input, " - "no more than 10 words, capturing only the core topic." - "If the input is a greeting, small talk, or has no clear topic, " - "(e.g., “hi”, “hello”, “haha”), return . " - "Output only the title itself or , with no explanations." - ), - prompt=f"Generate a concise title for the following user query:\n{user_prompt}", - ) + try: + llm_resp = await prov.text_chat( + system_prompt=( + "You are a conversation title generator. " + "Generate a concise title in the same language as the user's input, " + "no more than 10 words, capturing only the core topic." + "If the input is a greeting, small talk, or has no clear topic, " + '(e.g., "hi", "hello", "haha"), return . ' + "Output only the title itself or , with no explanations." 
+ ), + prompt=f"Generate a concise title for the following user query:\n{user_prompt}", + ) + except Exception as e: + logger.warning("Failed to generate chatui title: %s", e) + return + if llm_resp and llm_resp.completion_text: title = llm_resp.completion_text.strip() if not title or "" in title: @@ -807,26 +814,33 @@ def _proactive_cron_job_tools(req: ProviderRequest) -> None: def _get_compress_provider( - config: MainAgentBuildConfig, plugin_context: Context + config: MainAgentBuildConfig, + plugin_context: Context, + active_provider: Provider | None, ) -> Provider | None: - if not config.llm_compress_provider_id: - return None if config.context_limit_reached_strategy != "llm_compress": return None - provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) - if provider is None: + + _ = active_provider + + if not config.llm_compress_provider_id: + return None + + selected_provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) + if selected_provider is None: logger.warning( - "未找到指定的上下文压缩模型 %s,将跳过压缩。", + "Configured llm_compress_provider_id not found: %s. Skip compression.", config.llm_compress_provider_id, ) return None - if not isinstance(provider, Provider): + if not isinstance(selected_provider, Provider): logger.warning( - "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", + "Configured llm_compress_provider_id is not a Provider: %s. Skip compression.", config.llm_compress_provider_id, ) return None - return provider + + return selected_provider async def build_main_agent( @@ -970,7 +984,7 @@ async def build_main_agent( streaming=config.streaming_response, llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, - llm_compress_provider=_get_compress_provider(config, plugin_context), + llm_compress_provider=_get_compress_provider(config, plugin_context, provider), truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 546768812..92257910f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -94,6 +94,7 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", + "llm_compress_use_compact_api": True, "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -929,6 +930,19 @@ class ChatProviderTemplate(TypedDict): "proxy": "", "custom_headers": {}, }, + "OpenAI Responses": { + "id": "openai_responses", + "provider": "openai", + "type": "openai_responses", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.openai.com/v1", + "timeout": 120, + "proxy": "", + "custom_headers": {}, + "custom_extra_body": {}, + }, "Google Gemini": { "id": "google_gemini", "provider": "google", @@ -2828,6 +2842,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.llm_compress_use_compact_api": { + "description": "Prefer compact API when available", + "type": "bool", + "hint": "When enabled, local runner first tries provider native compact API and falls back to LLM summary compression.", + "condition": { + "provider_settings.context_limit_reached_strategy": "llm_compress", + "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py 
b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index d26f67add..12f8437f6 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -79,6 +79,9 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" ) + self.llm_compress_use_compact_api: bool = settings.get( + "llm_compress_use_compact_api", True + ) self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -113,6 +116,7 @@ async def initialize(self, ctx: PipelineContext) -> None: llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider_id=self.llm_compress_provider_id, + llm_compress_use_compact_api=self.llm_compress_use_compact_api, max_context_length=self.max_context_length, dequeue_context_length=self.dequeue_context_length, llm_safety_mode=self.llm_safety_mode, diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ff0bb303d..4bbba07b2 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -291,6 +291,10 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.openai_source import ( ProviderOpenAIOfficial as ProviderOpenAIOfficial, ) + case "openai_responses": + from .sources.openai_responses_source import ( + ProviderOpenAIResponses as ProviderOpenAIResponses, + ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "groq_chat_completion": diff --git a/astrbot/core/provider/sources/openai_responses_source.py b/astrbot/core/provider/sources/openai_responses_source.py new file mode 100644 index 000000000..efa2a0138 --- /dev/null +++ b/astrbot/core/provider/sources/openai_responses_source.py @@ -0,0 +1,750 @@ +import inspect +import json +from collections.abc import AsyncGenerator +from typing import Any + +import httpx +from openai.types.responses.response import Response as OpenAIResponse +from openai.types.responses.response_usage import ResponseUsage + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core.agent.message import ImageURLPart, Message, TextPart +from astrbot.core.agent.tool import ToolSet +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import LLMResponse, TokenUsage +from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure + +from ..register import register_provider_adapter +from .openai_source import ProviderOpenAIOfficial + + +@register_provider_adapter( + "openai_responses", + "OpenAI API Responses Provider Adapter", + default_config_tmpl={ + "id": "openai_responses", + "provider": "openai", + "type": "openai_responses", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.openai.com/v1", + "timeout": 120, + "proxy": "", + "custom_headers": {}, + "custom_extra_body": {}, + }, + provider_display_name="OpenAI Responses", +) +class ProviderOpenAIResponses(ProviderOpenAIOfficial): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.default_params = inspect.signature( + self.client.responses.create, + ).parameters.keys() + self.reasoning_key = "reasoning" + 
+ def supports_native_compact(self) -> bool: + return True + + async def compact_context(self, messages: list[Message]) -> list[Message]: + if not messages: + return messages + + message_dicts = self._ensure_message_to_dicts(messages) + payload = { + "model": self.get_model(), + "input": self._messages_to_response_input(message_dicts), + } + + compact_url = self._build_compact_url() + headers = { + "Authorization": f"Bearer {self.get_current_key()}", + "Content-Type": "application/json", + } + headers.update(self._build_extra_headers()) + + proxy = self.provider_config.get("proxy", "") + client_kwargs: dict[str, Any] = {"timeout": self.timeout} + if proxy: + client_kwargs["proxy"] = proxy + + try: + async with httpx.AsyncClient(**client_kwargs) as client: + response = await client.post(compact_url, json=payload, headers=headers) + response.raise_for_status() + compact_data = response.json() + except Exception as e: + if is_connection_error(e): + log_connection_failure("OpenAI", e, proxy) + raise + + compact_input = self._extract_compact_input(compact_data) + compact_messages = self._response_input_to_messages(compact_input) + if not compact_messages: + raise ValueError("Responses compact returned empty context.") + return compact_messages + + def _build_compact_url(self) -> str: + base_url = str(self.client.base_url).rstrip("/") + if base_url.endswith("/v1"): + return f"{base_url}/responses/compact" + return f"{base_url}/v1/responses/compact" + + def _extract_compact_input(self, compact_data: Any) -> list[dict[str, Any]]: + if not isinstance(compact_data, dict): + raise ValueError("Invalid compact response payload.") + + candidate_keys = ( + "input", + "items", + "input_items", + "compacted_items", + "compacted_input", + ) + for key in candidate_keys: + value = compact_data.get(key) + if isinstance(value, list): + return [item for item in value if isinstance(item, dict)] + + response_obj = compact_data.get("response") + if isinstance(response_obj, dict): + for key in candidate_keys: + value = response_obj.get(key) + if isinstance(value, list): + return [item for item in value if isinstance(item, dict)] + + output = compact_data.get("output") + if isinstance(output, list): + return self._response_output_to_input_items(output) + + raise ValueError("Responses compact payload does not contain compacted items.") + + def _response_output_to_input_items( + self, output: list[Any] + ) -> list[dict[str, Any]]: + converted: list[dict[str, Any]] = [] + for item in output: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "message": + role = item.get("role", "assistant") + if role not in {"system", "developer", "user", "assistant"}: + role = "assistant" + content = item.get("content", []) + converted_content = self._output_content_to_input_content(content) + converted.append( + { + "type": "message", + "role": role, + "content": converted_content, + } + ) + elif item_type == "function_call": + converted.append( + { + "type": "function_call", + "call_id": item.get("call_id", item.get("id", "")), + "name": item.get("name", ""), + "arguments": item.get("arguments", "{}"), + } + ) + return converted + + def _output_content_to_input_content(self, content: Any) -> list[dict[str, Any]]: + converted: list[dict[str, Any]] = [] + if not isinstance(content, list): + return converted + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type == "output_text": + text = part.get("text") + if text: + 
converted.append({"type": "input_text", "text": str(text)}) + elif part_type == "input_text": + text = part.get("text") + if text: + converted.append({"type": "input_text", "text": str(text)}) + return converted + + def _response_input_to_messages( + self, input_items: list[dict[str, Any]] + ) -> list[Message]: + messages: list[Message] = [] + for item in input_items: + item_type = item.get("type") + if item_type == "message": + role = item.get("role") + if role not in {"system", "developer", "user", "assistant"}: + continue + content = self._response_content_to_message_content(item.get("content")) + if content is None: + content = "" + messages.append(Message(role=role, content=content)) + elif item_type == "function_call": + call_id = item.get("call_id") or item.get("id") + name = item.get("name") + arguments = item.get("arguments", "{}") + if not call_id or not name: + continue + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + messages.append( + Message( + role="assistant", + content="", + tool_calls=[ + { + "type": "function", + "id": str(call_id), + "function": { + "name": str(name), + "arguments": arguments, + }, + } + ], + ) + ) + elif item_type == "function_call_output": + call_id = item.get("call_id") + output = item.get("output", "") + if not call_id: + continue + messages.append( + Message( + role="tool", + tool_call_id=str(call_id), + content=str(output), + ) + ) + + return messages + + def _response_content_to_message_content(self, content: Any) -> str | list: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + parts: list[Any] = [] + plain_text: list[str] = [] + has_non_text = False + + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text"}: + text = part.get("text") + if text is not None: + plain_text.append(str(text)) + parts.append(TextPart(text=str(text))) + elif part_type in {"input_image", "image_url"}: + image_url = None + if part_type == "input_image": + image_url = part.get("image_url") or part.get("file_url") + else: + image_data = part.get("image_url") + if isinstance(image_data, dict): + image_url = image_data.get("url") + elif isinstance(image_data, str): + image_url = image_data + if image_url: + has_non_text = True + parts.append( + ImageURLPart( + image_url=ImageURLPart.ImageURL(url=str(image_url)) + ) + ) + + if has_non_text: + return parts + return "\n".join(plain_text).strip() + + def _messages_to_response_input( + self, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + response_input: list[dict[str, Any]] = [] + + for message in messages: + role = message.get("role") + content = message.get("content") + tool_calls = message.get("tool_calls") + + if role in {"system", "developer", "user", "assistant"}: + converted_content = self._message_content_to_response_content(content) + message_item: dict[str, Any] = { + "type": "message", + "role": role, + } + if isinstance(converted_content, str): + message_item["content"] = converted_content + else: + message_item["content"] = converted_content + response_input.append(message_item) + + if role == "assistant" and isinstance(tool_calls, list): + for tool_call in tool_calls: + normalized = self._normalize_tool_call(tool_call) + if not normalized: + continue + response_input.append( + { + "type": "function_call", + "call_id": normalized["id"], + "name": normalized["name"], + "arguments": normalized["arguments"], + } + 
) + + if role == "tool": + call_id = message.get("tool_call_id") + if call_id: + response_input.append( + { + "type": "function_call_output", + "call_id": str(call_id), + "output": self._extract_text_from_content(content), + } + ) + + return response_input + + def _normalize_tool_call(self, tool_call: Any) -> dict[str, str] | None: + if isinstance(tool_call, str): + try: + tool_call = json.loads(tool_call) + except Exception: + return None + if not isinstance(tool_call, dict): + return None + + tool_type = tool_call.get("type") + if tool_type != "function": + return None + + function_data = tool_call.get("function", {}) + if not isinstance(function_data, dict): + return None + + name = function_data.get("name") + call_id = tool_call.get("id") + arguments = function_data.get("arguments", "{}") + if not name or not call_id: + return None + + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + + return { + "id": str(call_id), + "name": str(name), + "arguments": arguments, + } + + def _message_content_to_response_content( + self, content: Any + ) -> str | list[dict[str, Any]]: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + converted: list[dict[str, Any]] = [] + for part in content: + if not isinstance(part, dict): + continue + + part_type = part.get("type") + if part_type in {"text", "input_text", "output_text"}: + text = part.get("text") + if text is not None: + converted.append({"type": "input_text", "text": str(text)}) + elif part_type in {"image_url", "input_image"}: + image_part = self._normalize_image_part(part) + if image_part: + converted.append(image_part) + elif part_type == "input_file": + file_id = part.get("file_id") + file_url = part.get("file_url") + file_data = part.get("file_data") + input_file: dict[str, Any] = {"type": "input_file"} + if file_id: + input_file["file_id"] = file_id + elif file_url: + input_file["file_url"] = file_url + elif file_data: + input_file["file_data"] = file_data + filename = part.get("filename") + if filename: + input_file["filename"] = filename + if len(input_file) > 1: + converted.append(input_file) + + if not converted: + return "" + if len(converted) == 1 and converted[0].get("type") == "input_text": + return str(converted[0].get("text", "")) + return converted + + def _normalize_image_part(self, part: dict[str, Any]) -> dict[str, Any] | None: + if part.get("type") == "input_image": + image_url = part.get("image_url") + if image_url: + normalized = {"type": "input_image", "image_url": str(image_url)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + file_id = part.get("file_id") + if file_id: + normalized = {"type": "input_image", "file_id": str(file_id)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + return None + + image_data = part.get("image_url") + image_url = None + if isinstance(image_data, dict): + image_url = image_data.get("url") + elif isinstance(image_data, str): + image_url = image_data + + if not image_url: + return None + + normalized = {"type": "input_image", "image_url": str(image_url)} + detail = part.get("detail") + if detail: + normalized["detail"] = detail + return normalized + + def _extract_text_from_content(self, content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + texts: list[str] = [] + for part in content: + if isinstance(part, dict): + text = part.get("text") + if text is not None: 
+ texts.append(str(text)) + return "\n".join(texts) + return str(content) if content is not None else "" + + def _build_responses_input_and_instructions( + self, + messages: list[dict[str, Any]], + ) -> tuple[list[dict[str, Any]], str | None]: + response_input = self._messages_to_response_input(messages) + filtered_input: list[dict[str, Any]] = [] + instruction_chunks: list[str] = [] + + for item in response_input: + if item.get("type") == "message" and item.get("role") in { + "system", + "developer", + }: + instruction_text = self._extract_text_from_content(item.get("content")) + if instruction_text: + instruction_chunks.append(instruction_text) + continue + filtered_input.append(item) + + instructions = "\n\n".join(instruction_chunks).strip() + return filtered_input, instructions or None + + def _build_extra_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if isinstance(self.custom_headers, dict): + for key, value in self.custom_headers.items(): + if str(key).lower() == "authorization": + continue + headers[str(key)] = str(value) + return headers + + def _convert_tools_to_responses( + self, tool_list: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + response_tools: list[dict[str, Any]] = [] + for tool in tool_list: + if not isinstance(tool, dict): + continue + if tool.get("type") != "function": + continue + + function_body = tool.get("function") + if isinstance(function_body, dict): + name = function_body.get("name") + if not name: + continue + response_tools.append( + { + "type": "function", + "name": name, + "description": function_body.get("description", ""), + "parameters": function_body.get("parameters", {}), + "strict": False, + } + ) + continue + + name = tool.get("name") + if not name: + continue + response_tools.append( + { + "type": "function", + "name": name, + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + "strict": False, + } + ) + + return response_tools + + def _extract_response_usage(self, usage: ResponseUsage | None) -> TokenUsage | None: + if usage is None: + return None + cached = 0 + if usage.input_tokens_details: + cached = usage.input_tokens_details.cached_tokens or 0 + input_other = max(0, (usage.input_tokens or 0) - cached) + output = usage.output_tokens or 0 + return TokenUsage(input_other=input_other, input_cached=cached, output=output) + + async def _parse_openai_response( + self, + response: OpenAIResponse, + tools: ToolSet | None, + ) -> LLMResponse: + llm_response = LLMResponse("assistant") + + if response.error: + raise Exception(f"Responses API error: {response.error.message}") + + completion_text = response.output_text.strip() + if completion_text: + llm_response.result_chain = MessageChain().message(completion_text) + + reasoning_segments: list[str] = [] + for output_item in response.output: + output_type = getattr(output_item, "type", "") + if output_type == "reasoning": + summary = getattr(output_item, "summary", []) + for summary_part in summary: + text = getattr(summary_part, "text", "") + if text: + reasoning_segments.append(str(text)) + if output_type == "function_call" and tools is not None: + arguments = getattr(output_item, "arguments", "{}") + function_name = getattr(output_item, "name", "") + call_id = getattr(output_item, "call_id", "") + parsed_arguments: dict[str, Any] + try: + parsed_arguments = json.loads(arguments) if arguments else {} + except Exception: + parsed_arguments = {} + llm_response.tools_call_args.append(parsed_arguments) + 
llm_response.tools_call_name.append(str(function_name)) + llm_response.tools_call_ids.append(str(call_id)) + + if reasoning_segments: + llm_response.reasoning_content = "\n".join(reasoning_segments) + + if not llm_response.completion_text and not llm_response.tools_call_args: + raise Exception(f"Responses API returned empty output: {response}") + + llm_response.raw_completion = response + llm_response.id = response.id + llm_response.usage = self._extract_response_usage(response.usage) + return llm_response + + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + request_payload = dict(payloads) + response_input, instructions = self._build_responses_input_and_instructions( + request_payload.pop("messages", []) + ) + request_payload["input"] = response_input + if instructions and not request_payload.get("instructions"): + request_payload["instructions"] = instructions + + if tools: + model = request_payload.get("model", "").lower() + omit_empty_param_field = "gemini" in model + tool_list = tools.get_func_desc_openai_style( + omit_empty_parameter_field=omit_empty_param_field, + ) + response_tools = self._convert_tools_to_responses(tool_list) + if response_tools: + request_payload["tools"] = response_tools + + extra_body: dict[str, Any] = {} + for key in list(request_payload.keys()): + if key not in self.default_params: + extra_body[key] = request_payload.pop(key) + + custom_extra_body = self.provider_config.get("custom_extra_body", {}) + if isinstance(custom_extra_body, dict): + extra_body.update(custom_extra_body) + + extra_headers = self._build_extra_headers() + completion = await self.client.responses.create( + **request_payload, + stream=False, + extra_body=extra_body, + extra_headers=extra_headers, + ) + + if not isinstance(completion, OpenAIResponse): + raise TypeError(f"Unexpected response object: {type(completion)}") + + logger.debug(f"response: {completion}") + return await self._parse_openai_response(completion, tools) + + async def _query_stream( + self, + payloads: dict, + tools: ToolSet | None, + ) -> AsyncGenerator[LLMResponse, None]: + request_payload = dict(payloads) + response_input, instructions = self._build_responses_input_and_instructions( + request_payload.pop("messages", []) + ) + request_payload["input"] = response_input + if instructions and not request_payload.get("instructions"): + request_payload["instructions"] = instructions + + if tools: + model = request_payload.get("model", "").lower() + omit_empty_param_field = "gemini" in model + tool_list = tools.get_func_desc_openai_style( + omit_empty_parameter_field=omit_empty_param_field, + ) + response_tools = self._convert_tools_to_responses(tool_list) + if response_tools: + request_payload["tools"] = response_tools + + extra_body: dict[str, Any] = {} + for key in list(request_payload.keys()): + if key not in self.default_params: + extra_body[key] = request_payload.pop(key) + + custom_extra_body = self.provider_config.get("custom_extra_body", {}) + if isinstance(custom_extra_body, dict): + extra_body.update(custom_extra_body) + + response_id: str | None = None + extra_headers = self._build_extra_headers() + try: + async with self.client.responses.stream( + **request_payload, + extra_body=extra_body, + extra_headers=extra_headers, + ) as stream: + async for event in stream: + event_type = getattr(event, "type", "") + if event_type == "response.created": + response_id = event.response.id + continue + + if event_type == "response.output_text.delta": + delta = getattr(event, "delta", "") + if 
delta: + yield LLMResponse( + role="assistant", + result_chain=MessageChain( + chain=[Comp.Plain(str(delta))] + ), + is_chunk=True, + id=response_id, + ) + continue + + if event_type == "response.reasoning_summary_text.delta": + delta = getattr(event, "delta", "") + if delta: + yield LLMResponse( + role="assistant", + reasoning_content=str(delta), + is_chunk=True, + id=response_id, + ) + continue + + if event_type == "error": + raise Exception( + f"Responses stream error: {getattr(event, 'code', 'unknown')} {getattr(event, 'message', '')}" + ) + + if event_type == "response.failed": + error_obj = getattr(event.response, "error", None) + if error_obj is not None: + raise Exception( + f"Responses stream failed: {error_obj.code} {error_obj.message}" + ) + raise Exception("Responses stream failed.") + + final_response = await stream.get_final_response() + except Exception as e: + if self._is_retryable_upstream_error(e) or is_connection_error(e): + logger.warning( + "Responses stream failed, fallback to non-stream create: %s", + e, + ) + yield await self._query(payloads, tools) + return + raise + + yield await self._parse_openai_response(final_response, tools) + + def _is_retryable_upstream_error(self, e: Exception) -> bool: + status_code = getattr(e, "status_code", None) + if status_code in {500, 502, 503, 504}: + return True + + message = str(e).lower() + if "upstream request failed" in message: + return True + + body = getattr(e, "body", None) + if isinstance(body, dict): + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + if str(error_obj.get("type", "")).lower() == "upstream_error": + return True + + return False + + async def _handle_api_error( + self, + e: Exception, + payloads: dict, + context_query: list, + func_tool: ToolSet | None, + chosen_key: str, + available_api_keys: list[str], + retry_cnt: int, + max_retries: int, + ) -> tuple: + if is_connection_error(e): + proxy = self.provider_config.get("proxy", "") + log_connection_failure("OpenAI", e, proxy) + return await super()._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a7c0e3a57..8cc38aded 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -375,15 +375,20 @@ async def stream(): try: async with track_conversation(self.running_convs, webchat_conv_id): while True: + result = None try: result = await asyncio.wait_for(back_queue.get(), timeout=1) except asyncio.TimeoutError: continue except asyncio.CancelledError: - logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") + logger.debug( + f"[WebChat] user {username} disconnected from stream." 
+ ) client_disconnected = True + break except Exception as e: logger.error(f"WebChat stream error: {e}") + continue if not result: continue diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index efea4c7cf..faba4bafd 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -2,6 +2,7 @@ import copy import inspect import os +import re import traceback from pathlib import Path from typing import Any @@ -39,6 +40,140 @@ MAX_FILE_BYTES = 500 * 1024 * 1024 +_MASKED_SECRET_PLACEHOLDER = "***" + +_SENSITIVE_FIELD_NAMES = { + "access_token", + "api_key", + "apikey", + "authorization", + "bearer_token", + "client_secret", + "key", + "password", + "refresh_token", + "secret", + "token", +} + + +def _is_sensitive_field_name(name: str) -> bool: + lower_name = name.strip().lower() + if lower_name in _SENSITIVE_FIELD_NAMES: + return True + + parts = [p for p in re.split(r"[^a-z0-9]+", lower_name) if p] + if not parts: + return False + + if parts[-1] in {"key", "secret", "token", "password"}: + return True + if parts[0] in {"auth", "authorization"}: + return True + return "authorization" in parts + + +def _mask_secret_value(value: Any) -> Any: + if isinstance(value, str): + if not value: + return value + return _MASKED_SECRET_PLACEHOLDER + if isinstance(value, list): + return [_mask_secret_value(item) for item in value] + if isinstance(value, dict): + return {key: _mask_secret_value(item) for key, item in value.items()} + return value + + +def _is_masked_secret_value(value: Any) -> bool: + if isinstance(value, str): + return value == _MASKED_SECRET_PLACEHOLDER + if isinstance(value, list): + return bool(value) and all(_is_masked_secret_value(item) for item in value) + if isinstance(value, dict): + return bool(value) and all( + _is_masked_secret_value(item) for item in value.values() + ) + return False + + +def sanitize_provider_config(provider_config: dict[str, Any]) -> dict[str, Any]: + sanitized = copy.deepcopy(provider_config) + + def _sanitize_mapping(mapping: dict[str, Any]) -> None: + for key, value in list(mapping.items()): + if key == "custom_headers" and isinstance(value, dict): + sanitized_headers = {} + for header_key, header_value in value.items(): + if _is_sensitive_field_name(header_key): + sanitized_headers[header_key] = _mask_secret_value(header_value) + else: + sanitized_headers[header_key] = header_value + mapping[key] = sanitized_headers + continue + + if _is_sensitive_field_name(key): + mapping[key] = _mask_secret_value(value) + continue + + if isinstance(value, dict): + _sanitize_mapping(value) + elif isinstance(value, list): + new_list = [] + for item in value: + if isinstance(item, dict): + item_copy = copy.deepcopy(item) + _sanitize_mapping(item_copy) + new_list.append(item_copy) + else: + new_list.append(item) + mapping[key] = new_list + + _sanitize_mapping(sanitized) + return sanitized + + +def restore_masked_provider_config( + new_config: dict[str, Any], old_config: dict[str, Any] +) -> dict[str, Any]: + restored = copy.deepcopy(new_config) + + def _restore_mapping(mapping: dict[str, Any], old_mapping: dict[str, Any]) -> None: + for key, value in list(mapping.items()): + old_value = old_mapping.get(key) + + if key == "custom_headers" and isinstance(value, dict): + old_headers = old_value if isinstance(old_value, dict) else {} + for header_key, header_value in list(value.items()): + if ( + _is_sensitive_field_name(header_key) + and _is_masked_secret_value(header_value) + and header_key in old_headers + ): + 
value[header_key] = copy.deepcopy(old_headers[header_key]) + continue + + if _is_sensitive_field_name(key): + if _is_masked_secret_value(value) and old_value is not None: + mapping[key] = copy.deepcopy(old_value) + continue + + if isinstance(value, dict) and isinstance(old_value, dict): + _restore_mapping(value, old_value) + continue + + if isinstance(value, list) and isinstance(old_value, list): + for idx, item in enumerate(value): + if ( + isinstance(item, dict) + and idx < len(old_value) + and isinstance(old_value[idx], dict) + ): + _restore_mapping(item, old_value[idx]) + + _restore_mapping(restored, old_config) + return restored + def try_cast(value: Any, type_: str): if type_ == "int": @@ -694,10 +829,10 @@ async def get_provider_config_list(self): prov = self.core_lifecycle.provider_manager.get_merged_provider_config( provider ) - provider_list.append(prov) + provider_list.append(sanitize_provider_config(prov)) elif not ps_id and provider.get("provider_type", "") in provider_type_ls: # agent runner, embedding, etc - provider_list.append(provider) + provider_list.append(sanitize_provider_config(provider)) return Response().ok(provider_list).__dict__ async def get_provider_model_list(self): @@ -1179,6 +1314,15 @@ async def post_update_provider(self): if not origin_provider_id or not new_config: return Response().error("参数错误").__dict__ + old_config = None + for provider in self.config.get("provider", []): + if provider.get("id") == origin_provider_id: + old_config = provider + break + + if isinstance(old_config, dict): + new_config = restore_masked_provider_config(new_config, old_config) + try: await self.core_lifecycle.provider_manager.update_provider( origin_provider_id, new_config diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 2166d5391..d79149ef2 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -211,6 +211,10 @@ "llm_compress_provider_id": { "description": "Model Provider ID for Context Compression", "hint": "When left empty, will fall back to the 'Truncate by Turns' strategy." + }, + "llm_compress_use_compact_api": { + "description": "Prefer compact API when available", + "hint": "When enabled, local runner first tries provider native compact API and falls back to LLM summary compression." 
} } }, diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 2d1c11cda..a4b996eb9 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -214,6 +214,10 @@ "llm_compress_provider_id": { "description": "用于上下文压缩的模型提供商 ID", "hint": "留空时将降级为\"按对话轮数截断\"的策略。" + }, + "llm_compress_use_compact_api": { + "description": "\u4f18\u5148\u4f7f\u7528 compact \u63a5\u53e3", + "hint": "\u542f\u7528\u540e\u4f1a\u4f18\u5148\u8c03\u7528\u5f53\u524d\u4f1a\u8bdd\u4f9b\u5e94\u5546\u7684 native compact \u63a5\u53e3\uff1b\u82e5\u4e0d\u652f\u6301\u5219\u56de\u9000\u5230\u5e38\u89c4 LLM \u538b\u7f29\u3002" } } }, diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 0b955ff40..600090a2d 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -344,6 +344,82 @@ async def test_sequential_processing_order(self): # Truncator should be called first mock_truncate.assert_called_once() + @pytest.mark.asyncio + async def test_native_compact_is_preferred_when_available(self): + """Native compact should be used before summary generation when available.""" + + class NativeCompactProvider(MockProvider): + def __init__(self): + super().__init__() + self.compact_context = AsyncMock( + return_value=[Message(role="user", content="compacted")] + ) + self.text_chat = AsyncMock() + + def supports_native_compact(self): + return True + + provider = NativeCompactProvider() + config = ContextConfig( + max_context_tokens=10, + llm_compress_provider=provider, # type: ignore[arg-type] + llm_compress_keep_recent=2, + ) + manager = ContextManager(config) + + messages = self.create_messages(8) + + with patch.object( + manager.compressor, "should_compress", side_effect=[True, False] + ): + result = await manager.process(messages) + + provider.compact_context.assert_awaited_once() + provider.text_chat.assert_not_called() + assert any(msg.content == "compacted" for msg in result) + + @pytest.mark.asyncio + async def test_native_compact_fallback_to_summary_on_failure(self): + """Fallback to summary compression when native compact fails.""" + + class FailingCompactProvider(MockProvider): + def __init__(self): + super().__init__() + self.compact_context = AsyncMock( + side_effect=RuntimeError("compact failed") + ) + self.text_chat = AsyncMock( + return_value=LLMResponse( + role="assistant", + completion_text="summary fallback", + ) + ) + + def supports_native_compact(self): + return True + + provider = FailingCompactProvider() + config = ContextConfig( + max_context_tokens=10, + llm_compress_provider=provider, # type: ignore[arg-type] + llm_compress_keep_recent=2, + ) + manager = ContextManager(config) + + messages = self.create_messages(8) + + with patch.object( + manager.compressor, "should_compress", side_effect=[True, False] + ): + result = await manager.process(messages) + + provider.compact_context.assert_awaited_once() + provider.text_chat.assert_awaited_once() + assert any( + msg.content == "Our previous history conversation summary: summary fallback" + for msg in result + ) + # ==================== Error Handling Tests ==================== @pytest.mark.asyncio diff --git a/tests/test_openai_responses_source.py b/tests/test_openai_responses_source.py new file mode 100644 index 000000000..a7a5bab43 --- /dev/null +++ b/tests/test_openai_responses_source.py @@ -0,0 +1,79 @@ +import os 
+import sys + +# Add project root to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from astrbot.core.provider.sources.openai_responses_source import ( + ProviderOpenAIResponses, +) + + +class _DummyError(Exception): + def __init__(self, message: str, status_code=None, body=None): + super().__init__(message) + self.status_code = status_code + self.body = body + + +def _provider() -> ProviderOpenAIResponses: + return ProviderOpenAIResponses.__new__(ProviderOpenAIResponses) + + +def test_is_retryable_upstream_error_with_5xx_status(): + err = _DummyError("server error", status_code=502) + + assert _provider()._is_retryable_upstream_error(err) + + +def test_is_retryable_upstream_error_with_upstream_error_type(): + err = _DummyError( + "bad gateway", + status_code=400, + body={"error": {"type": "upstream_error"}}, + ) + + assert _provider()._is_retryable_upstream_error(err) + + +def test_is_retryable_upstream_error_returns_false_for_non_retryable_error(): + err = _DummyError( + "invalid request", + status_code=400, + body={"error": {"type": "invalid_request_error"}}, + ) + + assert not _provider()._is_retryable_upstream_error(err) + + +def test_build_responses_input_and_instructions_moves_system_messages(): + provider = _provider() + provider.custom_headers = {} + + response_input, instructions = provider._build_responses_input_and_instructions( + [ + {"role": "system", "content": "sys text"}, + {"role": "developer", "content": [{"type": "text", "text": "dev text"}]}, + {"role": "user", "content": "hello"}, + ] + ) + + assert instructions == "sys text\n\ndev text" + assert all( + item.get("role") not in {"system", "developer"} for item in response_input + ) + assert any(item.get("role") == "user" for item in response_input) + + +def test_build_extra_headers_keeps_custom_headers_and_ignores_authorization(): + provider = _provider() + provider.custom_headers = { + "X-Test": "value", + "Authorization": "Bearer should-not-pass", + } + + headers = provider._build_extra_headers() + + assert "User-Agent" not in headers + assert headers["X-Test"] == "value" + assert "Authorization" not in headers diff --git a/tests/test_provider_config_sanitization.py b/tests/test_provider_config_sanitization.py new file mode 100644 index 000000000..9725f6cdb --- /dev/null +++ b/tests/test_provider_config_sanitization.py @@ -0,0 +1,86 @@ +import os +import sys + +# Add project root to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from astrbot.dashboard.routes.config import ( + restore_masked_provider_config, + sanitize_provider_config, +) + + +def test_sanitize_provider_config_masks_sensitive_fields(): + raw = { + "id": "openai", + "type": "openai_chat_completion", + "key": ["sk-123"], + "custom_headers": { + "Authorization": "Bearer sk-abc", + "X-Trace": "ok", + }, + "nested": { + "api_key": "secret-key", + "token": "tkn", + "keep": "value", + }, + } + + sanitized = sanitize_provider_config(raw) + + assert sanitized["key"] == ["***"] + assert sanitized["custom_headers"]["Authorization"] == "***" + assert sanitized["custom_headers"]["X-Trace"] == "ok" + assert sanitized["nested"]["api_key"] == "***" + assert sanitized["nested"]["token"] == "***" + assert sanitized["nested"]["keep"] == "value" + + +def test_sanitize_provider_config_keeps_non_sensitive_fields(): + raw = { + "id": "provider-id", + "model": "gpt-4.1", + "provider_type": "chat_completion", + } + + sanitized = sanitize_provider_config(raw) + + assert sanitized 
== raw
+
+
+def test_sanitize_provider_config_does_not_mask_non_secret_token_fields():
+    raw = {
+        "max_tokens": 4096,
+        "token_limit": 8192,
+        "monkey": "banana",
+    }
+
+    sanitized = sanitize_provider_config(raw)
+
+    assert sanitized["max_tokens"] == 4096
+    assert sanitized["token_limit"] == 8192
+    assert sanitized["monkey"] == "banana"
+
+
+def test_restore_masked_provider_config_recovers_existing_secrets():
+    old = {
+        "id": "openai",
+        "key": ["sk-old"],
+        "custom_headers": {
+            "Authorization": "Bearer old",
+            "X-Trace": "old-trace",
+        },
+        "max_tokens": 1024,
+        "enable": True,
+    }
+    sanitized = sanitize_provider_config(old)
+    sanitized["enable"] = False
+    sanitized["max_tokens"] = 2048
+
+    restored = restore_masked_provider_config(sanitized, old)
+
+    assert restored["key"] == ["sk-old"]
+    assert restored["custom_headers"]["Authorization"] == "Bearer old"
+    assert restored["custom_headers"]["X-Trace"] == "old-trace"
+    assert restored["enable"] is False
+    assert restored["max_tokens"] == 2048

From 25bd4afb77b2c69f7d11028ed727744b5d71fa57 Mon Sep 17 00:00:00 2001
From: xunxi
Date: Thu, 12 Feb 2026 00:11:10 +0800
Subject: [PATCH 2/7] Remove unrelated code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/agent/context/manager.py      |   4 +-
 astrbot/core/astr_main_agent.py            |   6 +-
 astrbot/dashboard/routes/chat.py           |   4 +-
 astrbot/dashboard/routes/config.py         | 146 +--------------------
 tests/test_provider_config_sanitization.py |  86 ------------
 5 files changed, 7 insertions(+), 239 deletions(-)
 delete mode 100644 tests/test_provider_config_sanitization.py

diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py
index 6f1c5abdb..216a3e7e1 100644
--- a/astrbot/core/agent/context/manager.py
+++ b/astrbot/core/agent/context/manager.py
@@ -55,7 +55,7 @@ async def process(
         try:
             result = messages
 
-            # 1. ???????(Enforce max turns)
+            # 1. 基于轮次的截断 (Enforce max turns)
             if self.config.enforce_max_turns != -1:
                 result = self.truncator.truncate_by_turns(
                     result,
@@ -63,7 +63,7 @@
                     drop_turns=self.config.truncate_turns,
                 )
 
-            # 2. ?? token ???
+            # 2. 基于 token 的压缩
             if self.config.max_context_tokens > 0:
                 total_tokens = self.token_counter.count_tokens(
                     result, trusted_token_usage
diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index 9b4efb79d..072975a4a 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -821,21 +821,19 @@ def _get_compress_provider(
     if config.context_limit_reached_strategy != "llm_compress":
         return None
 
-    _ = active_provider
-
     if not config.llm_compress_provider_id:
         return None
 
     selected_provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
     if selected_provider is None:
         logger.warning(
-            "Configured llm_compress_provider_id not found: %s. Skip compression.",
+            "未找到指定的上下文压缩模型 %s,将跳过压缩。",
             config.llm_compress_provider_id,
         )
         return None
     if not isinstance(selected_provider, Provider):
         logger.warning(
-            "Configured llm_compress_provider_id is not a Provider: %s. 
Skip compression.", + "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", config.llm_compress_provider_id, ) return None diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 8cc38aded..d905a0d3f 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -381,9 +381,7 @@ async def stream(): except asyncio.TimeoutError: continue except asyncio.CancelledError: - logger.debug( - f"[WebChat] user {username} disconnected from stream." - ) + logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") client_disconnected = True break except Exception as e: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index faba4bafd..125d9c07e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -40,139 +40,6 @@ MAX_FILE_BYTES = 500 * 1024 * 1024 -_MASKED_SECRET_PLACEHOLDER = "***" - -_SENSITIVE_FIELD_NAMES = { - "access_token", - "api_key", - "apikey", - "authorization", - "bearer_token", - "client_secret", - "key", - "password", - "refresh_token", - "secret", - "token", -} - - -def _is_sensitive_field_name(name: str) -> bool: - lower_name = name.strip().lower() - if lower_name in _SENSITIVE_FIELD_NAMES: - return True - - parts = [p for p in re.split(r"[^a-z0-9]+", lower_name) if p] - if not parts: - return False - - if parts[-1] in {"key", "secret", "token", "password"}: - return True - if parts[0] in {"auth", "authorization"}: - return True - return "authorization" in parts - - -def _mask_secret_value(value: Any) -> Any: - if isinstance(value, str): - if not value: - return value - return _MASKED_SECRET_PLACEHOLDER - if isinstance(value, list): - return [_mask_secret_value(item) for item in value] - if isinstance(value, dict): - return {key: _mask_secret_value(item) for key, item in value.items()} - return value - - -def _is_masked_secret_value(value: Any) -> bool: - if isinstance(value, str): - return value == _MASKED_SECRET_PLACEHOLDER - if isinstance(value, list): - return bool(value) and all(_is_masked_secret_value(item) for item in value) - if isinstance(value, dict): - return bool(value) and all( - _is_masked_secret_value(item) for item in value.values() - ) - return False - - -def sanitize_provider_config(provider_config: dict[str, Any]) -> dict[str, Any]: - sanitized = copy.deepcopy(provider_config) - - def _sanitize_mapping(mapping: dict[str, Any]) -> None: - for key, value in list(mapping.items()): - if key == "custom_headers" and isinstance(value, dict): - sanitized_headers = {} - for header_key, header_value in value.items(): - if _is_sensitive_field_name(header_key): - sanitized_headers[header_key] = _mask_secret_value(header_value) - else: - sanitized_headers[header_key] = header_value - mapping[key] = sanitized_headers - continue - - if _is_sensitive_field_name(key): - mapping[key] = _mask_secret_value(value) - continue - - if isinstance(value, dict): - _sanitize_mapping(value) - elif isinstance(value, list): - new_list = [] - for item in value: - if isinstance(item, dict): - item_copy = copy.deepcopy(item) - _sanitize_mapping(item_copy) - new_list.append(item_copy) - else: - new_list.append(item) - mapping[key] = new_list - - _sanitize_mapping(sanitized) - return sanitized - - -def restore_masked_provider_config( - new_config: dict[str, Any], old_config: dict[str, Any] -) -> dict[str, Any]: - restored = copy.deepcopy(new_config) - - def _restore_mapping(mapping: dict[str, Any], old_mapping: dict[str, Any]) -> None: - for key, value in list(mapping.items()): - old_value = 
old_mapping.get(key) - - if key == "custom_headers" and isinstance(value, dict): - old_headers = old_value if isinstance(old_value, dict) else {} - for header_key, header_value in list(value.items()): - if ( - _is_sensitive_field_name(header_key) - and _is_masked_secret_value(header_value) - and header_key in old_headers - ): - value[header_key] = copy.deepcopy(old_headers[header_key]) - continue - - if _is_sensitive_field_name(key): - if _is_masked_secret_value(value) and old_value is not None: - mapping[key] = copy.deepcopy(old_value) - continue - - if isinstance(value, dict) and isinstance(old_value, dict): - _restore_mapping(value, old_value) - continue - - if isinstance(value, list) and isinstance(old_value, list): - for idx, item in enumerate(value): - if ( - isinstance(item, dict) - and idx < len(old_value) - and isinstance(old_value[idx], dict) - ): - _restore_mapping(item, old_value[idx]) - - _restore_mapping(restored, old_config) - return restored def try_cast(value: Any, type_: str): @@ -829,10 +696,10 @@ async def get_provider_config_list(self): prov = self.core_lifecycle.provider_manager.get_merged_provider_config( provider ) - provider_list.append(sanitize_provider_config(prov)) + provider_list.append(prov) elif not ps_id and provider.get("provider_type", "") in provider_type_ls: # agent runner, embedding, etc - provider_list.append(sanitize_provider_config(provider)) + provider_list.append(provider) return Response().ok(provider_list).__dict__ async def get_provider_model_list(self): @@ -1314,15 +1181,6 @@ async def post_update_provider(self): if not origin_provider_id or not new_config: return Response().error("参数错误").__dict__ - old_config = None - for provider in self.config.get("provider", []): - if provider.get("id") == origin_provider_id: - old_config = provider - break - - if isinstance(old_config, dict): - new_config = restore_masked_provider_config(new_config, old_config) - try: await self.core_lifecycle.provider_manager.update_provider( origin_provider_id, new_config diff --git a/tests/test_provider_config_sanitization.py b/tests/test_provider_config_sanitization.py deleted file mode 100644 index 9725f6cdb..000000000 --- a/tests/test_provider_config_sanitization.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -import sys - -# Add project root to sys.path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from astrbot.dashboard.routes.config import ( - restore_masked_provider_config, - sanitize_provider_config, -) - - -def test_sanitize_provider_config_masks_sensitive_fields(): - raw = { - "id": "openai", - "type": "openai_chat_completion", - "key": ["sk-123"], - "custom_headers": { - "Authorization": "Bearer sk-abc", - "X-Trace": "ok", - }, - "nested": { - "api_key": "secret-key", - "token": "tkn", - "keep": "value", - }, - } - - sanitized = sanitize_provider_config(raw) - - assert sanitized["key"] == ["***"] - assert sanitized["custom_headers"]["Authorization"] == "***" - assert sanitized["custom_headers"]["X-Trace"] == "ok" - assert sanitized["nested"]["api_key"] == "***" - assert sanitized["nested"]["token"] == "***" - assert sanitized["nested"]["keep"] == "value" - - -def test_sanitize_provider_config_keeps_non_sensitive_fields(): - raw = { - "id": "provider-id", - "model": "gpt-4.1", - "provider_type": "chat_completion", - } - - sanitized = sanitize_provider_config(raw) - - assert sanitized == raw - - -def test_sanitize_provider_config_does_not_mask_non_secret_token_fields(): - raw = { - "max_tokens": 4096, - "token_limit": 
8192,
-        "monkey": "banana",
-    }
-
-    sanitized = sanitize_provider_config(raw)
-
-    assert sanitized["max_tokens"] == 4096
-    assert sanitized["token_limit"] == 8192
-    assert sanitized["monkey"] == "banana"
-
-
-def test_restore_masked_provider_config_recovers_existing_secrets():
-    old = {
-        "id": "openai",
-        "key": ["sk-old"],
-        "custom_headers": {
-            "Authorization": "Bearer old",
-            "X-Trace": "old-trace",
-        },
-        "max_tokens": 1024,
-        "enable": True,
-    }
-    sanitized = sanitize_provider_config(old)
-    sanitized["enable"] = False
-    sanitized["max_tokens"] = 2048
-
-    restored = restore_masked_provider_config(sanitized, old)
-
-    assert restored["key"] == ["sk-old"]
-    assert restored["custom_headers"]["Authorization"] == "Bearer old"
-    assert restored["custom_headers"]["X-Trace"] == "old-trace"
-    assert restored["enable"] is False
-    assert restored["max_tokens"] == 2048

From eeaed4da6d7d149c39a11cb257e320299a301009 Mon Sep 17 00:00:00 2001
From: xunxi
Date: Thu, 12 Feb 2026 00:39:18 +0800
Subject: [PATCH 3/7] Remove redundant logs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/astr_main_agent.py | 4 +++-
 astrbot/core/provider/entities.py | 8 +++++++-
 .../provider/sources/openai_responses_source.py | 13 ++++++++++---
 astrbot/dashboard/routes/config.py | 2 --
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index 072975a4a..796886a53 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -824,7 +824,9 @@ def _get_compress_provider(
     if not config.llm_compress_provider_id:
         return None
 
-    selected_provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
+    selected_provider = plugin_context.get_provider_by_id(
+        config.llm_compress_provider_id
+    )
     if selected_provider is None:
         logger.warning(
             "未找到指定的上下文压缩模型 %s,将跳过压缩。",
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index 20c5a7947..0ade94246 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -9,6 +9,7 @@
 from anthropic.types import Message as AnthropicMessage
 from google.genai.types import GenerateContentResponse
 from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.responses.response import Response as OpenAIResponse
 
 import astrbot.core.message.components as Comp
 from astrbot import logger
@@ -276,7 +277,11 @@ class LLMResponse:
     """The signature of the reasoning content, if any."""
 
     raw_completion: (
-        ChatCompletion | GenerateContentResponse | AnthropicMessage | None
+        ChatCompletion
+        | GenerateContentResponse
+        | AnthropicMessage
+        | OpenAIResponse
+        | None
     ) = None
     """The raw completion response from the LLM provider."""
 
@@ -305,6 +310,7 @@ def __init__(
         raw_completion: ChatCompletion
         | GenerateContentResponse
         | AnthropicMessage
+        | OpenAIResponse
        | None = None,
         is_chunk: bool = False,
         id: str | None = None,
diff --git a/astrbot/core/provider/sources/openai_responses_source.py b/astrbot/core/provider/sources/openai_responses_source.py
index efa2a0138..27dee024e 100644
--- a/astrbot/core/provider/sources/openai_responses_source.py
+++ b/astrbot/core/provider/sources/openai_responses_source.py
@@ -653,7 +653,9 @@ async def _query_stream(
         async for event in stream:
             event_type = getattr(event, "type", "")
             if event_type == "response.created":
-                response_id = event.response.id
+                response_obj = getattr(event,
"response", None) + if response_obj: + response_id = getattr(response_obj, "id", None) continue if event_type == "response.output_text.delta": @@ -686,10 +688,15 @@ async def _query_stream( ) if event_type == "response.failed": - error_obj = getattr(event.response, "error", None) + response_obj = getattr(event, "response", None) + error_obj = ( + getattr(response_obj, "error", None) + if response_obj + else None + ) if error_obj is not None: raise Exception( - f"Responses stream failed: {error_obj.code} {error_obj.message}" + f"Responses stream failed: {getattr(error_obj, 'code', 'unknown')} {getattr(error_obj, 'message', '')}" ) raise Exception("Responses stream failed.") diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 125d9c07e..efea4c7cf 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -2,7 +2,6 @@ import copy import inspect import os -import re import traceback from pathlib import Path from typing import Any @@ -41,7 +40,6 @@ MAX_FILE_BYTES = 500 * 1024 * 1024 - def try_cast(value: Any, type_: str): if type_ == "int": try: From b3d11c057f1d67993cb95e428a28db88e62f0959 Mon Sep 17 00:00:00 2001 From: xunxi Date: Thu, 12 Feb 2026 01:38:04 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E4=BD=BF=E7=94=A8sdk?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/openai_responses_source.py | 96 +++++++----- tests/test_openai_responses_source.py | 141 ++++++++++++++++++ 2 files changed, 196 insertions(+), 41 deletions(-) diff --git a/astrbot/core/provider/sources/openai_responses_source.py b/astrbot/core/provider/sources/openai_responses_source.py index 27dee024e..f928e3bd7 100644 --- a/astrbot/core/provider/sources/openai_responses_source.py +++ b/astrbot/core/provider/sources/openai_responses_source.py @@ -3,7 +3,6 @@ from collections.abc import AsyncGenerator from typing import Any -import httpx from openai.types.responses.response import Response as OpenAIResponse from openai.types.responses.response_usage import ResponseUsage @@ -53,45 +52,42 @@ async def compact_context(self, messages: list[Message]) -> list[Message]: return messages message_dicts = self._ensure_message_to_dicts(messages) - payload = { + request_payload = { "model": self.get_model(), "input": self._messages_to_response_input(message_dicts), } - compact_url = self._build_compact_url() - headers = { - "Authorization": f"Bearer {self.get_current_key()}", - "Content-Type": "application/json", - } - headers.update(self._build_extra_headers()) - - proxy = self.provider_config.get("proxy", "") - client_kwargs: dict[str, Any] = {"timeout": self.timeout} - if proxy: - client_kwargs["proxy"] = proxy + request_options: dict[str, Any] = {} + extra_headers = self._build_extra_headers() + if extra_headers: + request_options["extra_headers"] = extra_headers try: - async with httpx.AsyncClient(**client_kwargs) as client: - response = await client.post(compact_url, json=payload, headers=headers) - response.raise_for_status() - compact_data = response.json() + compact_response = await self.client.responses.compact( + **request_payload, + **request_options, + ) except Exception as e: if is_connection_error(e): + proxy = self.provider_config.get("proxy", "") log_connection_failure("OpenAI", e, proxy) raise + if hasattr(compact_response, "model_dump"): + compact_data = compact_response.model_dump(mode="json") + elif isinstance(compact_response, dict): + compact_data = compact_response + else: + compact_data = { + 
"output": getattr(compact_response, "output", []), + } + compact_input = self._extract_compact_input(compact_data) compact_messages = self._response_input_to_messages(compact_input) if not compact_messages: raise ValueError("Responses compact returned empty context.") return compact_messages - def _build_compact_url(self) -> str: - base_url = str(self.client.base_url).rstrip("/") - if base_url.endswith("/v1"): - return f"{base_url}/responses/compact" - return f"{base_url}/v1/responses/compact" - def _extract_compact_input(self, compact_data: Any) -> list[dict[str, Any]]: if not isinstance(compact_data, dict): raise ValueError("Invalid compact response payload.") @@ -470,6 +466,22 @@ def _build_extra_headers(self) -> dict[str, str]: headers[str(key)] = str(value) return headers + def _resolve_tool_strict( + self, + tool: dict[str, Any], + function_body: dict[str, Any] | None, + ) -> bool | None: + if isinstance(function_body, dict) and isinstance( + function_body.get("strict"), bool + ): + return function_body["strict"] + if isinstance(tool.get("strict"), bool): + return tool["strict"] + default_strict = self.provider_config.get("responses_tool_strict") + if isinstance(default_strict, bool): + return default_strict + return None + def _convert_tools_to_responses( self, tool_list: list[dict[str, Any]] ) -> list[dict[str, Any]]: @@ -485,29 +497,31 @@ def _convert_tools_to_responses( name = function_body.get("name") if not name: continue - response_tools.append( - { - "type": "function", - "name": name, - "description": function_body.get("description", ""), - "parameters": function_body.get("parameters", {}), - "strict": False, - } - ) + response_tool = { + "type": "function", + "name": name, + "description": function_body.get("description", ""), + "parameters": function_body.get("parameters", {}), + } + strict = self._resolve_tool_strict(tool, function_body) + if strict is not None: + response_tool["strict"] = strict + response_tools.append(response_tool) continue name = tool.get("name") if not name: continue - response_tools.append( - { - "type": "function", - "name": name, - "description": tool.get("description", ""), - "parameters": tool.get("parameters", {}), - "strict": False, - } - ) + response_tool = { + "type": "function", + "name": name, + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + } + strict = self._resolve_tool_strict(tool, None) + if strict is not None: + response_tool["strict"] = strict + response_tools.append(response_tool) return response_tools diff --git a/tests/test_openai_responses_source.py b/tests/test_openai_responses_source.py index a7a5bab43..192425305 100644 --- a/tests/test_openai_responses_source.py +++ b/tests/test_openai_responses_source.py @@ -1,9 +1,14 @@ import os import sys +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest # Add project root to sys.path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from astrbot.core.agent.message import Message from astrbot.core.provider.sources.openai_responses_source import ( ProviderOpenAIResponses, ) @@ -77,3 +82,139 @@ def test_build_extra_headers_keeps_custom_headers_and_ignores_authorization(): assert "User-Agent" not in headers assert headers["X-Test"] == "value" assert "Authorization" not in headers + + +@pytest.mark.asyncio +async def test_compact_context_uses_sdk_compact_api(): + provider = _provider() + provider.provider_config = {"proxy": ""} + provider.get_model = lambda: "gpt-5.3-codex" + 
provider._ensure_message_to_dicts = lambda messages: [
+        {"role": "user", "content": "hello"}
+    ]
+    provider._messages_to_response_input = lambda _: [
+        {
+            "type": "message",
+            "role": "user",
+            "content": [{"type": "input_text", "text": "hello"}],
+        }
+    ]
+    provider._build_extra_headers = lambda: {"X-Test": "1"}
+
+    compact_mock = AsyncMock(
+        return_value=SimpleNamespace(
+            model_dump=lambda mode="json": {
+                "output": [
+                    {
+                        "type": "message",
+                        "role": "user",
+                        "content": [{"type": "output_text", "text": "compacted"}],
+                    }
+                ]
+            }
+        )
+    )
+    provider.client = SimpleNamespace(responses=SimpleNamespace(compact=compact_mock))
+
+    result = await provider.compact_context([Message(role="user", content="hello")])
+
+    compact_mock.assert_awaited_once_with(
+        model="gpt-5.3-codex",
+        input=[
+            {
+                "type": "message",
+                "role": "user",
+                "content": [{"type": "input_text", "text": "hello"}],
+            }
+        ],
+        extra_headers={"X-Test": "1"},
+    )
+    assert result
+
+
+@pytest.mark.asyncio
+async def test_compact_context_raises_when_compact_returns_empty_messages():
+    provider = _provider()
+    provider.provider_config = {"proxy": ""}
+    provider.get_model = lambda: "gpt-5.3-codex"
+    provider._ensure_message_to_dicts = lambda messages: [
+        {"role": "user", "content": "hello"}
+    ]
+    provider._messages_to_response_input = lambda _: [
+        {
+            "type": "message",
+            "role": "user",
+            "content": [{"type": "input_text", "text": "hello"}],
+        }
+    ]
+    provider._build_extra_headers = lambda: {}
+
+    compact_mock = AsyncMock(
+        return_value=SimpleNamespace(model_dump=lambda mode="json": {"output": []})
+    )
+    provider.client = SimpleNamespace(responses=SimpleNamespace(compact=compact_mock))
+
+    with pytest.raises(ValueError, match="empty context"):
+        await provider.compact_context([Message(role="user", content="hello")])
+
+
+def test_convert_tools_to_responses_does_not_force_strict_false():
+    provider = _provider()
+    provider.provider_config = {}
+
+    response_tools = provider._convert_tools_to_responses(
+        [
+            {
+                "type": "function",
+                "function": {
+                    "name": "demo",
+                    "description": "desc",
+                    "parameters": {"type": "object", "properties": {}},
+                },
+            }
+        ]
+    )
+
+    assert response_tools
+    assert "strict" not in response_tools[0]
+
+
+def test_convert_tools_to_responses_keeps_explicit_strict_setting():
+    provider = _provider()
+    provider.provider_config = {}
+
+    response_tools = provider._convert_tools_to_responses(
+        [
+            {
+                "type": "function",
+                "function": {
+                    "name": "demo",
+                    "description": "desc",
+                    "parameters": {"type": "object", "properties": {}},
+                    "strict": True,
+                },
+            }
+        ]
+    )
+
+    assert response_tools[0]["strict"] is True
+
+
+def test_convert_tools_to_responses_supports_provider_default_strict():
+    provider = _provider()
+    provider.provider_config = {"responses_tool_strict": True}
+
+    response_tools = provider._convert_tools_to_responses(
+        [
+            {
+                "type": "function",
+                "function": {
+                    "name": "demo",
+                    "description": "desc",
+                    "parameters": {"type": "object", "properties": {}},
+                },
+            }
+        ]
+    )
+
+    assert response_tools[0]["strict"] is True

From 1a1b9b6cb48f23fa376ee36c54483b66ac00a0b4 Mon Sep 17 00:00:00 2001
From: xunxi
Date: Thu, 12 Feb 2026 02:57:25 +0800
Subject: [PATCH 5/7] Fix webchat not displaying parallel tool calls correctly
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../agent/runners/tool_loop_agent_runner.py | 34 ++++++-------
astrbot/dashboard/routes/chat.py | 9 ++++ dashboard/src/composables/useMessages.ts | 50 ++++++++++++++----- 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 8fb01bfcb..a3dbf44ba 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -659,24 +659,24 @@ async def _handle_function_tools( ), ) - # yield the last tool call result - if tool_call_result_blocks: - last_tcr_content = str(tool_call_result_blocks[-1].content) - yield _HandleFunctionToolsResult.from_message_chain( - MessageChain( - type="tool_call_result", - chain=[ - Json( - data={ - "id": func_tool_id, - "ts": time.time(), - "result": last_tcr_content, - } - ) - ], + # yield the tool call result + if tool_call_result_blocks: + last_tcr_content = str(tool_call_result_blocks[-1].content) + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": last_tcr_content, + } + ) + ], + ) ) - ) - logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") + logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") # 处理函数调用响应 if tool_call_result_blocks: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index d905a0d3f..3ea9772b1 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -403,6 +403,15 @@ async def stream(): streaming = result.get("streaming", False) chain_type = result.get("chain_type") + if ( + enable_streaming + and msg_type == "plain" + and chain_type in {"tool_call", "tool_call_result"} + and not streaming + ): + result["streaming"] = True + streaming = True + if chain_type == "agent_stats": stats_info = { "type": "agent_stats", diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index 620dd6a15..c413cb5cf 100644 --- a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -479,19 +479,38 @@ export function useMessages( } } } else if (chain_type === 'tool_call_result') { - // 解析工具调用结果数据 + // Parse tool call result payload const resultData = JSON.parse(chunk_json.data); - if (message_obj) { - // 遍历所有 tool_call parts 找到对应的 tool_call - for (const part of message_obj.message) { - if (part.type === 'tool_call' && part.tool_calls) { - const toolCall = part.tool_calls.find((tc: ToolCall) => tc.id === resultData.id); - if (toolCall) { - toolCall.result = resultData.result; - toolCall.finished_ts = resultData.ts; - break; - } + const updateToolCallInContent = (content: MessageContent | null | undefined): boolean => { + if (!content || !Array.isArray(content.message)) { + return false; + } + for (const part of content.message) { + if (part.type !== 'tool_call' || !part.tool_calls) { + continue; + } + const toolCall = part.tool_calls.find((tc: ToolCall) => tc.id === resultData.id); + if (!toolCall) { + continue; + } + toolCall.result = resultData.result; + toolCall.finished_ts = resultData.ts; + return true; + } + return false; + }; + + let updated = updateToolCallInContent(message_obj); + if (!updated) { + for (let i = messages.value.length - 1; i >= 0; i--) { + const message = messages.value[i]?.content; + if (message?.type !== 'bot') { + continue; + } + if (updateToolCallInContent(message)) { + updated = true; + break; } } } @@ -548,7 +567,14 @@ export function 
useMessages(
       }
     }
 
-    if ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) {
+    const isToolCallEvent =
+      chunk_json.type === 'plain' &&
+      (chunk_json.chain_type === 'tool_call' || chunk_json.chain_type === 'tool_call_result');
+
+    if (
+      ((chunk_json.type === 'break' && chunk_json.streaming) || !chunk_json.streaming) &&
+      !isToolCallEvent
+    ) {
       in_streaming = false;
       if (!chunk_json.streaming) {
         isStreaming.value = false;

From 3da4cc3202f330dde66d033676431f8e8d4415ec Mon Sep 17 00:00:00 2001
From: xunxi
Date: Thu, 12 Feb 2026 03:06:53 +0800
Subject: [PATCH 6/7] Change some comments
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/agent/context/manager.py | 4 ++--
 astrbot/core/astr_main_agent.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py
index 216a3e7e1..5d879c611 100644
--- a/astrbot/core/agent/context/manager.py
+++ b/astrbot/core/agent/context/manager.py
@@ -55,7 +55,7 @@ async def process(
         try:
             result = messages
 
-            # 1. 基于轮次的截断 (Enforce max turns)
+            # 1. Enforce max turns
             if self.config.enforce_max_turns != -1:
                 result = self.truncator.truncate_by_turns(
                     result,
@@ -63,7 +63,7 @@ async def process(
                     drop_turns=self.config.truncate_turns,
                 )
 
-            # 2. 基于 token 的压缩
+            # 2. Token-based compression
             if self.config.max_context_tokens > 0:
                 total_tokens = self.token_counter.count_tokens(
                     result, trusted_token_usage
diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index 796886a53..b8a0f9456 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -829,13 +829,13 @@ def _get_compress_provider(
         )
     if selected_provider is None:
         logger.warning(
-            "未找到指定的上下文压缩模型 %s,将跳过压缩。",
+            "Configured llm_compress_provider_id not found: %s. Skip compression.",
             config.llm_compress_provider_id,
         )
         return None
     if not isinstance(selected_provider, Provider):
         logger.warning(
-            "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。",
+            "Configured llm_compress_provider_id is not a Provider: %s. Skip compression.",
             config.llm_compress_provider_id,
         )
         return None

From e89e5e6f4c38ed0b8506283cdff9440d3fbdaefd Mon Sep 17 00:00:00 2001
From: xunxi
Date: Fri, 13 Feb 2026 13:59:02 +0800
Subject: [PATCH 7/7] Format code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrbot/core/agent/context/compressor.py | 17 +++++++++++++----
 astrbot/core/agent/context/config.py | 2 ++
 astrbot/core/agent/context/manager.py | 1 +
 .../agent/runners/tool_loop_agent_runner.py | 3 +++
 astrbot/core/astr_main_agent.py | 1 +
 5 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py
index 6686b5279..596fc82da 100644
--- a/astrbot/core/agent/context/compressor.py
+++ b/astrbot/core/agent/context/compressor.py
@@ -1,3 +1,5 @@
+import typing as T
+from collections.abc import Awaitable, Callable
 from typing import TYPE_CHECKING, Protocol, runtime_checkable
 
 from ..message import Message
@@ -154,6 +156,7 @@ def __init__(
         keep_recent: int = 4,
         instruction_text: str | None = None,
         compression_threshold: float = 0.82,
+        use_compact_api: bool = True,
     ) -> None:
         """Initialize the LLM summary compressor.
 
         Args:
             provider: The LLM provider.
             keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation. compression_threshold: The compression trigger threshold (default: 0.82). + use_compact_api: Whether to prefer provider native compact API when available. """ self.provider = provider self.keep_recent = keep_recent self.compression_threshold = compression_threshold + self.use_compact_api = use_compact_api self.instruction_text = instruction_text or ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" @@ -212,8 +217,13 @@ async def _try_native_compact( if not callable(compact_context): return None + compact_context_callable = T.cast( + "Callable[[list[Message]], Awaitable[list[Message]]]", + compact_context, + ) + try: - compacted_messages = await compact_context(messages_to_summarize) + compacted_messages = await compact_context_callable(messages_to_summarize) except Exception as e: logger.warning( f"Native compact failed, fallback to summary compression: {e}" @@ -246,9 +256,8 @@ async def __call__(self, messages: list[Message]) -> list[Message]: if not messages_to_summarize: return messages - native_compact_supported = self._supports_native_compact() - - if native_compact_supported: + # Only try native compact if user allows it and provider supports it + if self.use_compact_api and self._supports_native_compact(): compacted = await self._try_native_compact( system_messages, messages_to_summarize, diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py index b8fd8eb96..4d861a43a 100644 --- a/astrbot/core/agent/context/config.py +++ b/astrbot/core/agent/context/config.py @@ -29,6 +29,8 @@ class ContextConfig: """Number of recent messages to keep during LLM-based compression.""" llm_compress_provider: "Provider | None" = None """LLM provider used for compression tasks. If None, truncation strategy is used.""" + llm_compress_use_compact_api: bool = True + """Whether to prefer provider native compact API when available.""" custom_token_counter: TokenCounter | None = None """Custom token counting method. 
If None, the default method is used.""" custom_compressor: ContextCompressor | None = None diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 5d879c611..2deecff82 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -35,6 +35,7 @@ def __init__( provider=config.llm_compress_provider, keep_recent=config.llm_compress_keep_recent, instruction_text=config.llm_compress_instruction, + use_compact_api=config.llm_compress_use_compact_api, ) else: self.compressor = TruncateByTurnsCompressor( diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index a3dbf44ba..b974da75e 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -85,6 +85,7 @@ async def reset( llm_compress_instruction: str | None = None, llm_compress_keep_recent: int = 0, llm_compress_provider: Provider | None = None, + llm_compress_use_compact_api: bool = True, # truncate by turns compressor truncate_turns: int = 1, # customize @@ -99,6 +100,7 @@ async def reset( self.llm_compress_instruction = llm_compress_instruction self.llm_compress_keep_recent = llm_compress_keep_recent self.llm_compress_provider = llm_compress_provider + self.llm_compress_use_compact_api = llm_compress_use_compact_api self.truncate_turns = truncate_turns self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor @@ -114,6 +116,7 @@ async def reset( llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider=self.llm_compress_provider, + llm_compress_use_compact_api=self.llm_compress_use_compact_api, custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index b8a0f9456..3018ae63c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -985,6 +985,7 @@ async def build_main_agent( llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, llm_compress_provider=_get_compress_provider(config, plugin_context, provider), + llm_compress_use_compact_api=config.llm_compress_use_compact_api, truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode,