Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions getstream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,38 @@ def __init__(
timeout=timeout,
user_agent=user_agent,
)
self.client = httpx.Client(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
)
http_client = getattr(self, "_http_client", None)
if http_client is not None:
if not isinstance(http_client, httpx.Client):
raise TypeError(
f"http_client must be an httpx.Client instance, "
f"got {type(http_client).__name__}"
)
http_client.headers.update(self.headers)
http_client.params = http_client.params.merge(self.params)
http_client.base_url = self.base_url or ""
if self.timeout is not None:
http_client.timeout = httpx.Timeout(self.timeout)
self.client = http_client
self._owns_http_client = False
else:
transport = getattr(self, "_transport", None)
if transport is not None:
self.client = httpx.Client(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
transport=transport,
)
else:
self.client = httpx.Client(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
)
self._owns_http_client = True

def __enter__(self):
return self
Expand Down Expand Up @@ -348,8 +374,13 @@ def _upload_multipart(
def close(self):
"""
Close HTTPX client.

If the client was provided externally via ``http_client``, this is a
no-op — the caller that created the client is responsible for closing
it.
"""
self.client.close()
if getattr(self, "_owns_http_client", True):
self.client.close()


class AsyncBaseClient(TelemetryEndpointMixin, BaseConfig, ResponseParserMixin, ABC):
Expand All @@ -368,12 +399,38 @@ def __init__(
timeout=timeout,
user_agent=user_agent,
)
self.client = httpx.AsyncClient(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
)
http_client = getattr(self, "_http_client", None)
if http_client is not None:
if not isinstance(http_client, httpx.AsyncClient):
raise TypeError(
f"http_client must be an httpx.AsyncClient instance, "
f"got {type(http_client).__name__}"
)
http_client.headers.update(self.headers)
http_client.params = http_client.params.merge(self.params)
http_client.base_url = self.base_url or ""
if self.timeout is not None:
http_client.timeout = httpx.Timeout(self.timeout)
self.client = http_client
self._owns_http_client = False
else:
transport = getattr(self, "_transport", None)
if transport is not None:
self.client = httpx.AsyncClient(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
transport=transport,
)
else:
self.client = httpx.AsyncClient(
base_url=self.base_url or "",
headers=self.headers,
params=self.params,
timeout=httpx.Timeout(self.timeout),
)
self._owns_http_client = True

async def __aenter__(self):
return self
Expand All @@ -382,8 +439,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()

async def aclose(self):
"""Close HTTPX async client (closes pools/keep-alives)."""
await self.client.aclose()
"""Close HTTPX async client (closes pools/keep-alives).

If the client was provided externally via ``http_client``, this is a
no-op — the caller that created the client is responsible for closing
it.
"""
if getattr(self, "_owns_http_client", True):
await self.client.aclose()

async def _upload_multipart(
self,
Expand Down
137 changes: 88 additions & 49 deletions getstream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List, Optional
from uuid import uuid4

import httpx
import jwt
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand Down Expand Up @@ -47,7 +48,12 @@ def __init__(
timeout: Optional[float] = 6.0,
base_url: Optional[str] = BASE_URL,
user_agent: Optional[str] = None,
transport=None,
http_client=None,
):
if transport is not None and http_client is not None:
raise ValueError("Cannot specify both 'transport' and 'http_client'")

if None in (api_key, api_secret, timeout, base_url):
s = Settings() # loads from env and optional .env
api_key = api_key or s.api_key
Expand All @@ -68,10 +74,29 @@ def __init__(

self.base_url = validate_and_clean_url(base_url)
self.user_agent = user_agent
self._transport = transport
self._http_client = http_client
self.token = self._create_token()
super().__init__(
self.api_key, self.base_url, self.token, self.timeout, self.user_agent
)
# After super().__init__(), self.client is fully built and configured.
# When the user provided custom HTTP config, sub-clients share this
# client instead of each building their own.
if transport is not None or http_client is not None:
self._shared_client = self.client
else:
self._shared_client = None

def _apply_shared_client(self, sub_client):
"""Replace a sub-client's auto-created httpx client with the shared
one built from user-provided transport/http_client config."""
if self._shared_client is not None:
if isinstance(sub_client.client, httpx.Client):
sub_client.client.close()
sub_client.client = self._shared_client
sub_client._owns_http_client = False
return sub_client

def create_token(
self,
Expand Down Expand Up @@ -169,13 +194,15 @@ def video(self) -> AsyncVideoClient:
Video stream client.

"""
return AsyncVideoClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
AsyncVideoClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@cached_property
Expand All @@ -184,13 +211,15 @@ def chat(self) -> AsyncChatClient:
Chat stream client.

"""
return AsyncChatClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
AsyncChatClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@cached_property
Expand All @@ -199,13 +228,15 @@ def moderation(self) -> AsyncModerationClient:
Moderation stream client.

"""
return AsyncModerationClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
AsyncModerationClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

async def aclose(self):
Expand Down Expand Up @@ -291,13 +322,15 @@ def video(self) -> VideoClient:
Video stream client.

"""
return VideoClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
VideoClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@cached_property
Expand All @@ -306,13 +339,15 @@ def chat(self) -> ChatClient:
Chat stream client.

"""
return ChatClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
ChatClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@cached_property
Expand All @@ -321,13 +356,15 @@ def moderation(self) -> ModerationClient:
Moderation stream client.

"""
return ModerationClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
ModerationClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@cached_property
Expand All @@ -336,13 +373,15 @@ def feeds(self) -> FeedsClient:
Feeds stream client.

"""
return FeedsClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
return self._apply_shared_client(
FeedsClient(
api_key=self.api_key,
base_url=self.base_url,
token=self.token,
timeout=self.timeout,
stream=self,
user_agent=self.user_agent,
)
)

@telemetry.operation_name("getstream.api.common.create_user")
Expand Down
Loading
Loading