diff --git a/README.md b/README.md index 4641d503..36c98677 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ type Settings struct { PublicHost string // Public IP to expose (only an IP address is accepted at this stage) PublicIPResolver PublicIPResolver // (Optional) To fetch a public IP lookup PassiveTransferPortRange PasvPortGetter // (Optional) Port Range for data connections. Random if not specified + PassiveTransferPortMultiplexing bool // Allow different client IPs to share passive listener ports ActiveTransferPortNon20 bool // Do not impose the port 20 for active data transfer (#88, RFC 1579) IdleTimeout int // Maximum inactivity time before disconnecting (#58) ConnectionTimeout int // Maximum time to establish passive or active transfer connections diff --git a/driver.go b/driver.go index 63443623..54562563 100644 --- a/driver.go +++ b/driver.go @@ -303,28 +303,31 @@ const ( // Settings defines all the server settings type Settings struct { - Listener net.Listener // (Optional) To provide an already initialized listener - ListenAddr string // Listening address - PublicHost string // Public IP to expose (only an IP address is accepted at this stage) - Banner string // Banner to use in server status response - PassiveTransferPortRange PasvPortGetter // (Optional) Port Mapping for data connections. Random if not specified - PublicIPResolver PublicIPResolver // (Optional) To fetch a public IP lookup - IdleTimeout int // Maximum inactivity time before disconnecting (#58) - ConnectionTimeout int // Maximum time to establish passive or active transfer connections - ActiveTransferPortNon20 bool // Do not impose the port 20 for active data transfer (#88, RFC 1579) - DisableMLSD bool // Disable MLSD support - DisableMLST bool // Disable MLST support - DisableMFMT bool // Disable MFMT support (modify file mtime) - TLSRequired TLSRequirement // defines the TLS mode - DisableLISTArgs bool // Disable ls like options (-a,-la etc.) for directory listing - DisableSite bool // Disable SITE command - DisableActiveMode bool // Disable Active FTP - EnableHASH bool // Enable support for calculating hash value of files - DisableSTAT bool // Disable Server STATUS, STAT on files and directories will still work - DisableSYST bool // Disable SYST - EnableCOMB bool // Enable COMB support - DeflateCompressionLevel int // Deflate compression level (0-9). 0 means disabled - DefaultTransferType TransferType // Transfer type to use if the client don't send the TYPE command + Listener net.Listener // (Optional) To provide an already initialized listener + ListenAddr string // Listening address + PublicHost string // Public IP to expose (only an IP address is accepted at this stage) + Banner string // Banner to use in server status response + // PassiveTransferPortRange is the optional port mapping for passive data connections. + PassiveTransferPortRange PasvPortGetter + PassiveTransferPortMultiplexing bool // Allow different client IPs to share passive listener ports + PublicIPResolver PublicIPResolver // (Optional) To fetch a public IP lookup + IdleTimeout int // Maximum inactivity time before disconnecting (#58) + ConnectionTimeout int // Maximum time to establish passive or active transfer connections + ActiveTransferPortNon20 bool // Do not impose the port 20 for active data transfer (#88, RFC 1579) + DisableMLSD bool // Disable MLSD support + DisableMLST bool // Disable MLST support + DisableMFMT bool // Disable MFMT support (modify file mtime) + TLSRequired TLSRequirement // defines the TLS mode + DisableLISTArgs bool // Disable ls like options (-a,-la etc.) for directory listing + DisableSite bool // Disable SITE command + DisableActiveMode bool // Disable Active FTP + EnableHASH bool // Enable support for calculating hash value of files + // DisableSTAT disables Server STATUS. STAT on files and directories still works. + DisableSTAT bool + DisableSYST bool // Disable SYST + EnableCOMB bool // Enable COMB support + DeflateCompressionLevel int // Deflate compression level (0-9). 0 means disabled + DefaultTransferType TransferType // Transfer type to use if the client don't send the TYPE command // ActiveConnectionsCheck defines the security requirements for active connections ActiveConnectionsCheck DataConnectionRequirement // PasvConnectionsCheck defines the security requirements for passive connections diff --git a/passive_multiplexer.go b/passive_multiplexer.go new file mode 100644 index 00000000..759bad7d --- /dev/null +++ b/passive_multiplexer.go @@ -0,0 +1,415 @@ +package ftpserver + +import ( + "errors" + "fmt" + "log/slog" + "net" + "sync" + "time" +) + +var errPassiveListenerReservedForIP = errors.New("passive listener already reserved for client ip") + +type passiveDeadlineSetter interface { + SetDeadline(deadline time.Time) error +} + +type passivePortCandidate struct { + exposedPort int + listenedPort int +} + +type passiveListenersManager struct { + logger *slog.Logger + mu sync.Mutex + listeners map[int]*sharedPassiveListener + closed bool +} + +func newPassiveListenersManager(logger *slog.Logger) *passiveListenersManager { + return &passiveListenersManager{ + logger: logger, + listeners: make(map[int]*sharedPassiveListener), + } +} + +func (m *passiveListenersManager) reserve( + remoteIP net.IP, + portRange PasvPortGetter, +) (int, net.Listener, passiveDeadlineSetter, error) { + for _, candidate := range getPassivePortCandidates(portRange) { + listener, err := m.getOrCreate(candidate.listenedPort) + if err != nil { + continue + } + + reservation, err := listener.reserve(remoteIP) + if err == nil { + return candidate.exposedPort, reservation, reservation, nil + } + if errors.Is(err, errPassiveListenerReservedForIP) { + continue + } + } + + return 0, nil, nil, ErrNoAvailableListeningPort +} + +func (m *passiveListenersManager) close() error { + m.mu.Lock() + if m.closed { + m.mu.Unlock() + + return nil + } + + m.closed = true + listeners := make([]*sharedPassiveListener, 0, len(m.listeners)) + for _, listener := range m.listeners { + listeners = append(listeners, listener) + } + m.mu.Unlock() + + var closeErr error + for _, listener := range listeners { + if err := listener.close(); err != nil && closeErr == nil { + closeErr = err + } + } + + return closeErr +} + +func (m *passiveListenersManager) getOrCreate(port int) (*sharedPassiveListener, error) { + m.mu.Lock() + if m.closed { + m.mu.Unlock() + + return nil, net.ErrClosed + } + if listener, ok := m.listeners[port]; ok { + m.mu.Unlock() + + return listener, nil + } + m.mu.Unlock() + + listener, err := newSharedPassiveListener(port, m.logger.With("passivePort", port)) + if err != nil { + return nil, err + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + _ = listener.close() + + return nil, net.ErrClosed + } + + if existing, ok := m.listeners[port]; ok { + _ = listener.close() + + return existing, nil + } + + m.listeners[port] = listener + + return listener, nil +} + +type sharedPassiveListener struct { + logger *slog.Logger + listener *net.TCPListener + mu sync.Mutex + reservations map[string]*passiveReservationListener + closed bool +} + +func newSharedPassiveListener(port int, logger *slog.Logger) (*sharedPassiveListener, error) { + laddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("0.0.0.0:%d", port)) + if err != nil { + return nil, newNetworkError(fmt.Sprintf("could not resolve port %d", port), err) + } + + tcpListener, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + + result := &sharedPassiveListener{ + logger: logger, + listener: tcpListener, + reservations: make(map[string]*passiveReservationListener), + } + + go result.serve() + + return result, nil +} + +func (l *sharedPassiveListener) serve() { + for { + conn, err := l.listener.Accept() + if err != nil { + if isClosedListenerError(err) { + l.failAll(net.ErrClosed) + + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Temporary() { //nolint:staticcheck + l.logger.Warn("Temporary passive accept error", "err", err) + + continue + } + + l.failAll(err) + + return + } + + l.dispatch(conn) + } +} + +func (l *sharedPassiveListener) reserve(remoteIP net.IP) (*passiveReservationListener, error) { + key := remoteIP.String() + + l.mu.Lock() + defer l.mu.Unlock() + + if l.closed { + return nil, net.ErrClosed + } + if _, ok := l.reservations[key]; ok { + return nil, errPassiveListenerReservedForIP + } + + reservation := &passiveReservationListener{ + parent: l, + remoteIP: key, + connCh: make(chan net.Conn, 1), + closedCh: make(chan struct{}), + } + l.reservations[key] = reservation + + return reservation, nil +} + +func (l *sharedPassiveListener) dispatch(conn net.Conn) { + ipAddress, err := getIPFromRemoteAddr(conn.RemoteAddr()) + if err != nil { + l.logger.Warn("Could not parse passive data connection IP", "err", err) + _ = conn.Close() + + return + } + + key := ipAddress.String() + + l.mu.Lock() + reservation := l.reservations[key] + if reservation != nil { + delete(l.reservations, key) + } + l.mu.Unlock() + + if reservation == nil || !reservation.deliver(conn) { + _ = conn.Close() + } +} + +func (l *sharedPassiveListener) release(remoteIP string) { + l.mu.Lock() + defer l.mu.Unlock() + + if reservation, ok := l.reservations[remoteIP]; ok { + delete(l.reservations, remoteIP) + reservation.markReleased() + } +} + +func (l *sharedPassiveListener) failAll(err error) { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + + return + } + + l.closed = true + reservations := make([]*passiveReservationListener, 0, len(l.reservations)) + for _, reservation := range l.reservations { + reservations = append(reservations, reservation) + } + l.reservations = nil + l.mu.Unlock() + + for _, reservation := range reservations { + reservation.fail(err) + } +} + +func (l *sharedPassiveListener) close() error { + err := l.listener.Close() + l.failAll(net.ErrClosed) + + return err +} + +type passiveReservationListener struct { + parent *sharedPassiveListener + remoteIP string + connCh chan net.Conn + closedCh chan struct{} + closeOnce sync.Once + stateMu sync.Mutex + deadline time.Time + released bool + failureErr error +} + +func (l *passiveReservationListener) Accept() (net.Conn, error) { + timeout := l.getDeadline() + var timerCh <-chan time.Time + var timer *time.Timer + + if !timeout.IsZero() { + wait := time.Until(timeout) + if wait <= 0 { + return nil, newPassiveAcceptTimeoutError() + } + + timer = time.NewTimer(wait) + timerCh = timer.C + defer timer.Stop() + } + + select { + case conn := <-l.connCh: + if conn == nil { + return nil, l.getFailure() + } + + return conn, nil + case <-l.closedCh: + return nil, l.getFailure() + case <-timerCh: + return nil, newPassiveAcceptTimeoutError() + } +} + +func (l *passiveReservationListener) Close() error { + l.closeOnce.Do(func() { + l.parent.release(l.remoteIP) + l.markReleased() + + select { + case conn := <-l.connCh: + if conn != nil { + _ = conn.Close() + } + default: + } + + close(l.closedCh) + }) + + return nil +} + +func (l *passiveReservationListener) Addr() net.Addr { + return l.parent.listener.Addr() +} + +func (l *passiveReservationListener) SetDeadline(deadline time.Time) error { + l.stateMu.Lock() + defer l.stateMu.Unlock() + + l.deadline = deadline + + return nil +} + +func (l *passiveReservationListener) deliver(conn net.Conn) bool { + select { + case <-l.closedCh: + return false + default: + } + + select { + case l.connCh <- conn: + l.markReleased() + + return true + case <-l.closedCh: + return false + } +} + +func (l *passiveReservationListener) fail(err error) { + l.stateMu.Lock() + if l.failureErr == nil { + l.failureErr = err + } + l.stateMu.Unlock() + + _ = l.Close() +} + +func (l *passiveReservationListener) markReleased() { + l.stateMu.Lock() + defer l.stateMu.Unlock() + + l.released = true +} + +func (l *passiveReservationListener) getDeadline() time.Time { + l.stateMu.Lock() + defer l.stateMu.Unlock() + + return l.deadline +} + +func (l *passiveReservationListener) getFailure() error { + l.stateMu.Lock() + defer l.stateMu.Unlock() + + if l.failureErr != nil { + return l.failureErr + } + + return net.ErrClosed +} + +type passiveAcceptTimeoutError struct{} + +func (passiveAcceptTimeoutError) Error() string { return "i/o timeout" } +func (passiveAcceptTimeoutError) Timeout() bool { return true } +func (passiveAcceptTimeoutError) Temporary() bool { return true } + +func newPassiveAcceptTimeoutError() error { + return &net.OpError{ + Op: "accept", + Net: "tcp", + Err: passiveAcceptTimeoutError{}, + } +} + +func isClosedListenerError(err error) bool { + if errors.Is(err, net.ErrClosed) { + return true + } + + errOp := &net.OpError{} + if errors.As(err, &errOp) && errOp.Err != nil { + return errOp.Err.Error() == "use of closed network connection" + } + + return false +} diff --git a/passive_multiplexer_test.go b/passive_multiplexer_test.go new file mode 100644 index 00000000..bbf4b9ae --- /dev/null +++ b/passive_multiplexer_test.go @@ -0,0 +1,261 @@ +package ftpserver + +import ( + "errors" + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/secsy/goftp" + "github.com/stretchr/testify/require" +) + +type trackingTestConn struct { + testNetConn + closed bool +} + +var ( + errTestClosedListener = errors.New("use of closed network connection") + errTestDifferent = errors.New("different error") + errTestPassiveFailure = errors.New("boom") +) + +func (c *trackingTestConn) Close() error { + c.closed = true + + return nil +} + +func getFreePassivePort(t *testing.T) int { + t.Helper() + + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer func() { + require.NoError(t, listener.Close()) + }() + + addr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + + return addr.Port +} + +func TestPassiveListenersManagerMultiplexesByClientIP(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + manager := newPassiveListenersManager(slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + defer func() { + req.NoError(manager.close()) + }() + + portRange := &PortRange{Start: port, End: port} + ip1 := net.ParseIP("127.0.0.2") + ip2 := net.ParseIP("127.0.0.3") + + exposedPort1, listener1, _, err := manager.reserve(ip1, portRange) + req.NoError(err) + req.Equal(port, exposedPort1) + + exposedPort2, listener2, _, err := manager.reserve(ip2, portRange) + req.NoError(err) + req.Equal(port, exposedPort2) + + req.Len(manager.listeners, 1) + sharedListener := manager.listeners[port] + + conn1 := &testNetConn{remoteAddr: &net.TCPAddr{IP: ip1, Port: 40001}} + sharedListener.dispatch(conn1) + accepted1, err := listener1.Accept() + req.NoError(err) + req.Same(conn1, accepted1) + + conn2 := &testNetConn{remoteAddr: &net.TCPAddr{IP: ip2, Port: 40002}} + sharedListener.dispatch(conn2) + accepted2, err := listener2.Accept() + req.NoError(err) + req.Same(conn2, accepted2) +} + +func TestPassiveListenersManagerRejectsSameIPForSamePort(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + manager := newPassiveListenersManager(slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + defer func() { + req.NoError(manager.close()) + }() + + portRange := &PortRange{Start: port, End: port} + clientIP := net.ParseIP("127.0.0.2") + + exposedPort, listener, deadlineSetter, err := manager.reserve(clientIP, portRange) + req.Equal(port, exposedPort) + req.NoError(err) + req.NotNil(listener) + req.NotNil(deadlineSetter) + + exposedPort, listener, deadlineSetter, err = manager.reserve(clientIP, portRange) + req.ErrorIs(err, ErrNoAvailableListeningPort) + req.Zero(exposedPort) + req.Nil(listener) + req.Nil(deadlineSetter) +} + +func TestPassiveListenersManagerCloseReleasesReservation(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + manager := newPassiveListenersManager(slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + defer func() { + req.NoError(manager.close()) + }() + + portRange := &PortRange{Start: port, End: port} + clientIP := net.ParseIP("127.0.0.2") + + exposedPort, listener, deadlineSetter, err := manager.reserve(clientIP, portRange) + req.NoError(err) + req.Equal(port, exposedPort) + req.NotNil(deadlineSetter) + req.NoError(listener.Close()) + + exposedPort, listener, deadlineSetter, err = manager.reserve(clientIP, portRange) + req.NoError(err) + req.Equal(port, exposedPort) + req.NotNil(listener) + req.NotNil(deadlineSetter) +} + +func TestPassivePortMultiplexingSameClientExhaustion(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + driver := &TestServerDriver{ + Settings: &Settings{ + ListenAddr: "127.0.0.1:0", + DefaultTransferType: TransferTypeBinary, + PassiveTransferPortRange: &PortRange{Start: port, End: port}, + PassiveTransferPortMultiplexing: true, + }, + } + server := NewTestServerWithTestDriver(t, driver) + + client, err := goftp.DialConfig(goftp.Config{ + User: authUser, + Password: authPass, + }, server.Addr()) + req.NoError(err) + defer func() { panicOnError(client.Close()) }() + + raw, err := client.OpenRawConn() + req.NoError(err) + defer func() { req.NoError(raw.Close()) }() + + returnCode, message, err := raw.SendCommand("PASV") + req.NoError(err) + req.Equal(StatusEnteringPASV, returnCode, message) + + returnCode, message, err = raw.SendCommand("PASV") + req.NoError(err) + req.Equal(StatusServiceNotAvailable, returnCode, message) + req.Contains(message, ErrNoAvailableListeningPort.Error()) +} + +func TestPassiveReservationListenerTimeoutAndHelpers(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + listener, err := newSharedPassiveListener(port, slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + req.NoError(err) + defer func() { + req.NoError(listener.close()) + }() + + reservation, err := listener.reserve(net.ParseIP("127.0.0.2")) + req.NoError(err) + req.Equal(listener.listener.Addr(), reservation.Addr()) + + req.NoError(reservation.SetDeadline(time.Now().Add(-time.Second))) + + _, err = reservation.Accept() + req.Error(err) + + var opErr *net.OpError + req.ErrorAs(err, &opErr) + + var timeoutErr passiveAcceptTimeoutError + req.ErrorAs(err, &timeoutErr) + req.Equal("i/o timeout", timeoutErr.Error()) + req.True(timeoutErr.Timeout()) + + req.True(isClosedListenerError(net.ErrClosed)) + req.True(isClosedListenerError(&net.OpError{Err: errTestClosedListener})) + req.False(isClosedListenerError(errTestDifferent)) +} + +func TestPassiveReservationListenerCloseAndFailures(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + listener, err := newSharedPassiveListener(port, slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + req.NoError(err) + defer func() { + req.NoError(listener.close()) + }() + + reservation, err := listener.reserve(net.ParseIP("127.0.0.2")) + req.NoError(err) + + expectedErr := errTestPassiveFailure + reservation.stateMu.Lock() + reservation.failureErr = expectedErr + reservation.stateMu.Unlock() + reservation.connCh <- nil + + _, err = reservation.Accept() + req.ErrorIs(err, expectedErr) + + reservation2, err := listener.reserve(net.ParseIP("127.0.0.3")) + req.NoError(err) + + conn := &trackingTestConn{} + reservation2.connCh <- conn + req.NoError(reservation2.Close()) + req.True(conn.closed) + req.False(reservation2.deliver(&trackingTestConn{})) + + _, err = reservation2.Accept() + req.ErrorIs(err, net.ErrClosed) +} + +func TestSharedPassiveListenerDispatchRejectionsAndClosedManager(t *testing.T) { + req := require.New(t) + port := getFreePassivePort(t) + manager := newPassiveListenersManager(slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + req.NoError(manager.close()) + req.NoError(manager.close()) + + _, err := manager.getOrCreate(port) + req.ErrorIs(err, net.ErrClosed) + + listener, err := newSharedPassiveListener(port, slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + req.NoError(err) + + unknownConn := &trackingTestConn{ + testNetConn: testNetConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.9"), Port: 40003}}, + } + listener.dispatch(unknownConn) + req.True(unknownConn.closed) + + invalidConn := &trackingTestConn{ + testNetConn: testNetConn{remoteAddr: &net.UnixAddr{Name: "sock", Net: "unix"}}, + } + listener.dispatch(invalidConn) + req.True(invalidConn.closed) + + req.NoError(listener.close()) + _, err = listener.reserve(net.ParseIP("127.0.0.4")) + req.ErrorIs(err, net.ErrClosed) + + _, err = newSharedPassiveListener(-1, slog.New(slog.NewTextHandler(io.Discard, nil))) //nolint:sloglint + req.Error(err) +} diff --git a/server.go b/server.go index 942c6fa3..d3f9b0bd 100644 --- a/server.go +++ b/server.go @@ -126,11 +126,12 @@ var specialAttentionCommands = []string{"ABOR", "STAT", "QUIT"} //nolint:gocheck // FtpServer is where everything is stored // We want to keep it as simple as possible type FtpServer struct { - Logger *slog.Logger // Structured logger (log/slog) - settings *Settings // General settings - listener net.Listener // listener used to receive files - clientCounter uint32 // Clients counter - driver MainDriver // Driver to handle the client authentication and the file access driver selection + Logger *slog.Logger // Structured logger (log/slog) + settings *Settings // General settings + listener net.Listener // listener used to receive files + passiveListeners *passiveListenersManager + clientCounter uint32 // Clients counter + driver MainDriver // Driver to handle the client authentication and the file access driver selection } func (server *FtpServer) loadSettings() error { @@ -165,6 +166,7 @@ func (server *FtpServer) loadSettings() error { } server.settings = settings + server.passiveListeners = newPassiveListenersManager(server.Logger) return nil } @@ -351,6 +353,12 @@ func (server *FtpServer) Stop() error { return newNetworkError("couln't close listener", err) } + if server.passiveListeners != nil { + if err := server.passiveListeners.close(); err != nil && !errors.Is(err, net.ErrClosed) { + server.Logger.Warn("Could not close passive listeners", "err", err) + } + } + return nil } diff --git a/transfer_pasv.go b/transfer_pasv.go index 3044d5a3..6038164c 100644 --- a/transfer_pasv.go +++ b/transfer_pasv.go @@ -29,13 +29,13 @@ var _ transferHandler = (*passiveTransferHandler)(nil) // Passive connection type passiveTransferHandler struct { - listener net.Listener // TCP or SSL Listener - tcpListener *net.TCPListener // TCP Listener (only keeping it to define a deadline during the accept) - Port int // TCP Port we are listening on - connection net.Conn // TCP Connection established - settings *Settings // Settings - info string // transfer info - logger *slog.Logger // Logger + listener net.Listener // TCP or SSL Listener + deadlineSetter passiveDeadlineSetter // Listener used to set accept deadlines + Port int // TCP Port we are listening on + connection net.Conn // TCP Connection established + settings *Settings // Settings + info string // transfer info + logger *slog.Logger // Logger // data connection requirement checker checkDataConn func(dataConnIP net.IP, channelType DataChannel) error } @@ -85,38 +85,63 @@ const ( portSearchMaxAttempts = 1000 ) -func (c *clientHandler) findListenerWithinPortRange(portMapping PasvPortGetter) (int, *net.TCPListener, error) { +func getPassivePortCandidates(portMapping PasvPortGetter) []passivePortCandidate { nbAttempts := portMapping.NumberAttempts() - // Making sure we trying a reasonable amount of ports before giving up if nbAttempts < portSearchMinAttempts { nbAttempts = portSearchMinAttempts } else if nbAttempts > portSearchMaxAttempts { nbAttempts = portSearchMaxAttempts } - for i := 0; i < nbAttempts; i++ { + maxFetches := nbAttempts * 4 + if maxFetches < nbAttempts { + maxFetches = nbAttempts + } + + result := make([]passivePortCandidate, 0, nbAttempts) + tried := make(map[int]struct{}, nbAttempts) + + for i := 0; len(result) < nbAttempts && i < maxFetches; i++ { exposedPort, listenedPort, ok := portMapping.FetchNext() if !ok { break } - laddr, errResolve := net.ResolveTCPAddr("tcp", fmt.Sprintf("0.0.0.0:%d", listenedPort)) + if _, ok := tried[listenedPort]; ok { + continue + } + + tried[listenedPort] = struct{}{} + result = append(result, passivePortCandidate{ + exposedPort: exposedPort, + listenedPort: listenedPort, + }) + } + + return result +} + +func (c *clientHandler) findListenerWithinPortRange(portMapping PasvPortGetter) (int, *net.TCPListener, error) { + candidates := getPassivePortCandidates(portMapping) + + for _, candidate := range candidates { + laddr, errResolve := net.ResolveTCPAddr("tcp", fmt.Sprintf("0.0.0.0:%d", candidate.listenedPort)) if errResolve != nil { - c.logger.Error("Problem resolving local port", "err", errResolve, "port", listenedPort) + c.logger.Error("Problem resolving local port", "err", errResolve, "port", candidate.listenedPort) - return 0, nil, newNetworkError(fmt.Sprintf("could not resolve port %d", listenedPort), errResolve) + return 0, nil, newNetworkError(fmt.Sprintf("could not resolve port %d", candidate.listenedPort), errResolve) } tcpListener, errListen := net.ListenTCP("tcp", laddr) if errListen == nil { - return exposedPort, tcpListener, nil + return candidate.exposedPort, tcpListener, nil } } c.logger.Warn( "Could not find any free port", - "nbAttempts", nbAttempts, + "nbAttempts", len(candidates), ) return 0, nil, ErrNoAvailableListeningPort @@ -124,17 +149,7 @@ func (c *clientHandler) findListenerWithinPortRange(portMapping PasvPortGetter) func (c *clientHandler) handlePASV(_ string) error { command := c.GetLastCommand() - addr, _ := net.ResolveTCPAddr("tcp", ":0") - var tcpListener *net.TCPListener - var err error - portMapping := c.server.settings.PassiveTransferPortRange - exposedPort := 0 - - if portMapping != nil { - exposedPort, tcpListener, err = c.findListenerWithinPortRange(portMapping) - } else { - tcpListener, err = net.ListenTCP("tcp", addr) - } + exposedPort, listener, deadlineSetter, err := c.getPassiveListener() if err != nil { c.logger.Error("Could not listen for passive connection", "err", err) @@ -142,9 +157,6 @@ func (c *clientHandler) handlePASV(_ string) error { return nil } - // The listener will either be plain TCP or TLS - var listener net.Listener - listener = tcpListener if wrapper, ok := c.server.driver.(MainDriverExtensionPassiveWrapper); ok { listener, err = wrapper.WrapPassiveListener(listener) @@ -167,17 +179,17 @@ func (c *clientHandler) handlePASV(_ string) error { } if exposedPort == 0 { - if tcpAddr, ok := tcpListener.Addr().(*net.TCPAddr); ok { + if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok { exposedPort = tcpAddr.Port } } transferHandler := &passiveTransferHandler{ - tcpListener: tcpListener, - listener: listener, - Port: exposedPort, - settings: c.server.settings, - logger: c.logger, - checkDataConn: c.checkDataConnectionRequirement, + listener: listener, + deadlineSetter: deadlineSetter, + Port: exposedPort, + settings: c.server.settings, + logger: c.logger, + checkDataConn: c.checkDataConnectionRequirement, } // We should rewrite this part @@ -201,6 +213,36 @@ func (c *clientHandler) handlePASV(_ string) error { return nil } +func (c *clientHandler) getPassiveListener() (int, net.Listener, passiveDeadlineSetter, error) { + portMapping := c.server.settings.PassiveTransferPortRange + if c.server.settings.PassiveTransferPortMultiplexing && portMapping != nil { + controlConnIP, err := getIPFromRemoteAddr(c.RemoteAddr()) + if err != nil { + return 0, nil, nil, err + } + + return c.server.passiveListeners.reserve(controlConnIP, portMapping) + } + + addr, _ := net.ResolveTCPAddr("tcp", ":0") + var ( + exposedPort int + tcpListener *net.TCPListener + err error + ) + + if portMapping != nil { + exposedPort, tcpListener, err = c.findListenerWithinPortRange(portMapping) + } else { + tcpListener, err = net.ListenTCP("tcp", addr) + } + if err != nil { + return 0, nil, nil, err + } + + return exposedPort, tcpListener, tcpListener, nil +} + func (c *clientHandler) handlePassivePASV(transferHandler *passiveTransferHandler) bool { portByte1 := transferHandler.Port / 256 portByte2 := transferHandler.Port - (portByte1 * 256) @@ -225,30 +267,35 @@ func (c *clientHandler) handlePassivePASV(transferHandler *passiveTransferHandle } func (p *passiveTransferHandler) ConnectionWait(wait time.Duration) (net.Conn, error) { - if p.connection == nil { - var err error - if err = p.tcpListener.SetDeadline(time.Now().Add(wait)); err != nil { - return nil, fmt.Errorf("failed to set deadline: %w", err) - } + if p.connection != nil { + return p.connection, nil + } - p.connection, err = p.listener.Accept() - if err != nil { - return nil, fmt.Errorf("failed to accept passive transfer connection: %w", err) - } + var err error + if p.deadlineSetter != nil { + err = p.deadlineSetter.SetDeadline(time.Now().Add(wait)) + } + if err != nil { + return nil, fmt.Errorf("failed to set deadline: %w", err) + } - ipAddress, err := getIPFromRemoteAddr(p.connection.RemoteAddr()) - if err != nil { - p.logger.Warn("Could get remote passive IP address", "err", err) + p.connection, err = p.listener.Accept() + if err != nil { + return nil, fmt.Errorf("failed to accept passive transfer connection: %w", err) + } - return nil, err - } + ipAddress, err := getIPFromRemoteAddr(p.connection.RemoteAddr()) + if err != nil { + p.logger.Warn("Could get remote passive IP address", "err", err) - if err := p.checkDataConn(ipAddress, DataChannelPassive); err != nil { - // we don't want to expose the full error to the client, we just log it - p.logger.Warn("Could not validate passive data connection requirement", "err", err) + return nil, err + } - return nil, &ipValidationError{error: "data connection security requirements not met"} - } + if err := p.checkDataConn(ipAddress, DataChannelPassive); err != nil { + // we don't want to expose the full error to the client, we just log it + p.logger.Warn("Could not validate passive data connection requirement", "err", err) + + return nil, &ipValidationError{error: "data connection security requirements not met"} } return p.connection, nil @@ -270,8 +317,8 @@ func (p *passiveTransferHandler) Open() (net.Conn, error) { // Closing only the client connection is not supported at that time func (p *passiveTransferHandler) Close() error { - if p.tcpListener != nil { - if err := p.tcpListener.Close(); err != nil { + if p.listener != nil { + if err := p.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.logger.Warn("Problem closing passive listener", "err", err) } } diff --git a/transfer_test.go b/transfer_test.go index 8557cb3a..edfb9d0f 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -1117,11 +1117,11 @@ func TestPASVConnectionWait(t *testing.T) { remoteAddr: &net.TCPAddr{IP: nil, Port: 21}, // invalid IP }, }, - tcpListener: tcpListener, - Port: tcpListener.Addr().(*net.TCPAddr).Port, - settings: cltHandler.server.settings, - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), //nolint:sloglint // DiscardHandler requires Go 1.23+ - checkDataConn: cltHandler.checkDataConnectionRequirement, + deadlineSetter: tcpListener, + Port: tcpListener.Addr().(*net.TCPAddr).Port, + settings: cltHandler.server.settings, + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), //nolint:sloglint // DiscardHandler requires Go 1.23+ + checkDataConn: cltHandler.checkDataConnectionRequirement, } defer func() {