94 changes: 90 additions & 4 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -91,6 +91,7 @@ async def reset(
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -120,6 +121,17 @@ async def reset(
self.context_manager = ContextManager(self.context_config)

self.provider = provider
self.fallback_providers: list[Provider] = []
seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
for fallback_provider in fallback_providers or []:
fallback_id = str(fallback_provider.provider_config.get("id", ""))
if fallback_provider is provider:
continue
if fallback_id and fallback_id in seen_provider_ids:
continue
self.fallback_providers.append(fallback_provider)
if fallback_id:
seen_provider_ids.add(fallback_id)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
@@ -166,23 +178,97 @@ async def reset(
self.stats = AgentStats()
self.stats.start_time = time.time()

async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}

if include_model:
# For primary provider we keep explicit model selection if provided.
payload["model"] = self.req.model
if self.streaming:
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**payload)

async def _iter_llm_responses_with_fallback(

issue (complexity): Consider extracting the fallback loop into smaller helpers, moving model-selection out of _iter_llm_responses, and centralizing provider de-duplication to simplify control flow and reduce state tracking complexity.

  1. Split _iter_llm_responses_with_fallback into smaller helpers to avoid flag juggling
    Right now the iterator is tracking has_stream_output, is_last_candidate, last_exception, and last_err_response in a single loop with break/continue/return. You can move the “single provider” logic into a focused helper that reports a clear status, and keep the high‑level fallback loop simple.

For example:

class ProviderRunStatus(enum.Enum):
    SUCCESS = "success"
    ERROR_BEFORE_STREAM = "error_before_stream"
    ERROR_DURING_STREAM = "error_during_stream"


async def _iter_single_provider(
    self, *, include_model: bool
) -> tuple[ProviderRunStatus, LLMResponse | None]:
    """Yield responses for current provider; return status + final response if any."""
    has_stream_output = False
    final_resp: LLMResponse | None = None

    async for resp in self._iter_llm_responses(include_model=include_model):
        if resp.is_chunk:
            has_stream_output = True
            yield resp
            continue

        final_resp = resp
        # distinguish error-before-stream vs after-stream
        if resp.role == "err" and not has_stream_output:
            return ProviderRunStatus.ERROR_BEFORE_STREAM, resp
        return ProviderRunStatus.SUCCESS, resp

    # reached end-of-stream without final_resp (defensive)
    if has_stream_output:
        return ProviderRunStatus.SUCCESS, final_resp
    return ProviderRunStatus.ERROR_BEFORE_STREAM, final_resp

Then the fallback wrapper becomes mostly about looping providers and reacting to the status, without internal flags:

async def _iter_llm_responses_with_fallback(self) -> AsyncIterator[LLMResponse]:
    candidates = [self.provider, *self.fallback_providers]
    last_err_response: LLMResponse | None = None
    last_exception: Exception | None = None

    for idx, candidate in enumerate(candidates):
        is_last_candidate = idx == len(candidates) - 1
        self.provider = candidate

        try:
            status, final_resp = await self._iter_single_provider(include_model=(idx == 0))
        except Exception as exc:  # noqa: BLE001
            last_exception = exc
            # log + continue
            continue

        if status is ProviderRunStatus.ERROR_BEFORE_STREAM and not is_last_candidate:
            last_err_response = final_resp
            # log + try next provider
            continue

        # success or last provider error
        if final_resp:
            yield final_resp
        return

    # fallbacks exhausted
    if last_err_response:
        yield last_err_response
        return
    if last_exception:
        yield LLMResponse(
            role="err",
            completion_text=f"All chat models failed: {type(last_exception).__name__}: {last_exception}",
        )
        return
    yield LLMResponse(
        role="err",
        completion_text="All available chat models are unavailable.",
    )

This preserves the current behavior (error before any chunk → try next provider; once streaming starts, stick with that provider) but removes the need to reason about multiple flags in the same loop.

  2. Decouple “include_model” from _iter_llm_responses
    The include_model flag exists solely to distinguish primary vs fallback providers. You can keep _iter_llm_responses focused on “iterate LLM responses for the current provider” by moving model selection into the caller.

For example:

def _build_llm_payload(self, *, model: str | None) -> dict[str, Any]:
    payload: dict[str, Any] = {
        "contexts": self.run_context.messages,
        "func_tool": self.req.func_tool,
        "session_id": self.req.session_id,
        "extra_user_content_parts": self.req.extra_user_content_parts,
    }
    if model:
        payload["model"] = model
    return payload


async def _iter_llm_responses(
    self, payload: dict[str, Any]
) -> AsyncIterator[LLMResponse]:
    if self.streaming:
        async for resp in self.provider.text_chat_stream(**payload):  # type: ignore
            yield resp
    else:
        yield await self.provider.text_chat(**payload)

Then callers decide whether to pass self.req.model (primary) or None (fallback), instead of _iter_llm_responses knowing about primary vs fallback:

payload = self._build_llm_payload(model=self.req.model if is_primary else None)
async for resp in self._iter_llm_responses(payload):
    ...
  3. Centralize provider de‑duplication
    You’re already doing id-based deduplication in __init__, and it sounds like reset / _get_fallback_chat_providers have similar logic. To avoid two places enforcing uniqueness, factor this into a small helper used everywhere:
def _dedupe_providers(
    self, primary: Provider, fallbacks: list[Provider] | None
) -> list[Provider]:
    seen_ids: set[str] = {str(primary.provider_config.get("id", ""))}
    deduped: list[Provider] = []

    for p in fallbacks or []:
        pid = str(p.provider_config.get("id", ""))
        if p is primary:
            continue
        if pid and pid in seen_ids:
            continue
        deduped.append(p)
        if pid:
            seen_ids.add(pid)
    return deduped

Use it in __init__ and any reset/_get_fallback_chat_providers logic:

self.provider = provider
self.fallback_providers = self._dedupe_providers(provider, fallback_providers)

This removes duplicated id-checking and keeps the uniqueness rules in a single, easy-to-audit place.

self,
) -> T.AsyncGenerator[LLMResponse, None]:
"""Wrap _iter_llm_responses with provider fallback handling."""
candidates = [self.provider, *self.fallback_providers]
total_candidates = len(candidates)
last_exception: Exception | None = None
last_err_response: LLMResponse | None = None

for idx, candidate in enumerate(candidates):
candidate_id = candidate.provider_config.get("id", "<unknown>")
is_last_candidate = idx == total_candidates - 1
if idx > 0:
logger.warning(
"Switched from %s to fallback chat provider: %s",
self.provider.provider_config.get("id", "<unknown>"),
candidate_id,
)
self.provider = candidate
has_stream_output = False
try:
async for resp in self._iter_llm_responses(include_model=idx == 0):
if resp.is_chunk:
has_stream_output = True
yield resp
continue

if (
resp.role == "err"
Comment on lines +220 to +229

issue (bug_risk): Exceptions after streaming has started currently trigger fallback, which may mix outputs from different providers.

Because has_stream_output is only checked in the normal response path, an exception after some chunks have been yielded will still trigger fallback. That means the caller can see partial output from provider A followed by a full response from provider B, which is likely incorrect/undefined behavior for consumers.

In the except Exception block, consider using has_stream_output to either:

  • stop and surface an error (or re-raise) once any chunks have been emitted, or
  • skip fallback for this candidate after streaming has started.

That would align the exception path with the existing logic that only falls back when no streamed output was produced; a minimal sketch of this pattern appears after this file's diff.

and not has_stream_output
and (not is_last_candidate)
):
last_err_response = resp
logger.warning(
"Chat Model %s returns error response, trying fallback to next provider.",
candidate_id,
)
break

yield resp
return

if has_stream_output:
return
except Exception as exc: # noqa: BLE001
last_exception = exc
logger.warning(
"Chat Model %s request error: %s",
candidate_id,
exc,
exc_info=True,
)
continue

if last_err_response:
yield last_err_response
return
if last_exception:
yield LLMResponse(
role="err",
completion_text=(
"All chat models failed: "
f"{type(last_exception).__name__}: {last_exception}"
),
)
return
yield LLMResponse(
role="err",
completion_text="All available chat models are unavailable.",
)

def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
@@ -215,7 +301,7 @@ async def step(self):
)
self._simple_print_message_role("[AftCompact]")

async for llm_response in self._iter_llm_responses():
async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
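As a concrete illustration of the bug_risk suggestion above, the following is a minimal, self-contained sketch (not the PR's code): once any chunk has been yielded, an exception surfaces as an error instead of triggering fallback. The names iter_with_fallback, provider.stream(), and the stripped-down LLMResponse class are placeholders assumed only for this example.

import logging
import typing as T

logger = logging.getLogger(__name__)


class LLMResponse:
    # Minimal stand-in for the project's LLMResponse, used only in this sketch.
    def __init__(self, role: str = "assistant", completion_text: str = "", is_chunk: bool = False):
        self.role = role
        self.completion_text = completion_text
        self.is_chunk = is_chunk


async def iter_with_fallback(providers: list) -> T.AsyncGenerator[LLMResponse, None]:
    # Each provider is assumed to expose `stream() -> AsyncIterator[LLMResponse]` (placeholder API).
    last_exception: Exception | None = None
    for idx, provider in enumerate(providers):
        has_stream_output = False
        try:
            async for resp in provider.stream():
                if resp.is_chunk:
                    has_stream_output = True
                yield resp
            return  # this provider finished normally, no fallback needed
        except Exception as exc:  # noqa: BLE001
            if has_stream_output:
                # Chunks already reached the caller: surface the error rather than
                # mixing in output from a different provider.
                yield LLMResponse(role="err", completion_text=f"Chat model failed mid-stream: {exc}")
                return
            last_exception = exc
            logger.warning("Chat provider #%d failed before streaming, trying next: %s", idx, exc)
    yield LLMResponse(
        role="err",
        completion_text=f"All chat models failed: {last_exception}",
    )

With this shape, a consumer that has already rendered partial chunks from provider A sees a single error trailer instead of a full, unrelated response from provider B, which is the behavior the comment asks for.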
38 changes: 38 additions & 0 deletions astrbot/core/astr_main_agent.py
@@ -870,6 +870,41 @@ def _get_compress_provider(
return provider


def _get_fallback_chat_providers(
provider: Provider, plugin_context: Context, provider_settings: dict
) -> list[Provider]:
fallback_ids = provider_settings.get("fallback_chat_models", [])
if not isinstance(fallback_ids, list):
logger.warning(
"fallback_chat_models setting is not a list, skip fallback providers."
)
return []

provider_id = str(provider.provider_config.get("id", ""))
seen_provider_ids: set[str] = {provider_id} if provider_id else set()
fallbacks: list[Provider] = []

for fallback_id in fallback_ids:
if not isinstance(fallback_id, str) or not fallback_id:
continue
if fallback_id in seen_provider_ids:
continue
fallback_provider = plugin_context.get_provider_by_id(fallback_id)
if fallback_provider is None:
logger.warning("Fallback chat provider `%s` not found, skip.", fallback_id)
continue
if not isinstance(fallback_provider, Provider):
logger.warning(
"Fallback chat provider `%s` is invalid type: %s, skip.",
fallback_id,
type(fallback_provider),
)
continue
fallbacks.append(fallback_provider)
seen_provider_ids.add(fallback_id)
return fallbacks


async def build_main_agent(
*,
event: AstrMessageEvent,
@@ -1093,6 +1128,9 @@ async def build_main_agent(
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
),
)

if apply_reset:
16 changes: 14 additions & 2 deletions astrbot/core/config/default.py
@@ -68,6 +68,7 @@
"provider_settings": {
"enable": True,
"default_provider_id": "",
"fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -2207,6 +2208,10 @@ class ChatProviderTemplate(TypedDict):
"default_provider_id": {
"type": "string",
},
"fallback_chat_models": {
"type": "list",
"items": {"type": "string"},
},
"wake_prefix": {
"type": "string",
},
@@ -2504,15 +2509,22 @@ class ChatProviderTemplate(TypedDict):
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"description": "默认对话模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
},
"provider_settings.fallback_chat_models": {
"description": "回退对话模型列表",
"type": "list",
"items": {"type": "string"},
"_special": "select_providers",
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
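For orientation, here is a hypothetical excerpt of provider_settings with the new key filled in; only the keys shown in the diff above come from the PR, while the provider ids are made-up placeholders.

# Hypothetical user configuration (not from the PR); the fallback ids must match
# the "id" fields of chat providers configured elsewhere.
provider_settings = {
    "enable": True,
    "default_provider_id": "openai_main",  # primary chat model (placeholder id)
    "fallback_chat_models": ["claude_backup", "local_ollama"],  # tried in this order when the primary fails
}

At runtime, _get_fallback_chat_providers resolves these ids through plugin_context.get_provider_by_id and skips unknown, duplicate, or wrongly-typed entries with a warning, so a stale id degrades gracefully instead of breaking agent construction.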
8 changes: 8 additions & 0 deletions dashboard/src/components/shared/ConfigItemRenderer.vue
@@ -10,6 +10,14 @@
<template v-else-if="itemMeta?._special === 'select_provider_tts'">
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'text_to_speech'" />
</template>
<template v-else-if="itemMeta?._special === 'select_providers'">
<ProviderSelector
:model-value="modelValue"
@update:model-value="emitUpdate"
:provider-type="'chat_completion'"
:multiple="true"
/>
</template>
<template v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
<ProviderSelector
:model-value="modelValue"