diff --git a/server/cmd/api/api/process.go b/server/cmd/api/api/process.go index 12eb448a..38a458f4 100644 --- a/server/cmd/api/api/process.go +++ b/server/cmd/api/api/process.go @@ -26,6 +26,7 @@ import ( "github.com/kernel/kernel-images/server/lib/logger" oapi "github.com/kernel/kernel-images/server/lib/oapi" "github.com/kernel/kernel-images/server/lib/ptyio" + "github.com/kernel/kernel-images/server/lib/wsdrain" openapi_types "github.com/oapi-codegen/runtime/types" ) @@ -659,7 +660,7 @@ func writeJSON(w http.ResponseWriter, status int, body string) { // - Server sends TextMessage with JSON for events (e.g., exit code) // // This endpoint is intentionally not defined in OpenAPI. -func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Request, id string) { +func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Request, id string, reg *wsdrain.Registry) { ctx := r.Context() log := logger.FromContext(ctx) @@ -702,6 +703,9 @@ func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Reques } defer wsConn.CloseNow() + untrack := reg.Track(wsConn) + defer untrack() + // Set a generous read limit for PTY data wsConn.SetReadLimit(1024 * 1024) // 1MB diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index b3500a52..d839f4fb 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + "github.com/coder/websocket" "github.com/ghodss/yaml" "github.com/go-chi/chi/v5" chiMiddleware "github.com/go-chi/chi/v5/middleware" @@ -32,6 +33,7 @@ import ( "github.com/kernel/kernel-images/server/lib/scaletozero" "github.com/kernel/kernel-images/server/lib/sysmon" "github.com/kernel/kernel-images/server/lib/telemetry" + "github.com/kernel/kernel-images/server/lib/wsdrain" ) func main() { @@ -80,6 +82,9 @@ func main() { os.Exit(1) } + // ws conn tracker + wsRegistry := wsdrain.New() + // DevTools WebSocket upstream manager: tail Chromium supervisord log const chromiumLogPath 
= "/var/log/supervisord/chromium" upstreamMgr := devtoolsproxy.NewUpstreamManager(chromiumLogPath, slogger) @@ -166,7 +171,7 @@ func main() { // Uses WebSocket for bidirectional streaming, which works well through proxies. r.Get("/process/{process_id}/attach", func(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "process_id") - apiService.HandleProcessAttachWS(w, r, id) + apiService.HandleProcessAttachWS(w, r, id, wsRegistry) }) // Serve extension files for Chrome policy-installed extensions @@ -214,7 +219,7 @@ func main() { rDevtools.Get("/json/list", jsonTargetHandler) rDevtools.Get("/json/list/", jsonTargetHandler) rDevtools.Get("/*", func(w http.ResponseWriter, r *http.Request) { - devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz, telemetrySession.Publish).ServeHTTP(w, r) + devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz, telemetrySession.Publish, wsRegistry).ServeHTTP(w, r) }) srvDevtools := &http.Server{ @@ -240,6 +245,7 @@ func main() { rChromeDriver.Handle("/*", chromedriverproxy.Handler(slogger, &chromedriverproxy.Options{ ChromeDriverUpstream: config.ChromeDriverUpstreamAddr, DevToolsProxyAddr: config.DevToolsProxyAddr, + Registry: wsRegistry, })) srvChromeDriver := &http.Server{ @@ -285,6 +291,12 @@ func main() { g.Go(func() error { return apiService.Shutdown(shutdownCtx) }) + g.Go(func() error { + if n := wsRegistry.CloseAll(websocket.StatusGoingAway, "browser shutting down"); n > 0 { + slogger.Info("closed active websocket connections for shutdown", "count", n) + } + return nil + }) g.Go(func() error { upstreamMgr.Stop() return srvDevtools.Shutdown(shutdownCtx) diff --git a/server/lib/chromedriverproxy/proxy.go b/server/lib/chromedriverproxy/proxy.go index 0b83a5c7..612f34b6 100644 --- a/server/lib/chromedriverproxy/proxy.go +++ b/server/lib/chromedriverproxy/proxy.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/coder/websocket" + 
"github.com/kernel/kernel-images/server/lib/wsdrain" "github.com/kernel/kernel-images/server/lib/wsproxy" ) @@ -25,6 +26,9 @@ const ( type Options struct { ChromeDriverUpstream string DevToolsProxyAddr string + // Registry, when set, tracks accepted WebDriver/BiDi connections so they + // are closed with a Going Away frame on server shutdown. + Registry *wsdrain.Registry } func resolveOptions(opts *Options) Options { @@ -41,6 +45,7 @@ func resolveOptions(opts *Options) Options { if opts.DevToolsProxyAddr != "" { resolved.DevToolsProxyAddr = opts.DevToolsProxyAddr } + resolved.Registry = opts.Registry return resolved } @@ -313,6 +318,7 @@ func proxyWebSocket(w http.ResponseWriter, r *http.Request, logger *slog.Logger, DialOptions: dialOpts, Logger: logger, Transform: transform, + Registry: cfg.Registry, }) } diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index aded45a8..df8b47eb 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -22,6 +22,7 @@ import ( "github.com/kernel/kernel-images/server/lib/events" oapi "github.com/kernel/kernel-images/server/lib/oapi" "github.com/kernel/kernel-images/server/lib/scaletozero" + "github.com/kernel/kernel-images/server/lib/wsdrain" "github.com/kernel/kernel-images/server/lib/wsproxy" ) @@ -309,7 +310,7 @@ type EventPublisher func(ev events.Event) (events.Envelope, bool) // If logCDPMessages is true, all CDP messages will be logged with their direction. // publish is invoked on accept (cdp_connect) and on teardown (cdp_disconnect); pass // nil to disable emission. 
-func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller, publish EventPublisher) http.Handler { +func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller, publish EventPublisher, reg *wsdrain.Registry) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Counts every relayed message so cdp_disconnect can report message_count. var msgCount atomic.Int64 @@ -349,6 +350,9 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess } clientConn.SetReadLimit(100 * 1024 * 1024) + untrack := reg.Track(clientConn) + defer untrack() + publishCdpConnect(publish) connectedAt := time.Now() diff --git a/server/lib/devtoolsproxy/proxy_test.go b/server/lib/devtoolsproxy/proxy_test.go index c3de042e..5956a555 100644 --- a/server/lib/devtoolsproxy/proxy_test.go +++ b/server/lib/devtoolsproxy/proxy_test.go @@ -23,6 +23,7 @@ import ( "github.com/kernel/kernel-images/server/lib/events" oapi "github.com/kernel/kernel-images/server/lib/oapi" "github.com/kernel/kernel-images/server/lib/scaletozero" + "github.com/kernel/kernel-images/server/lib/wsdrain" "github.com/kernel/kernel-images/server/lib/wsproxy" ) @@ -132,7 +133,7 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { // seed current upstream to echo server including path/query (bypass tailing) mgr.setCurrent((&url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path, RawQuery: u.RawQuery}).String()) - proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil) + proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil, nil) proxySrv := httptest.NewServer(proxy) defer proxySrv.Close() @@ -164,6 +165,64 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { } } +func TestWebSocketProxyHandler_RegistryClosesClientWithGoingAway(t *testing.T) { + echoSrv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{OriginPatterns: []string{"*"}}) + if err != nil { + return + } + defer c.Close(websocket.StatusNormalClosure, "") + ctx := r.Context() + for { + mt, msg, err := c.Read(ctx) + if err != nil { + return + } + if err := c.Write(ctx, mt, msg); err != nil { + return + } + } + })) + defer echoSrv.Close() + + u, _ := url.Parse(echoSrv.URL) + logger := silentLogger() + mgr := NewUpstreamManager("/dev/null", logger) + mgr.setCurrent((&url.URL{Scheme: "ws", Host: u.Host, Path: "/echo"}).String()) + + reg := wsdrain.New() + proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil, reg)) + defer proxySrv.Close() + + pu, _ := url.Parse(proxySrv.URL) + pu.Scheme = "ws" + ctx := context.Background() + conn, _, err := websocket.Dial(ctx, pu.String(), nil) + if err != nil { + t.Fatalf("dial proxy failed: %v", err) + } + defer conn.Close(websocket.StatusInternalError, "") + + // Round-trip so the proxy session is fully established and registered. + if err := conn.Write(ctx, websocket.MessageText, []byte("ping")); err != nil { + t.Fatalf("write failed: %v", err) + } + if _, _, err := conn.Read(ctx); err != nil { + t.Fatalf("read failed: %v", err) + } + + if n := reg.CloseAll(websocket.StatusGoingAway, "shutting down"); n != 1 { + t.Fatalf("CloseAll closed %d conns, want 1", n) + } + + // The client should observe a 1001 Going Away, not the 1000 the proxy's + // own cleanup would otherwise send. + _, _, err = conn.Read(ctx) + if got := websocket.CloseStatus(err); got != websocket.StatusGoingAway { + t.Fatalf("client close status = %v (err %v), want StatusGoingAway", got, err) + } +} + func TestDialUpstreamWithRetry_RechecksCurrentAfterMissedUpdate(t *testing.T) { // Start a working websocket upstream. 
upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -461,7 +520,7 @@ func TestWebSocketProxyHandler_EmitsConnectAndDisconnect(t *testing.T) { mgr.setCurrent(u.String()) rp := &recordingPublisher{} - proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish)) + proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil)) defer proxySrv.Close() pu, _ := url.Parse(proxySrv.URL) @@ -636,7 +695,7 @@ func TestWebSocketProxyHandler_EmitsUpstreamChangedOnMidStreamRestart(t *testing mgr.setCurrent(urlA.String()) rp := &recordingPublisher{} - proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish)) + proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil)) defer proxySrv.Close() pu, _ := url.Parse(proxySrv.URL) @@ -720,7 +779,7 @@ func TestWebSocketProxyHandler_KicksClientOffStaleUpstreamOnURLChange(t *testing mgr.setCurrent(urlA.String()) rp := &recordingPublisher{} - proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish)) + proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil)) defer proxySrv.Close() pu, _ := url.Parse(proxySrv.URL) @@ -772,7 +831,7 @@ func TestWebSocketProxyHandler_EmitsUpstreamErrorOnDialFailure(t *testing.T) { mgr.setCurrent(deadURL) rp := &recordingPublisher{} - proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish)) + proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil)) defer proxySrv.Close() pu, _ := url.Parse(proxySrv.URL) diff --git a/server/lib/wsdrain/wsdrain.go 
b/server/lib/wsdrain/wsdrain.go
new file mode 100644
index 00000000..0df8a093
--- /dev/null
+++ b/server/lib/wsdrain/wsdrain.go
@@ -0,0 +1,88 @@
+// Package wsdrain tracks live WebSocket connections so they can be closed with
+// a single status code when the server shuts down.
+package wsdrain
+
+import (
+	"sync"
+
+	"github.com/coder/websocket"
+)
+
+// Conn is the subset of *websocket.Conn the registry needs.
+type Conn interface {
+	Close(code websocket.StatusCode, reason string) error
+}
+
+// Registry tracks active connections. Construct one with New; the zero value
+// is also usable. All methods are safe for concurrent use and tolerate a nil
+// receiver, so callers may pass a nil *Registry to disable tracking.
+type Registry struct {
+	mu    sync.Mutex
+	conns map[Conn]struct{}
+}
+
+// New returns an empty Registry ready to track connections.
+func New() *Registry {
+	return &Registry{conns: make(map[Conn]struct{})}
+}
+
+// Track registers conn and returns a function that removes it. The returned
+// function is idempotent; call it (e.g. via defer) when the connection ends.
+// A nil receiver or nil conn yields a no-op function.
+func (r *Registry) Track(conn Conn) func() {
+	if r == nil || conn == nil {
+		return func() {}
+	}
+	r.mu.Lock()
+	if r.conns == nil {
+		// Zero-value Registry: allocate lazily instead of panicking on a
+		// nil-map write.
+		r.conns = make(map[Conn]struct{})
+	}
+	r.conns[conn] = struct{}{}
+	r.mu.Unlock()
+
+	var once sync.Once
+	return func() {
+		once.Do(func() {
+			r.mu.Lock()
+			delete(r.conns, conn)
+			r.mu.Unlock()
+		})
+	}
+}
+
+// CloseAll closes every tracked connection with the given code and reason and
+// returns how many it closed. Connections are snapshotted under the lock and
+// closed outside it, so untrack functions are never blocked by a slow Close.
+//
+// Connections are closed concurrently and CloseAll returns once every Close
+// has returned. Close on a *websocket.Conn may block on the close handshake
+// when the peer is unresponsive; closing one connection at a time would make
+// shutdown latency grow with the number of tracked connections.
+//
+// Close errors are ignored: the connection is being discarded regardless, and
+// the first Close wins, so a later normal-closure from the connection's own
+// teardown does not override the code sent here.
+func (r *Registry) CloseAll(code websocket.StatusCode, reason string) int {
+	if r == nil {
+		return 0
+	}
+	r.mu.Lock()
+	conns := make([]Conn, 0, len(r.conns))
+	for c := range r.conns {
+		conns = append(conns, c)
+	}
+	r.mu.Unlock()
+
+	var wg sync.WaitGroup
+	wg.Add(len(conns))
+	for _, c := range conns {
+		go func(c Conn) {
+			defer wg.Done()
+			_ = c.Close(code, reason) // best-effort; errors deliberately ignored
+		}(c)
+	}
+	wg.Wait()
+	return len(conns)
+}
diff --git a/server/lib/wsdrain/wsdrain_test.go b/server/lib/wsdrain/wsdrain_test.go
new file mode 100644
index 00000000..d48b4bdb
--- /dev/null
+++ b/server/lib/wsdrain/wsdrain_test.go
@@ -0,0 +1,70 @@
+package wsdrain
+
+import (
+	"testing"
+
+	"github.com/coder/websocket"
+)
+
+type fakeConn struct {
+	closes []websocket.StatusCode
+}
+
+func (c *fakeConn) Close(code websocket.StatusCode, _ string) error {
+	c.closes = append(c.closes, code)
+	return nil
+}
+
+func TestCloseAllClosesTrackedConns(t *testing.T) {
+	r := New()
+	a, b := &fakeConn{}, &fakeConn{}
+	r.Track(a)
+	r.Track(b)
+
+	if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 2 {
+		t.Fatalf("CloseAll returned %d, want 2", n)
+	}
+	for _, c := range []*fakeConn{a, b} {
+		if len(c.closes) != 1 || c.closes[0] != websocket.StatusGoingAway {
+			t.Fatalf("conn closed with %v, want one StatusGoingAway", c.closes)
+		}
+	}
+}
+
+func TestUntrackRemovesConn(t *testing.T) {
+	r := New()
+	c := &fakeConn{}
+	untrack := r.Track(c)
+	untrack()
+	untrack() // idempotent
+
+	if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 0 {
+		t.Fatalf("CloseAll returned %d after untrack, want 0", n)
+	}
+	if len(c.closes) != 0 {
+		t.Fatalf("untracked conn was closed: %v", c.closes)
+	}
+}
+
+func TestNilRegistryIsNoop(t *testing.T) {
+	var r *Registry
+	untrack := r.Track(&fakeConn{})
+	untrack()
+	if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 0 {
+		t.Fatalf("nil registry CloseAll returned %d, want 0", n)
+	}
+}
+
+func TestZeroValueRegistryIsUsable(t *testing.T) {
+	var r Registry
+	c := &fakeConn{}
+	untrack := r.Track(c)
+	defer untrack()
+
+	if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 1 {
+		t.Fatalf("zero-value registry CloseAll returned %d, want 1", n)
+	}
+	if len(c.closes) != 1 || c.closes[0] != websocket.StatusGoingAway {
+		t.Fatalf("conn closed with %v, want one StatusGoingAway", c.closes)
+	}
+}
diff --git a/server/lib/wsproxy/wsproxy.go b/server/lib/wsproxy/wsproxy.go
index 8a14e490..8a7e5940 100644
--- a/server/lib/wsproxy/wsproxy.go
+++ 
b/server/lib/wsproxy/wsproxy.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/coder/websocket" + "github.com/kernel/kernel-images/server/lib/wsdrain" ) // Conn abstracts a WebSocket connection for testing and flexibility. @@ -28,6 +29,9 @@ type ProxyOptions struct { DialOptions *websocket.DialOptions Logger *slog.Logger Transform MessageTransform + // Registry, when set, tracks the accepted client connection so it is + // closed with a Going Away frame on server shutdown. + Registry *wsdrain.Registry } // PumpExitCause names which side caused Pump to return. Callers use this to @@ -121,6 +125,9 @@ func Proxy(w http.ResponseWriter, r *http.Request, upstreamURL string, opts Prox } clientConn.SetReadLimit(100 * 1024 * 1024) + untrack := opts.Registry.Track(clientConn) + defer untrack() + upstreamConn, _, err := websocket.Dial(r.Context(), upstreamURL, opts.DialOptions) if err != nil { logger.Error("dial upstream failed", slog.String("err", err.Error()), slog.String("url", upstreamURL))