Skip to content

Commit 6dd9673

Browse files
committed
fix: prevent duplicate reconnect tasks and socket leaks on auth failure
_trigger_reconnect now sets _reconnecting=True before scheduling the task, so two failures in the same event loop tick can't both spawn reconnect tasks. The redundant guard inside _reconnect is removed. _open_connection now closes the socket on any handshake failure (not just wrong response type), preventing leaked connections from recv() or json.loads() errors.
1 parent 43d059d commit 6dd9673

1 file changed

Lines changed: 23 additions & 18 deletions

File tree

getstream/ws/client.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,29 @@ async def _open_connection(self) -> dict:
115115
close_timeout=1.0,
116116
)
117117

118-
auth_payload = {
119-
"token": self._ensure_token(),
120-
"user_details": self._user_details,
121-
}
122-
await self._websocket.send(json.dumps(auth_payload))
123-
124-
raw = await self._websocket.recv()
125-
message = json.loads(raw)
126-
127-
msg_type = message.get("type")
128-
if msg_type != "connection.ok":
129-
await self._websocket.close()
118+
try:
119+
auth_payload = {
120+
"token": self._ensure_token(),
121+
"user_details": self._user_details,
122+
}
123+
await self._websocket.send(json.dumps(auth_payload))
124+
125+
raw = await self._websocket.recv()
126+
message = json.loads(raw)
127+
128+
msg_type = message.get("type")
129+
if msg_type != "connection.ok":
130+
raise StreamWSAuthError(
131+
f"Expected connection.ok, got {msg_type}: {message}",
132+
response=message,
133+
)
134+
except Exception:
135+
try:
136+
await self._websocket.close()
137+
except Exception:
138+
pass
130139
self._websocket = None
131-
raise StreamWSAuthError(
132-
f"Expected connection.ok, got {msg_type}: {message}", response=message
133-
)
140+
raise
134141

135142
self._connection_id = message.get("connection_id")
136143
self._last_received = time.monotonic()
@@ -159,6 +166,7 @@ def _start_tasks(self) -> None:
159166

160167
def _trigger_reconnect(self, reason: str) -> None:
161168
if self._connected and not self._reconnecting:
169+
self._reconnecting = True
162170
self._reconnect_task = asyncio.create_task(self._reconnect(reason))
163171

164172
async def _reader_loop(self) -> None:
@@ -209,9 +217,6 @@ async def _heartbeat_loop(self) -> None:
209217
break
210218

211219
async def _reconnect(self, reason: str) -> None:
212-
if self._reconnecting:
213-
return
214-
self._reconnecting = True
215220
logger.info("Reconnecting: %s", reason)
216221

217222
try:

0 commit comments

Comments
 (0)