-
Notifications
You must be signed in to change notification settings - Fork 86
Expand file tree
/
Copy pathclient_v2.py
More file actions
97 lines (87 loc) · 3.57 KB
/
client_v2.py
File metadata and controls
97 lines (87 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import typing
from concurrent.futures import ThreadPoolExecutor
import aiohttp
import httpx
from .client import AsyncClient, Client
from .environment import ClientEnvironment
from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client
class _CombinedRawClient:
    """Proxy that combines v1 and v2 raw clients.

    V2Client and Client both assign to self._raw_client in __init__,
    causing a collision when combined in ClientV2/AsyncClientV2.
    This proxy delegates to v2 first, falling back to v1 for
    legacy methods like generate_stream.

    Only attributes that are not found on the proxy itself reach
    ``__getattr__``, so the two wrapped clients are stored as ``_v1``/``_v2``
    and everything else is forwarded.
    """

    def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any):
        self._v1 = v1_raw_client
        self._v2 = v2_raw_client

    def __getattr__(self, name: str) -> typing.Any:
        """Resolve ``name`` on the v2 raw client, falling back to the v1 one.

        Args:
            name: Attribute name that normal lookup on the proxy did not find.

        Returns:
            The attribute from the v2 raw client if it has one, otherwise the
            attribute from the v1 raw client.

        Raises:
            AttributeError: If neither wrapped client has ``name``, or if the
                proxy was created without running ``__init__``.
        """
        # Read the wrapped clients straight from the instance dict. If
        # __init__ was bypassed (copy.copy, unpickling, cls.__new__), using
        # ``self._v2`` here would re-enter __getattr__ and recurse until
        # RecursionError. Raising AttributeError instead lets hasattr() and
        # the copy/pickle machinery fall back to their defaults.
        state = self.__dict__
        if "_v1" not in state or "_v2" not in state:
            raise AttributeError(name)
        try:
            return getattr(state["_v2"], name)
        except AttributeError:
            pass
        try:
            return getattr(state["_v1"], name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
class ClientV2(V2Client, Client):  # type: ignore
    """Synchronous client that exposes the v2 API alongside the v1 helpers.

    Both bases are initialised explicitly rather than through ``super()``.
    ``Client`` goes first, and its ``self._client_wrapper`` is then handed to
    ``V2Client``. Each base stores its own raw client on ``self._raw_client``,
    so the v1 one is captured before the v2 initialiser overwrites it. The
    two are then merged behind a ``_CombinedRawClient``.

    Note: the ``base_url`` and ``thread_pool_executor`` defaults are evaluated
    once, when this class is defined. ``CO_API_URL`` is therefore read at
    import time, and every instance relying on the default shares the same
    ``ThreadPoolExecutor(64)``.
    """

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        v1_options = dict(
            api_key=api_key,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            timeout=timeout,
            httpx_client=httpx_client,
            thread_pool_executor=thread_pool_executor,
            log_warning_experimental_features=log_warning_experimental_features,
        )

        # v1 setup runs first; its raw client must be saved before v2 setup
        # reassigns self._raw_client.
        Client.__init__(self, **v1_options)
        legacy_raw_client = self._raw_client

        # v2 setup shares the client wrapper produced by the v1 setup.
        V2Client.__init__(self, client_wrapper=self._client_wrapper)
        v2_raw_client = self._raw_client

        # Expose both raw clients through one object (v2 looked up first).
        combined = _CombinedRawClient(legacy_raw_client, v2_raw_client)
        self._raw_client = typing.cast(RawV2Client, combined)
class AsyncClientV2(AsyncV2Client, AsyncClient):  # type: ignore
    """Asynchronous client that exposes the v2 API alongside the v1 helpers.

    Both bases are initialised explicitly rather than through ``super()``.
    ``AsyncClient`` goes first, and its ``self._client_wrapper`` is then
    handed to ``AsyncV2Client``. Each base stores its own raw client on
    ``self._raw_client``, so the v1 one is captured before the v2 initialiser
    overwrites it. The two are then merged behind a ``_CombinedRawClient``.

    ``httpx_client`` is marked deprecated; ``aiohttp_session`` is the other
    transport option, and both are passed through to ``AsyncClient``.

    Note: the ``base_url`` and ``thread_pool_executor`` defaults are evaluated
    once, when this class is defined. ``CO_API_URL`` is therefore read at
    import time, and every instance relying on the default shares the same
    ``ThreadPoolExecutor(64)``.
    """

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        aiohttp_session: typing.Optional["aiohttp.ClientSession"] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,  # Deprecated
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        v1_options = dict(
            api_key=api_key,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            timeout=timeout,
            aiohttp_session=aiohttp_session,
            httpx_client=httpx_client,
            thread_pool_executor=thread_pool_executor,
            log_warning_experimental_features=log_warning_experimental_features,
        )

        # v1 setup runs first; its raw client must be saved before v2 setup
        # reassigns self._raw_client.
        AsyncClient.__init__(self, **v1_options)
        legacy_raw_client = self._raw_client

        # v2 setup shares the client wrapper produced by the v1 setup.
        AsyncV2Client.__init__(self, client_wrapper=self._client_wrapper)
        v2_raw_client = self._raw_client

        # Expose both raw clients through one object (v2 looked up first).
        combined = _CombinedRawClient(legacy_raw_client, v2_raw_client)
        self._raw_client = typing.cast(AsyncRawV2Client, combined)