diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 09fdd9507e..13302dbeb9 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -325,6 +325,8 @@ def _get_azure_ad_token(self) -> str | None: @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self._refresh_api_key() + headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {} options = model_copy(options) @@ -612,6 +614,8 @@ async def _get_azure_ad_token(self) -> str | None: @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self._refresh_api_key() + headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {} options = model_copy(options) diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 52c24eba27..4b7ddf036e 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -114,6 +114,44 @@ def token_provider() -> str: assert calls[1].request.headers.get("Authorization") == "Bearer second" +@pytest.mark.respx() +def test_client_api_key_provider_refresh_sync(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def api_key_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_endpoint="https://example-resource.azure.openai.com", + ) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert calls[0].request.headers.get("api-key") == "first" + assert calls[1].request.headers.get("api-key") == "second" + + @pytest.mark.asyncio @pytest.mark.respx() async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None: @@ -154,6 +192,46 @@ def token_provider() -> str: assert calls[1].request.headers.get("Authorization") == "Bearer second" +@pytest.mark.asyncio +@pytest.mark.respx() +async def test_client_api_key_provider_refresh_async(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def api_key_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_endpoint="https://example-resource.azure.openai.com", + ) + + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert calls[0].request.headers.get("api-key") == "first" + assert calls[1].request.headers.get("api-key") == "second" + + class TestAzureLogging: @pytest.fixture(autouse=True) def logger_with_filter(self) -> logging.Logger: