feat: add fallback chat model chain in tool loop runner #5109
Changes from all commits
```diff
@@ -91,6 +91,7 @@ async def reset(
         custom_token_counter: TokenCounter | None = None,
         custom_compressor: ContextCompressor | None = None,
         tool_schema_mode: str | None = "full",
+        fallback_providers: list[Provider] | None = None,
         **kwargs: T.Any,
     ) -> None:
         self.req = request
```
```diff
@@ -120,6 +121,17 @@ async def reset(
         self.context_manager = ContextManager(self.context_config)
 
         self.provider = provider
+        self.fallback_providers: list[Provider] = []
+        seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
+        for fallback_provider in fallback_providers or []:
+            fallback_id = str(fallback_provider.provider_config.get("id", ""))
+            if fallback_provider is provider:
+                continue
+            if fallback_id and fallback_id in seen_provider_ids:
+                continue
+            self.fallback_providers.append(fallback_provider)
+            if fallback_id:
+                seen_provider_ids.add(fallback_id)
         self.final_llm_resp = None
         self._state = AgentState.IDLE
         self.tool_executor = tool_executor
```
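For readers skimming the diff, here is a small self-contained sketch that replays the de-duplication rules above against stub providers (`FakeProvider` and the example ids are illustrative only, not part of the codebase):

```python
from dataclasses import dataclass, field


@dataclass
class FakeProvider:
    """Stand-in for the real Provider; only provider_config matters here."""

    provider_config: dict = field(default_factory=dict)


primary = FakeProvider({"id": "openai-main"})
candidates = [
    FakeProvider({"id": "openai-main"}),    # same id as primary -> skipped
    primary,                                # same object as primary -> skipped
    FakeProvider({"id": "claude-backup"}),  # kept
    FakeProvider({"id": "claude-backup"}),  # duplicate id -> skipped
    FakeProvider({"id": ""}),               # empty id -> kept, never de-duplicated
]

seen_ids = {str(primary.provider_config.get("id", ""))}
fallbacks = []
for candidate in candidates:
    candidate_id = str(candidate.provider_config.get("id", ""))
    if candidate is primary:
        continue
    if candidate_id and candidate_id in seen_ids:
        continue
    fallbacks.append(candidate)
    if candidate_id:
        seen_ids.add(candidate_id)

print([p.provider_config.get("id") for p in fallbacks])  # ['claude-backup', '']
```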
```diff
@@ -166,23 +178,97 @@ async def reset(
         self.stats = AgentStats()
         self.stats.start_time = time.time()
 
-    async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
+    async def _iter_llm_responses(
+        self, *, include_model: bool = True
+    ) -> T.AsyncGenerator[LLMResponse, None]:
         """Yields chunks *and* a final LLMResponse."""
         payload = {
             "contexts": self.run_context.messages,  # list[Message]
             "func_tool": self.req.func_tool,
-            "model": self.req.model,  # NOTE: in fact, this arg is None in most cases
             "session_id": self.req.session_id,
             "extra_user_content_parts": self.req.extra_user_content_parts,  # list[ContentPart]
         }
 
+        if include_model:
+            # For primary provider we keep explicit model selection if provided.
+            payload["model"] = self.req.model
         if self.streaming:
             stream = self.provider.text_chat_stream(**payload)
             async for resp in stream:  # type: ignore
                 yield resp
         else:
             yield await self.provider.text_chat(**payload)
 
+    async def _iter_llm_responses_with_fallback(
+        self,
+    ) -> T.AsyncGenerator[LLMResponse, None]:
+        """Wrap _iter_llm_responses with provider fallback handling."""
+        candidates = [self.provider, *self.fallback_providers]
+        total_candidates = len(candidates)
+        last_exception: Exception | None = None
+        last_err_response: LLMResponse | None = None
+
+        for idx, candidate in enumerate(candidates):
+            candidate_id = candidate.provider_config.get("id", "<unknown>")
+            is_last_candidate = idx == total_candidates - 1
+            if idx > 0:
+                logger.warning(
+                    "Switched from %s to fallback chat provider: %s",
+                    self.provider.provider_config.get("id", "<unknown>"),
+                    candidate_id,
+                )
+            self.provider = candidate
+            has_stream_output = False
+            try:
+                async for resp in self._iter_llm_responses(include_model=idx == 0):
+                    if resp.is_chunk:
+                        has_stream_output = True
+                        yield resp
+                        continue
+
+                    if (
+                        resp.role == "err"
+                        and not has_stream_output
+                        and (not is_last_candidate)
+                    ):
+                        last_err_response = resp
+                        logger.warning(
+                            "Chat Model %s returns error response, trying fallback to next provider.",
+                            candidate_id,
+                        )
+                        break
```

Comment on lines +220 to +229

Contributor

issue (bug_risk): Exceptions raised after streaming has started currently trigger fallback, which may mix outputs from different providers. Because the `except` branch never checks `has_stream_output`, a provider that fails mid-stream is still replaced by the next candidate; consider skipping the fallback there once any chunk has been yielded. That would align the exception path with the existing logic that only falls back when no streamed output was produced.
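One possible rendering of that suggestion, shown against the `except` branch from this hunk (an interpretation of the review comment, not code from the PR; re-raising instead of yielding an error response would be an equally valid choice):

```python
            except Exception as exc:  # noqa: BLE001
                if has_stream_output:
                    # Chunks from this provider were already yielded; surface the
                    # failure instead of mixing in output from another provider.
                    yield LLMResponse(
                        role="err",
                        completion_text=f"{type(exc).__name__}: {exc}",
                    )
                    return
                last_exception = exc
                logger.warning(
                    "Chat Model %s request error: %s",
                    candidate_id,
                    exc,
                    exc_info=True,
                )
                continue
```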
```diff
+
+                    yield resp
+                    return
+
+                if has_stream_output:
+                    return
+            except Exception as exc:  # noqa: BLE001
+                last_exception = exc
+                logger.warning(
+                    "Chat Model %s request error: %s",
+                    candidate_id,
+                    exc,
+                    exc_info=True,
+                )
+                continue
+
+        if last_err_response:
+            yield last_err_response
+            return
+        if last_exception:
+            yield LLMResponse(
+                role="err",
+                completion_text=(
+                    "All chat models failed: "
+                    f"{type(last_exception).__name__}: {last_exception}"
+                ),
+            )
+            return
+        yield LLMResponse(
+            role="err",
+            completion_text="All available chat models are unavailable.",
+        )
+
     def _simple_print_message_role(self, tag: str = ""):
         roles = []
         for message in self.run_context.messages:
```
```diff
@@ -215,7 +301,7 @@ async def step(self):
             )
             self._simple_print_message_role("[AftCompact]")
 
-        async for llm_response in self._iter_llm_responses():
+        async for llm_response in self._iter_llm_responses_with_fallback():
             if llm_response.is_chunk:
                 # update ttft
                 if self.stats.time_to_first_token == 0:
```
issue (complexity): Consider extracting the fallback loop into smaller helpers, moving model selection out of `_iter_llm_responses`, and centralizing provider de-duplication to simplify control flow and reduce state-tracking complexity.

**Split `_iter_llm_responses_with_fallback` into smaller helpers to avoid flag juggling.** Right now the iterator is tracking `has_stream_output`, `is_last_candidate`, `last_exception`, and `last_err_response` in a single loop with `break`/`continue`/`return`. You can move the "single provider" logic into a focused helper that reports a clear status, and keep the high-level fallback loop simple. For example:
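The reviewer's own snippet is not preserved in this export; the sketch below is one possible reconstruction, written as a method on the runner class from the diff (`_ProviderRunOutcome` and `_run_one_provider` are hypothetical names):

```python
from dataclasses import dataclass


@dataclass
class _ProviderRunOutcome:
    """What happened while driving one provider (hypothetical helper type)."""

    streamed: bool = False                      # at least one chunk was yielded
    finished: bool = False                      # a final response was yielded
    err_response: "LLMResponse | None" = None   # pre-stream error response, if any


async def _run_one_provider(
    self, outcome: _ProviderRunOutcome, *, use_request_model: bool
) -> "T.AsyncGenerator[LLMResponse, None]":
    """Drive the current self.provider once and record what happened in `outcome`."""
    async for resp in self._iter_llm_responses(include_model=use_request_model):
        if resp.is_chunk:
            outcome.streamed = True
            yield resp
            continue
        if resp.role == "err" and not outcome.streamed:
            # Error before any output: record it and let the caller decide
            # whether to fall back to the next provider.
            outcome.err_response = resp
            return
        # A normal final response: pass it through and stop.
        yield resp
        outcome.finished = True
        return
```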
Then the fallback wrapper becomes mostly about looping over the providers and reacting to the reported status, without internal flags:
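Again an illustrative reconstruction rather than the reviewer's original snippet:

```python
async def _iter_llm_responses_with_fallback(
    self,
) -> "T.AsyncGenerator[LLMResponse, None]":
    last_err: "LLMResponse | None" = None
    for idx, candidate in enumerate([self.provider, *self.fallback_providers]):
        self.provider = candidate
        outcome = _ProviderRunOutcome()
        try:
            async for resp in self._run_one_provider(
                outcome, use_request_model=idx == 0
            ):
                yield resp
        except Exception:  # noqa: BLE001
            if outcome.streamed:
                raise  # once streaming has started, stick with this provider
            logger.warning("Chat provider #%d failed before any output, trying next", idx)
            continue
        if outcome.streamed or outcome.finished:
            return  # this provider produced output; no fallback needed
        last_err = outcome.err_response  # pre-stream error: try the next candidate
    # Every candidate failed before producing any output.
    yield last_err or LLMResponse(
        role="err",
        completion_text="All available chat models are unavailable.",
    )
```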
This preserves the current behavior (error before any chunk → try next provider; once streaming starts, stick with that provider) but removes the need to reason about multiple flags in the same loop.
**Decouple model selection from `_iter_llm_responses`.** The `include_model` flag exists solely to distinguish the primary provider from fallback providers. You can keep `_iter_llm_responses` focused on iterating LLM responses for the current provider by moving model selection into the caller. For example:
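One way that decoupled version could look, with the model passed explicitly instead of an `include_model` flag (a sketch only; the `model` parameter name is illustrative):

```python
async def _iter_llm_responses(
    self, *, model: str | None = None
) -> "T.AsyncGenerator[LLMResponse, None]":
    """Yields chunks *and* a final LLMResponse for the current provider."""
    payload = {
        "contexts": self.run_context.messages,  # list[Message]
        "func_tool": self.req.func_tool,
        "session_id": self.req.session_id,
        "extra_user_content_parts": self.req.extra_user_content_parts,
    }
    if model is not None:
        # Only the caller decides whether an explicit model should be used.
        payload["model"] = model
    if self.streaming:
        async for resp in self.provider.text_chat_stream(**payload):  # type: ignore
            yield resp
    else:
        yield await self.provider.text_chat(**payload)
```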
Then callers decide whether to pass `self.req.model` (primary) or `None` (fallback), instead of `_iter_llm_responses` knowing about primary vs. fallback.

**Centralize provider de-duplication.** You're already doing id-based deduplication in `__init__`, and it sounds like `reset`/`_get_fallback_chat_providers` have similar logic. To avoid two places enforcing uniqueness, factor this into a small helper used everywhere, and call it from `__init__` and any `reset`/`_get_fallback_chat_providers` logic. This removes the duplicated id-checking and keeps the uniqueness rules in a single, easy-to-audit place.
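A sketch of such a helper, reusing the exact rules from the diff (the name `_dedup_fallback_providers` is illustrative):

```python
def _dedup_fallback_providers(
    primary: "Provider",
    candidates: "list[Provider] | None",
) -> "list[Provider]":
    """Return candidates minus the primary provider and id duplicates, order preserved."""
    seen_ids = {str(primary.provider_config.get("id", ""))}
    unique: "list[Provider]" = []
    for candidate in candidates or []:
        candidate_id = str(candidate.provider_config.get("id", ""))
        if candidate is primary:
            continue
        if candidate_id and candidate_id in seen_ids:
            continue
        unique.append(candidate)
        if candidate_id:
            seen_ids.add(candidate_id)
    return unique
```

With a helper like this, the loop in `reset()` collapses to `self.fallback_providers = _dedup_fallback_providers(provider, fallback_providers)`, and any other de-duplication site can apply the same rule.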