77import sys
88import socket
99import errno
10- from tornado .iostream import IOStream as BaseIOStream , StreamClosedError , _ERRNO_WOULDBLOCK
10+ from tornado .iostream import IOStream as BaseIOStream , SSLIOStream as BaseSSLIOStream , StreamClosedError , _ERRNO_WOULDBLOCK , ssl , ssl_wrap_socket , _client_ssl_defaults
1111from tornado .concurrent import Future
1212from tornado .gen import coroutine
1313from tornado .ioloop import IOLoop
@@ -24,8 +24,9 @@ def current_ioloop():
2424
2525
2626class IOStream (BaseIOStream ):
27- def __init__ (self , address , bind_address , * args , ** kwargs ):
28- socket = self .init_socket (address , bind_address )
27+ def __init__ (self , address , bind_address , socket = None , * args , ** kwargs ):
28+ if socket is None :
29+ socket = self .init_socket (address , bind_address )
2930
3031 super (IOStream , self ).__init__ (socket , * args , ** kwargs )
3132
@@ -192,4 +193,166 @@ def write(self, data):
192193 if self ._write_buffer_size :
193194 if not self ._state & self .io_loop .WRITE :
194195 self ._state = self ._state | self .io_loop .WRITE
195- self .io_loop .update_handler (self .fileno (), self ._state )
196+ self .io_loop .update_handler (self .fileno (), self ._state )
197+
198+ def start_tls (self , server_side , ssl_options = None , server_hostname = None , connect_timeout = None ):
199+ if (self ._read_callback or self ._read_future or
200+ self ._write_callback or self ._write_futures or
201+ self ._connect_callback or self ._connect_future or
202+ self ._pending_callbacks or self ._closed or
203+ self ._read_buffer or self ._write_buffer ):
204+ raise ValueError ("IOStream is not idle; cannot convert to SSL" )
205+
206+ if ssl_options is None :
207+ ssl_options = _client_ssl_defaults
208+
209+ socket = self .socket
210+ self .io_loop .remove_handler (socket )
211+ self .socket = None
212+ socket = ssl_wrap_socket (socket , ssl_options ,
213+ server_hostname = server_hostname ,
214+ server_side = server_side ,
215+ do_handshake_on_connect = False )
216+ orig_close_callback = self ._close_callback
217+ self ._close_callback = None
218+
219+ future = Future ()
220+ ssl_stream = SSLIOStream (socket , ssl_options = ssl_options )
221+
222+ # Wrap the original close callback so we can fail our Future as well.
223+ # If we had an "unwrap" counterpart to this method we would need
224+ # to restore the original callback after our Future resolves
225+ # so that repeated wrap/unwrap calls don't build up layers.
226+
227+ def close_callback ():
228+ if not future .done ():
229+ # Note that unlike most Futures returned by IOStream,
230+ # this one passes the underlying error through directly
231+ # instead of wrapping everything in a StreamClosedError
232+ # with a real_error attribute. This is because once the
233+ # connection is established it's more helpful to raise
234+ # the SSLError directly than to hide it behind a
235+ # StreamClosedError (and the client is expecting SSL
236+ # issues rather than network issues since this method is
237+ # named start_tls).
238+ future .set_exception (ssl_stream .error or StreamClosedError ())
239+ if orig_close_callback is not None :
240+ orig_close_callback ()
241+
242+ if connect_timeout :
243+ def timeout ():
244+ ssl_stream ._loop_connect_timeout = None
245+ if not future .done ():
246+ ssl_stream .close ((None , IOError ("Connect timeout" ), None ))
247+
248+ ssl_stream ._loop_connect_timeout = self .io_loop .call_later (connect_timeout , timeout )
249+
250+ ssl_stream .set_close_callback (close_callback )
251+ ssl_stream ._ssl_connect_callback = lambda : future .set_result (ssl_stream )
252+ ssl_stream .max_buffer_size = self .max_buffer_size
253+ ssl_stream .read_chunk_size = self .read_chunk_size
254+ return future
255+
256+ class SSLIOStream (IOStream , BaseSSLIOStream ):
257+ def __init__ (self , socket , * args , ** kwargs ):
258+ self ._ssl_options = kwargs .pop ('ssl_options' , _client_ssl_defaults )
259+ IOStream .__init__ (self , None , None , socket , * args , ** kwargs )
260+
261+ self ._ssl_accepting = True
262+ self ._handshake_reading = False
263+ self ._handshake_writing = False
264+ self ._ssl_connect_callback = None
265+ self ._loop_connect_timeout = None
266+ self ._server_hostname = None
267+
268+ # If the socket is already connected, attempt to start the handshake.
269+ try :
270+ self .socket .getpeername ()
271+ except socket .error :
272+ pass
273+ else :
274+ # Indirectly start the handshake, which will run on the next
275+ # IOLoop iteration and then the real IO state will be set in
276+ # _handle_events.
277+ self ._add_io_state (self .io_loop .WRITE )
278+
279+ def _handle_read (self ):
280+ if self ._ssl_accepting :
281+ self ._do_ssl_handshake ()
282+ return
283+
284+ chunk = True
285+
286+ while True :
287+ try :
288+ chunk = self .socket .recv (self .read_chunk_size )
289+ if not chunk :
290+ break
291+ if self ._read_buffer_size :
292+ self ._read_buffer += chunk
293+ else :
294+ self ._read_buffer = bytearray (chunk )
295+ self ._read_buffer_size += len (chunk )
296+ except ssl .SSLError as e :
297+ if e .args [0 ] == ssl .SSL_ERROR_WANT_READ :
298+ break
299+
300+ self .close (exc_info = True )
301+ return
302+ except (socket .error , IOError , OSError ) as e :
303+ en = e .errno if hasattr (e , 'errno' ) else e .args [0 ]
304+ if en in _ERRNO_WOULDBLOCK :
305+ break
306+
307+ if en == errno .EINTR :
308+ continue
309+
310+ self .close (exc_info = True )
311+ return
312+
313+ if self ._read_future is not None and self ._read_buffer_size >= self ._read_bytes :
314+ future , self ._read_future = self ._read_future , None
315+ self ._read_buffer , data = bytearray (), self ._read_buffer
316+ self ._read_buffer_size = 0
317+ self ._read_bytes = 0
318+ future .set_result (data )
319+
320+ if not chunk :
321+ self .close ()
322+ return
323+
324+ def _handle_write (self ):
325+ if self ._ssl_accepting :
326+ self ._do_ssl_handshake ()
327+ return
328+
329+ try :
330+ num_bytes = self .socket .send (memoryview (self ._write_buffer )[
331+ self ._write_buffer_pos : self ._write_buffer_pos + self ._write_buffer_size ])
332+ self ._write_buffer_pos += num_bytes
333+ self ._write_buffer_size -= num_bytes
334+ except ssl .SSLError as e :
335+ if e .args [0 ] != ssl .SSL_ERROR_WANT_WRITE :
336+ self .close (exc_info = True )
337+ return
338+ except (socket .error , IOError , OSError ) as e :
339+ en = e .errno if hasattr (e , 'errno' ) else e .args [0 ]
340+ if en not in _ERRNO_WOULDBLOCK :
341+ self .close (exc_info = True )
342+ return
343+
344+ if not self ._write_buffer_size :
345+ if self ._write_buffer_pos > 0 :
346+ self ._write_buffer = bytearray ()
347+ self ._write_buffer_pos = 0
348+
349+ if self ._state & self .io_loop .WRITE :
350+ self ._state = self ._state & ~ self .io_loop .WRITE
351+ self .io_loop .update_handler (self .fileno (), self ._state )
352+
353+ def _run_ssl_connect_callback (self ):
354+ if self ._state & self .io_loop .WRITE :
355+ self ._state = self ._state & ~ self .io_loop .WRITE
356+ self .io_loop .update_handler (self .fileno (), self ._state )
357+
358+ BaseSSLIOStream ._run_ssl_connect_callback (self )
0 commit comments