diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index dcb6773f..b9692529 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -52,6 +52,15 @@ const ( DefaultConnBufferSize = 32 * 1024 ) +// normalizeConnBufferSize keeps 0 as "use the default", so every caller in the +// packet/TLS stack derives buffer sizes from the same effective value. +func normalizeConnBufferSize(bufferSize int) int { + if bufferSize == 0 { + return DefaultConnBufferSize + } + return bufferSize +} + type rwStatus int const ( @@ -116,9 +125,7 @@ func getPooledWriter(conn net.Conn, size int) *bufio.Writer { } func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter { - if bufferSize == 0 { - bufferSize = DefaultConnBufferSize - } + bufferSize = normalizeConnBufferSize(bufferSize) return &basicReadWriter{ Conn: conn, ReadWriter: bufio.NewReadWriter(getPooledReader(conn, bufferSize), getPooledWriter(conn, bufferSize)), @@ -274,7 +281,11 @@ type PacketIO interface { // PacketIO is a helper to read and write sql and proxy protocol. type packetIO struct { - lastKeepAlive config.KeepAlive + lastKeepAlive config.KeepAlive + // TLS allocates another buffered layer after the handshake. Keep the + // normalized base connection buffer size here so the TLS layer can scale + // from the caller's setting instead of falling back to unrelated constants. + connBufferSize int rawConn net.Conn readWriter packetReadWriter limitReader io.LimitedReader // reuse memory to reduce allocation @@ -288,10 +299,12 @@ type packetIO struct { } func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO { + bufferSize = normalizeConnBufferSize(bufferSize) p := &packetIO{ - rawConn: conn, - logger: lg, - readWriter: newBasicReadWriter(conn, bufferSize), + connBufferSize: bufferSize, + rawConn: conn, + logger: lg, + readWriter: newBasicReadWriter(conn, bufferSize), } p.ApplyOpts(opts...) return p diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index ff79d570..8a1bc248 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -11,6 +11,39 @@ import ( "github.com/pingcap/tiproxy/pkg/util/bufio" ) +const ( + // The TLS layer keeps its own post-handshake buffers, so follow the caller's + // connBufferSize proportionally but cap the sizes to avoid recreating the same + // memory pressure as the base connection buffers. + minTLSReadBufferSize = 1 * 1024 + maxTLSReadBufferSize = 4 * 1024 + minTLSWriteBufferSize = 1 * 1024 + maxTLSWriteBufferSize = 16 * 1024 + tlsReadBufferDivisor = 4 + tlsWriteBufferDivisor = 2 +) + +func clampBufferSize(size, minSize, maxSize int) int { + if size < minSize { + return minSize + } + if size > maxSize { + return maxSize + } + return size +} + +// TLS reads mostly serve packet-header peeks after the handshake, so a smaller +// reader is usually enough. TLS writes are more sensitive to fragmentation, so +// keep the writer relatively larger. Clamp both ends so tiny custom connection +// buffers do not make TLS unusably small, and huge connection buffers do not +// reintroduce excessive per-connection TLS memory. +func tlsBufferSizes(connBufferSize int) (readSize int, writeSize int) { + connBufferSize = normalizeConnBufferSize(connBufferSize) + return clampBufferSize(connBufferSize/tlsReadBufferDivisor, minTLSReadBufferSize, maxTLSReadBufferSize), + clampBufferSize(connBufferSize/tlsWriteBufferDivisor, minTLSWriteBufferSize, maxTLSWriteBufferSize) +} + // tlsHandshakeConn is only used as the underlying connection in tls.Conn. // TLS handshake must read from the buffered reader because the handshake data may be already buffered in the reader. // TLS handshake can not use the buffered writer directly because it assumes the data will be flushed automatically, @@ -30,7 +63,7 @@ func (p *packetIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionStat if err := tlsConn.Handshake(); err != nil { return tls.ConnectionState{}, p.wrapErr(errors.Wrap(errors.WithStack(err), ErrHandshakeTLS)) } - p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) + p.readWriter = newTLSReadWriter(p.readWriter, tlsConn, p.connBufferSize) return tlsConn.ConnectionState(), nil } @@ -41,7 +74,7 @@ func (p *packetIO) ClientTLSHandshake(tlsConfig *tls.Config) error { if err := tlsConn.Handshake(); err != nil { return p.wrapErr(errors.Wrap(errors.WithStack(err), ErrHandshakeTLS)) } - p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) + p.readWriter = newTLSReadWriter(p.readWriter, tlsConn, p.connBufferSize) return nil } @@ -57,10 +90,15 @@ type tlsReadWriter struct { conn *tls.Conn } -func newTLSReadWriter(rw packetReadWriter, tlsConn *tls.Conn) *tlsReadWriter { +func newTLSReadWriter(rw packetReadWriter, tlsConn *tls.Conn, connBufferSize int) *tlsReadWriter { // Can not modify rw and reuse it because tlsConn is using rw internally. - // We must create another buffer. - buf := bufio.NewReadWriter(bufio.NewReaderSize(tlsConn, DefaultConnBufferSize), bufio.NewWriterSize(tlsConn, DefaultConnBufferSize)) + // We must create another buffer. Size it from the base connection buffer so + // custom connBufferSize values keep a consistent memory profile after TLS. + readBufferSize, writeBufferSize := tlsBufferSizes(connBufferSize) + buf := bufio.NewReadWriter( + bufio.NewReaderSize(tlsConn, readBufferSize), + bufio.NewWriterSize(tlsConn, writeBufferSize), + ) return &tlsReadWriter{ packetReadWriter: rw, buf: buf, diff --git a/pkg/proxy/net/tls_test.go b/pkg/proxy/net/tls_test.go index 49dd6151..7c3b7ade 100644 --- a/pkg/proxy/net/tls_test.go +++ b/pkg/proxy/net/tls_test.go @@ -5,6 +5,7 @@ package net import ( "crypto/tls" + "fmt" "io" "net" "testing" @@ -25,7 +26,10 @@ func TestTLSReadWrite(t *testing.T) { conn := &tlsInternalConn{brw} tlsConn := tls.Client(conn, ctls) require.NoError(t, tlsConn.Handshake()) - trw := newTLSReadWriter(brw, tlsConn) + trw := newTLSReadWriter(brw, tlsConn, DefaultConnBufferSize) + readBufferSize, writeBufferSize := tlsBufferSizes(DefaultConnBufferSize) + require.Equal(t, readBufferSize, trw.buf.Reader.Size()) + require.Equal(t, writeBufferSize, trw.buf.Writer.Size()) // check tls connection state require.True(t, trw.TLSConnectionState().HandshakeComplete) // check out bytes @@ -49,7 +53,10 @@ func TestTLSReadWrite(t *testing.T) { conn := &tlsInternalConn{brw} tlsConn := tls.Server(conn, stls) require.NoError(t, tlsConn.Handshake()) - trw := newTLSReadWriter(brw, tlsConn) + trw := newTLSReadWriter(brw, tlsConn, DefaultConnBufferSize) + readBufferSize, writeBufferSize := tlsBufferSizes(DefaultConnBufferSize) + require.Equal(t, readBufferSize, trw.buf.Reader.Size()) + require.Equal(t, writeBufferSize, trw.buf.Writer.Size()) // check tls connection state require.True(t, trw.TLSConnectionState().HandshakeComplete) // check in bytes @@ -82,3 +89,58 @@ func TestTLSReadWrite(t *testing.T) { require.Equal(t, message[1:], data) }, 1) } + +func TestTLSBufferSizes(t *testing.T) { + cases := []struct { + connBufferSize int + readBufferSize int + writeBufferSize int + }{ + {connBufferSize: 0, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024}, + {connBufferSize: 1 * 1024, readBufferSize: 1 * 1024, writeBufferSize: 1 * 1024}, + {connBufferSize: 4 * 1024, readBufferSize: 1 * 1024, writeBufferSize: 2 * 1024}, + {connBufferSize: DefaultConnBufferSize, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024}, + {connBufferSize: DefaultConnBufferSize * 2, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024}, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("conn-%d", tc.connBufferSize), func(t *testing.T) { + readBufferSize, writeBufferSize := tlsBufferSizes(tc.connBufferSize) + require.Equal(t, tc.readBufferSize, readBufferSize) + require.Equal(t, tc.writeBufferSize, writeBufferSize) + }) + } +} + +func TestPacketIOTLSBufferSizes(t *testing.T) { + stls, ctls, err := security.CreateTLSConfigForTest() + require.NoError(t, err) + + for _, connBufferSize := range []int{1 * 1024, 4 * 1024, DefaultConnBufferSize, DefaultConnBufferSize * 2} { + t.Run(fmt.Sprintf("conn-%d", connBufferSize), func(t *testing.T) { + readBufferSize, writeBufferSize := tlsBufferSizes(connBufferSize) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + cli := NewPacketIO(c, nil, connBufferSize) + require.NoError(t, cli.ClientTLSHandshake(ctls)) + trw, ok := cli.readWriter.(*tlsReadWriter) + require.True(t, ok) + require.Equal(t, readBufferSize, trw.buf.Reader.Size()) + require.Equal(t, writeBufferSize, trw.buf.Writer.Size()) + require.NoError(t, cli.Close()) + }, + func(t *testing.T, c net.Conn) { + srv := NewPacketIO(c, nil, connBufferSize) + _, err := srv.ServerTLSHandshake(stls) + require.NoError(t, err) + trw, ok := srv.readWriter.(*tlsReadWriter) + require.True(t, ok) + require.Equal(t, readBufferSize, trw.buf.Reader.Size()) + require.Equal(t, writeBufferSize, trw.buf.Writer.Size()) + require.NoError(t, srv.Close()) + }, + 1, + ) + }) + } +}