-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_sse_connection_manager.py
More file actions
157 lines (140 loc) · 6.52 KB
/
_sse_connection_manager.py
File metadata and controls
157 lines (140 loc) · 6.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import base64
import time
from typing import Optional, Callable, TYPE_CHECKING
import sseclient # type: ignore
from requests import Response
from requests.exceptions import HTTPError
from sdk_reforge._internal_logging import InternalLogger
from sdk_reforge._requests import ApiClient
from sdk_reforge._sse_watchdog import WatchdogResponseWrapper
import prefab_pb2 as Prefab
from sdk_reforge.config_sdk_interface import ConfigSDKInterface
if TYPE_CHECKING:
from sdk_reforge._sse_watchdog import SSEWatchdog
# A streaming connection that ends sooner than this many seconds is "short".
SHORT_CONNECTION_THRESHOLD = 2
# Consecutive short connections after which the reconnect delay starts doubling.
CONSECUTIVE_SHORT_CONNECTION_LIMIT = 2
# Lower and upper bounds, in seconds, of the exponential reconnect backoff.
MIN_BACKOFF_TIME = 1
MAX_BACKOFF_TIME = 30
class TooQuickConnectionException(Exception):
    """Signal that SSE connections keep closing too soon after opening.

    The streaming loop raises this internally to route control into its
    backoff path; it is not expected to escape that loop.
    """
