1- import json
21import logging
3- from datetime import datetime , timedelta , timezone
4- from pathlib import Path
5- from typing import Generator , Optional , Tuple , Type , Callable , Union
2+ from typing import Generator , Optional
63
74import dateutil .parser
85import httpx
9- from platformdirs import user_cache_dir
106
117from ..config import API_KEY , CONFIG , APIKey
128from . import USER_AGENT , RetryableHTTPClient , encode_json
13-
9+ from . cache import ACCESS_TOKEN_CACHE_CLS_TYPE , FileCachedToken , get_access_token_cache_cache
1410logger = logging .getLogger (__name__ )
1511
1612
17- class CachedToken :
18- def __init__ (self , token_refresh_before_expiry = timedelta (seconds = 30 )):
19- self .token_refresh_before_expiry = token_refresh_before_expiry
20-
21- cache = self ._get_cached_token ()
22-
23- if cache :
24- self ._access_token , self ._access_token_expires_at_utc = cache
25- else :
26- self ._access_token : Optional [str ] = None
27- self ._access_token_expires_at_utc : Optional [datetime ] = None
28-
29- @property
30- def access_token (self ) -> Optional [str ]:
31- return self ._access_token if self .is_access_token_valid else None
32-
33- @property
34- def is_access_token_valid (self ) -> bool :
35- if not self ._access_token :
36- return False
37-
38- if self ._access_token_expires_at_utc :
39- return datetime .now (tz = timezone .utc ) < (
40- self ._access_token_expires_at_utc - self .token_refresh_before_expiry
41- )
42-
43- return False
44-
45- def set_cached_token (self , access_token : str , expires_at_utc : datetime ) -> None :
46- self ._access_token = access_token
47- self ._access_token_expires_at_utc = expires_at_utc
48- self ._set_cached_token ()
49-
50- def _set_cached_token (self ) -> None :
51- raise NotImplementedError
52-
53- def _get_cached_token (self ) -> Optional [Tuple [str , datetime ]]:
54- raise NotImplementedError
55-
56-
57- class FileCachedToken (CachedToken ):
58- def __init__ (self ):
59- self ._cache_file = Path (user_cache_dir ("syncsparkpy" )) / "auth.json"
60-
61- super ().__init__ ()
62-
63- def _get_cached_token (self ) -> Optional [Tuple [str , datetime ]]:
64- # Cache is optional, we can fail to read it and not worry
65- if self ._cache_file .exists ():
66- try :
67- cached_token = json .loads (self ._cache_file .read_text ())
68- cached_access_token = cached_token ["access_token" ]
69- cached_expiry = datetime .fromisoformat (cached_token ["expires_at_utc" ])
70- return cached_access_token , cached_expiry
71- except Exception as e :
72- logger .warning (
73- f"Failed to read cached access token @ { self ._cache_file } " , exc_info = e
74- )
75-
76- return None
77-
78- def _set_cached_token (self ) -> None :
79- # Cache is optional, we can fail to read it and not worry
80- try :
81- self ._cache_file .parent .mkdir (parents = True , exist_ok = True )
82- self ._cache_file .write_text (
83- json .dumps (
84- {
85- "access_token" : self ._access_token ,
86- "expires_at_utc" : self ._access_token_expires_at_utc .isoformat (),
87- }
88- )
89- )
90- except Exception as e :
91- logger .warning (
92- f"Failed to write cached access token @ { self ._cache_file } " , exc_info = e
93- )
94-
95-
96- # Putting this here instead of config.py because circular imports and typing.
97- _access_token_cache_cls = FileCachedToken # Default to local file caching.
98- ACCESS_TOKEN_CACHE_CLS_TYPE = Union [Type [CachedToken ], Callable [[], CachedToken ]]
99-
100-
101- def set_access_token_cache_cls (access_token_cache_cls : ACCESS_TOKEN_CACHE_CLS_TYPE ) -> None :
102- global _access_token_cache_cls
103- _access_token_cache_cls = access_token_cache_cls
104-
105-
10613class SyncAuth (httpx .Auth ):
10714 requires_response_body = True
10815
@@ -425,7 +332,7 @@ async def _send(self, request: httpx.Request) -> dict:
425332 return {"error" : {"code" : "Sync API Error" , "message" : "Transaction failure" }}
426333
427334
428- _sync_client : SyncClient = None
335+ _sync_client : Optional [ SyncClient ] = None
429336
430337
431338def get_default_client () -> SyncClient :
@@ -434,12 +341,12 @@ def get_default_client() -> SyncClient:
434341 _sync_client = SyncClient (
435342 CONFIG .api_url ,
436343 API_KEY ,
437- access_token_cache_cls = _access_token_cache_cls
344+ access_token_cache_cls = get_access_token_cache_cache ()
438345 )
439346 return _sync_client
440347
441348
442- _async_sync_client : ASyncClient = None
349+ _async_sync_client : Optional [ ASyncClient ] = None
443350
444351
445352def get_default_async_client () -> ASyncClient :
@@ -448,6 +355,6 @@ def get_default_async_client() -> ASyncClient:
448355 _async_sync_client = ASyncClient (
449356 CONFIG .api_url ,
450357 API_KEY ,
451- access_token_cache_cls = _access_token_cache_cls
358+ access_token_cache_cls = get_access_token_cache_cache ()
452359 )
453360 return _async_sync_client
0 commit comments