Skip to content
This repository was archived by the owner on Feb 11, 2026. It is now read-only.

Commit d6cbc69

Browse files
committed
финальные (!) фиксы сокета (надеюсь)
1 parent 65f9843 commit d6cbc69

4 files changed

Lines changed: 58 additions & 87 deletions

File tree

src/pymax/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def __init__(
140140
self._file_upload_waiters: dict[int, asyncio.Future[dict[str, Any]]] = {}
141141
self._background_tasks: set[asyncio.Task[Any]] = set()
142142
self._stop_event = asyncio.Event()
143+
self._sock_lock = asyncio.Lock()
144+
self._read_buffer = bytearray()
143145

144146
self._seq: int = 0
145147
self._error_count: int = 0

src/pymax/mixins/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ async def _login(self) -> None:
254254
raise RuntimeError("Account requires registration")
255255

256256
if password_challenge and not login_attrs:
257-
token = await self._two_factor_auth(password_challenge)
257+
token = await self._two_factor_auth(password_challenge, None)
258258
else:
259259
token = login_attrs.get("token")
260260

src/pymax/mixins/socket.py

Lines changed: 53 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@
2626

2727

2828
class SocketMixin(BaseTransport):
29-
def _close_socket_safely(self) -> None:
30-
if self._socket is not None:
31-
sock = self._socket
32-
self._socket = None
33-
with contextlib.suppress(ssl.SSLError, Exception):
34-
sock.close()
35-
3629
@property
3730
def sock(self) -> socket.socket:
3831
if self._socket is None or not self.is_connected:
@@ -127,7 +120,7 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
127120
else:
128121
sock.sendall(b"\x05\x01\x00")
129122

130-
response = self._recv_exactly(sock, 2)
123+
response = self._recv_exactly_plain(sock, 2)
131124
if response[0] != 0x05:
132125
sock.close()
133126
raise ConnectionError("Invalid SOCKS5 proxy response")
@@ -152,7 +145,7 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
152145
auth_req = b"\x01" + bytes([len(u)]) + u + bytes([len(p)]) + p
153146
sock.sendall(auth_req)
154147

155-
auth_resp = self._recv_exactly(sock, 2)
148+
auth_resp = self._ssl_read_exactly(sock, 2)
156149
if auth_resp != b"\x01\x00":
157150
sock.close()
158151
raise ConnectionError("SOCKS5 authentication failed")
@@ -166,34 +159,41 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
166159
)
167160
sock.sendall(connect_req)
168161

169-
resp = self._recv_exactly(sock, 4)
162+
resp = self._recv_exactly_plain(sock, 4)
170163
if resp[0] != 0x05 or resp[1] != 0x00:
171164
sock.close()
172165
raise ConnectionError(f"SOCKS5 connect failed, code={resp[1]}")
173166

174167
atyp = resp[3]
175168
if atyp == 0x01:
176-
self._recv_exactly(sock, 4 + 2)
169+
self._recv_exactly_plain(sock, 4 + 2)
177170
elif atyp == 0x03:
178-
domain_len = self._recv_exactly(sock, 1)[0]
179-
self._recv_exactly(sock, domain_len + 2)
171+
domain_len = self._recv_exactly_plain(sock, 1)[0]
172+
self._recv_exactly_plain(sock, domain_len + 2)
180173
elif atyp == 0x04:
181-
self._recv_exactly(sock, 16 + 2)
174+
self._recv_exactly_plain(sock, 16 + 2)
182175
else:
183176
sock.close()
184177
raise ConnectionError(f"Unknown ATYP: {atyp}")
185178

186179
return sock
187180

