Skip to content

Commit 3f07de2

Browse files
committed
fix ssl
1 parent 1f133bd commit 3f07de2

3 files changed

Lines changed: 217 additions & 10 deletions

File tree

tormysql/connections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def finish(future):
221221
else:
222222
child_gr.switch(future.result())
223223

224-
future = self._sock.start_tls(False, self.ctx, server_hostname=self.host)
224+
future = self._sock.start_tls(False, self.ctx, server_hostname=self.host, connect_timeout=self.connect_timeout)
225225
future.add_done_callback(finish)
226226
self._rfile = self._sock = main.switch()
227227

tormysql/platform/asyncio.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, address, bind_address):
2626
self._transport = None
2727
self._close_callback = None
2828
self._connect_future = None
29+
self._connect_ssl_future = None
2930
self._read_future = None
3031
self._read_bytes = 0
3132
self._closed = False
@@ -47,6 +48,13 @@ def on_closed(self, exc_info = False):
4748
self._connect_future.set_exception(StreamClosedError(None))
4849
self._connect_future = None
4950

51+
if self._connect_ssl_future:
52+
if exc_info:
53+
self._connect_ssl_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info)
54+
else:
55+
self._connect_ssl_future.set_exception(StreamClosedError(None))
56+
self._connect_ssl_future = None
57+
5058
if self._read_future:
5159
if exc_info:
5260
self._read_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info)
@@ -82,12 +90,12 @@ def connect(self, address, connect_timeout = 0, server_hostname = None):
8290
self._loop = current_ioloop()
8391
future = self._connect_future = Future(loop=self._loop)
8492
if connect_timeout:
85-
def timeout():
93+
def on_timeout():
8694
self._loop_connect_timeout = None
8795
if self._connect_future:
8896
self.close((None, IOError("Connect timeout"), None))
8997

90-
self._loop_connect_timeout = self._loop.call_later(connect_timeout, connect_timeout)
98+
self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout)
9199

92100
def connected(connect_future):
93101
if self._loop_connect_timeout:
@@ -96,18 +104,18 @@ def connected(connect_future):
96104

97105
if connect_future._exception is not None:
98106
self.on_closed(connect_future.exception())
107+
self._connect_future = None
99108
else:
100109
self._connect_future = None
101110
future.set_result(connect_future.result())
102-
self._connect_future = None
103111

104112
connect_future = ensure_future(self._connect(address, server_hostname))
105113
connect_future.add_done_callback(connected)
106114
return self._connect_future
107115

108116
def connection_made(self, transport):
109117
self._transport = transport
110-
if self._connect_future is None:
118+
if self._connect_future is None and self._connect_ssl_future is None:
111119
transport.close()
112120
else:
113121
self._transport.set_write_buffer_limits(1024 * 1024 * 1024)
@@ -151,4 +159,40 @@ def write(self, data):
151159
if self._closed:
152160
raise StreamClosedError(IOError('Already Closed'))
153161

154-
self._transport.write(data)
162+
self._transport.write(data)
163+
164+
def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout=None):
165+
if not self._transport or self._read_future:
166+
raise ValueError("IOStream is not idle; cannot convert to SSL")
167+
168+
self._connect_ssl_future = connect_ssl_future = Future(loop=self._loop)
169+
waiter = Future(loop=self._loop)
170+
171+
def on_connected(future):
172+
if self._loop_connect_timeout:
173+
self._loop_connect_timeout.cancel()
174+
self._loop_connect_timeout = None
175+
176+
if connect_ssl_future._exception is not None:
177+
self.on_closed(future.exception())
178+
self._connect_ssl_future = None
179+
else:
180+
self._connect_ssl_future = None
181+
connect_ssl_future.set_result(self)
182+
waiter.add_done_callback(on_connected)
183+
184+
if connect_timeout:
185+
def on_timeout():
186+
self._loop_connect_timeout = None
187+
if not waiter.done():
188+
self.close((None, IOError("Connect timeout"), None))
189+
190+
self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout)
191+
192+
self._transport.pause_reading()
193+
sock, self._transport._sock = self._transport._sock, None
194+
self._transport = self._loop._make_ssl_transport(
195+
sock, self, ssl_options, waiter,
196+
server_side=False, server_hostname=server_hostname)
197+
198+
return connect_ssl_future

