-
-
Notifications
You must be signed in to change notification settings - Fork 200
Expand file tree
/
Copy pathdecorator.py
More file actions
230 lines (191 loc) · 7.82 KB
/
decorator.py
File metadata and controls
230 lines (191 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import (
Awaitable,
Callable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import (
get_typed_return_annotation,
get_typed_signature,
)
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_304_NOT_MODIFIED
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
# Module-level logger, named after this module so applications can configure
# it through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
# NullHandler is a no-op handler: it emits nothing itself.
logger.addHandler(logging.NullHandler())
# P stands for the parameter specification of a decorated callable; it is used
# in the ``Callable[P, ...]`` annotations of the ``cache`` decorator.
P = ParamSpec("P")
# R stands for the type of the value a decorated callable produces.
R = TypeVar("R")
def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
    """Return ``signature`` extended with the ``extra`` parameters.

    The extras are inserted after every existing parameter except a trailing
    variadic keyword (``**kwargs``-style) parameter, which is kept at the very
    end. When no extras are given, the input signature object is returned
    unchanged.
    """
    if len(extra) == 0:
        return signature

    existing = list(signature.parameters.values())

    # Walk back from the end to find where the trailing VAR_KEYWORD run starts
    # (``cut == len(existing)`` when there is none).
    cut = len(existing)
    while cut > 0 and existing[cut - 1].kind is Parameter.VAR_KEYWORD:
        cut -= 1

    head = existing[:cut]
    # Reversed to keep the same ordering as popping them off the end one by one.
    tail = existing[cut:][::-1]
    return signature.replace(parameters=[*head, *extra, *tail])
def _locate_param(
    sig: Signature, dep: Parameter, to_inject: List[Parameter]
) -> Parameter:
    """Find the parameter of ``sig`` that carries ``dep``'s annotation.

    The first parameter whose annotation is the very same object as
    ``dep.annotation`` is returned. If there is no such parameter, ``dep`` is
    appended to ``to_inject`` and returned instead.
    """
    wanted = dep.annotation
    for candidate in sig.parameters.values():
        # Identity comparison on purpose: only the exact annotation object counts.
        if candidate.annotation is wanted:
            return candidate

    to_inject.append(dep)
    return dep
def _uncacheable(request: Optional[Request]) -> bool:
    """Determine if this request should not be cached

    Returns true if:
    - Caching has been disabled globally
    - This is not a GET request
    - The request's Cache-Control header contains a "no-store" or "no-cache"
      directive

    Returns false when caching is enabled and no request object is available,
    because there is then nothing to inspect.

    Cache-Control is a comma-separated list of case-insensitive directives, so
    the header is split into directives instead of being compared as a whole.
    Values such as "no-cache, no-store" or "No-Cache" are honoured too.
    """
    if not FastAPICache.get_enable():
        return True
    if request is None:
        return False
    if request.method != "GET":
        return True

    header = request.headers.get("Cache-Control")
    if not header:
        return False

    # Keep only the directive name: drop any "=value" part, surrounding
    # whitespace and case differences.
    directives = {
        part.split("=", 1)[0].strip().lower() for part in header.split(",")
    }
    return not directives.isdisjoint(("no-store", "no-cache"))
def _weak_etag(payload: object) -> str:
    """Build a weak ETag for ``payload`` that is identical in every process.

    The built-in ``hash()`` is salted per interpreter process for ``str`` and
    ``bytes``, so it cannot be used for a validator that several workers (or
    the same worker after a restart) must agree on.
    """
    import hashlib  # local import keeps this helper self-contained

    if isinstance(payload, (bytes, bytearray, memoryview)):
        raw = bytes(payload)
    else:
        raw = str(payload).encode("utf-8")
    digest = hashlib.blake2b(raw, digest_size=16).hexdigest()
    # The opaque tag is quoted, as the ETag field syntax requires.
    return f'W/"{digest}"'


def cache(
    expire: Optional[int] = None,
    coder: Optional[Type[Coder]] = None,
    key_builder: Optional[KeyBuilder] = None,
    namespace: str = "",
    injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """Cache the result of the decorated function in the configured backend.

    The decorated function may be sync or async; sync functions are run in a
    thread pool. The returned wrapper is always async. On a cache miss the
    function is called and its encoded result is stored; on a hit the stored
    value is decoded using the function's return annotation. When a response
    object is available, ``Cache-Control``, ``ETag`` and the configured cache
    status header are set on it, and a 304 response is returned if the
    request's ``If-None-Match`` equals the ETag of the cached value.

    Errors raised by the backend while reading or writing are logged as
    warnings and otherwise ignored (a failed read is treated as a miss).

    :param expire: lifetime passed to the backend and advertised as
        ``max-age``; a falsy value falls back to ``FastAPICache.get_expire()``.
    :param coder: coder used to encode results and decode cached values;
        falls back to ``FastAPICache.get_coder()``.
    :param key_builder: callable producing the cache key (may return an
        awaitable); falls back to ``FastAPICache.get_key_builder()``.
    :param namespace: combined with the global prefix as
        ``"{prefix}:{namespace}"`` and handed to the key builder.
    :param injected_dependency_namespace: name prefix of the keyword-only
        request/response parameters that are added to the wrapper's signature
        when the decorated function does not declare parameters annotated
        with ``Request`` / ``Response`` itself.
    :return: a decorator wrapping the function with the caching logic.
    """
    injected_request = Parameter(
        name=f"{injected_dependency_namespace}_request",
        annotation=Request,
        kind=Parameter.KEYWORD_ONLY,
    )
    injected_response = Parameter(
        name=f"{injected_dependency_namespace}_response",
        annotation=Response,
        kind=Parameter.KEYWORD_ONLY,
    )

    def wrapper(
        func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[R]]:
        # get_typed_signature ensures that any forward references are resolved first
        wrapped_signature = get_typed_signature(func)
        to_inject: List[Parameter] = []
        request_param = _locate_param(wrapped_signature, injected_request, to_inject)
        response_param = _locate_param(wrapped_signature, injected_response, to_inject)
        return_type = get_typed_return_annotation(func)

        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
            async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
                """Run cached sync functions in thread pool just like FastAPI."""
                # if the wrapped function does NOT have request or response in
                # its function signature, make sure we don't pass them in as
                # keyword arguments
                kwargs.pop(injected_request.name, None)
                kwargs.pop(injected_response.name, None)

                if iscoroutinefunction(func):
                    # async, return as is.
                    # unintuitively, we have to await once here, so that caller
                    # does not have to await twice. See
                    # https://stackoverflow.com/a/59268198/532513
                    return await func(*args, **kwargs)
                else:
                    # sync, wrap in thread and return async
                    # see above why we have to await even although caller also awaits.
                    return await run_in_threadpool(func, *args, **kwargs)  # type: ignore[arg-type]

            # The key builder receives request/response separately, so they are
            # removed from the kwargs copy that is handed to it.
            copy_kwargs = kwargs.copy()
            request: Optional[Request] = copy_kwargs.pop(request_param.name, None)  # type: ignore[assignment]
            response: Optional[Response] = copy_kwargs.pop(response_param.name, None)  # type: ignore[assignment]

            if _uncacheable(request):
                return await ensure_async_func(*args, **kwargs)

            prefix = FastAPICache.get_prefix()
            # Resolve the effective settings per call instead of rebinding the
            # decorator arguments: rebinding would freeze the global defaults
            # seen by the first request for the lifetime of the process.
            active_coder = coder or FastAPICache.get_coder()
            active_expire = expire or FastAPICache.get_expire()
            active_key_builder = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()
            cache_status_header = FastAPICache.get_cache_status_header()

            cache_key = active_key_builder(
                func,
                f"{prefix}:{namespace}",
                request=request,
                response=response,
                args=args,
                kwargs=copy_kwargs,
            )
            if isawaitable(cache_key):
                cache_key = await cache_key
            assert isinstance(cache_key, str)  # noqa: S101  # assertion is a type guard

            try:
                ttl, cached = await backend.get_with_ttl(cache_key)
            except Exception:
                # Best effort: a failing backend degrades to a cache miss.
                logger.warning(
                    "Error retrieving cache key '%s' from backend:",
                    cache_key,
                    exc_info=True,
                )
                ttl, cached = 0, None

            if cached is None:  # cache miss
                result = await ensure_async_func(*args, **kwargs)
                to_cache = active_coder.encode(result)

                try:
                    await backend.set(cache_key, to_cache, active_expire)
                except Exception:
                    # Best effort: the computed result is still returned.
                    logger.warning(
                        "Error setting cache key '%s' in backend:",
                        cache_key,
                        exc_info=True,
                    )

                if response:
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={active_expire}",
                            "ETag": _weak_etag(to_cache),
                            cache_status_header: "MISS",
                        }
                    )

            else:  # cache hit
                if response:
                    etag = _weak_etag(cached)
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={ttl}",
                            "ETag": etag,
                            cache_status_header: "HIT",
                        }
                    )

                    if_none_match = request and request.headers.get("if-none-match")
                    if if_none_match == etag:
                        # The client already holds this representation.
                        response.status_code = HTTP_304_NOT_MODIFIED
                        return response

                result = cast(
                    R, active_coder.decode_as_type(cached, type_=return_type)
                )

            return result

        # Expose the injected request/response parameters in the signature that
        # is attached to the wrapper.
        inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)  # type: ignore[attr-defined]

        return inner

    return wrapper