diff --git a/lib/go/thrift/server_socket.go b/lib/go/thrift/server_socket.go index 164221e92b..8b38c0bfd1 100644 --- a/lib/go/thrift/server_socket.go +++ b/lib/go/thrift/server_socket.go @@ -17,6 +17,7 @@ * under the License. */ + package thrift import ( @@ -26,6 +27,8 @@ import ( ) type TServerSocket struct { + // TServerSocketListenerFactory abstracts how listeners are created. + factory func(net.Addr) (net.Listener, error) addr net.Addr clientTimeout time.Duration @@ -44,28 +47,61 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T if err != nil { return nil, err } - return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil + + return NewTServerSocketFromAddrTimeout(addr, clientTimeout), nil } +// NewTServerSocketFromAddrTimeout returns TServerSocket // Creates a TServerSocket from a net.Addr func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket { - return &TServerSocket{addr: addr, clientTimeout: clientTimeout} + factory := func(addr net.Addr) (net.Listener, error) { + return net.Listen(addr.Network(), addr.String()) + } + + return NewTServerSocketFromFactoryTimeout(factory, addr, clientTimeout) } -func (p *TServerSocket) Listen() error { +// NewTServerSocketFromFactoryTimeout returns TServerSocket +// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.) +func NewTServerSocketFromFactoryTimeout(factory func(addr net.Addr) (listener net.Listener, err error), addr net.Addr, clientTimeout time.Duration) *TServerSocket { + return &TServerSocket{ + factory: factory, + addr: addr, + clientTimeout: clientTimeout, + } +} + +func (p *TServerSocket) try_listen(raise bool) error { p.mu.Lock() defer p.mu.Unlock() - if p.IsListening() { + + if p.listener != nil { + if (raise) { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } return nil } - l, err := net.Listen(p.addr.Network(), p.addr.String()) + + l, err := p.factory(p.addr) if err != nil { return err } + p.listener = l + p.interrupted = false return nil } +// Open does try to listen and return on failure +// Connects the socket, creating a new socket object if necessary. +func (p *TServerSocket) Open() error { + return p.try_listen(true /* raise error if listening */) +} + +func (p *TServerSocket) Listen() error { + return p.try_listen(false /* do not raise error if listening */) +} + func (p *TServerSocket) Accept() (TTransport, error) { p.mu.RLock() interrupted := p.interrupted @@ -87,51 +123,43 @@ func (p *TServerSocket) Accept() (TTransport, error) { return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil } +// IsListening returns listener != nil // Checks whether the socket is listening. func (p *TServerSocket) IsListening() bool { + p.mu.RLock() + defer p.mu.RUnlock() return p.listener != nil } -// Connects the socket, creating a new socket object if necessary. -func (p *TServerSocket) Open() error { - p.mu.Lock() - defer p.mu.Unlock() - if p.IsListening() { - return NewTTransportException(ALREADY_OPEN, "Server socket already open") - } - if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil { - return err - } else { - p.listener = l - } - return nil -} - func (p *TServerSocket) Addr() net.Addr { p.mu.RLock() defer p.mu.RUnlock() - if p.IsListening() { + + if p.listener != nil { return p.listener.Addr() } return p.addr } -func (p *TServerSocket) Close() error { - var err error +func (p *TServerSocket) try_close(interrupt bool) error { p.mu.Lock() - if p.IsListening() { + defer p.mu.Unlock() + if (interrupt){ + p.interrupted = true + } + + var err error = nil + if p.listener != nil { err = p.listener.Close() p.listener = nil } - p.mu.Unlock() return err } -func (p *TServerSocket) Interrupt() error { - p.mu.Lock() - p.interrupted = true - p.mu.Unlock() - p.Close() +func (p *TServerSocket) Close() error { + return p.try_close(false /* do not set interrupted flag */) +} - return nil +func (p *TServerSocket) Interrupt() error { + return p.try_close(true /* set interrupted flag */) }