# Module-level logger used by the SSE connection manager in this module.
logger = InternalLogger(__name__)
class SSEConnectionManager:
    """Maintain a long-lived SSE connection that streams config updates.

    The manager repeatedly opens ``/api/v2/sse/config`` through
    ``api_client``, base64-decodes each event into a ``Prefab.Configs``
    message and hands it to ``config_client.load_configs``. Failed or
    unusually short connections are retried with exponential backoff bounded
    by ``MIN_BACKOFF_TIME`` and ``MAX_BACKOFF_TIME``.
    """

    def __init__(
        self,
        api_client: ApiClient,
        config_client: ConfigSDKInterface,
        urls: list[str],
        watchdog: Optional["SSEWatchdog"] = None,
    ):
        """Store collaborators; no connection is opened until ``streaming_loop``.

        Args:
            api_client: client used to issue the streaming request.
            config_client: SDK object that supplies the API key and highwater
                mark, decides whether to keep streaming, and receives configs.
            urls: hosts passed to ``api_client.resilient_request``.
            watchdog: optional watchdog; its ``touch`` callback is given to
                ``WatchdogResponseWrapper`` so received data can be tracked.
        """
        self.api_client = api_client
        self.config_client = config_client
        # The SSE client for the currently open stream, or None between streams.
        self.sse_client: Optional[sseclient.SSEClient] = None
        self.timing = Timing()
        self.urls = urls
        self.watchdog = watchdog

    def streaming_loop(self) -> None:
        """Stream config updates until the config client says to stop.

        Each iteration opens one SSE connection and processes events until
        the stream ends, then waits before reconnecting:

        * a connection lasting at least ``SHORT_CONNECTION_THRESHOLD`` seconds
          resets the backoff to ``MIN_BACKOFF_TIME``;
        * ``CONSECUTIVE_SHORT_CONNECTION_LIMIT`` short connections in a row,
          or any error, double the backoff (capped at ``MAX_BACKOFF_TIME``);
        * a 401/403 response notifies the config client and ends the loop.

        ``KeyboardInterrupt`` and ``SystemExit`` propagate to the caller.
        """
        too_short_connection_count = 0
        backoff_time = MIN_BACKOFF_TIME
        try:
            while self.config_client.continue_connection_processing():
                try:
                    response = self._open_stream()
                    elapsed_time = self.timing.time_execution(
                        lambda: self.process_response(response)
                    )
                    if elapsed_time < SHORT_CONNECTION_THRESHOLD:
                        too_short_connection_count += 1
                        if (
                            too_short_connection_count
                            >= CONSECUTIVE_SHORT_CONNECTION_LIMIT
                        ):
                            raise TooQuickConnectionException(
                                f"{too_short_connection_count} consecutive "
                                f"connections lasted under "
                                f"{SHORT_CONNECTION_THRESHOLD}s"
                            )
                    else:
                        # A healthy connection resets both the streak and the delay.
                        too_short_connection_count = 0
                        backoff_time = MIN_BACKOFF_TIME
                    time.sleep(backoff_time)
                except TooQuickConnectionException as e:
                    logger.debug(
                        f"Connection ended quickly: {str(e)}. Will apply backoff."
                    )
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME)
                    time.sleep(backoff_time)
                except HTTPError as e:
                    status_code = (
                        e.response.status_code if e.response is not None else None
                    )
                    if status_code in (401, 403):
                        logger.warning(
                            f"Received {status_code} response, stopping SSE"
                        )
                        self.config_client.handle_unauthorized_response()
                        # Rejected credentials will not fix themselves; without
                        # this break the loop would re-request immediately
                        # (no sleep on this path) as long as the client keeps
                        # reporting that processing should continue.
                        break
                    backoff_time = self._backoff_after_error(e, backoff_time)
                except (KeyboardInterrupt, SystemExit):
                    # System exceptions must terminate the thread.
                    raise
                except BaseException as e:
                    backoff_time = self._backoff_after_error(e, backoff_time)
        finally:
            logger.info(
                f"Streaming loop exited "
                f"(shutdown={self.config_client.is_shutting_down()})"
            )

    def _open_stream(self) -> Response:
        """Open the SSE config stream; raise ``HTTPError`` on a 4xx/5xx status.

        A rejected response is closed before the error propagates so the
        streamed connection is not left dangling.
        """
        logger.debug("Starting streaming connection")
        headers = {
            # Standard SSE resume header; presumably the highwater mark is the
            # id of the newest config already loaded — confirm with the SDK.
            "Last-Event-ID": f"{self.config_client.highwater_mark()}",
            "accept": "text/event-stream",
        }
        response = self.api_client.resilient_request(
            "/api/v2/sse/config",
            headers=headers,
            stream=True,
            auth=("authuser", self.config_client.options.api_key),
            timeout=(5, 60),
            hosts=self.urls,
        )
        try:
            response.raise_for_status()
        except HTTPError:
            response.close()
            raise
        return response

    def _backoff_after_error(self, error: BaseException, backoff_time: int) -> int:
        """Log ``error``, sleep for the doubled backoff and return the new delay.

        While the client is shutting down nothing is logged, no sleep happens
        and ``backoff_time`` is returned unchanged so the loop can exit promptly.
        """
        if self.config_client.is_shutting_down():
            return backoff_time
        backoff_time = min(backoff_time * 2, MAX_BACKOFF_TIME)
        logger.warning(
            f"Streaming connection error ({type(error).__name__}): {str(error)}, "
            f"Will retry in {backoff_time} seconds"
        )
        time.sleep(backoff_time)
        return backoff_time

    def process_response(self, response: Response) -> None:
        """Hand off a successful response here for processing.

        Each event with non-empty data is base64-decoded, parsed as
        ``Prefab.Configs`` and passed to ``config_client.load_configs``.
        Returns when the stream ends, when the client starts shutting down,
        or when an event decodes to zero bytes (the caller then reconnects as
        if the connection had dropped). The SSE client is closed and
        ``self.sse_client`` reset on every exit path, including errors.
        """
        # Wrap response to track data received for watchdog
        source = (
            WatchdogResponseWrapper(response, self.watchdog.touch)
            if self.watchdog
            else response
        )
        client = sseclient.SSEClient(source)
        self.sse_client = client
        try:
            for event in client.events():
                if self.config_client.is_shutting_down():
                    logger.info("Client is shutting down, exiting SSE event loop")
                    return
                if not event.data:
                    continue
                decoded_data = base64.b64decode(event.data)
                if not decoded_data:
                    logger.warning(
                        "Received zero-byte config payload from SSE stream, treating as connection error"
                    )
                    # Return early to trigger reconnection logic
                    return
                configs = Prefab.Configs.FromString(decoded_data)
                self.config_client.load_configs(configs, "sse_streaming")
        finally:
            # Previously the early returns skipped this and leaked the stream.
            client.close()
            self.sse_client = None
class Timing:
    """Measure how long a callable takes to run.

    ``now`` is kept as its own method so tests can substitute a fake clock.
    """

    def time_execution(self, func: Callable[[], None]) -> float:
        """Executes the given function and returns the time it took to execute.

        The duration is in seconds. If ``func`` raises, the exception
        propagates and no duration is returned.
        """
        start_time = self.now()
        func()  # Execute the block of code
        return self.now() - start_time

    def now(self) -> float:
        """Return a monotonic timestamp in seconds. This can be mocked in tests.

        ``time.monotonic`` is used instead of ``time.time`` because the wall
        clock can jump (NTP sync, manual changes), which would make measured
        durations negative or inflated. Only differences between two ``now()``
        values are meaningful.
        """
        return time.monotonic()