Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def __init__(
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the OpenAI client (legacy approach).
For a complete list of supported arguments, see https://pypi.org/project/openai/.
The ``http_client`` key accepts either an ``httpx.AsyncClient`` instance or a
zero-argument callable that returns one. When a callable (factory) is provided,
it is invoked on every request to produce a fresh client, avoiding the
"closed client" error that occurs when the same instance is reused.
**model_config: Configuration options for the OpenAI model.

Raises:
Expand Down Expand Up @@ -552,6 +556,10 @@ async def _get_client(self) -> AsyncIterator[Any]:
- Otherwise, creates a new AsyncOpenAI client from client_args and automatically
closes it when the context exits.

If ``http_client`` in *client_args* is a callable (factory), it is invoked on each
request to produce a fresh ``httpx.AsyncClient``, preventing the "closed client" error
that occurs when the same client instance is reused across ``async with`` blocks.

Note: We create a new client per request to avoid connection sharing in the underlying
httpx client, as the asyncio event loop does not allow connections to be shared.
For more details, see https://github.com/encode/httpx/discussions/2959.
Expand All @@ -567,7 +575,11 @@ async def _get_client(self) -> AsyncIterator[Any]:
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
# refer to https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
resolved_args = dict(self.client_args)
http_client = resolved_args.get("http_client")
if http_client is not None and callable(http_client) and not hasattr(http_client, "send"):
resolved_args["http_client"] = http_client()
async with openai.AsyncOpenAI(**resolved_args) as client:
yield client

@override
Expand Down
82 changes: 82 additions & 0 deletions tests/strands/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,3 +1533,85 @@ def test_format_request_messages_multiple_tool_calls_with_images():
},
]
assert tru_result == exp_result


class TestGetClientHttpClientFactory:
    """Verify ``_get_client`` handling of http_client factories vs. plain instances."""

    @staticmethod
    def _async_context_client():
        """Return an AsyncMock usable as an async context manager that yields itself."""
        client = unittest.mock.AsyncMock()
        client.__aenter__ = unittest.mock.AsyncMock(return_value=client)
        client.__aexit__ = unittest.mock.AsyncMock(return_value=None)
        return client

    @pytest.mark.asyncio
    async def test_http_client_factory_called_on_each_request(self):
        """A callable http_client is re-invoked on every ``_get_client`` call."""
        first_client = unittest.mock.MagicMock()
        second_client = unittest.mock.MagicMock()
        # spec=[] strips all attributes, so the mock has no ``send`` and is
        # treated as a factory rather than an httpx client instance.
        factory = unittest.mock.MagicMock(
            spec=[],
            side_effect=[first_client, second_client],
        )

        with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as openai_cls:
            openai_cls.return_value = self._async_context_client()

            model = OpenAIModel(
                client_args={"api_key": "test-key", "http_client": factory},
                model_id="gpt-4",
            )

            for _ in range(2):
                async with model._get_client():
                    pass

            assert factory.call_count == 2
            recorded = openai_cls.call_args_list
            assert recorded[0][1]["http_client"] == first_client
            assert recorded[1][1]["http_client"] == second_client

    @pytest.mark.asyncio
    async def test_http_client_instance_passed_through(self):
        """A non-factory http_client (one exposing ``send``) is forwarded unchanged."""
        instance = unittest.mock.MagicMock()
        instance.send = unittest.mock.MagicMock()  # real httpx clients expose .send()

        with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as openai_cls:
            openai_cls.return_value = self._async_context_client()

            model = OpenAIModel(
                client_args={"api_key": "test-key", "http_client": instance},
                model_id="gpt-4",
            )

            async with model._get_client():
                pass

            openai_cls.assert_called_once_with(api_key="test-key", http_client=instance)

    @pytest.mark.asyncio
    async def test_client_args_not_mutated_by_factory(self):
        """Resolving a factory must not overwrite the stored client_args."""
        # spec=[] again guarantees the factory lacks ``send`` and is invoked.
        factory = unittest.mock.MagicMock(
            spec=[],
            return_value=unittest.mock.MagicMock(),
        )

        with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as openai_cls:
            openai_cls.return_value = self._async_context_client()

            model = OpenAIModel(
                client_args={"api_key": "test-key", "http_client": factory},
                model_id="gpt-4",
            )

            async with model._get_client():
                pass

        # The stored args must still hold the factory, not the produced client.
        assert model.client_args["http_client"] is factory