Skip to content

Commit 21248d6

Browse files
Merge pull request #17 from Deadpool2000/fix-4
Fix #4 - Handle thread safety for shared clients (Celery fixes)
2 parents 7869447 + 81cd2a8 commit 21248d6

3 files changed

Lines changed: 108 additions & 2 deletions

File tree

openapi_python_sdk/client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import threading
23
from typing import Any, Dict
34

45
import httpx
@@ -15,13 +16,33 @@ class Client:
1516
"""
1617

1718
def __init__(self, token: str, client: Any = None, timeout: float = 30.0):
18-
self.client = client if client is not None else httpx.Client(timeout=timeout)
19+
self._client = client
20+
self._thread_local = threading.local()
21+
self.timeout = timeout
1922
self.auth_header: str = f"Bearer {token}"
2023
self.headers: Dict[str, str] = {
2124
"Authorization": self.auth_header,
2225
"Content-Type": "application/json",
2326
}
2427

28+
@property
29+
def client(self) -> Any:
30+
"""
31+
Thread-safe access to the underlying HTTP client.
32+
If a custom client was provided at initialization, it is returned.
33+
Otherwise, a thread-local httpx.Client is created and returned.
34+
"""
35+
if self._client is not None:
36+
return self._client
37+
38+
if not hasattr(self._thread_local, "client"):
39+
self._thread_local.client = httpx.Client(timeout=self.timeout)
40+
return self._thread_local.client
41+
42+
@client.setter
43+
def client(self, value: Any):
44+
self._client = value
45+
2546
def __enter__(self):
2647
"""Enable use as a synchronous context manager."""
2748
return self

openapi_python_sdk/oauth_client.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import threading
23
from typing import Any, Dict, List
34

45
import httpx
@@ -13,7 +14,9 @@ class OauthClient:
1314
"""
1415

1516
def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None, timeout: float = 30.0):
16-
self.client = client if client is not None else httpx.Client(timeout=timeout)
17+
self._client = client
18+
self._thread_local = threading.local()
19+
self.timeout = timeout
1720
self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL
1821
self.auth_header: str = (
1922
"Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode()
@@ -23,6 +26,23 @@ def __init__(self, username: str, apikey: str, test: bool = False, client: Any =
2326
"Content-Type": "application/json",
2427
}
2528

29+
@property
30+
def client(self) -> Any:
31+
"""
32+
Thread-safe access to the underlying HTTP client.
33+
If a custom client was provided at initialization, it is returned.
34+
Otherwise, a thread-local httpx.Client is created and returned.
35+
"""
36+
if self._client is not None:
37+
return self._client
38+
if not hasattr(self._thread_local, "client"):
39+
self._thread_local.client = httpx.Client(timeout=self.timeout)
40+
return self._thread_local.client
41+
42+
@client.setter
43+
def client(self, value: Any):
44+
self._client = value
45+
2646
def __enter__(self):
2747
"""Enable use as a synchronous context manager."""
2848
return self

tests/test_thread_safety.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import threading
2+
import unittest
3+
4+
import httpx
5+
6+
from openapi_python_sdk import Client, OauthClient
7+
8+
9+
class TestThreadSafety(unittest.TestCase):
10+
def test_oauth_client_thread_safety(self):
11+
oauth = OauthClient(username="user", apikey="key")
12+
13+
clients = []
14+
def get_client():
15+
clients.append(oauth.client)
16+
17+
threads = [threading.Thread(target=get_client) for _ in range(5)]
18+
for t in threads:
19+
t.start()
20+
for t in threads:
21+
t.join()
22+
23+
# Each thread should have gotten a unique client instance
24+
self.assertEqual(len(clients), 5)
25+
self.assertEqual(len(set(id(c) for c in clients)), 5)
26+
27+
def test_client_thread_safety(self):
28+
client = Client(token="tok")
29+
30+
clients = []
31+
def get_client():
32+
clients.append(client.client)
33+
34+
threads = [threading.Thread(target=get_client) for _ in range(5)]
35+
for t in threads:
36+
t.start()
37+
for t in threads:
38+
t.join()
39+
40+
# Each thread should have gotten a unique client instance
41+
self.assertEqual(len(clients), 5)
42+
self.assertEqual(len(set(id(c) for c in clients)), 5)
43+
44+
def test_shared_client_injection_still_works(self):
45+
# If we explicitly pass a client, it SHOULD be shared (backward compatibility)
46+
shared_engine = httpx.Client()
47+
oauth = OauthClient(username="user", apikey="key", client=shared_engine)
48+
49+
clients = []
50+
def get_client():
51+
clients.append(oauth.client)
52+
53+
threads = [threading.Thread(target=get_client) for _ in range(5)]
54+
for t in threads:
55+
t.start()
56+
for t in threads:
57+
t.join()
58+
59+
# All threads should have the SAME instance because it was injected
60+
self.assertEqual(len(clients), 5)
61+
self.assertEqual(len(set(id(c) for c in clients)), 1)
62+
self.assertEqual(id(clients[0]), id(shared_engine))
63+
64+
if __name__ == "__main__":
65+
unittest.main()

0 commit comments

Comments
 (0)