Skip to content

Commit 2f2e5ab

Browse files
committed
apply cloud connection fixes
1 parent 461adc5 commit 2f2e5ab

3 files changed

Lines changed: 86 additions & 34 deletions

File tree

scratchattach/cloud/_base.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class EventStream(SupportsRead[Iterator[dict[str, Any]]], SupportsClose):
3939
"""
4040
Allows you to stream events
4141
"""
42+
timeout: Optional[Union[float, int]] = None
4243

4344
class AnyCloud(ABC, Generic[T]):
4445
"""
@@ -129,38 +130,57 @@ def __init__(self, cloud: BaseCloud):
129130
warnings.warn("Initial cloud connection attempt failed, retrying...", exceptions.UnexpectedWebsocketEventWarning)
130131
self.packets_left = []
131132

132-
def receive_new(self, non_blocking: bool = False):
133+
def receive_new(self, non_blocking: bool = False, timeout: Optional[float] = 0):
134+
timeout = None if timeout is None else max(timeout, 0)
135+
timeout_value = self.timeout if timeout is None else timeout
133136
if non_blocking:
134-
self.source_cloud.websocket.settimeout(0)
135-
try:
136-
received = self.source_cloud.websocket.recv().splitlines()
137-
self.packets_left.extend(received)
138-
except Exception:
139-
pass
137+
timeout_value = 0
138+
if self.source_cloud.websocket.gettimeout() != timeout_value:
139+
self.source_cloud.websocket.settimeout(timeout_value)
140+
# print("Receiving...")
141+
try:
142+
received = self.source_cloud.websocket.recv().splitlines()
143+
except websocket.WebSocketTimeoutException:
140144
return
141-
self.source_cloud.websocket.settimeout(None)
142-
received = self.source_cloud.websocket.recv().splitlines()
145+
# print(f"{received=}")
143146
self.packets_left.extend(received)
144-
147+
145148
def read(self, amount: int = -1) -> Iterator[dict[str, Any]]:
149+
# print("Reading...")
146150
i = 0
151+
recv_once = amount == -1
152+
recv_at_least = max(amount, 0)
153+
start_time = time.time()
154+
if self.timeout is not None:
155+
has_timeout = True
156+
timeout_end = start_time + self.timeout
157+
else:
158+
has_timeout = False
159+
timeout_end = 0.0
160+
done = False
161+
# print("Getting data...")
147162
with self.reading:
148-
try:
149-
self.receive_new(amount != -1)
150-
while (self.packets_left and amount == -1) or (amount != -1 and i < amount):
151-
if not self.packets_left and amount != -1:
152-
self.receive_new()
153-
yield json.loads(self.packets_left.pop(0))
154-
i += 1
155-
except Exception:
156-
self.source_cloud.reconnect()
157-
self.receive_new(amount != -1)
158-
while (self.packets_left and amount == -1) or (amount != -1 and i < amount):
159-
if not self.packets_left and amount != -1:
160-
self.receive_new()
161-
yield json.loads(self.packets_left.pop(0))
162-
i += 1
163+
# print("Getting data...", end_time is None, end_time > time.time(), end_time is None or end_time > time.time())
164+
while not done:
165+
# print("Getting data...")
166+
try:
167+
self.receive_new(not recv_once, timeout = timeout_end - time.time() if has_timeout else None)
168+
while ((not has_timeout or time.time() < timeout_end)
169+
and ((recv_once and self.packets_left) or (not recv_once and i < recv_at_least))):
170+
if not self.packets_left and not recv_once:
171+
self.receive_new(timeout = timeout_end - time.time() if has_timeout else None)
172+
if not self.packets_left:
173+
continue
174+
yield json.loads(self.packets_left.pop(0))
175+
i += 1
176+
done = True
177+
except Exception:
178+
# traceback.print_exc()
179+
self.source_cloud.reconnect()
163180

181+
def __del__(self):
182+
self.close()
183+
164184
def close(self) -> None:
165185
self.source_cloud.disconnect()
166186

@@ -305,6 +325,10 @@ def _handshake(self):
305325
self._send_packet(packet)
306326

307327
def connect(self):
328+
if self.websocket:
329+
self.websocket.close()
330+
if self.event_stream:
331+
self.event_stream = None
308332
self.websocket = websocket.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE})
309333
self.websocket.connect(
310334
self.cloud_host,
@@ -329,6 +353,8 @@ def disconnect(self):
329353
self.websocket.close()
330354
except Exception:
331355
pass
356+
if self.event_stream:
357+
self.event_stream = None
332358

333359
def _assert_valid_value(self, value):
334360
if not (value in [True, False, float('inf'), -float('inf')]):
@@ -445,6 +471,9 @@ def create_event_stream(self):
445471
raise ValueError("Cloud already has an event stream.")
446472
self.event_stream = WebSocketEventStream(self)
447473
return self.event_stream
474+
475+
def __del__(self):
476+
self.disconnect()
448477

449478
class LogCloudMeta(ABCMeta):
450479
def __instancecheck__(cls, instance) -> bool:

scratchattach/eventhandlers/_base.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@
55
from threading import Thread
66
from collections.abc import Callable
77
import traceback
8+
from typing import Optional
89
from scratchattach.utils.requests import requests
910
from scratchattach.utils import exceptions
1011

1112
class BaseEventHandler(ABC):
1213
_events: defaultdict[str, list[Callable]]
1314
_threaded_events: defaultdict[str, list[Callable]]
15+
running: bool
16+
_thread: Optional[Thread]
17+
_call_threads: list[Thread]
1418

1519
def __init__(self):
1620
self._thread = None
1721
self.running = False
22+
self._call_threads = []
1823
self._events = defaultdict(list)
1924
self._threaded_events = defaultdict(list)
2025

@@ -35,14 +40,18 @@ def start(self, *, thread=True, ignore_exceptions=True):
3540
else:
3641
self._thread = None
3742
self._updater()
38-
43+
3944
def call_event(self, event_name, args : list = []):
4045
try:
46+
# print(f"Calling for {event_name}...")
4147
if event_name in self._threaded_events:
4248
for func in self._threaded_events[event_name]:
43-
Thread(target=func, args=args).start()
49+
thread = Thread(target=func, args=args)
50+
self._call_threads.append(thread)
51+
thread.start()
4452
if event_name in self._events:
4553
for func in self._events[event_name]:
54+
# print(f"Called {func}.")
4655
func(*args)
4756
except Exception as e:
4857
if self.ignore_exceptions:
@@ -54,31 +63,44 @@ def call_event(self, event_name, args : list = []):
5463
except Exception:
5564
print(e)
5665
else:
57-
raise(e)
66+
raise e
5867

5968
@abstractmethod
6069
def _updater(self):
6170
pass
6271

63-
def stop(self):
72+
def __del__(self):
73+
self.stop()
74+
75+
def stop(self, wait_call_threads: bool = True):
6476
"""
6577
Permanently stops the event handler.
6678
"""
79+
# print("Stopping event handler...")
6780
self.running = False
68-
if self._thread is not None:
81+
thread = self._thread
82+
if thread is not None:
83+
thread.join()
6984
self._thread = None
85+
if not wait_call_threads:
86+
return
87+
for thread in self._call_threads:
88+
thread.join()
7089

7190
def pause(self):
7291
"""
7392
Pauses the event handler.
7493
"""
7594
self.running = False
95+
thread = self._thread
96+
if thread is not None:
97+
thread.join()
7698

7799
def resume(self):
78100
"""
79101
Resumes the event handler.
80102
"""
81-
if self.running is False:
103+
if not self.running:
82104
self.start()
83105

84106
def event(self, function=None, *, thread=False):

scratchattach/eventhandlers/cloud_events.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ def _updater(self):
2929

3030
self.call_event("on_ready")
3131

32-
if self.running is False:
32+
if not self.running:
3333
return
34-
while True:
34+
while self.running:
3535
try:
36-
while True:
36+
while self.running:
37+
self.source_stream.timeout = 1
3738
for data in self.source_stream.read():
3839
try:
3940
_a = cloud_activity.CloudActivity(timestamp=time.time()*1000, _session=self._session, cloud=self.cloud)

0 commit comments

Comments
 (0)