-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsession.py
More file actions
406 lines (361 loc) · 14.3 KB
/
session.py
File metadata and controls
406 lines (361 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, TypeAlias
import nanoid
import websockets
from aiochannel import Channel
from opentelemetry.trace import Span, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from replit_river.common_session import (
SendMessage,
SessionState,
TerminalStates,
)
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
from replit_river.messages import (
FailedSendingMessageException,
WebsocketClosedException,
send_transport_message,
)
from replit_river.seq_manager import (
SeqManager,
)
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport_options import TransportOptions
from replit_river.websocket_wrapper import WebsocketWrapper, WsState
from .rpc import (
ACK_BIT,
TransportMessage,
TransportMessageTracingSetter,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared helpers used to inject the active trace context into outgoing
# transport messages (see Session.send_message).
trace_setter = TransportMessageTracingSetter()
trace_propagator = TraceContextTextMapPropagator()

# Async callback that receives the session being closed.
CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]]
# Zero-argument async callback scheduled after a websocket close when a retry
# was requested.
RetryConnectionCallback: TypeAlias = Callable[[], Coroutine[Any, Any, Any]]
class Session:
    """Common functionality shared between client_session and server_session.

    A session holds one wrapped websocket at a time, a message buffer,
    sequence/ack bookkeeping, the open stream channels, and the background
    tasks created on its behalf.
    """

    _transport_id: str
    _to_id: str
    session_id: str
    _transport_options: TransportOptions

    # session state
    _state: SessionState
    _state_lock: asyncio.Lock
    _close_session_callback: CloseSessionCallback
    _close_session_after_time_secs: float | None

    # ws state
    _ws_lock: asyncio.Lock
    _ws_wrapper: WebsocketWrapper
    _heartbeat_misses: int
    _retry_connection_callback: RetryConnectionCallback | None

    # stream for tasks
    _streams: dict[str, Channel[Any]]

    # book keeping
    _seq_manager: SeqManager
    _buffer: MessageBuffer
    _task_manager: BackgroundTaskManager

    def __init__(
        self,
        transport_id: str,
        to_id: str,
        session_id: str,
        websocket: websockets.WebSocketCommonProtocol,
        transport_options: TransportOptions,
        close_session_callback: CloseSessionCallback,
        retry_connection_callback: RetryConnectionCallback | None = None,
    ) -> None:
        """Create an ACTIVE session around an existing websocket.

        Args:
            transport_id: Value placed in the ``from_`` field of sent messages.
            to_id: Value placed in the ``to`` field of sent messages.
            session_id: Identifier of this session.
            websocket: The initial websocket; it is wrapped in a
                ``WebsocketWrapper``.
            transport_options: Options read for buffer size, heartbeat
                timing, grace period and close-check interval.
            close_session_callback: Awaited with this session during
                ``close()``.
            retry_connection_callback: Optional coroutine factory scheduled by
                ``close_websocket`` when a retry is requested.
        """
        self._transport_id = transport_id
        self._to_id = to_id
        self.session_id = session_id
        self._transport_options = transport_options

        # session state
        self._state = SessionState.ACTIVE
        self._state_lock = asyncio.Lock()
        self._close_session_callback = close_session_callback
        self._close_session_after_time_secs: float | None = None

        # ws state
        self._ws_lock = asyncio.Lock()
        self._ws_wrapper = WebsocketWrapper(websocket)
        self._heartbeat_misses = 0
        self._retry_connection_callback = retry_connection_callback

        # stream for tasks
        self._streams: dict[str, Channel[Any]] = {}

        # book keeping
        self._seq_manager = SeqManager()
        self._buffer = MessageBuffer(self._transport_options.buffer_size)
        self._task_manager = BackgroundTaskManager()

    def _setup_heartbeats_task(
        self,
        do_close_websocket: Callable[[], Awaitable[None]],
    ) -> None:
        """Start the heartbeat task and the session-close check task.

        Args:
            do_close_websocket: Passed to ``setup_heartbeat`` as its
                ``close_websocket`` callback.
        """

        def increment_and_get_heartbeat_misses() -> int:
            self._heartbeat_misses += 1
            return self._heartbeat_misses

        self._task_manager.create_task(
            setup_heartbeat(
                self.session_id,
                self._transport_options.heartbeat_ms,
                self._transport_options.heartbeats_until_dead,
                # Report CONNECTING whenever the websocket is not open so the
                # heartbeat loop skips sending.
                lambda: (
                    self._state
                    if self._ws_wrapper.ws_state == WsState.OPEN
                    else SessionState.CONNECTING
                ),
                lambda: self._close_session_after_time_secs,
                close_websocket=do_close_websocket,
                send_message=self.send_message,
                increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses,
            )
        )
        self._task_manager.create_task(
            check_to_close_session(
                self._transport_id,
                self._transport_options.close_session_check_interval_ms,
                lambda: self._state,
                self._get_current_time,
                lambda: self._close_session_after_time_secs,
                self.close,
            )
        )

    async def is_session_open(self) -> bool:
        """Return True while the session state is ACTIVE."""
        async with self._state_lock:
            return self._state == SessionState.ACTIVE

    async def is_websocket_open(self) -> bool:
        """Return whether the current websocket wrapper reports being open."""
        async with self._ws_lock:
            return self._ws_wrapper.is_open()

    async def _begin_close_session_countdown(self) -> None:
        """Begin the countdown to close session, this should be called when
        websocket is closed.
        """
        # calculate the value now before establishing it so that there are no
        # await points between the check and the assignment to avoid a TOCTOU
        # race.
        grace_period_ms = self._transport_options.session_disconnect_grace_ms
        close_session_after_time_secs = (
            await self._get_current_time() + grace_period_ms / 1000
        )
        if self._close_session_after_time_secs is not None:
            # already in grace period, no need to set again
            return
        logger.info(
            "websocket closed from %s to %s begin grace period",
            self._transport_id,
            self._to_id,
        )
        self._close_session_after_time_secs = close_session_after_time_secs

    async def replace_with_new_websocket(
        self, new_ws: websockets.WebSocketCommonProtocol
    ) -> None:
        """Swap in ``new_ws`` and resend everything currently in the buffer.

        The previous wrapper is closed only when ``new_ws`` has a different
        id. Resending stops at the first send failure; the failure is logged
        and not raised.
        """
        async with self._ws_lock:
            old_wrapper = self._ws_wrapper
            old_ws_id = old_wrapper.ws.id
            if new_ws.id != old_ws_id:
                await old_wrapper.close()
            self._ws_wrapper = WebsocketWrapper(new_ws)

        # Send buffered messages to the new ws. Iterate over a snapshot of the
        # buffer taken before the first await.
        buffered_messages = list(self._buffer.buffer)
        for msg in buffered_messages:
            try:
                await send_transport_message(
                    msg,
                    new_ws,
                    self._begin_close_session_countdown,
                )
            except WebsocketClosedException:
                logger.info(
                    "Connection closed while sending buffered messages", exc_info=True
                )
                break
            except FailedSendingMessageException:
                logger.exception("Error while sending buffered messages")
                break

    async def _get_current_time(self) -> float:
        """Return the running event loop's clock reading."""
        # get_running_loop() is the documented way to reach the loop from
        # inside a coroutine.
        return asyncio.get_running_loop().time()

    def _reset_session_close_countdown(self) -> None:
        """Clear the missed-heartbeat counter and any pending close deadline."""
        self._heartbeat_misses = 0
        self._close_session_after_time_secs = None

    async def get_next_expected_seq(self) -> int:
        """Get the next expected sequence number from the server."""
        return self._seq_manager.get_ack()

    async def get_next_sent_seq(self) -> int:
        """Get the next sequence number that the client will send."""
        # NOTE(review): `or` also falls back to the seq manager when the
        # buffer reports a seq of 0 - confirm that is intended.
        return self._buffer.get_next_sent_seq() or self._seq_manager.get_seq()

    async def get_next_expected_ack(self) -> int:
        """Get the next expected ack that the client expects."""
        return self._seq_manager.get_seq()

    async def send_message(
        self,
        stream_id: str,
        payload: dict[Any, Any] | str,
        control_flags: int = 0,
        service_name: str | None = None,
        procedure_name: str | None = None,
        span: Span | None = None,
    ) -> None:
        """Send serialized messages to the websockets.

        The message is put in the buffer first and then sent over the current
        websocket if that websocket is open. Nothing is done when the session
        is not ACTIVE or the buffer has been closed. Send failures are logged
        and not raised; the message stays in the buffer.

        Args:
            stream_id: Stream the message belongs to.
            payload: Message payload.
            control_flags: Control flag bits for the message.
            service_name: Optional service name to attach.
            procedure_name: Optional procedure name to attach.
            span: When given, its trace context is injected into the message.
        """
        # if the session is not active, we should not do anything
        if self._state != SessionState.ACTIVE:
            return
        await self._buffer.has_capacity()
        # Start of critical section. No await between here and buffer.put()!
        msg = TransportMessage(
            streamId=stream_id,
            id=nanoid.generate(),
            from_=self._transport_id,
            to=self._to_id,
            seq=self._seq_manager.get_seq_and_increment(),
            ack=self._seq_manager.get_ack(),
            controlFlags=control_flags,
            payload=payload,
            serviceName=service_name,
            procedureName=procedure_name,
        )
        if span:
            with use_span(span):
                trace_propagator.inject(msg, None, trace_setter)
        try:
            try:
                self._buffer.put(msg)
            except MessageBufferClosedError:
                # The session is closed and is no longer accepting new messages.
                return
            async with self._ws_lock:
                if not self._ws_wrapper.is_open():
                    # If the websocket is closed, we should not send the message
                    # and wait for the retry from the buffer.
                    return
            await send_transport_message(
                msg, self._ws_wrapper.ws, self._begin_close_session_countdown
            )
        except WebsocketClosedException as e:
            logger.debug(
                "Connection closed while sending message %r, waiting for "
                "retry from buffer",
                type(e),
                exc_info=e,
            )
        except FailedSendingMessageException:
            logger.error(
                "Failed sending message, waiting for retry from buffer", exc_info=True
            )

    async def close_websocket(
        self, ws_wrapper: WebsocketWrapper, should_retry: bool
    ) -> None:
        """Mark the websocket as closed, close the websocket, and retry if needed."""
        async with self._ws_lock:
            # Already closed.
            if not ws_wrapper.is_open():
                return
            await ws_wrapper.close()
        if should_retry and self._retry_connection_callback:
            self._task_manager.create_task(self._retry_connection_callback())

    def _should_abort_streams_after_transport_failure(self) -> bool:
        """Return True unless the transport options enable transparent_reconnect."""
        return not self._transport_options.transparent_reconnect

    def _abort_all_streams(self) -> None:
        """Close all active stream channels, notifying any waiting consumers."""
        if not self._streams:
            return
        for stream in self._streams.values():
            stream.close()
        self._streams.clear()

    async def close(self) -> None:
        """Close the session and all associated streams.

        Does nothing beyond logging when the session is not ACTIVE.
        """
        logger.info(
            "%s closing session to %s, ws: %s, current_state : %s",
            self._transport_id,
            self._to_id,
            self._ws_wrapper.id,
            self._ws_wrapper.ws_state.name,
        )
        async with self._state_lock:
            if self._state != SessionState.ACTIVE:
                # already closing
                return
            self._state = SessionState.CLOSING
            self._reset_session_close_countdown()
            await self._task_manager.cancel_all_tasks()

            await self.close_websocket(self._ws_wrapper, should_retry=False)

            await self._buffer.close()

            # Clear the session in transports
            await self._close_session_callback(self)

            # TODO: unexpected_close should close stream differently here to
            # throw exception correctly.
            self._abort_all_streams()

            self._state = SessionState.CLOSED
