From ce40a528486d820f25ad34dd5b08abd05be8ec8a Mon Sep 17 00:00:00 2001 From: djshow832 Date: Thu, 2 Apr 2026 15:43:26 +0800 Subject: [PATCH] evict backend --- conf/proxy.toml | 8 ++ lib/config/proxy.go | 24 ++++ lib/config/proxy_test.go | 22 +++ pkg/balance/router/group.go | 35 +++++ pkg/balance/router/mock_test.go | 10 ++ pkg/balance/router/router.go | 65 ++++++++- pkg/balance/router/router_score.go | 63 +++++++-- pkg/balance/router/router_score_test.go | 157 +++++++++++++++++++++ pkg/manager/config/config_test.go | 20 +++ pkg/proxy/backend/backend_conn_mgr.go | 19 +++ pkg/proxy/backend/backend_conn_mgr_test.go | 31 ++++ 11 files changed, 445 insertions(+), 9 deletions(-) diff --git a/conf/proxy.toml b/conf/proxy.toml index d05b6c9a7..d1fa70e57 100644 --- a/conf/proxy.toml +++ b/conf/proxy.toml @@ -23,6 +23,14 @@ graceful-close-conn-timeout = 15 +# fail-backend-list marks backend pod names or backend addresses as failed. TiProxy will stop routing new +# connections to them and migrate existing connections away. +# fail-backend-list = ["db-2033841436272623616-0f6e346b-tidb-0", "10.0.0.10:4000"] + +# failover-timeout is measured in seconds. If a failed backend still has remaining connections after the timeout, +# TiProxy will force close them. +# failover-timeout = 60 + # possible values: # "" => enable static routing. # "pd-addr:pd-port" => automatically tidb discovery. diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 43fc0573e..dd467140a 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -68,6 +68,11 @@ type ProxyServerOnline struct { // BackendClusters represents multiple backend clusters that the proxy can route to. It can be reloaded // online. BackendClusters []BackendCluster `yaml:"backend-clusters,omitempty" toml:"backend-clusters,omitempty" json:"backend-clusters,omitempty" reloadable:"true"` + // FailBackendList contains backend pod names or backend addresses (IP:port) that should be drained immediately + // and excluded from new routing. + FailBackendList []string `yaml:"fail-backend-list,omitempty" toml:"fail-backend-list,omitempty" json:"fail-backend-list,omitempty" reloadable:"true"` + // FailoverTimeout is the grace period in seconds before force closing the remaining connections on failed backends. + FailoverTimeout int `yaml:"failover-timeout,omitempty" toml:"failover-timeout,omitempty" json:"failover-timeout,omitempty" reloadable:"true"` } type ProxyServer struct { @@ -134,6 +139,7 @@ func NewConfig() *Config { cfg.Proxy.FrontendKeepalive, cfg.Proxy.BackendHealthyKeepalive, cfg.Proxy.BackendUnhealthyKeepalive = DefaultKeepAlive() cfg.Proxy.PDAddrs = "127.0.0.1:2379" cfg.Proxy.GracefulCloseConnTimeout = 15 + cfg.Proxy.FailoverTimeout = 60 cfg.API.Addr = "0.0.0.0:3080" @@ -160,6 +166,7 @@ func (cfg *Config) Clone() *Config { newCfg.Labels = maps.Clone(cfg.Labels) newCfg.Proxy.PublicEndpoints = slices.Clone(cfg.Proxy.PublicEndpoints) newCfg.Proxy.BackendClusters = slices.Clone(cfg.Proxy.BackendClusters) + newCfg.Proxy.FailBackendList = slices.Clone(cfg.Proxy.FailBackendList) for i := range newCfg.Proxy.BackendClusters { newCfg.Proxy.BackendClusters[i].NSServers = slices.Clone(newCfg.Proxy.BackendClusters[i].NSServers) } @@ -279,6 +286,23 @@ func (ps *ProxyServer) Check() error { return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.backend-clusters.ns-servers: %s", err.Error()) } } + if ps.FailoverTimeout < 0 { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.failover-timeout must be greater than or equal to 0") + } + failBackends := ps.FailBackendList[:0] + failBackendSet := make(map[string]struct{}, len(ps.FailBackendList)) + for i, backendName := range ps.FailBackendList { + backendName = strings.TrimSpace(backendName) + if backendName == "" { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.fail-backend-list[%d] is empty", i) + } + if _, ok := failBackendSet[backendName]; ok { + return errors.Wrapf(ErrInvalidConfigValue, "duplicate proxy.fail-backend-list entry %s", backendName) + } + failBackendSet[backendName] = struct{}{} + failBackends = append(failBackends, backendName) + } + ps.FailBackendList = failBackends return nil } diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index da9f3ecaa..d7a061278 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -26,6 +26,8 @@ var testProxyConfig = Config{ FrontendKeepalive: KeepAlive{Enabled: true}, ProxyProtocol: "v2", GracefulWaitBeforeShutdown: 10, + FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, + FailoverTimeout: 60, ConnBufferSize: 32 * 1024, BackendClusters: []BackendCluster{ { @@ -188,6 +190,24 @@ func TestProxyCheck(t *testing.T) { }, err: ErrInvalidConfigValue, }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.FailBackendList = []string{"db-tidb-0", " "} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.FailBackendList = []string{"db-tidb-0", "db-tidb-0"} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.FailoverTimeout = -1 + }, + err: ErrInvalidConfigValue, + }, } for _, tc := range testcases { cfg := testProxyConfig @@ -311,10 +331,12 @@ func TestCloneConfig(t *testing.T) { require.Equal(t, cfg, *clone) cfg.Labels["c"] = "d" cfg.Proxy.PublicEndpoints[0] = "2.2.2.0/24" + cfg.Proxy.FailBackendList[0] = "db-tidb-9" cfg.Proxy.BackendClusters[0].Name = "cluster-updated" cfg.Proxy.BackendClusters[0].NSServers[0] = "10.0.0.9" require.NotContains(t, clone.Labels, "c") require.Equal(t, []string{"1.1.1.0/24"}, clone.Proxy.PublicEndpoints) + require.Equal(t, []string{"db-tidb-0", "db-tidb-1"}, clone.Proxy.FailBackendList) require.Equal(t, "cluster-a", clone.Proxy.BackendClusters[0].Name) require.Equal(t, []string{"10.0.0.2", "10.0.0.3"}, clone.Proxy.BackendClusters[0].NSServers) } diff --git a/pkg/balance/router/group.go b/pkg/balance/router/group.go index 43341054b..c661d19a7 100644 --- a/pkg/balance/router/group.go +++ b/pkg/balance/router/group.go @@ -253,6 +253,9 @@ func (g *Group) Balance(ctx context.Context) { i := 0 for ele := fromBackend.connList.Front(); ele != nil && ctx.Err() == nil && i < count; ele = ele.Next() { conn := ele.Value + if conn.forceClosing { + continue + } switch conn.phase { case phaseRedirectNotify: // A connection cannot be redirected again when it has not finished redirecting. @@ -279,6 +282,7 @@ func (g *Group) onCreateConn(backendInst BackendInst, conn RedirectableConn, suc RedirectableConn: conn, createTime: time.Now(), phase: phaseNotRedirected, + forceClosing: false, } g.addConn(backend, connWrapper) conn.SetEventReceiver(g) @@ -287,6 +291,37 @@ func (g *Group) onCreateConn(backendInst BackendInst, conn RedirectableConn, suc } } +func (g *Group) CloseTimedOutFailoverConnections(now time.Time, timeout time.Duration) { + g.Lock() + defer g.Unlock() + for _, backend := range g.backends { + active, since := backend.Failover() + if !active { + continue + } + if timeout > 0 && since.Add(timeout).After(now) { + continue + } + for ele := backend.connList.Front(); ele != nil; ele = ele.Next() { + conn := ele.Value + if conn.phase == phaseClosed || conn.forceClosing { + continue + } + fields := []zap.Field{ + zap.Uint64("connID", conn.ConnectionID()), + zap.String("backend_addr", backend.addr), + zap.String("backend_pod", backend.PodName()), + zap.Duration("failover_timeout", timeout), + zap.Duration("failover_elapsed", now.Sub(since)), + } + if conn.ForceClose() { + conn.forceClosing = true + g.lg.Warn("force close connection on failover backend", fields...) + } + } + } +} + func (g *Group) removeConn(backend *backendWrapper, ce *glist.Element[*connWrapper]) { backend.connList.Remove(ce) setBackendConnMetrics(backend.addr, backend.connList.Len()) diff --git a/pkg/balance/router/mock_test.go b/pkg/balance/router/mock_test.go index d8eb98950..8ec358046 100644 --- a/pkg/balance/router/mock_test.go +++ b/pkg/balance/router/mock_test.go @@ -69,6 +69,16 @@ func (conn *mockRedirectableConn) Redirect(inst BackendInst) bool { return true } +func (conn *mockRedirectableConn) ForceClose() bool { + conn.Lock() + defer conn.Unlock() + if conn.closing { + return false + } + conn.closing = true + return true +} + func (conn *mockRedirectableConn) GetRedirectingAddr() string { conn.Lock() defer conn.Unlock() diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go index ee844bf9e..9f31ad8f3 100644 --- a/pkg/balance/router/router.go +++ b/pkg/balance/router/router.go @@ -4,6 +4,7 @@ package router import ( + "net" "strings" "sync" "time" @@ -68,6 +69,8 @@ type RedirectableConn interface { Value(key any) any // Redirect returns false if the current conn is not redirectable. Redirect(backend BackendInst) bool + // ForceClose closes the connection immediately and returns false if it's already closing. + ForceClose() bool ConnectionID() uint64 ConnInfo() []zap.Field } @@ -85,8 +88,11 @@ type backendWrapper struct { mu struct { sync.RWMutex observer.BackendHealth + failoverActive bool + failoverSince time.Time } - addr string + addr string + podName string // connScore is used for calculating backend scores and check if the backend can be removed from the list. // connScore = connList.Len() + incoming connections - outgoing connections. connScore int @@ -100,6 +106,7 @@ type backendWrapper struct { func newBackendWrapper(addr string, health observer.BackendHealth) *backendWrapper { wrapper := &backendWrapper{ addr: addr, + podName: backendPodNameFromAddr(addr), connList: glist.New[*connWrapper](), } wrapper.setHealth(health) @@ -128,12 +135,50 @@ func (b *backendWrapper) Addr() string { } func (b *backendWrapper) Healthy() bool { + b.mu.RLock() + healthy := b.mu.Healthy && !b.mu.failoverActive + b.mu.RUnlock() + return healthy +} + +func (b *backendWrapper) ObservedHealthy() bool { b.mu.RLock() healthy := b.mu.Healthy b.mu.RUnlock() return healthy } +func (b *backendWrapper) PodName() string { + return b.podName +} + +func (b *backendWrapper) setFailover(active bool, since time.Time) (changed bool, failoverSince time.Time) { + b.mu.Lock() + defer b.mu.Unlock() + if active { + if b.mu.failoverActive { + return false, b.mu.failoverSince + } + b.mu.failoverActive = true + b.mu.failoverSince = since + return true, b.mu.failoverSince + } + if !b.mu.failoverActive { + return false, time.Time{} + } + b.mu.failoverActive = false + b.mu.failoverSince = time.Time{} + return true, time.Time{} +} + +func (b *backendWrapper) Failover() (active bool, since time.Time) { + b.mu.RLock() + active = b.mu.failoverActive + since = b.mu.failoverSince + b.mu.RUnlock() + return +} + func (b *backendWrapper) ServerVersion() string { b.mu.RLock() version := b.mu.ServerVersion @@ -213,4 +258,22 @@ type connWrapper struct { lastRedirect time.Time createTime time.Time phase connPhase + forceClosing bool +} + +func backendPodNameFromAddr(addr string) string { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + if host == "" { + return "" + } + if ip := net.ParseIP(host); ip != nil { + return host + } + if idx := strings.IndexByte(host, '.'); idx >= 0 { + return host[:idx] + } + return host } diff --git a/pkg/balance/router/router_score.go b/pkg/balance/router/router_score.go index d70afa4a9..9fc528b2f 100644 --- a/pkg/balance/router/router_score.go +++ b/pkg/balance/router/router_score.go @@ -48,14 +48,17 @@ type ScoreBasedRouter struct { serverVersion string // The backend supports redirection only when they have signing certs. supportRedirection bool + failoverBackends map[string]struct{} + failoverTimeout time.Duration } // NewScoreBasedRouter creates a ScoreBasedRouter. func NewScoreBasedRouter(logger *zap.Logger) *ScoreBasedRouter { return &ScoreBasedRouter{ - logger: logger, - backends: make(map[string]*backendWrapper), - groups: make([]*Group, 0), + logger: logger, + backends: make(map[string]*backendWrapper), + groups: make([]*Group, 0), + failoverBackends: make(map[string]struct{}), } } @@ -78,6 +81,9 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver default: r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule)) } + r.Lock() + r.setFailoverConfigLocked(cfg) + r.Unlock() childCtx, cancelFunc := context.WithCancel(ctx) r.cancelFunc = cancelFunc @@ -195,11 +201,14 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt } var serverVersion string supportRedirection := true + now := time.Now() for addr, health := range backends { backend, ok := router.backends[addr] if !ok && health.Healthy { router.logger.Debug("add new backend to router", zap.String("addr", addr), zap.Stringer("health", health)) - router.backends[addr] = newBackendWrapper(addr, *health) + backend = newBackendWrapper(addr, *health) + router.backends[addr] = backend + router.setBackendFailoverLocked(backend, now) serverVersion = health.ServerVersion } else if ok { if !health.Equals(backend.getHealth()) { @@ -231,7 +240,7 @@ func (router *ScoreBasedRouter) updateGroups() { for _, backend := range router.backends { // If connList.Len() == 0, there won't be any outgoing connections. // And if also connScore == 0, there won't be any incoming connections. - if !backend.Healthy() && backend.connList.Len() == 0 && backend.connScore <= 0 { + if !backend.ObservedHealthy() && backend.connList.Len() == 0 && backend.connScore <= 0 { delete(router.backends, backend.addr) if backend.group != nil { backend.group.RemoveBackend(backend.addr) @@ -308,6 +317,7 @@ func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { func (router *ScoreBasedRouter) setConfig(cfg *config.Config) { router.Lock() defer router.Unlock() + router.setFailoverConfigLocked(cfg) for _, group := range router.groups { group.SetConfig(cfg) } @@ -320,12 +330,49 @@ func (router *ScoreBasedRouter) rebalance(ctx context.Context) { router.Lock() defer router.Unlock() - if !router.supportRedirection { - return + if router.supportRedirection { + for _, group := range router.groups { + group.Balance(ctx) + } } for _, group := range router.groups { - group.Balance(ctx) + group.CloseTimedOutFailoverConnections(time.Now(), router.failoverTimeout) + } +} + +func (router *ScoreBasedRouter) setFailoverConfigLocked(cfg *config.Config) { + failoverBackends := make(map[string]struct{}, len(cfg.Proxy.FailBackendList)) + for _, backend := range cfg.Proxy.FailBackendList { + failoverBackends[backend] = struct{}{} + } + router.failoverBackends = failoverBackends + router.failoverTimeout = time.Duration(cfg.Proxy.FailoverTimeout) * time.Second + now := time.Now() + for _, backend := range router.backends { + router.setBackendFailoverLocked(backend, now) + } +} + +func (router *ScoreBasedRouter) setBackendFailoverLocked(backend *backendWrapper, now time.Time) { + _, active := router.failoverBackends[backend.PodName()] + if !active { + _, active = router.failoverBackends[backend.Addr()] + } + changed, since := backend.setFailover(active, now) + if !changed { + return + } + fields := []zap.Field{ + zap.String("backend_addr", backend.Addr()), + zap.String("backend_pod", backend.PodName()), + zap.Duration("failover_timeout", router.failoverTimeout), + } + if active { + fields = append(fields, zap.Time("failover_since", since)) + router.logger.Warn("backend enters failover", fields...) + return } + router.logger.Info("backend exits failover", fields...) } func (router *ScoreBasedRouter) ConnCount() int { diff --git a/pkg/balance/router/router_score_test.go b/pkg/balance/router/router_score_test.go index 3dd8e4042..257af3036 100644 --- a/pkg/balance/router/router_score_test.go +++ b/pkg/balance/router/router_score_test.go @@ -128,6 +128,13 @@ func (tester *routerTester) updateBackendLocalityByAddr(addr string, local bool) tester.notifyHealth() } +func (tester *routerTester) updateBackendRedirectSupportByAddr(addr string, support bool) { + health, ok := tester.backends[addr] + require.True(tester.t, ok) + health.SupportRedirection = support + tester.notifyHealth() +} + func (tester *routerTester) getBackendByIndex(index int) *backendWrapper { addr := strconv.Itoa(index + 1) backend := tester.router.backends[addr] @@ -723,6 +730,94 @@ func TestSetBackendStatus(t *testing.T) { } } +func TestBackendPodNameFromAddr(t *testing.T) { + require.Equal(t, "db-2033841436272623616-0f6e346b-tidb-0", backendPodNameFromAddr("db-2033841436272623616-0f6e346b-tidb-0.peer.ns.svc:4000")) + require.Equal(t, "127.0.0.1", backendPodNameFromAddr("127.0.0.1:4000")) + require.Equal(t, "backend-host", backendPodNameFromAddr("backend-host")) +} + +func TestFailoverBackend(t *testing.T) { + tester := newRouterTester(t, nil) + tester.addBackends(2) + tester.addConnections(20) + + fromBackend := tester.getBackendByIndex(0) + toBackend := tester.getBackendByIndex(1) + tester.router.setConfig(&config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailBackendList: []string{fromBackend.PodName()}, + FailoverTimeout: 60, + }, + }, + }) + + require.False(t, fromBackend.Healthy()) + selector := tester.router.GetBackendSelector(ClientInfo{}) + backend, err := selector.Next() + require.NoError(t, err) + selector.Finish(nil, false) + require.NotNil(t, backend) + require.Equal(t, toBackend.Addr(), backend.Addr()) + + tester.rebalance(1) + require.Equal(t, 10, fromBackend.ConnCount()) + tester.checkRedirectingNum(10) + tester.redirectFinish(10, true) + require.Equal(t, 0, fromBackend.ConnCount()) + require.Equal(t, 20, toBackend.ConnCount()) + tester.checkBackendConnMetrics() +} + +func TestFailoverBackendByAddr(t *testing.T) { + tester := newRouterTester(t, nil) + tester.addBackends(2) + + fromBackend := tester.getBackendByIndex(0) + toBackend := tester.getBackendByIndex(1) + tester.router.setConfig(&config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailBackendList: []string{fromBackend.Addr()}, + FailoverTimeout: 60, + }, + }, + }) + + require.False(t, fromBackend.Healthy()) + require.True(t, toBackend.Healthy()) + selector := tester.router.GetBackendSelector(ClientInfo{}) + backend, err := selector.Next() + require.NoError(t, err) + selector.Finish(nil, false) + require.NotNil(t, backend) + require.Equal(t, toBackend.Addr(), backend.Addr()) +} + +func TestFailoverTimeoutForceClose(t *testing.T) { + tester := newRouterTester(t, nil) + tester.addBackends(1) + tester.addConnections(3) + + backend := tester.getBackendByIndex(0) + tester.updateBackendRedirectSupportByAddr(backend.Addr(), false) + tester.router.setConfig(&config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailBackendList: []string{backend.PodName()}, + FailoverTimeout: 0, + }, + }, + }) + + tester.rebalance(1) + for _, conn := range tester.conns { + require.True(t, conn.closing) + } + tester.closeConnections(3, false) + tester.checkBackendConnMetrics() +} + func TestGetServerVersion(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) rt := NewScoreBasedRouter(lg) @@ -840,6 +935,68 @@ func TestWatchConfig(t *testing.T) { }, 3*time.Second, 10*time.Millisecond) } +func TestWatchFailoverConfig(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + cfgCh := make(chan *config.Config) + addr := "db-2033841436272623616-0f6e346b-tidb-0.peer.ns.svc:4000" + cfgGetter := newMockConfigGetter(&config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailoverTimeout: 60, + }, + }, + }) + bo := newMockBackendObserver() + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, cfgCh) + t.Cleanup(router.Close) + + bo.addBackend(addr, nil) + bo.notify(nil) + require.Eventually(t, func() bool { + backend := router.backends[addr] + return backend != nil && backend.Healthy() + }, 3*time.Second, 10*time.Millisecond) + + cfgCh <- &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailBackendList: []string{"db-2033841436272623616-0f6e346b-tidb-0"}, + FailoverTimeout: 60, + }, + }, + } + require.Eventually(t, func() bool { + backend := router.backends[addr] + return backend != nil && !backend.Healthy() + }, 3*time.Second, 10*time.Millisecond) + + cfgCh <- &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailBackendList: []string{addr}, + FailoverTimeout: 60, + }, + }, + } + require.Eventually(t, func() bool { + backend := router.backends[addr] + return backend != nil && !backend.Healthy() + }, 3*time.Second, 10*time.Millisecond) + + cfgCh <- &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + FailoverTimeout: 60, + }, + }, + } + require.Eventually(t, func() bool { + backend := router.backends[addr] + return backend != nil && backend.Healthy() + }, 3*time.Second, 10*time.Millisecond) +} + func TestControlSpeed(t *testing.T) { tests := []struct { balanceCount float64 diff --git a/pkg/manager/config/config_test.go b/pkg/manager/config/config_test.go index 2ea3557eb..8569553d4 100644 --- a/pkg/manager/config/config_test.go +++ b/pkg/manager/config/config_test.go @@ -88,6 +88,26 @@ func TestConfigReload(t *testing.T) { return c.API.Addr == "0.0.0.0:3081" }, }, + { + name: "failover override", + precfg: ` +proxy.fail-backend-list = ["db-tidb-0", "db-tidb-1"] +proxy.failover-timeout = 90 +`, + precheck: func(c *config.Config) bool { + return c.Proxy.FailoverTimeout == 90 && + len(c.Proxy.FailBackendList) == 2 && + c.Proxy.FailBackendList[0] == "db-tidb-0" && + c.Proxy.FailBackendList[1] == "db-tidb-1" + }, + postcfg: ` +proxy.fail-backend-list = [] +proxy.failover-timeout = 0 +`, + postcheck: func(c *config.Config) bool { + return c.Proxy.FailoverTimeout == 0 && len(c.Proxy.FailBackendList) == 0 + }, + }, { name: "non empty fields should not be override by empty fields", precfg: `proxy.addr = "gg"`, diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 9619d60cd..a66cb29be 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -697,6 +697,25 @@ func (mgr *BackendConnManager) Redirect(backendInst router.BackendInst) bool { return true } +func (mgr *BackendConnManager) ForceClose() bool { + for { + status := mgr.closeStatus.Load() + if status >= statusClosing { + return false + } + if mgr.closeStatus.CompareAndSwap(status, statusClosing) { + break + } + } + mgr.quitSource = SrcProxyQuit + if mgr.clientIO != nil { + if err := mgr.clientIO.Close(); err != nil && !pnet.IsDisconnectError(err) { + mgr.logger.Warn("force close client IO error", zap.Error(err)) + } + } + return true +} + func (mgr *BackendConnManager) notifyRedirectResult(ctx context.Context, rs *redirectResult) { _ = ctx if rs == nil { diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 05b95b46e..b5589859c 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -895,6 +895,37 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { ts.runTests(runners) } +func TestForceClose(t *testing.T) { + ts := newBackendMgrTester(t) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // force close + { + proxy: func(_, _ pnet.PacketIO) error { + require.True(t, ts.mp.ForceClose()) + return nil + }, + }, + // really closed + { + proxy: ts.checkConnClosed4Proxy, + }, + { + proxy: func(clientIO, backendIO pnet.PacketIO) error { + require.Equal(t, SrcProxyQuit, ts.mp.QuitSource()) + require.False(t, ts.mp.ForceClose()) + return nil + }, + }, + } + ts.runTests(runners) +} + func TestHandlerReturnError(t *testing.T) { tests := []struct { cfg cfgOverrider