181+
def _recv_exactly_plain(self, sock: socket.socket, n: int) -> bytes:
182+
buf = bytearray()
183+
while len(buf) < n:
184+
chunk = sock.recv(n - len(buf))
185+
if not chunk:
186+
raise ConnectionError("Socket closed during SOCKS handshake")
187+
buf.extend(chunk)
188+
return bytes(buf)
189+
188190
def _perform_ssl_handshake(self, raw_sock: socket.socket) -> socket.socket:
189191
"""
190192
Выполняет SSL handshake с сервером.
191193
192194
:param raw_sock: Обычный сокет
193195
:return: SSL сокет
194196
"""
195-
raw_sock.setblocking(True)
196-
197197
try:
198198
raw_sock.settimeout(10.0)
199199
wrapped = self._ssl_context.wrap_socket(
@@ -203,7 +203,7 @@ def _perform_ssl_handshake(self, raw_sock: socket.socket) -> socket.socket:
203203
suppress_ragged_eofs=True,
204204
)
205205
wrapped.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
206-
wrapped.setblocking(True)
206+
207207
return wrapped
208208
except ssl.SSLError as e:
209209
self.logger.error("SSL handshake failed: %s", e)
@@ -282,83 +282,45 @@ async def connect(self, user_agent: UserAgentPayload | None = None) -> dict[str,
282282
self.logger.debug("Handshake location: %s", data.get("payload", {}).get("location"))
283283
return data
284284

285-
def _recv_exactly(self, sock: socket.socket, n: int) -> bytes:
286-
"""
287-
Получает ровно n байт из сокета. Обрабатывает SSL ошибки корректно.
288-
"""
289-
buf = bytearray()
290-
try:
291-
while len(buf) < n:
292-
try:
293-
chunk = sock.recv(n - len(buf))
294-
except ssl.SSLWantReadError:
295-
continue
296-
except ssl.SSLWantWriteError:
297-
continue
298-
299-
if not chunk:
300-
break
301-
buf.extend(chunk)
302-
return bytes(buf)
303-
except (ssl.SSLError, ConnectionError, BrokenPipeError) as e:
304-
self.logger.debug("SSL/Connection error in _recv_exactly: %s", e)
305-
raise
285+
def _ssl_read_exactly(self, sock: socket.socket, n: int) -> bytes:
286+
while len(self._read_buffer) < n:
287+
chunk = sock.recv(8192)
288+
if not chunk:
289+
raise ConnectionResetError("SSL socket closed")
290+
self._read_buffer.extend(chunk)
306291

307-
async def _parse_header(
308-
self, loop: asyncio.AbstractEventLoop, sock: socket.socket
309-
) -> bytes | None:
310-
header = await loop.run_in_executor(None, lambda: self._recv_exactly(sock=sock, n=10))
311-
if not header or len(header) < 10:
312-
self.logger.error(
313-
"Socket connection closed (incomplete header: %d bytes received)",
314-
len(header) if header else 0,
315-
)
316-
self.is_connected = False
317-
raise ConnectionResetError("Socket closed while reading header")
292+
data = self._read_buffer[:n]
293+
del self._read_buffer[:n]
294+
return bytes(data)
318295

296+
async def _parse_header(self, loop, sock):
297+
async with self._sock_lock:
298+
header = await loop.run_in_executor(None, lambda: self._ssl_read_exactly(sock, 10))
319299
return header
320300

321-
async def _recv_data(
322-
self, loop: asyncio.AbstractEventLoop, header: bytes, sock: socket.socket
323-
) -> list[dict[str, Any]] | None:
324-
packed_len = int.from_bytes(header[6:10], "big", signed=False)
301+
async def _recv_data(self, loop, header, sock):
302+
packed_len = int.from_bytes(header[6:10], "big")
325303
payload_length = packed_len & 0xFFFFFF
326-
remaining = payload_length
327-
payload = bytearray()
328-
329-
while remaining > 0:
330-
min_read = min(remaining, 8192)
331-
chunk = await loop.run_in_executor(None, lambda: self._recv_exactly(sock, min_read))
332-
if not chunk:
333-
self.logger.error("Connection closed while reading payload")
334-
break
335-
payload.extend(chunk)
336-
remaining -= len(chunk)
337-
338-
if remaining > 0:
339-
self.logger.error("Incomplete payload received; skipping packet")
340-
return None
341304

342-
raw = header + payload
343-
if len(raw) < 10 + payload_length:
344-
self.logger.error(
345-
"Incomplete packet: expected %d bytes, got %d",
346-
10 + payload_length,
347-
len(raw),
348-
)
349-
await asyncio.sleep(RECV_LOOP_BACKOFF_DELAY)
350-
return None
305+
if payload_length == 0:
306+
raw = header
307+
else:
308+
async with self._sock_lock:
309+
payload = await loop.run_in_executor(
310+
None, lambda: self._ssl_read_exactly(sock, payload_length)
311+
)
312+
raw = header + payload
351313

352314
data = self._unpack_packet(raw)
353315
if not data:
354-
self.logger.warning("Failed to unpack packet, skipping")
316+
self.logger.warning("Failed to unpack packet")
355317
return None
356318

357319
payload_objs = data.get("payload")
358320
return (
359-
[{**data, "payload": obj} for obj in payload_objs]
360-
if isinstance(payload_objs, list)
361-
else [data]
321+
[{**data, "payload": payload_objs}]
322+
if not isinstance(payload_objs, list)
323+
else [{**data, "payload": obj} for obj in payload_objs]
362324
)
363325

364326
async def _recv_loop(self) -> None:
@@ -422,7 +384,6 @@ async def _recv_loop(self) -> None:
422384
)
423385
self.is_connected = False
424386

