diff --git a/getstream/base.py b/getstream/base.py index 3eb6f3b1..95abd2fe 100644 --- a/getstream/base.py +++ b/getstream/base.py @@ -158,12 +158,38 @@ def __init__( timeout=timeout, user_agent=user_agent, ) - self.client = httpx.Client( - base_url=self.base_url or "", - headers=self.headers, - params=self.params, - timeout=httpx.Timeout(self.timeout), - ) + http_client = getattr(self, "_http_client", None) + if http_client is not None: + if not isinstance(http_client, httpx.Client): + raise TypeError( + f"http_client must be an httpx.Client instance, " + f"got {type(http_client).__name__}" + ) + http_client.headers.update(self.headers) + http_client.params = http_client.params.merge(self.params) + http_client.base_url = self.base_url or "" + if self.timeout is not None: + http_client.timeout = httpx.Timeout(self.timeout) + self.client = http_client + self._owns_http_client = False + else: + transport = getattr(self, "_transport", None) + if transport is not None: + self.client = httpx.Client( + base_url=self.base_url or "", + headers=self.headers, + params=self.params, + timeout=httpx.Timeout(self.timeout), + transport=transport, + ) + else: + self.client = httpx.Client( + base_url=self.base_url or "", + headers=self.headers, + params=self.params, + timeout=httpx.Timeout(self.timeout), + ) + self._owns_http_client = True def __enter__(self): return self @@ -348,8 +374,13 @@ def _upload_multipart( def close(self): """ Close HTTPX client. + + If the client was provided externally via ``http_client``, this is a + no-op — the caller that created the client is responsible for closing + it. """ - self.client.close() + if getattr(self, "_owns_http_client", True): + self.client.close() class AsyncBaseClient(TelemetryEndpointMixin, BaseConfig, ResponseParserMixin, ABC): @@ -368,12 +399,38 @@ def __init__( timeout=timeout, user_agent=user_agent, ) - self.client = httpx.AsyncClient( - base_url=self.base_url or "", - headers=self.headers, - params=self.params, - timeout=httpx.Timeout(self.timeout), - ) + http_client = getattr(self, "_http_client", None) + if http_client is not None: + if not isinstance(http_client, httpx.AsyncClient): + raise TypeError( + f"http_client must be an httpx.AsyncClient instance, " + f"got {type(http_client).__name__}" + ) + http_client.headers.update(self.headers) + http_client.params = http_client.params.merge(self.params) + http_client.base_url = self.base_url or "" + if self.timeout is not None: + http_client.timeout = httpx.Timeout(self.timeout) + self.client = http_client + self._owns_http_client = False + else: + transport = getattr(self, "_transport", None) + if transport is not None: + self.client = httpx.AsyncClient( + base_url=self.base_url or "", + headers=self.headers, + params=self.params, + timeout=httpx.Timeout(self.timeout), + transport=transport, + ) + else: + self.client = httpx.AsyncClient( + base_url=self.base_url or "", + headers=self.headers, + params=self.params, + timeout=httpx.Timeout(self.timeout), + ) + self._owns_http_client = True async def __aenter__(self): return self @@ -382,8 +439,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.aclose() async def aclose(self): - """Close HTTPX async client (closes pools/keep-alives).""" - await self.client.aclose() + """Close HTTPX async client (closes pools/keep-alives). + + If the client was provided externally via ``http_client``, this is a + no-op — the caller that created the client is responsible for closing + it. + """ + if getattr(self, "_owns_http_client", True): + await self.client.aclose() async def _upload_multipart( self, diff --git a/getstream/stream.py b/getstream/stream.py index ec946b01..9abd9216 100644 --- a/getstream/stream.py +++ b/getstream/stream.py @@ -6,6 +6,7 @@ from typing import List, Optional from uuid import uuid4 +import httpx import jwt from pydantic_settings import BaseSettings, SettingsConfigDict @@ -47,7 +48,12 @@ def __init__( timeout: Optional[float] = 6.0, base_url: Optional[str] = BASE_URL, user_agent: Optional[str] = None, + transport=None, + http_client=None, ): + if transport is not None and http_client is not None: + raise ValueError("Cannot specify both 'transport' and 'http_client'") + if None in (api_key, api_secret, timeout, base_url): s = Settings() # loads from env and optional .env api_key = api_key or s.api_key @@ -68,10 +74,29 @@ def __init__( self.base_url = validate_and_clean_url(base_url) self.user_agent = user_agent + self._transport = transport + self._http_client = http_client self.token = self._create_token() super().__init__( self.api_key, self.base_url, self.token, self.timeout, self.user_agent ) + # After super().__init__(), self.client is fully built and configured. + # When the user provided custom HTTP config, sub-clients share this + # client instead of each building their own. + if transport is not None or http_client is not None: + self._shared_client = self.client + else: + self._shared_client = None + + def _apply_shared_client(self, sub_client): + """Replace a sub-client's auto-created httpx client with the shared + one built from user-provided transport/http_client config.""" + if self._shared_client is not None: + if isinstance(sub_client.client, httpx.Client): + sub_client.client.close() + sub_client.client = self._shared_client + sub_client._owns_http_client = False + return sub_client def create_token( self, @@ -169,13 +194,15 @@ def video(self) -> AsyncVideoClient: Video stream client. """ - return AsyncVideoClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + AsyncVideoClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @cached_property @@ -184,13 +211,15 @@ def chat(self) -> AsyncChatClient: Chat stream client. """ - return AsyncChatClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + AsyncChatClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @cached_property @@ -199,13 +228,15 @@ def moderation(self) -> AsyncModerationClient: Moderation stream client. """ - return AsyncModerationClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + AsyncModerationClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) async def aclose(self): @@ -291,13 +322,15 @@ def video(self) -> VideoClient: Video stream client. """ - return VideoClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + VideoClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @cached_property @@ -306,13 +339,15 @@ def chat(self) -> ChatClient: Chat stream client. """ - return ChatClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + ChatClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @cached_property @@ -321,13 +356,15 @@ def moderation(self) -> ModerationClient: Moderation stream client. """ - return ModerationClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + ModerationClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @cached_property @@ -336,13 +373,15 @@ def feeds(self) -> FeedsClient: Feeds stream client. """ - return FeedsClient( - api_key=self.api_key, - base_url=self.base_url, - token=self.token, - timeout=self.timeout, - stream=self, - user_agent=self.user_agent, + return self._apply_shared_client( + FeedsClient( + api_key=self.api_key, + base_url=self.base_url, + token=self.token, + timeout=self.timeout, + stream=self, + user_agent=self.user_agent, + ) ) @telemetry.operation_name("getstream.api.common.create_user") diff --git a/tests/test_http_client.py b/tests/test_http_client.py new file mode 100644 index 00000000..b4c30ce8 --- /dev/null +++ b/tests/test_http_client.py @@ -0,0 +1,228 @@ +import httpx +import pytest + +from getstream import Stream, AsyncStream + + +def _mock_transport(status=200, body=None): + body = body or {} + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(status, json=body, request=request) + + return httpx.MockTransport(handler) + + +def _capture_transport(): + captured = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={}, request=request) + + return httpx.MockTransport(handler), captured + + +# ── transport (primary API) ────────────────────────────────────────── + + +class TestSyncTransport: + def test_transport_used_for_requests(self): + transport, captured = _capture_transport() + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + client.get_app() + assert len(captured) == 1 + + def test_stream_headers_applied(self): + transport, captured = _capture_transport() + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + client.get_app() + + req = captured[0] + assert "authorization" in req.headers + assert req.headers["stream-auth-type"] == "jwt" + assert "x-stream-client" in req.headers + assert req.url.params["api_key"] == "k" + + def test_sub_clients_share_client(self): + transport = _mock_transport() + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + shared = client.client + assert client.video.client is shared + assert client.chat.client is shared + assert client.moderation.client is shared + assert client.feeds.client is shared + + def test_close_closes_sdk_built_client(self): + transport = _mock_transport() + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + inner = client.client + client.close() + assert inner.is_closed + + def test_sub_client_close_is_noop(self): + transport = _mock_transport() + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + client.video.close() + assert not client.client.is_closed + + def test_custom_limits_propagated(self): + limits = httpx.Limits(max_connections=42, max_keepalive_connections=10) + transport = httpx.HTTPTransport(limits=limits) + client = Stream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + pool = client.client._transport._pool + assert pool._max_connections == 42 + assert pool._max_keepalive_connections == 10 + + def test_default_path_unchanged(self): + client = Stream(api_key="k", api_secret="s", base_url="http://test") + assert client._owns_http_client is True + assert isinstance(client.client, httpx.Client) + assert client._shared_client is None + + +@pytest.mark.asyncio +class TestAsyncTransport: + async def test_transport_used_for_requests(self): + transport, captured = _capture_transport() + client = AsyncStream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + await client.get_app() + assert len(captured) == 1 + + async def test_sub_clients_share_client(self): + transport = _mock_transport() + client = AsyncStream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + shared = client.client + assert client.video.client is shared + assert client.chat.client is shared + assert client.moderation.client is shared + + async def test_aclose_closes_sdk_built_client(self): + transport = _mock_transport() + client = AsyncStream( + api_key="k", api_secret="s", base_url="http://test", transport=transport + ) + inner = client.client + await client.aclose() + assert inner.is_closed + + +# ── http_client (escape hatch) ─────────────────────────────────────── + + +class TestSyncHttpClientEscapeHatch: + def test_custom_http_client_is_used(self): + transport, captured = _capture_transport() + custom = httpx.Client(transport=transport) + + client = Stream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + assert client.client is custom + client.get_app() + assert len(captured) == 1 + + def test_sub_clients_share_custom_http_client(self): + custom = httpx.Client(transport=_mock_transport()) + client = Stream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + assert client.video.client is custom + assert client.chat.client is custom + + def test_close_does_not_close_user_provided_client(self): + custom = httpx.Client(transport=_mock_transport()) + client = Stream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + client.close() + assert not custom.is_closed + + def test_user_headers_preserved(self): + transport, captured = _capture_transport() + custom = httpx.Client(transport=transport, headers={"X-Custom": "val"}) + client = Stream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + client.get_app() + + req = captured[0] + assert req.headers["x-custom"] == "val" + assert "authorization" in req.headers + + +@pytest.mark.asyncio +class TestAsyncHttpClientEscapeHatch: + async def test_custom_async_client_is_used(self): + transport = _mock_transport() + custom = httpx.AsyncClient(transport=transport) + client = AsyncStream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + assert client.client is custom + + async def test_aclose_does_not_close_user_provided_client(self): + custom = httpx.AsyncClient(transport=_mock_transport()) + client = AsyncStream( + api_key="k", api_secret="s", base_url="http://test", http_client=custom + ) + await client.aclose() + assert not custom.is_closed + + +# ── validation ─────────────────────────────────────────────────────── + + +class TestValidation: + def test_transport_and_http_client_mutually_exclusive(self): + with pytest.raises(ValueError, match="Cannot specify both"): + Stream( + api_key="k", + api_secret="s", + base_url="http://test", + transport=_mock_transport(), + http_client=httpx.Client(), + ) + + def test_wrong_type_raises_type_error(self): + with pytest.raises(TypeError, match="httpx.Client"): + Stream( + api_key="k", + api_secret="s", + base_url="http://test", + http_client="not-a-client", + ) + + def test_async_client_on_sync_raises_type_error(self): + with pytest.raises(TypeError, match="httpx.Client"): + Stream( + api_key="k", + api_secret="s", + base_url="http://test", + http_client=httpx.AsyncClient(), + ) + + def test_sync_client_on_async_raises_type_error(self): + with pytest.raises(TypeError, match="httpx.AsyncClient"): + AsyncStream( + api_key="k", + api_secret="s", + base_url="http://test", + http_client=httpx.Client(), + )