2626
2727
2828class SocketMixin (BaseTransport ):
29- def _close_socket_safely (self ) -> None :
30- if self ._socket is not None :
31- sock = self ._socket
32- self ._socket = None
33- with contextlib .suppress (ssl .SSLError , Exception ):
34- sock .close ()
35-
3629 @property
3730 def sock (self ) -> socket .socket :
3831 if self ._socket is None or not self .is_connected :
@@ -127,7 +120,7 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
127120 else :
128121 sock .sendall (b"\x05 \x01 \x00 " )
129122
130- response = self ._recv_exactly (sock , 2 )
123+ response = self ._recv_exactly_plain (sock , 2 )
131124 if response [0 ] != 0x05 :
132125 sock .close ()
133126 raise ConnectionError ("Invalid SOCKS5 proxy response" )
@@ -152,7 +145,7 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
152145 auth_req = b"\x01 " + bytes ([len (u )]) + u + bytes ([len (p )]) + p
153146 sock .sendall (auth_req )
154147
155- auth_resp = self ._recv_exactly (sock , 2 )
148+ auth_resp = self ._ssl_read_exactly (sock , 2 )
156149 if auth_resp != b"\x01 \x00 " :
157150 sock .close ()
158151 raise ConnectionError ("SOCKS5 authentication failed" )
@@ -166,34 +159,41 @@ def _create_socket_with_proxy(self, proxy: str) -> socket.socket:
166159 )
167160 sock .sendall (connect_req )
168161
169- resp = self ._recv_exactly (sock , 4 )
162+ resp = self ._recv_exactly_plain (sock , 4 )
170163 if resp [0 ] != 0x05 or resp [1 ] != 0x00 :
171164 sock .close ()
172165 raise ConnectionError (f"SOCKS5 connect failed, code={ resp [1 ]} " )
173166
174167 atyp = resp [3 ]
175168 if atyp == 0x01 :
176- self ._recv_exactly (sock , 4 + 2 )
169+ self ._recv_exactly_plain (sock , 4 + 2 )
177170 elif atyp == 0x03 :
178- domain_len = self ._recv_exactly (sock , 1 )[0 ]
179- self ._recv_exactly (sock , domain_len + 2 )
171+ domain_len = self ._recv_exactly_plain (sock , 1 )[0 ]
172+ self ._recv_exactly_plain (sock , domain_len + 2 )
180173 elif atyp == 0x04 :
181- self ._recv_exactly (sock , 16 + 2 )
174+ self ._recv_exactly_plain (sock , 16 + 2 )
182175 else :
183176 sock .close ()
184177 raise ConnectionError (f"Unknown ATYP: { atyp } " )
185178
186179 return sock
187180
181+ def _recv_exactly_plain (self , sock : socket .socket , n : int ) -> bytes :
182+ buf = bytearray ()
183+ while len (buf ) < n :
184+ chunk = sock .recv (n - len (buf ))
185+ if not chunk :
186+ raise ConnectionError ("Socket closed during SOCKS handshake" )
187+ buf .extend (chunk )
188+ return bytes (buf )
189+
188190 def _perform_ssl_handshake (self , raw_sock : socket .socket ) -> socket .socket :
189191 """
190192 Выполняет SSL handshake с сервером.
191193
192194 :param raw_sock: Обычный сокет
193195 :return: SSL сокет
194196 """
195- raw_sock .setblocking (True )
196-
197197 try :
198198 raw_sock .settimeout (10.0 )
199199 wrapped = self ._ssl_context .wrap_socket (
@@ -203,7 +203,7 @@ def _perform_ssl_handshake(self, raw_sock: socket.socket) -> socket.socket:
203203 suppress_ragged_eofs = True ,
204204 )
205205 wrapped .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
206- wrapped . setblocking ( True )
206+
207207 return wrapped
208208 except ssl .SSLError as e :
209209 self .logger .error ("SSL handshake failed: %s" , e )
@@ -282,83 +282,45 @@ async def connect(self, user_agent: UserAgentPayload | None = None) -> dict[str,
282282 self .logger .debug ("Handshake location: %s" , data .get ("payload" , {}).get ("location" ))
283283 return data
284284
285- def _recv_exactly (self , sock : socket .socket , n : int ) -> bytes :
286- """
287- Получает ровно n байт из сокета. Обрабатывает SSL ошибки корректно.
288- """
289- buf = bytearray ()
290- try :
291- while len (buf ) < n :
292- try :
293- chunk = sock .recv (n - len (buf ))
294- except ssl .SSLWantReadError :
295- continue
296- except ssl .SSLWantWriteError :
297- continue
298-
299- if not chunk :
300- break
301- buf .extend (chunk )
302- return bytes (buf )
303- except (ssl .SSLError , ConnectionError , BrokenPipeError ) as e :
304- self .logger .debug ("SSL/Connection error in _recv_exactly: %s" , e )
305- raise
285+ def _ssl_read_exactly (self , sock : socket .socket , n : int ) -> bytes :
286+ while len (self ._read_buffer ) < n :
287+ chunk = sock .recv (8192 )
288+ if not chunk :
289+ raise ConnectionResetError ("SSL socket closed" )
290+ self ._read_buffer .extend (chunk )
306291
307- async def _parse_header (
308- self , loop : asyncio .AbstractEventLoop , sock : socket .socket
309- ) -> bytes | None :
310- header = await loop .run_in_executor (None , lambda : self ._recv_exactly (sock = sock , n = 10 ))
311- if not header or len (header ) < 10 :
312- self .logger .error (
313- "Socket connection closed (incomplete header: %d bytes received)" ,
314- len (header ) if header else 0 ,
315- )
316- self .is_connected = False
317- raise ConnectionResetError ("Socket closed while reading header" )
292+ data = self ._read_buffer [:n ]
293+ del self ._read_buffer [:n ]
294+ return bytes (data )
318295
296+ async def _parse_header (self , loop , sock ):
297+ async with self ._sock_lock :
298+ header = await loop .run_in_executor (None , lambda : self ._ssl_read_exactly (sock , 10 ))
319299 return header
320300
321- async def _recv_data (
322- self , loop : asyncio .AbstractEventLoop , header : bytes , sock : socket .socket
323- ) -> list [dict [str , Any ]] | None :
324- packed_len = int .from_bytes (header [6 :10 ], "big" , signed = False )
301+ async def _recv_data (self , loop , header , sock ):
302+ packed_len = int .from_bytes (header [6 :10 ], "big" )
325303 payload_length = packed_len & 0xFFFFFF
326- remaining = payload_length
327- payload = bytearray ()
328-
329- while remaining > 0 :
330- min_read = min (remaining , 8192 )
331- chunk = await loop .run_in_executor (None , lambda : self ._recv_exactly (sock , min_read ))
332- if not chunk :
333- self .logger .error ("Connection closed while reading payload" )
334- break
335- payload .extend (chunk )
336- remaining -= len (chunk )
337-
338- if remaining > 0 :
339- self .logger .error ("Incomplete payload received; skipping packet" )
340- return None
341304
342- raw = header + payload
343- if len (raw ) < 10 + payload_length :
344- self .logger .error (
345- "Incomplete packet: expected %d bytes, got %d" ,
346- 10 + payload_length ,
347- len (raw ),
348- )
349- await asyncio .sleep (RECV_LOOP_BACKOFF_DELAY )
350- return None
305+ if payload_length == 0 :
306+ raw = header
307+ else :
308+ async with self ._sock_lock :
309+ payload = await loop .run_in_executor (
310+ None , lambda : self ._ssl_read_exactly (sock , payload_length )
311+ )
312+ raw = header + payload
351313
352314 data = self ._unpack_packet (raw )
353315 if not data :
354- self .logger .warning ("Failed to unpack packet, skipping " )
316+ self .logger .warning ("Failed to unpack packet" )
355317 return None
356318
357319 payload_objs = data .get ("payload" )
358320 return (
359- [{** data , "payload" : obj } for obj in payload_objs ]
360- if isinstance (payload_objs , list )
361- else [data ]
321+ [{** data , "payload" : payload_objs } ]
322+ if not isinstance (payload_objs , list )
323+ else [{ ** data , "payload" : obj } for obj in payload_objs ]
362324 )
363325
364326 async def _recv_loop (self ) -> None :
@@ -422,7 +384,6 @@ async def _recv_loop(self) -> None:
422384 )
423385 self .is_connected = False
424386
425- self ._close_socket_safely ()
426387 self ._socket = None
427388
428389 if self .reconnect and consecutive_errors < max_consecutive_errors :
@@ -433,12 +394,14 @@ async def _recv_loop(self) -> None:
433394 else :
434395 self .logger .warning (...)
435396 break
397+ except socket .timeout :
398+ self .logger .debug ("Socket timeout, continuing recv loop" )
399+ continue
436400 except Exception as e :
437401 consecutive_errors += 1
438402 self .logger .exception ("Error in recv_loop: %s" , e )
439403 self .is_connected = False
440404
441- self ._close_socket_safely ()
442405 self ._socket = None
443406
444407 if self .reconnect and consecutive_errors < max_consecutive_errors :
@@ -461,7 +424,6 @@ async def _send_and_wait(
461424 if not self .is_connected or self ._socket is None :
462425 raise SocketNotConnectedError
463426
464- sock = self .sock
465427 msg = self ._make_message (opcode , payload , cmd )
466428 loop = asyncio .get_running_loop ()
467429 fut : asyncio .Future [dict [str , Any ]] = loop .create_future ()
@@ -484,7 +446,13 @@ async def _send_and_wait(
484446 msg ["opcode" ],
485447 msg ["payload" ],
486448 )
487- await loop .run_in_executor (None , lambda : sock .sendall (packet ))
449+ async with self ._sock_lock :
450+ if not self .is_connected or self ._socket is None :
451+ raise SocketNotConnectedError
452+
453+ sock = self ._socket
454+ await loop .run_in_executor (None , lambda : sock .sendall (packet ))
455+
488456 data = await asyncio .wait_for (fut , timeout = timeout )
489457 self .logger .debug (
490458 "Received frame for seq=%s opcode=%s" ,
@@ -497,7 +465,6 @@ async def _send_and_wait(
497465 self .logger .warning ("Connection lost: %s, attempting reconnect..." , conn_err )
498466 self .is_connected = False
499467
500- self ._close_socket_safely ()
501468 except asyncio .TimeoutError :
502469 self .logger .exception ("Send and wait failed (opcode=%s, seq=%s)" , opcode , msg ["seq" ])
503470 raise SocketSendError from None
0 commit comments