tormysql/platform/tornado.py

Lines changed: 167 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import socket
99
import errno
10-
from tornado.iostream import IOStream as BaseIOStream, StreamClosedError, _ERRNO_WOULDBLOCK
10+
from tornado.iostream import IOStream as BaseIOStream, SSLIOStream as BaseSSLIOStream, StreamClosedError, _ERRNO_WOULDBLOCK, ssl, ssl_wrap_socket, _client_ssl_defaults
1111
from tornado.concurrent import Future
1212
from tornado.gen import coroutine
1313
from tornado.ioloop import IOLoop
@@ -24,8 +24,9 @@ def current_ioloop():
2424

2525

2626
class IOStream(BaseIOStream):
27-
def __init__(self, address, bind_address, *args, **kwargs):
28-
socket = self.init_socket(address, bind_address)
27+
def __init__(self, address, bind_address, socket = None, *args, **kwargs):
28+
if socket is None:
29+
socket = self.init_socket(address, bind_address)
2930

3031
super(IOStream, self).__init__(socket, *args, **kwargs)
3132

@@ -192,4 +193,166 @@ def write(self, data):
192193
if self._write_buffer_size:
193194
if not self._state & self.io_loop.WRITE:
194195
self._state = self._state | self.io_loop.WRITE
195-
self.io_loop.update_handler(self.fileno(), self._state)
196+
self.io_loop.update_handler(self.fileno(), self._state)
197+
198+
def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout = None):
199+
if (self._read_callback or self._read_future or
200+
self._write_callback or self._write_futures or
201+
self._connect_callback or self._connect_future or
202+
self._pending_callbacks or self._closed or
203+
self._read_buffer or self._write_buffer):
204+
raise ValueError("IOStream is not idle; cannot convert to SSL")
205+
206+
if ssl_options is None:
207+
ssl_options = _client_ssl_defaults
208+
209+
socket = self.socket
210+
self.io_loop.remove_handler(socket)
211+
self.socket = None
212+
socket = ssl_wrap_socket(socket, ssl_options,
213+
server_hostname=server_hostname,
214+
server_side=server_side,
215+
do_handshake_on_connect=False)
216+
orig_close_callback = self._close_callback
217+
self._close_callback = None
218+
219+
future = Future()
220+
ssl_stream = SSLIOStream(socket, ssl_options=ssl_options)
221+
222+
# Wrap the original close callback so we can fail our Future as well.
223+
# If we had an "unwrap" counterpart to this method we would need
224+
# to restore the original callback after our Future resolves
225+
# so that repeated wrap/unwrap calls don't build up layers.
226+
227+
def close_callback():
228+
if not future.done():
229+
# Note that unlike most Futures returned by IOStream,
230+
# this one passes the underlying error through directly
231+
# instead of wrapping everything in a StreamClosedError
232+
# with a real_error attribute. This is because once the
233+
# connection is established it's more helpful to raise
234+
# the SSLError directly than to hide it behind a
235+
# StreamClosedError (and the client is expecting SSL
236+
# issues rather than network issues since this method is
237+
# named start_tls).
238+
future.set_exception(ssl_stream.error or StreamClosedError())
239+
if orig_close_callback is not None:
240+
orig_close_callback()
241+
242+
if connect_timeout:
243+
def timeout():
244+
ssl_stream._loop_connect_timeout = None
245+
if not future.done():
246+
ssl_stream.close((None, IOError("Connect timeout"), None))
247+
248+
ssl_stream._loop_connect_timeout = self.io_loop.call_later(connect_timeout, timeout)
249+
250+
ssl_stream.set_close_callback(close_callback)
251+
ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
252+
ssl_stream.max_buffer_size = self.max_buffer_size
253+
ssl_stream.read_chunk_size = self.read_chunk_size
254+
return future
255+
256+
class SSLIOStream(IOStream, BaseSSLIOStream):
257+
def __init__(self, socket, *args, **kwargs):
258+
self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
259+
IOStream.__init__(self, None, None, socket, *args, **kwargs)
260+
261+
self._ssl_accepting = True
262+
self._handshake_reading = False
263+
self._handshake_writing = False
264+
self._ssl_connect_callback = None
265+
self._loop_connect_timeout = None
266+
self._server_hostname = None
267+
268+
# If the socket is already connected, attempt to start the handshake.
269+
try:
270+
self.socket.getpeername()
271+
except socket.error:
272+
pass
273+
else:
274+
# Indirectly start the handshake, which will run on the next
275+
# IOLoop iteration and then the real IO state will be set in
276+
# _handle_events.
277+
self._add_io_state(self.io_loop.WRITE)
278+
279+
def _handle_read(self):
280+
if self._ssl_accepting:
281+
self._do_ssl_handshake()
282+
return
283+
284+
chunk = True
285+
286+
while True:
287+
try:
288+
chunk = self.socket.recv(self.read_chunk_size)
289+
if not chunk:
290+
break
291+
if self._read_buffer_size:
292+
self._read_buffer += chunk
293+
else:
294+
self._read_buffer = bytearray(chunk)
295+
self._read_buffer_size += len(chunk)
296+
except ssl.SSLError as e:
297+
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
298+
break
299+
300+
self.close(exc_info=True)
301+
return
302+
except (socket.error, IOError, OSError) as e:
303+
en = e.errno if hasattr(e, 'errno') else e.args[0]
304+
if en in _ERRNO_WOULDBLOCK:
305+
break
306+
307+
if en == errno.EINTR:
308+
continue
309+
310+
self.close(exc_info=True)
311+
return
312+
313+
if self._read_future is not None and self._read_buffer_size >= self._read_bytes:
314+
future, self._read_future = self._read_future, None
315+
self._read_buffer, data = bytearray(), self._read_buffer
316+
self._read_buffer_size = 0
317+
self._read_bytes = 0
318+
future.set_result(data)
319+
320+
if not chunk:
321+
self.close()
322+
return
323+
324+
def _handle_write(self):
325+
if self._ssl_accepting:
326+
self._do_ssl_handshake()
327+
return
328+
329+
try:
330+
num_bytes = self.socket.send(memoryview(self._write_buffer)[
331+
self._write_buffer_pos: self._write_buffer_pos + self._write_buffer_size])
332+
self._write_buffer_pos += num_bytes
333+
self._write_buffer_size -= num_bytes
334+
except ssl.SSLError as e:
335+
if e.args[0] != ssl.SSL_ERROR_WANT_WRITE:
336+
self.close(exc_info=True)
337+
return
338+
except (socket.error, IOError, OSError) as e:
339+
en = e.errno if hasattr(e, 'errno') else e.args[0]
340+
if en not in _ERRNO_WOULDBLOCK:
341+
self.close(exc_info=True)
342+
return
343+
344+
if not self._write_buffer_size:
345+
if self._write_buffer_pos > 0:
346+
self._write_buffer = bytearray()
347+
self._write_buffer_pos = 0
348+
349+
if self._state & self.io_loop.WRITE:
350+
self._state = self._state & ~self.io_loop.WRITE
351+
self.io_loop.update_handler(self.fileno(), self._state)
352+
353+
def _run_ssl_connect_callback(self):
354+
if self._state & self.io_loop.WRITE:
355+
self._state = self._state & ~self.io_loop.WRITE
356+
self.io_loop.update_handler(self.fileno(), self._state)
357+
358+
BaseSSLIOStream._run_ssl_connect_callback(self)

0 commit comments

Comments
 (0)