From dcab3db5f1b9cc43729e4ad83d879c55e6f1318f Mon Sep 17 00:00:00 2001 From: djshow832 Date: Thu, 2 Apr 2026 22:11:20 +0800 Subject: [PATCH] reject conns --- lib/config/proxy.go | 14 ++++- lib/config/proxy_test.go | 31 +++++++++-- pkg/manager/memory/memory.go | 92 ++++++++++++++++++++++++++----- pkg/manager/memory/memory_test.go | 48 ++++++++++++++++ pkg/metrics/metrics.go | 1 + pkg/metrics/server.go | 8 +++ pkg/proxy/proxy.go | 34 +++++++++++- pkg/proxy/proxy_test.go | 85 ++++++++++++++++++++++++---- pkg/server/server.go | 2 +- 9 files changed, 280 insertions(+), 35 deletions(-) diff --git a/lib/config/proxy.go b/lib/config/proxy.go index cb0cc3209..fc4349065 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -50,9 +50,10 @@ type KeepAlive struct { } type ProxyServerOnline struct { - MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` - ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` - FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` + MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` + HighMemoryUsageRejectThreshold float64 `yaml:"high-memory-usage-reject-threshold,omitempty" toml:"high-memory-usage-reject-threshold,omitempty" json:"high-memory-usage-reject-threshold,omitempty" reloadable:"true"` + ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` + FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` // BackendHealthyKeepalive applies when the observer treats the backend as healthy. // The config values should be conservative to save CPU and tolerate network fluctuation. 
BackendHealthyKeepalive KeepAlive `yaml:"backend-healthy-keepalive" toml:"backend-healthy-keepalive" json:"backend-healthy-keepalive"` @@ -132,6 +133,7 @@ func NewConfig() *Config { cfg.Proxy.Addr = "0.0.0.0:6000" cfg.Proxy.FrontendKeepalive, cfg.Proxy.BackendHealthyKeepalive, cfg.Proxy.BackendUnhealthyKeepalive = DefaultKeepAlive() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 cfg.Proxy.PDAddrs = "127.0.0.1:2379" cfg.Proxy.GracefulCloseConnTimeout = 15 @@ -255,6 +257,12 @@ func (cfg *Config) GetBackendClusters() []BackendCluster { } func (ps *ProxyServer) Check() error { + if ps.HighMemoryUsageRejectThreshold < 0 || ps.HighMemoryUsageRejectThreshold > 1 { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.high-memory-usage-reject-threshold") + } + if ps.HighMemoryUsageRejectThreshold > 0 && ps.HighMemoryUsageRejectThreshold < 0.5 { + ps.HighMemoryUsageRejectThreshold = 0.5 + } if _, err := ps.GetSQLAddrs(); err != nil { return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr or proxy.port-range: %s", err.Error()) } diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index da9f3ecaa..a74346da7 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -22,11 +22,12 @@ var testProxyConfig = Config{ Addr: "0.0.0.0:4000", PDAddrs: "127.0.0.1:4089", ProxyServerOnline: ProxyServerOnline{ - MaxConnections: 1, - FrontendKeepalive: KeepAlive{Enabled: true}, - ProxyProtocol: "v2", - GracefulWaitBeforeShutdown: 10, - ConnBufferSize: 32 * 1024, + MaxConnections: 1, + HighMemoryUsageRejectThreshold: 0.9, + FrontendKeepalive: KeepAlive{Enabled: true}, + ProxyProtocol: "v2", + GracefulWaitBeforeShutdown: 10, + ConnBufferSize: 32 * 1024, BackendClusters: []BackendCluster{ { Name: "cluster-a", @@ -114,6 +115,26 @@ func TestProxyCheck(t *testing.T) { post func(*testing.T, *Config) err error }{ + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = -0.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 1.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 0.4 + }, + post: func(t *testing.T, c *Config) { + require.Equal(t, 0.5, c.Proxy.HighMemoryUsageRejectThreshold) + }, + }, { pre: func(t *testing.T, c *Config) { c.Workdir = "" diff --git a/pkg/manager/memory/memory.go b/pkg/manager/memory/memory.go index 135bed4fe..efea1854c 100644 --- a/pkg/manager/memory/memory.go +++ b/pkg/manager/memory/memory.go @@ -9,6 +9,7 @@ import ( "path/filepath" "runtime" "runtime/pprof" + "sync/atomic" "time" "github.com/pingcap/tidb/pkg/util/memory" @@ -18,16 +19,26 @@ import ( ) const ( - // Check the memory usage every 30 seconds. - checkInterval = 30 * time.Second + // Check the memory usage every 5 seconds. + checkInterval = 5 * time.Second // No need to record too frequently. recordMinInterval = 5 * time.Minute // Record the profiles when the memory usage is higher than 60%. alarmThreshold = 0.6 // Remove the oldest profiles when the number of profiles exceeds this limit. maxSavedProfiles = 20 + // Fail open if the latest sampled usage is too old. + snapshotExpireInterval = 2 * checkInterval ) +type UsageSnapshot struct { + Used uint64 + Limit uint64 + Usage float64 + UpdateTime time.Time + Valid bool +} + // MemManager is a manager for memory usage. // Although the continuous profiling collects profiles periodically, when TiProxy runs in the replayer mode, // the profiles are not collected. 
@@ -41,17 +52,22 @@ type MemManager struct { checkInterval time.Duration // used for test recordMinInterval time.Duration // used for test maxSavedProfiles int // used for test + snapshotExpire time.Duration // used for test memoryLimit uint64 + latestUsage atomic.Value } func NewMemManager(lg *zap.Logger, cfgGetter config.ConfigGetter) *MemManager { - return &MemManager{ + mgr := &MemManager{ lg: lg, cfgGetter: cfgGetter, checkInterval: checkInterval, recordMinInterval: recordMinInterval, maxSavedProfiles: maxSavedProfiles, + snapshotExpire: snapshotExpireInterval, } + mgr.latestUsage.Store(UsageSnapshot{}) + return mgr } func (m *MemManager) Start(ctx context.Context) { @@ -62,6 +78,9 @@ func (m *MemManager) Start(ctx context.Context) { return } m.memoryLimit = limit + if _, err = m.refreshUsage(); err != nil { + return + } childCtx, cancel := context.WithCancel(ctx) m.cancel = cancel m.wg.RunWithRecover(func() { @@ -83,32 +102,77 @@ func (m *MemManager) alarmLoop(ctx context.Context) { } func (m *MemManager) checkAndAlarm() { + snapshot, err := m.refreshUsage() + if err != nil || !snapshot.Valid { + return + } + if snapshot.Usage < alarmThreshold { + return + } if time.Since(m.lastRecordTime) < m.recordMinInterval { return } // The filename is hot-reloadable. - logPath := m.cfgGetter.GetConfig().Log.LogFile.Filename + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return + } + logPath := cfg.Log.LogFile.Filename if logPath == "" { return } recordDir := filepath.Dir(logPath) + m.lastRecordTime = snapshot.UpdateTime + m.lg.Warn("memory usage alarm", zap.Uint64("limit", snapshot.Limit), zap.Uint64("used", snapshot.Used), zap.Float64("usage", snapshot.Usage)) + now := time.Now().Format(time.RFC3339) + m.recordHeap(filepath.Join(recordDir, "heap_"+now)) + m.recordGoroutine(filepath.Join(recordDir, "goroutine_"+now)) + m.rmExpiredProfiles() +} + +func (m *MemManager) refreshUsage() (UsageSnapshot, error) { + if m.memoryLimit == 0 { + return UsageSnapshot{}, nil + } used, err := memory.MemUsed() if err != nil || used == 0 { m.lg.Error("get used memory failed", zap.Uint64("used", used), zap.Error(err)) - return + return UsageSnapshot{}, err } - memoryUsage := float64(used) / float64(m.memoryLimit) - if memoryUsage < alarmThreshold { - return + snapshot := UsageSnapshot{ + Used: used, + Limit: m.memoryLimit, + Usage: float64(used) / float64(m.memoryLimit), + UpdateTime: time.Now(), + Valid: true, } + m.latestUsage.Store(snapshot) + return snapshot, nil +} - m.lastRecordTime = time.Now() - m.lg.Warn("memory usage alarm", zap.Uint64("limit", m.memoryLimit), zap.Uint64("used", used), zap.Float64("usage", memoryUsage)) - now := time.Now().Format(time.RFC3339) - m.recordHeap(filepath.Join(recordDir, "heap_"+now)) - m.recordGoroutine(filepath.Join(recordDir, "goroutine_"+now)) - m.rmExpiredProfiles() +func (m *MemManager) LatestUsage() UsageSnapshot { + snapshot, _ := m.latestUsage.Load().(UsageSnapshot) + return snapshot +} + +func (m *MemManager) ShouldRejectNewConn() (bool, UsageSnapshot, float64) { + if m == nil || m.cfgGetter == nil { + return false, UsageSnapshot{}, 0 + } + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return false, UsageSnapshot{}, 0 + } + threshold := cfg.Proxy.HighMemoryUsageRejectThreshold + if threshold == 0 { + return false, UsageSnapshot{}, 0 + } + snapshot := m.LatestUsage() + if !snapshot.Valid || time.Since(snapshot.UpdateTime) > m.snapshotExpire { + return false, snapshot, threshold + } + return snapshot.Usage >= threshold, snapshot, threshold } func (m 
*MemManager) recordHeap(fileName string) { diff --git a/pkg/manager/memory/memory_test.go b/pkg/manager/memory/memory_test.go index 9d0d8fee8..31897db05 100644 --- a/pkg/manager/memory/memory_test.go +++ b/pkg/manager/memory/memory_test.go @@ -26,6 +26,12 @@ func (c *mockCfgGetter) GetConfig() *config.Config { } func TestRecordProfile(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + dir := t.TempDir() cfg := &config.Config{} cfg.Log.LogFile.Filename = path.Join(dir, "proxy.log") @@ -75,3 +81,45 @@ func TestRecordProfile(t *testing.T) { require.NoError(t, err) require.Len(t, entries, m.maxSavedProfiles) } + +func TestShouldRejectNewConn(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + cfg := config.NewConfig() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + cfgGetter := mockCfgGetter{cfg: cfg} + memory.MemUsed = func() (uint64, error) { + return 9 * (1 << 30), nil + } + memory.MemTotal = func() (uint64, error) { + return 10 * (1 << 30), nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + m.checkInterval = 50 * time.Millisecond + m.snapshotExpire = 200 * time.Millisecond + m.Start(context.Background()) + defer m.Close() + + require.Eventually(t, func() bool { + reject, snapshot, threshold := m.ShouldRejectNewConn() + return reject && snapshot.Valid && threshold == 0.9 + }, time.Second, 10*time.Millisecond) + m.Close() + + cfg.Proxy.HighMemoryUsageRejectThreshold = 0 + reject, _, threshold := m.ShouldRejectNewConn() + require.False(t, reject) + require.Zero(t, threshold) + + staleSnapshot := m.LatestUsage() + staleSnapshot.UpdateTime = time.Now().Add(-m.snapshotExpire - time.Second) + m.latestUsage.Store(staleSnapshot) + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + reject, _, threshold = m.ShouldRejectNewConn() + require.False(t, reject) + require.Equal(t, 0.9, threshold) +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 39c5eca1b..8501721e5 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -98,6 +98,7 @@ func init() { colls = []prometheus.Collector{ ConnGauge, CreateConnCounter, + RejectConnCounter, DisConnCounter, MaxProcsGauge, OwnerGauge, diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index b569a9910..53a288216 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -34,6 +34,14 @@ var ( Help: "Number of create connections.", }) + RejectConnCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: ModuleProxy, + Subsystem: LabelServer, + Name: "reject_connection_total", + Help: "Number of rejected connections.", + }, []string{LblType}) + DisConnCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: ModuleProxy, diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 59c89fcde..2cf9a39c7 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -14,6 +14,7 @@ import ( "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -46,6 +47,7 @@ type SQLServer struct { logger *zap.Logger certMgr *cert.CertManager idMgr *id.IDManager + memUsage memoryStateProvider hsHandler 
backend.HandshakeHandler cpt capture.Capture meter backend.Meter @@ -55,14 +57,19 @@ type SQLServer struct { mu serverState } +type memoryStateProvider interface { + ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) +} + // NewSQLServer creates a new SQLServer. func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, - meter backend.Meter, hsHandler backend.HandshakeHandler) (*SQLServer, error) { + meter backend.Meter, hsHandler backend.HandshakeHandler, memUsage memoryStateProvider) (*SQLServer, error) { var err error s := &SQLServer{ logger: logger, certMgr: certMgr, idMgr: idMgr, + memUsage: memUsage, hsHandler: hsHandler, cpt: cpt, meter: meter, @@ -153,6 +160,10 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) { } func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { + if s.rejectConnByMemory(conn) { + return + } + tcpKeepAlive, logger, connID, clientConn := func() (bool, *zap.Logger, uint64, *client.ClientConnection) { s.mu.Lock() defer s.mu.Unlock() @@ -161,6 +172,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { maxConns := s.mu.maxConnections // 'maxConns == 0' => unlimited connections if maxConns != 0 && conns >= maxConns { + metrics.RejectConnCounter.WithLabelValues("max_connections").Inc() s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.Stringer("client_addr", conn.RemoteAddr()), zap.Error(conn.Close())) return false, nil, 0, nil } @@ -209,6 +221,26 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { clientConn.Run(ctx) } +func (s *SQLServer) rejectConnByMemory(conn net.Conn) bool { + if s.memUsage == nil { + return false + } + reject, snapshot, threshold := s.memUsage.ShouldRejectNewConn() + if !reject { + return false + } + metrics.RejectConnCounter.WithLabelValues("memory").Inc() + s.logger.Warn("reject connection due to high memory usage", + zap.Stringer("client_addr", conn.RemoteAddr()), + zap.Float64("threshold", threshold), + zap.Float64("usage", snapshot.Usage), + zap.Uint64("used", snapshot.Used), + zap.Uint64("limit", snapshot.Limit), + zap.Time("last_update", snapshot.UpdateTime), + zap.Error(conn.Close())) + return true +} + func (s *SQLServer) fromPublicEndpoint(addr net.Addr) bool { if addr == nil || reflect.ValueOf(addr).IsNil() { return false diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index cd75c4c1d..106b310f4 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tiproxy/pkg/balance/router" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -35,7 +36,7 @@ func TestCreateConn(t *testing.T) { cfg := &config.Config{} certManager := cert.NewCertManager() require.NoError(t, certManager.Init(cfg, lg, nil)) - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -72,6 +73,58 @@ func TestCreateConn(t *testing.T) { checkMetrics(0, 2) } +func TestRejectConnByMemory(t *testing.T) { + lg, _ := 
logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) + server, err := NewSQLServer(lg, &config.Config{}, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, &mockMemUsageProvider{ + reject: true, + snapshot: mgrmem.UsageSnapshot{ + Used: 9 * (1 << 30), + Limit: 10 * (1 << 30), + Usage: 0.9, + UpdateTime: time.Now(), + Valid: true, + }, + threshold: 0.9, + }) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + rejectBefore, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createBefore, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeBefore, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + + conn, err := net.Dial("tcp", server.listeners[0].Addr().String()) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + require.Eventually(t, func() bool { + _ = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + var buf [1]byte + _, err := conn.Read(buf[:]) + return err != nil + }, time.Second, 10*time.Millisecond) + + require.Eventually(t, func() bool { + rejectAfter, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createAfter, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeAfter, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + return rejectAfter == rejectBefore+1 && createAfter == createBefore && connGaugeAfter == connGaugeBefore + }, time.Second, 10*time.Millisecond) +} + func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown finishes immediately if there's no connection. lg, _ := logger.CreateLoggerForTest(t) @@ -83,7 +136,7 @@ func TestGracefulCloseConn(t *testing.T) { }, }, } - server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler) + server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) require.NoError(t, err) finish := make(chan struct{}) go func() { @@ -113,7 +166,7 @@ func TestGracefulCloseConn(t *testing.T) { } // Graceful shutdown will be blocked if there are alive connections. - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) require.NoError(t, err) clientConn := createClientConn() go func() { @@ -139,7 +192,7 @@ func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown will shut down after GracefulCloseConnTimeout. 
cfg.Proxy.GracefulCloseConnTimeout = 1 - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) require.NoError(t, err) createClientConn() go func() { @@ -167,7 +220,7 @@ func TestGracefulShutDown(t *testing.T) { }, }, } - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -205,7 +258,7 @@ func TestMultiAddr(t *testing.T) { Proxy: config.ProxyServer{ Addr: "0.0.0.0:0,0.0.0.0:0", }, - }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}) + }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -232,7 +285,7 @@ func TestPortRange(t *testing.T) { Addr: fmt.Sprintf("127.0.0.1:%d", start), PortRange: []int{start, end}, }, - }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}) + }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -276,7 +329,7 @@ func TestConnAddrUsesActualListenerAddr(t *testing.T) { Proxy: config.ProxyServer{ Addr: "127.0.0.1:0", }, - }, certManager, id.NewIDManager(), nil, nil, handler) + }, certManager, id.NewIDManager(), nil, nil, handler, nil) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -303,7 +356,7 @@ func TestWatchCfg(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) hsHandler := backend.NewDefaultHandshakeHandler(nil) cfgch := make(chan *config.Config) - server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, nil, hsHandler) + server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, nil, hsHandler, nil) require.NoError(t, err) server.Run(context.Background(), cfgch) cfg := &config.Config{ @@ -374,7 +427,7 @@ func TestRecoverPanic(t *testing.T) { } return nil }, - }) + }, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -417,7 +470,7 @@ func TestPublicEndpoint(t *testing.T) { }, } - server, err := NewSQLServer(zap.NewNop(), &config.Config{}, nil, id.NewIDManager(), nil, nil, backend.NewDefaultHandshakeHandler(nil)) + server, err := NewSQLServer(zap.NewNop(), &config.Config{}, nil, id.NewIDManager(), nil, nil, backend.NewDefaultHandshakeHandler(nil), nil) require.NoError(t, err) for i, test := range tests { cfg := &config.Config{} @@ -439,6 +492,16 @@ type mockHsHandler struct { getRouter func(ctx backend.ConnContext, _ *pnet.HandshakeResp) (router.Router, error) } +type mockMemUsageProvider struct { + reject bool + snapshot mgrmem.UsageSnapshot + threshold float64 +} + +func (m *mockMemUsageProvider) ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) { + return m.reject, m.snapshot, m.threshold +} + // HandleHandshakeResp only panics for the first connections. 
func (handler *mockHsHandler) HandleHandshakeResp(ctx backend.ConnContext, resp *pnet.HandshakeResp) error { if handler.handshakeResp != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index 2e1cc836e..8ee74d64e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -184,7 +184,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) // setup proxy server { - srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), srv.meter, hsHandler) + srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), srv.meter, hsHandler, srv.memManager) if err != nil { return }
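
Example configuration (not part of the patch above): a minimal sketch of how the new threshold could be set in a TiProxy TOML config file, assuming the [proxy] section name and key spelling follow the toml struct tags added to lib/config/proxy.go.

[proxy]
# 0 disables the memory-based rejection; values in (0, 0.5) are raised to 0.5 by ProxyServer.Check().
# New connections are rejected before the handshake once the sampled memory usage reaches the
# threshold, and the check fails open when the latest sample is older than snapshotExpireInterval.
high-memory-usage-reject-threshold = 0.9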