diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index d6aed6dfa..8309e6674 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -91,6 +91,7 @@ async def reset(
         custom_token_counter: TokenCounter | None = None,
         custom_compressor: ContextCompressor | None = None,
         tool_schema_mode: str | None = "full",
+        fallback_providers: list[Provider] | None = None,
         **kwargs: T.Any,
     ) -> None:
         self.req = request
@@ -120,6 +121,17 @@ async def reset(
         self.context_manager = ContextManager(self.context_config)
 
         self.provider = provider
+        self.fallback_providers: list[Provider] = []
+        seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
+        for fallback_provider in fallback_providers or []:
+            fallback_id = str(fallback_provider.provider_config.get("id", ""))
+            if fallback_provider is provider:
+                continue
+            if fallback_id and fallback_id in seen_provider_ids:
+                continue
+            self.fallback_providers.append(fallback_provider)
+            if fallback_id:
+                seen_provider_ids.add(fallback_id)
         self.final_llm_resp = None
         self._state = AgentState.IDLE
         self.tool_executor = tool_executor
@@ -166,16 +178,19 @@ async def reset(
         self.stats = AgentStats()
         self.stats.start_time = time.time()
 
-    async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
+    async def _iter_llm_responses(
+        self, *, include_model: bool = True
+    ) -> T.AsyncGenerator[LLMResponse, None]:
         """Yields chunks *and* a final LLMResponse."""
         payload = {
             "contexts": self.run_context.messages,  # list[Message]
             "func_tool": self.req.func_tool,
-            "model": self.req.model,  # NOTE: in fact, this arg is None in most cases
             "session_id": self.req.session_id,
             "extra_user_content_parts": self.req.extra_user_content_parts,  # list[ContentPart]
         }
-
+        if include_model:
+            # For primary provider we keep explicit model selection if provided.
+            payload["model"] = self.req.model
         if self.streaming:
             stream = self.provider.text_chat_stream(**payload)
             async for resp in stream:  # type: ignore
@@ -183,6 +198,77 @@ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
         else:
             yield await self.provider.text_chat(**payload)
 
+    async def _iter_llm_responses_with_fallback(
+        self,
+    ) -> T.AsyncGenerator[LLMResponse, None]:
+        """Wrap _iter_llm_responses with provider fallback handling."""
+        candidates = [self.provider, *self.fallback_providers]
+        total_candidates = len(candidates)
+        last_exception: Exception | None = None
+        last_err_response: LLMResponse | None = None
+
+        for idx, candidate in enumerate(candidates):
+            candidate_id = candidate.provider_config.get("id", "")
+            is_last_candidate = idx == total_candidates - 1
+            if idx > 0:
+                logger.warning(
+                    "Switched from %s to fallback chat provider: %s",
+                    self.provider.provider_config.get("id", ""),
+                    candidate_id,
+                )
+            self.provider = candidate
+            has_stream_output = False
+            try:
+                async for resp in self._iter_llm_responses(include_model=idx == 0):
+                    if resp.is_chunk:
+                        has_stream_output = True
+                        yield resp
+                        continue
+
+                    if (
+                        resp.role == "err"
+                        and not has_stream_output
+                        and (not is_last_candidate)
+                    ):
+                        last_err_response = resp
+                        logger.warning(
+                            "Chat Model %s returns error response, trying fallback to next provider.",
+                            candidate_id,
+                        )
+                        break
+
+                    yield resp
+                    return
+
+                if has_stream_output:
+                    return
+            except Exception as exc:  # noqa: BLE001
+                last_exception = exc
+                logger.warning(
+                    "Chat Model %s request error: %s",
+                    candidate_id,
+                    exc,
+                    exc_info=True,
+                )
+                continue
+
+        if last_err_response:
+            yield last_err_response
+            return
+        if last_exception:
+            yield LLMResponse(
+                role="err",
+                completion_text=(
+                    "All chat models failed: "
+                    f"{type(last_exception).__name__}: {last_exception}"
+                ),
+            )
+            return
+        yield LLMResponse(
+            role="err",
+            completion_text="All available chat models are unavailable.",
+        )
+
     def _simple_print_message_role(self, tag: str = ""):
         roles = []
         for message in self.run_context.messages:
@@ -215,7 +301,7 @@ async def step(self):
             )
             self._simple_print_message_role("[AftCompact]")
 
-        async for llm_response in self._iter_llm_responses():
+        async for llm_response in self._iter_llm_responses_with_fallback():
             if llm_response.is_chunk:
                 # update ttft
                 if self.stats.time_to_first_token == 0:
diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index 12c4fde1d..7883dca8f 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -870,6 +870,41 @@ def _get_compress_provider(
     return provider
 
 
+def _get_fallback_chat_providers(
+    provider: Provider, plugin_context: Context, provider_settings: dict
+) -> list[Provider]:
+    fallback_ids = provider_settings.get("fallback_chat_models", [])
+    if not isinstance(fallback_ids, list):
+        logger.warning(
+            "fallback_chat_models setting is not a list, skip fallback providers."
+        )
+        return []
+
+    provider_id = str(provider.provider_config.get("id", ""))
+    seen_provider_ids: set[str] = {provider_id} if provider_id else set()
+    fallbacks: list[Provider] = []
+
+    for fallback_id in fallback_ids:
+        if not isinstance(fallback_id, str) or not fallback_id:
+            continue
+        if fallback_id in seen_provider_ids:
+            continue
+        fallback_provider = plugin_context.get_provider_by_id(fallback_id)
+        if fallback_provider is None:
+            logger.warning("Fallback chat provider `%s` not found, skip.", fallback_id)
+            continue
+        if not isinstance(fallback_provider, Provider):
+            logger.warning(
+                "Fallback chat provider `%s` is invalid type: %s, skip.",
+                fallback_id,
+                type(fallback_provider),
+            )
+            continue
+        fallbacks.append(fallback_provider)
+        seen_provider_ids.add(fallback_id)
+    return fallbacks
+
+
 async def build_main_agent(
     *,
     event: AstrMessageEvent,
@@ -1093,6 +1128,9 @@ async def build_main_agent(
         truncate_turns=config.dequeue_context_length,
         enforce_max_turns=config.max_context_length,
         tool_schema_mode=config.tool_schema_mode,
+        fallback_providers=_get_fallback_chat_providers(
+            provider, plugin_context, config.provider_settings
+        ),
     )
 
     if apply_reset:
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 235915c59..43d9991bd 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -68,6 +68,7 @@
     "provider_settings": {
         "enable": True,
         "default_provider_id": "",
+        "fallback_chat_models": [],
         "default_image_caption_provider_id": "",
         "image_caption_prompt": "Please describe the image using Chinese.",
         "provider_pool": ["*"],  # "*" 表示使用所有可用的提供者
@@ -2207,6 +2208,10 @@ class ChatProviderTemplate(TypedDict):
         "default_provider_id": {
             "type": "string",
         },
+        "fallback_chat_models": {
+            "type": "list",
+            "items": {"type": "string"},
+        },
         "wake_prefix": {
             "type": "string",
         },
@@ -2504,15 +2509,22 @@ class ChatProviderTemplate(TypedDict):
     },
     "ai": {
         "description": "模型",
-        "hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
+        "hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
         "type": "object",
         "items": {
             "provider_settings.default_provider_id": {
-                "description": "默认聊天模型",
+                "description": "默认对话模型",
                 "type": "string",
                 "_special": "select_provider",
                 "hint": "留空时使用第一个模型",
             },
+            "provider_settings.fallback_chat_models": {
+                "description": "回退对话模型列表",
+                "type": "list",
+                "items": {"type": "string"},
+                "_special": "select_providers",
+                "hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
+            },
             "provider_settings.default_image_caption_provider_id": {
                 "description": "默认图片转述模型",
                 "type": "string",
diff --git a/dashboard/src/components/shared/ConfigItemRenderer.vue b/dashboard/src/components/shared/ConfigItemRenderer.vue
index 5f2341ee7..3c3262064 100644
--- a/dashboard/src/components/shared/ConfigItemRenderer.vue
+++ b/dashboard/src/components/shared/ConfigItemRenderer.vue
@@ -10,6 +10,14 @@
+
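For reference, a minimal sketch (not part of the patch) of how the new `provider_settings` keys would be filled once this change lands; the provider IDs below are placeholders, not configuration shipped with the repository. The primary chat provider is still resolved from `default_provider_id`, while `fallback_chat_models` lists provider IDs that `_get_fallback_chat_providers` resolves and `_iter_llm_responses_with_fallback` tries in order after the primary request raises or returns an error response.

```python
# Hypothetical provider_settings excerpt; the provider IDs are placeholders.
provider_settings = {
    "enable": True,
    "default_provider_id": "openai-main",  # primary chat provider
    # Tried left to right when the primary provider raises or yields role="err".
    "fallback_chat_models": ["claude-backup", "local-ollama"],
}
```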