Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ const (
DefaultConnBufferSize = 32 * 1024
)

// normalizeConnBufferSize keeps 0 as "use the default", so every caller in the
// packet/TLS stack derives buffer sizes from the same effective value.
func normalizeConnBufferSize(bufferSize int) int {
if bufferSize == 0 {
return DefaultConnBufferSize
}
return bufferSize
}

type rwStatus int

const (
Expand Down Expand Up @@ -116,9 +125,7 @@ func getPooledWriter(conn net.Conn, size int) *bufio.Writer {
}

func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter {
if bufferSize == 0 {
bufferSize = DefaultConnBufferSize
}
bufferSize = normalizeConnBufferSize(bufferSize)
return &basicReadWriter{
Conn: conn,
ReadWriter: bufio.NewReadWriter(getPooledReader(conn, bufferSize), getPooledWriter(conn, bufferSize)),
Expand Down Expand Up @@ -274,7 +281,11 @@ type PacketIO interface {

// PacketIO is a helper to read and write sql and proxy protocol.
type packetIO struct {
lastKeepAlive config.KeepAlive
lastKeepAlive config.KeepAlive
// TLS allocates another buffered layer after the handshake. Keep the
// normalized base connection buffer size here so the TLS layer can scale
// from the caller's setting instead of falling back to unrelated constants.
connBufferSize int
rawConn net.Conn
readWriter packetReadWriter
limitReader io.LimitedReader // reuse memory to reduce allocation
Expand All @@ -288,10 +299,12 @@ type packetIO struct {
}

func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO {
bufferSize = normalizeConnBufferSize(bufferSize)
p := &packetIO{
rawConn: conn,
logger: lg,
readWriter: newBasicReadWriter(conn, bufferSize),
connBufferSize: bufferSize,
rawConn: conn,
logger: lg,
readWriter: newBasicReadWriter(conn, bufferSize),
}
p.ApplyOpts(opts...)
return p
Expand Down
48 changes: 43 additions & 5 deletions pkg/proxy/net/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,39 @@ import (
"github.com/pingcap/tiproxy/pkg/util/bufio"
)

const (
// The TLS layer keeps its own post-handshake buffers, so follow the caller's
// connBufferSize proportionally but cap the sizes to avoid recreating the same
// memory pressure as the base connection buffers.
minTLSReadBufferSize = 1 * 1024
maxTLSReadBufferSize = 4 * 1024
minTLSWriteBufferSize = 1 * 1024
maxTLSWriteBufferSize = 16 * 1024
tlsReadBufferDivisor = 4
tlsWriteBufferDivisor = 2
)

func clampBufferSize(size, minSize, maxSize int) int {
if size < minSize {
return minSize
}
if size > maxSize {
return maxSize
}
return size
}

// TLS reads mostly serve packet-header peeks after the handshake, so a smaller
// reader is usually enough. TLS writes are more sensitive to fragmentation, so
// keep the writer relatively larger. Clamp both ends so tiny custom connection
// buffers do not make TLS unusably small, and huge connection buffers do not
// reintroduce excessive per-connection TLS memory.
func tlsBufferSizes(connBufferSize int) (readSize int, writeSize int) {
connBufferSize = normalizeConnBufferSize(connBufferSize)
return clampBufferSize(connBufferSize/tlsReadBufferDivisor, minTLSReadBufferSize, maxTLSReadBufferSize),
clampBufferSize(connBufferSize/tlsWriteBufferDivisor, minTLSWriteBufferSize, maxTLSWriteBufferSize)
}

// tlsHandshakeConn is only used as the underlying connection in tls.Conn.
// TLS handshake must read from the buffered reader because the handshake data may be already buffered in the reader.
// TLS handshake can not use the buffered writer directly because it assumes the data will be flushed automatically,
Expand All @@ -30,7 +63,7 @@ func (p *packetIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionStat
if err := tlsConn.Handshake(); err != nil {
return tls.ConnectionState{}, p.wrapErr(errors.Wrap(errors.WithStack(err), ErrHandshakeTLS))
}
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn)
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn, p.connBufferSize)
return tlsConn.ConnectionState(), nil
}

Expand All @@ -41,7 +74,7 @@ func (p *packetIO) ClientTLSHandshake(tlsConfig *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return p.wrapErr(errors.Wrap(errors.WithStack(err), ErrHandshakeTLS))
}
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn)
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn, p.connBufferSize)
return nil
}

Expand All @@ -57,10 +90,15 @@ type tlsReadWriter struct {
conn *tls.Conn
}

func newTLSReadWriter(rw packetReadWriter, tlsConn *tls.Conn) *tlsReadWriter {
func newTLSReadWriter(rw packetReadWriter, tlsConn *tls.Conn, connBufferSize int) *tlsReadWriter {
// Can not modify rw and reuse it because tlsConn is using rw internally.
// We must create another buffer.
buf := bufio.NewReadWriter(bufio.NewReaderSize(tlsConn, DefaultConnBufferSize), bufio.NewWriterSize(tlsConn, DefaultConnBufferSize))
// We must create another buffer. Size it from the base connection buffer so
// custom connBufferSize values keep a consistent memory profile after TLS.
readBufferSize, writeBufferSize := tlsBufferSizes(connBufferSize)
buf := bufio.NewReadWriter(
bufio.NewReaderSize(tlsConn, readBufferSize),
bufio.NewWriterSize(tlsConn, writeBufferSize),
)
return &tlsReadWriter{
packetReadWriter: rw,
buf: buf,
Expand Down
66 changes: 64 additions & 2 deletions pkg/proxy/net/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package net

import (
"crypto/tls"
"fmt"
"io"
"net"
"testing"
Expand All @@ -25,7 +26,10 @@ func TestTLSReadWrite(t *testing.T) {
conn := &tlsInternalConn{brw}
tlsConn := tls.Client(conn, ctls)
require.NoError(t, tlsConn.Handshake())
trw := newTLSReadWriter(brw, tlsConn)
trw := newTLSReadWriter(brw, tlsConn, DefaultConnBufferSize)
readBufferSize, writeBufferSize := tlsBufferSizes(DefaultConnBufferSize)
require.Equal(t, readBufferSize, trw.buf.Reader.Size())
require.Equal(t, writeBufferSize, trw.buf.Writer.Size())
// check tls connection state
require.True(t, trw.TLSConnectionState().HandshakeComplete)
// check out bytes
Expand All @@ -49,7 +53,10 @@ func TestTLSReadWrite(t *testing.T) {
conn := &tlsInternalConn{brw}
tlsConn := tls.Server(conn, stls)
require.NoError(t, tlsConn.Handshake())
trw := newTLSReadWriter(brw, tlsConn)
trw := newTLSReadWriter(brw, tlsConn, DefaultConnBufferSize)
readBufferSize, writeBufferSize := tlsBufferSizes(DefaultConnBufferSize)
require.Equal(t, readBufferSize, trw.buf.Reader.Size())
require.Equal(t, writeBufferSize, trw.buf.Writer.Size())
// check tls connection state
require.True(t, trw.TLSConnectionState().HandshakeComplete)
// check in bytes
Expand Down Expand Up @@ -82,3 +89,58 @@ func TestTLSReadWrite(t *testing.T) {
require.Equal(t, message[1:], data)
}, 1)
}

func TestTLSBufferSizes(t *testing.T) {
cases := []struct {
connBufferSize int
readBufferSize int
writeBufferSize int
}{
{connBufferSize: 0, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024},
{connBufferSize: 1 * 1024, readBufferSize: 1 * 1024, writeBufferSize: 1 * 1024},
{connBufferSize: 4 * 1024, readBufferSize: 1 * 1024, writeBufferSize: 2 * 1024},
{connBufferSize: DefaultConnBufferSize, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024},
{connBufferSize: DefaultConnBufferSize * 2, readBufferSize: 4 * 1024, writeBufferSize: 16 * 1024},
}

for _, tc := range cases {
t.Run(fmt.Sprintf("conn-%d", tc.connBufferSize), func(t *testing.T) {
readBufferSize, writeBufferSize := tlsBufferSizes(tc.connBufferSize)
require.Equal(t, tc.readBufferSize, readBufferSize)
require.Equal(t, tc.writeBufferSize, writeBufferSize)
})
}
}

func TestPacketIOTLSBufferSizes(t *testing.T) {
stls, ctls, err := security.CreateTLSConfigForTest()
require.NoError(t, err)

for _, connBufferSize := range []int{1 * 1024, 4 * 1024, DefaultConnBufferSize, DefaultConnBufferSize * 2} {
t.Run(fmt.Sprintf("conn-%d", connBufferSize), func(t *testing.T) {
readBufferSize, writeBufferSize := tlsBufferSizes(connBufferSize)
testkit.TestTCPConn(t,
func(t *testing.T, c net.Conn) {
cli := NewPacketIO(c, nil, connBufferSize)
require.NoError(t, cli.ClientTLSHandshake(ctls))
trw, ok := cli.readWriter.(*tlsReadWriter)
require.True(t, ok)
require.Equal(t, readBufferSize, trw.buf.Reader.Size())
require.Equal(t, writeBufferSize, trw.buf.Writer.Size())
require.NoError(t, cli.Close())
},
func(t *testing.T, c net.Conn) {
srv := NewPacketIO(c, nil, connBufferSize)
_, err := srv.ServerTLSHandshake(stls)
require.NoError(t, err)
trw, ok := srv.readWriter.(*tlsReadWriter)
require.True(t, ok)
require.Equal(t, readBufferSize, trw.buf.Reader.Size())
require.Equal(t, writeBufferSize, trw.buf.Writer.Size())
require.NoError(t, srv.Close())
},
1,
)
})
}
}