425-
self._close_socket_safely()
426387
self._socket = None
427388

428389
if self.reconnect and consecutive_errors < max_consecutive_errors:
@@ -433,12 +394,14 @@ async def _recv_loop(self) -> None:
433394
else:
434395
self.logger.warning(...)
435396
break
397+
except socket.timeout:
398+
self.logger.debug("Socket timeout, continuing recv loop")
399+
continue
436400
except Exception as e:
437401
consecutive_errors += 1
438402
self.logger.exception("Error in recv_loop: %s", e)
439403
self.is_connected = False
440404

441-
self._close_socket_safely()
442405
self._socket = None
443406

444407
if self.reconnect and consecutive_errors < max_consecutive_errors:
@@ -461,7 +424,6 @@ async def _send_and_wait(
461424
if not self.is_connected or self._socket is None:
462425
raise SocketNotConnectedError
463426

464-
sock = self.sock
465427
msg = self._make_message(opcode, payload, cmd)
466428
loop = asyncio.get_running_loop()
467429
fut: asyncio.Future[dict[str, Any]] = loop.create_future()
@@ -484,7 +446,13 @@ async def _send_and_wait(
484446
msg["opcode"],
485447
msg["payload"],
486448
)
487-
await loop.run_in_executor(None, lambda: sock.sendall(packet))
449+
async with self._sock_lock:
450+
if not self.is_connected or self._socket is None:
451+
raise SocketNotConnectedError
452+
453+
sock = self._socket
454+
await loop.run_in_executor(None, lambda: sock.sendall(packet))
455+
488456
data = await asyncio.wait_for(fut, timeout=timeout)
489457
self.logger.debug(
490458
"Received frame for seq=%s opcode=%s",
@@ -497,7 +465,6 @@ async def _send_and_wait(
497465
self.logger.warning("Connection lost: %s, attempting reconnect...", conn_err)
498466
self.is_connected = False
499467

500-
self._close_socket_safely()
501468
except asyncio.TimeoutError:
502469
self.logger.exception("Send and wait failed (opcode=%s, seq=%s)", opcode, msg["seq"])
503470
raise SocketSendError from None

src/pymax/protocols.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(self, logger: Logger) -> None:
7070
self._outgoing: asyncio.Queue[dict[str, Any]] | None = None
7171
self._outgoing_task: asyncio.Task[Any] | None = None
7272
self._error_count: int = 0
73+
self._sock_lock: asyncio.Lock = asyncio.Lock()
74+
self._read_buffer: bytearray = bytearray()
7375
self._circuit_breaker: bool = False
7476
self._last_error_time: float = 0.0
7577
self._session_id: int

0 commit comments

Comments
 (0)