-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconfig_client.py
More file actions
291 lines (249 loc) · 10.4 KB
/
config_client.py
File metadata and controls
291 lines (249 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
from __future__ import annotations

import os
import threading
import time
from typing import Optional

from google.protobuf.json_format import MessageToJson, Parse, ParseError

import prefab_pb2 as Prefab

from ._internal_logging import InternalLogger
from ._count_down_latch import CountDownLatch
from ._requests import ApiClient, UnauthorizedException
from ._sse_connection_manager import SSEConnectionManager
from .config_client_interface import ConfigClientInterface
from .config_loader import ConfigLoader
from .config_resolver import ConfigResolver
from .config_value_unwrapper import ConfigValueUnwrapper
from .context import Context
from .config_resolver import Evaluation
from .constants import NoDefaultProvided, ConfigValueType
# When a config cache file loaded from disk is older than this many hours,
# ConfigClient.load_cache logs a "Stale Cache Load" message (still loads it).
STALE_CACHE_WARN_HOURS: int = 5
# Module-level logger used by everything in this file.
logger = InternalLogger(__name__)
class InitializationTimeoutException(Exception):
    """Raised when configuration did not finish initializing in time.

    ConfigClient raises this from a lookup when waiting for initialization
    timed out and the options request RAISE on connection failure.

    :param timeout_seconds: the timeout that elapsed, in seconds.
    :param key: the config key whose lookup triggered the wait.
    """

    def __init__(self, timeout_seconds, key):
        message = (
            f"Prefab couldn't initialize in {timeout_seconds} second timeout. "
            f"Trying to fetch key `{key}`."
        )
        super().__init__(message)
class MissingDefaultException(Exception):
    """Raised when a key resolves to no value and the caller gave no default.

    ConfigClient.handle_default raises this only when the options set
    ``on_no_default`` to RAISE; otherwise the lookup returns ``None``.

    :param key: the config key that had no value.
    """

    def __init__(self, key):
        message = (
            f"No value found for key '{key}' and no default was provided.\n"
            "If you'd prefer returning `None` rather than raising when this occurs, "
            "modify the `on_no_default` value you provide in your Options."
        )
        super().__init__(message)
class ConfigClient(ConfigClientInterface):
    """Loads, caches, and serves Prefab configuration values.

    On construction the client initializes from one of three sources,
    depending on the options: nothing (local-only mode), a JSON datafile, or,
    in a background thread, the remote API/CDN with the local cache file as
    a fallback. Lookups via :meth:`get` block until initialization finishes,
    the API key is rejected, or ``options.connection_timeout_seconds`` elapses.
    """

    def __init__(self, base_client):
        self.is_initialized = threading.Event()
        self.checkpointing_thread = None
        self.streaming_thread = None
        self.sse_client = None
        logger.info("Initializing ConfigClient")
        self.base_client = base_client
        self._options = base_client.options
        # Lookups wait on this latch; it is counted down by finish_init and by
        # handle_unauthorized_response (presumably a single-count latch —
        # confirm in _count_down_latch).
        self.init_latch = CountDownLatch()
        self.unauthorized_event = threading.Event()
        self.finish_init_mutex = threading.Lock()
        self.checkpoint_freq_secs = 60
        self.config_loader = ConfigLoader(base_client)
        self.config_resolver = ConfigResolver(base_client, self.config_loader)
        self._cache_path = None
        self.set_cache_path()
        self.api_client = ApiClient(self.options)
        self.sse_connection_manager = SSEConnectionManager(
            self.api_client, self, self.options.prefab_stream_urls
        )
        if self.options.is_local_only():
            self.finish_init("local only")
        elif self.options.has_datafile():
            self.load_json_file(self.options.datafile)
        else:
            # Don't load the checkpoint here; that would block the caller.
            # Let the background thread do it.
            self.start_checkpointing_thread()

    def get(
        self,
        key,
        default=NoDefaultProvided,
        context: Optional[dict | Context] = None,
    ) -> ConfigValueType:
        """Return the resolved value for ``key``.

        Waits for initialization first (see ``__get``). When an evaluation is
        produced it is recorded with the base client's telemetry manager.

        :param key: config key to look up.
        :param default: returned when no config matches; when omitted, the
            outcome is governed by ``options.on_no_default``.
        :param context: optional evaluation context (dict or Context).
        :raises UnauthorizedException: if the API key was rejected.
        :raises InitializationTimeoutException: if initialization timed out and
            ``options.on_connection_failure == "RAISE"``.
        :raises MissingDefaultException: if nothing matched, no default was
            given, and ``options.on_no_default == "RAISE"``.
        """
        evaluation_result = self.__get(key, None, {}, context=context)
        if evaluation_result is not None:
            self.base_client.telemetry_manager.record_evaluation(evaluation_result)
            if evaluation_result.config:
                return evaluation_result.unwrapped_value()
        return self.handle_default(key, default)

    def __get(
        self,
        key,
        lookup_key,
        properties,
        context: Optional[dict | Context] = None,
    ) -> None | Evaluation:
        """Wait for initialization, then resolve ``key`` via the resolver.

        ``lookup_key`` and ``properties`` are accepted but unused here.
        """
        ok_to_proceed = self.init_latch.wait(
            timeout=self.options.connection_timeout_seconds
        )
        # An unauthorized key trumps a timeout: report the real cause.
        if self.unauthorized_event.is_set():
            raise UnauthorizedException(self.options.api_key)
        if not ok_to_proceed:
            if self.options.on_connection_failure == "RAISE":
                raise InitializationTimeoutException(
                    self.options.connection_timeout_seconds, key
                )
            logger.warning(
                f"Couldn't initialize in {self.options.connection_timeout_seconds}. Key {key}. Returning what we have.",
            )
        return self.config_resolver.get(key, context=context)

    @property
    def options(self):
        """The options object taken from the base client at construction."""
        return self._options

    def handle_default(self, key, default):
        """Return ``default``, or apply the ``on_no_default`` policy if none given.

        :raises MissingDefaultException: when no default was supplied and
            ``options.on_no_default == "RAISE"``.
        """
        # NoDefaultProvided is a sentinel: compare by identity so a caller's
        # default with a custom (or array-like) __eq__ is never invoked.
        if default is not NoDefaultProvided:
            return default
        if self.options.on_no_default == "RAISE":
            raise MissingDefaultException(key)
        return None

    def load_checkpoint(self):
        """Load initial configs: remote API/CDN first, then the local cache."""
        try:
            if self.load_checkpoint_from_api_cdn():
                return
            # With a rejected API key every lookup raises UnauthorizedException,
            # so loading the cache would only mark the client ready (starting
            # streaming and firing on_ready_callback) to no purpose.
            if self.unauthorized_event.is_set():
                return
            if self.load_cache():
                return
            logger.warning("No success loading checkpoints")
        except UnauthorizedException:
            self.handle_unauthorized_response()

    def start_checkpointing_thread(self):
        """Run :meth:`load_checkpoint` on a daemon thread."""
        self.checkpointing_thread = threading.Thread(
            target=self.load_checkpoint, daemon=True
        )
        self.checkpointing_thread.start()

    def start_streaming(self):
        """Run the SSE connection manager's streaming loop on a daemon thread."""
        self.streaming_thread = threading.Thread(
            target=self.sse_connection_manager.streaming_loop, daemon=True
        )
        self.streaming_thread.start()

    def is_shutting_down(self):
        """Return True once the base client's shutdown flag has been set."""
        return self.base_client.shutdown_flag.is_set()

    def continue_connection_processing(self):
        """Return True while neither shutting down nor unauthorized."""
        return not self.is_shutting_down() and not self.unauthorized_event.is_set()

    def highwater_mark(self) -> int:
        """Return the config loader's current highwater mark."""
        return self.config_loader.highwater_mark

    def load_initial_data(self):
        """Load the initial checkpoint, recording an unauthorized response.

        NOTE(review): load_checkpoint already handles UnauthorizedException,
        so this wrapper appears redundant; kept for existing callers.
        """
        try:
            self.load_checkpoint()
        except UnauthorizedException:
            self.handle_unauthorized_response()

    def load_checkpoint_from_api_cdn(self):
        """Fetch configs newer than the current highwater mark from the API/CDN.

        :returns: True if configs were fetched and loaded; False if the
            response was not OK or the API key was rejected (the latter is
            recorded via :meth:`handle_unauthorized_response`).
        """
        try:
            hwm = self.config_loader.highwater_mark
            response = self.api_client.resilient_request(
                "/api/v1/configs/" + str(hwm),
                auth=("authuser", self.options.api_key),
                timeout=4,
                allow_cache=True,
            )
        except UnauthorizedException:
            self.handle_unauthorized_response()
            return False
        if not response.ok:
            logger.info(
                "Checkpoint remote_cdn_api failed to load",
            )
            return False
        configs = Prefab.Configs.FromString(response.content)
        self.load_configs(configs, "remote_api_cdn")
        return True

    def load_configs(self, configs: Prefab.Configs, source: str) -> None:
        """Apply a batch of configs from ``source`` and finish initialization.

        Sets the resolver's project environment and default context, stores
        every config in the loader, logs whether the highwater mark advanced,
        refreshes the resolver, then calls :meth:`finish_init`.
        """
        project_id = configs.config_service_pointer.project_id
        project_env_id = configs.config_service_pointer.project_env_id
        self.config_resolver.project_env_id = project_env_id
        starting_highwater_mark = self.config_loader.highwater_mark

        # Repeated protobuf fields are never None (an unset default_context
        # just has no contexts), so iterate them unconditionally.
        default_contexts = {}
        for context in configs.default_context.contexts:
            default_contexts[context.type] = {
                k: ConfigValueUnwrapper(v, self.config_resolver).unwrap()
                for k, v in context.values.items()
            }
        self.config_resolver.default_context = default_contexts

        for config in configs.configs:
            self.config_loader.set(config, source)
        if self.config_loader.highwater_mark > starting_highwater_mark:
            logger.info(
                f"Found new checkpoint with highwater id {self.config_loader.highwater_mark} from {source} in project {project_id} environment: {project_env_id}",
            )
        else:
            logger.debug(
                f"Checkpoint with highwater id {self.config_loader.highwater_mark} from {source}. No changes.",
            )
        self.config_resolver.update()
        self.finish_init(source)

    def cache_configs(self, configs):
        """Write ``configs`` as JSON to the local cache file, if enabled.

        The JSON is written to a temporary sibling file and moved into place
        with os.replace, so a reader never sees a half-written cache (the
        rename is atomic on POSIX). Write failures are logged and ignored:
        the cache is only a fallback for load_checkpoint.
        """
        if not self.options.use_local_cache:
            return
        path = self.cache_path
        if not path:
            return
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(MessageToJson(configs))
            os.replace(tmp_path, path)
        except OSError as e:
            logger.info(f"error writing cache to {path}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup of the temporary file
            return
        logger.debug(f"Cached configs to {path}")

    def load_cache(self):
        """Load configs from the local cache file, if enabled and readable.

        Logs (without failing) when the file is older than
        STALE_CACHE_WARN_HOURS.

        :returns: True if the cache was loaded, False otherwise (disabled, no
            cache path, or the file was unreadable or not valid Configs JSON).
        """
        if not self.options.use_local_cache:
            return False
        path = self.cache_path
        if not path:
            return False
        try:
            with open(path, "r", encoding="utf-8") as f:
                raw = f.read()
            modified_at = os.path.getmtime(path)
            configs = Parse(raw, Prefab.Configs())
        except (OSError, UnicodeDecodeError, ParseError) as e:
            # A missing, unreadable, or corrupt cache must not kill the
            # checkpoint thread; just report it and fall through.
            logger.info(f"error loading from cache {e}")
            return False
        self.load_configs(configs, "cache")
        hours_old = round((time.time() - modified_at) / 3600, 2)
        if hours_old > STALE_CACHE_WARN_HOURS:
            logger.info(f"Stale Cache Load: {hours_old} hours old")
        return True

    def load_json_file(self, datafile):
        """Load configs from a JSON datafile (errors propagate to the caller)."""
        with open(datafile, encoding="utf-8") as f:
            configs = Parse(f.read(), Prefab.Configs())
        self.load_configs(configs, "datafile")

    def finish_init(self, source):
        """Mark the client initialized and release lookups waiting on the latch.

        Safe to call repeatedly: streaming (when loading from the API) and
        ``on_ready_callback`` are started only on the first call.
        """
        with self.finish_init_mutex:
            was_initialized = self.is_initialized.is_set()
            self.is_initialized.set()
            self.init_latch.count_down()
            if not was_initialized:
                logger.warning(f"Unlocked config via {source}")
                if self.options.is_loading_from_api():
                    self.start_streaming()
                if self.options.on_ready_callback:
                    # Run the user's callback off this thread so it cannot
                    # block (or deadlock on) finish_init_mutex.
                    threading.Thread(
                        target=self.options.on_ready_callback, daemon=True
                    ).start()

    def set_cache_path(self):
        """Choose the cache file path from $XDG_CACHE_HOME or $HOME/.cache.

        Per the XDG Base Directory spec, an empty $XDG_CACHE_HOME is treated
        as unset. If neither variable yields a directory, no cache is used.
        """
        home_dir = os.environ.get("HOME")
        home_dir_cache_path = os.path.join(home_dir, ".cache") if home_dir else None
        cache_dir = os.environ.get("XDG_CACHE_HOME") or home_dir_cache_path
        if cache_dir:
            file_name = f"prefab.cache.{self.base_client.options.api_key_id}.json"
            self.cache_path = os.path.join(cache_dir, file_name)

    @property
    def cache_path(self):
        """The cache file path, or None; ensures its directory exists."""
        if self._cache_path:
            try:
                os.makedirs(os.path.dirname(self._cache_path), exist_ok=True)
            except OSError as e:
                # Don't let an unwritable cache dir escape from a property
                # read; the subsequent open() reports the failure instead.
                logger.debug(f"Unable to create cache directory: {e}")
        return self._cache_path

    @cache_path.setter
    def cache_path(self, path):
        self._cache_path = path

    def record_log(self, path, severity):
        """Forward a log record to the base client."""
        self.base_client.record_log(path, severity)

    def is_ready(self) -> bool:
        """Return True once :meth:`finish_init` has run."""
        return self.is_initialized.is_set()

    def handle_unauthorized_response(self):
        """Record a rejected API key and release lookups waiting on the latch."""
        logger.error("Received unauthorized response")
        self.unauthorized_event.set()
        self.init_latch.count_down()

    def close(self) -> None:
        """Close the SSE client, if one has been assigned."""
        # NOTE(review): nothing in this class assigns sse_client beyond None
        # in __init__; confirm whether close() should also stop
        # sse_connection_manager.
        if self.sse_client:
            self.sse_client.close()