diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index b19f3460dd..3285c0de03 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -438,11 +438,6 @@ async def _fallback_to_text_only_and_retry( image_fallback_used, ) - def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: - """创建带代理的 HTTP 客户端""" - proxy = provider_config.get("proxy", "") - return create_proxy_client("OpenAI", proxy) - def __init__(self, provider_config, provider_settings) -> None: super().__init__(provider_config, provider_settings) self.chosen_api_key = None @@ -450,6 +445,8 @@ def __init__(self, provider_config, provider_settings) -> None: self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) self.custom_headers = provider_config.get("custom_headers", {}) + self.client: AsyncOpenAI | AsyncAzureOpenAI | None = None + self._client_alive = False if isinstance(self.timeout, str): self.timeout = int(self.timeout) @@ -459,34 +456,59 @@ def __init__(self, provider_config, provider_settings) -> None: for key in self.custom_headers: self.custom_headers[key] = str(self.custom_headers[key]) - if "api_version" in provider_config: + self.client = self._create_openai_client() + self._client_alive = True + + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() + + model = provider_config.get("model", "unknown") + self.set_model(model) + + self.reasoning_key = "reasoning_content" + + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: + """创建带代理的 HTTP 客户端""" + proxy = provider_config.get("proxy", "") + + return create_proxy_client("OpenAI", proxy) + + def _create_openai_client( + self, + api_key: str | None = None, + ) -> AsyncOpenAI | AsyncAzureOpenAI: + """创建 OpenAI/Azure 客户端实例,将初始化逻辑解耦以便复用。""" + api_key = api_key or self.chosen_api_key + if "api_version" in self.provider_config: # Using Azure OpenAI API - self.client = AsyncAzureOpenAI( - api_key=self.chosen_api_key, - api_version=provider_config.get("api_version", None), + return AsyncAzureOpenAI( + api_key=api_key, + api_version=self.provider_config.get("api_version", None), default_headers=self.custom_headers, - base_url=provider_config.get("api_base", ""), + base_url=self.provider_config.get("api_base", ""), timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) else: # Using OpenAI Official API - self.client = AsyncOpenAI( - api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), + return AsyncOpenAI( + api_key=api_key, + base_url=self.provider_config.get("api_base", None), default_headers=self.custom_headers, timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) - self.default_params = inspect.signature( - self.client.chat.completions.create, - ).parameters.keys() - - model = provider_config.get("model", "unknown") - self.set_model(model) - - self.reasoning_key = "reasoning_content" + def _ensure_client(self) -> None: + """确保 client 可用,仅在真实 API 调用前按需重建。""" + if self.client is None or not self._client_alive: + logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...") + self.client = self._create_openai_client() + self._client_alive = True + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() def _ollama_disable_thinking_enabled(self) -> bool: value = self.provider_config.get("ollama_disable_thinking", False) @@ -509,6 +531,7 @@ def _apply_provider_specific_extra_body_overrides( extra_body["reasoning_effort"] = "none" async def get_models(self): + self._ensure_client() try: models_str = [] models = await self.client.models.list() @@ -520,6 +543,7 @@ async def get_models(self): raise Exception(f"获取模型列表失败:{e}") async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -592,6 +616,7 @@ async def _query_stream( tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -1145,7 +1170,10 @@ async def text_chat( retry_cnt = 0 for retry_cnt in range(max_retries): try: - self.client.api_key = chosen_key + self.chosen_api_key = chosen_key + self._ensure_client() + if self.client is not None: + self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break except Exception as e: @@ -1216,7 +1244,10 @@ async def text_chat_stream( retry_cnt = 0 for retry_cnt in range(max_retries): try: - self.client.api_key = chosen_key + self.chosen_api_key = chosen_key + self._ensure_client() + if self.client is not None: + self.client.api_key = chosen_key async for response in self._query_stream(payloads, func_tool): yield response break @@ -1270,13 +1301,15 @@ async def _remove_image_from_context(self, contexts: list): return new_contexts def get_current_key(self) -> str: - return self.client.api_key + return self.chosen_api_key def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key) -> None: - self.client.api_key = key + self.chosen_api_key = key + if self.client is not None: + self.client.api_key = key async def assemble_context( self, @@ -1355,5 +1388,12 @@ async def encode_image_bs64(self, image_url: str) -> str: return image_data async def terminate(self): + """关闭 client 并将引用置为 None,确保后续仅在真实调用时重建。""" if self.client: - await self.client.close() + try: + await self.client.close() + except Exception as e: + logger.warning(f"关闭 OpenAI client 时出错: {e}") + finally: + self.client = None + self._client_alive = False