async def check_to_close_session(
    transport_id: str,
    close_session_check_interval_ms: float,
    get_state: Callable[[], SessionState],
    get_current_time: Callable[[], Awaitable[float]],
    get_close_session_after_time_secs: Callable[[], float | None],
    do_close: Callable[[], Awaitable[None]],
) -> None:
    """Periodically check whether the close deadline has passed and close.

    The loop ends when the session reports a terminal state, or after
    ``do_close`` has been awaited once the deadline is exceeded.

    Args:
        transport_id: Used only in the log message.
        close_session_check_interval_ms: Sleep between checks, in milliseconds.
        get_state: Returns the current session state.
        get_current_time: Awaitable returning the current time in seconds.
        get_close_session_after_time_secs: Returns the close deadline in
            seconds, or None when no deadline is set.
        do_close: Awaited once when the deadline has passed.
    """
    check_interval_secs = close_session_check_interval_ms / 1000
    while True:
        await asyncio.sleep(check_interval_secs)
        if get_state() in TerminalStates:
            # already closing
            return
        # calculate the value now before comparing it so that there are no
        # await points between the check and the comparison to avoid a TOCTOU
        # race.
        current_time = await get_current_time()
        close_session_after_time_secs = get_close_session_after_time_secs()
        # Compare against None explicitly: a deadline of 0.0 is still a
        # deadline, not "no deadline".
        if close_session_after_time_secs is None:
            continue
        if current_time > close_session_after_time_secs:
            logger.info("Grace period ended for %s, closing session", transport_id)
            await do_close()
            return
async def setup_heartbeat(
    session_id: str,
    heartbeat_ms: float,
    heartbeats_until_dead: int,
    get_state: Callable[[], SessionState],
    get_closing_grace_period: Callable[[], float | None],
    close_websocket: Callable[[], Awaitable[None]],
    send_message: SendMessage[None],
    increment_and_get_heartbeat_misses: Callable[[], int],
) -> None:
    """Send a heartbeat message every ``heartbeat_ms`` while the session is ACTIVE.

    The loop returns once the session reports a terminal state. While the
    state is anything else other than ACTIVE, the heartbeat is skipped for
    that tick.

    Args:
        session_id: Used only in the log message.
        heartbeat_ms: Interval between heartbeats, in milliseconds.
        heartbeats_until_dead: Miss count above which the websocket is closed.
        get_state: Returns the current session state.
        get_closing_grace_period: Returns the pending close deadline, or None
            when no grace period is running.
        close_websocket: Awaited when the miss count is exceeded and no grace
            period is running.
        send_message: Used to send the heartbeat message.
        increment_and_get_heartbeat_misses: Bumps and returns the miss counter;
            called after every heartbeat send.
    """
    heartbeat_interval_secs = heartbeat_ms / 1000
    while True:
        await asyncio.sleep(heartbeat_interval_secs)
        state = get_state()
        # Terminal states must be checked before the generic "not ACTIVE"
        # guard, otherwise this branch can never run and the loop never ends.
        if state in TerminalStates:
            logger.debug(
                "Session is closed, no need to send heartbeat, state : "
                "%r close_session_after_this: %r",
                state,
                get_closing_grace_period(),
            )
            # session is closing / closed, no need to send heartbeat anymore
            return
        if state != SessionState.ACTIVE:
            logger.debug("Websocket is not connected, not sending heartbeat")
            continue
        try:
            await send_message(
                stream_id="heartbeat",
                # TODO: make this a message class
                # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
                payload={
                    "ack": 0,
                },
                control_flags=ACK_BIT,
                procedure_name=None,
                service_name=None,
                span=None,
            )
            # The counter is bumped on every send; the grace-period lookup only
            # happens once the miss limit is exceeded. When a grace period is
            # already running there is nothing further to do.
            if (
                increment_and_get_heartbeat_misses() > heartbeats_until_dead
                and get_closing_grace_period() is None
            ):
                logger.info(
                    "%r closing websocket because of heartbeat misses",
                    session_id,
                )
                await close_websocket()
        except FailedSendingMessageException:
            # this is expected during websocket closed period
            continue