Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion server/cmd/api/api/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/kernel/kernel-images/server/lib/logger"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
"github.com/kernel/kernel-images/server/lib/ptyio"
"github.com/kernel/kernel-images/server/lib/wsdrain"
openapi_types "github.com/oapi-codegen/runtime/types"
)

Expand Down Expand Up @@ -659,7 +660,7 @@ func writeJSON(w http.ResponseWriter, status int, body string) {
// - Server sends TextMessage with JSON for events (e.g., exit code)
//
// This endpoint is intentionally not defined in OpenAPI.
func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Request, id string) {
func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Request, id string, reg *wsdrain.Registry) {
ctx := r.Context()
log := logger.FromContext(ctx)

Expand Down Expand Up @@ -702,6 +703,9 @@ func (s *ApiService) HandleProcessAttachWS(w http.ResponseWriter, r *http.Reques
}
defer wsConn.CloseNow()

untrack := reg.Track(wsConn)
defer untrack()

// Set a generous read limit for PTY data
wsConn.SetReadLimit(1024 * 1024) // 1MB

Expand Down
16 changes: 14 additions & 2 deletions server/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"syscall"
"time"

"github.com/coder/websocket"
"github.com/ghodss/yaml"
"github.com/go-chi/chi/v5"
chiMiddleware "github.com/go-chi/chi/v5/middleware"
Expand All @@ -32,6 +33,7 @@ import (
"github.com/kernel/kernel-images/server/lib/scaletozero"
"github.com/kernel/kernel-images/server/lib/sysmon"
"github.com/kernel/kernel-images/server/lib/telemetry"
"github.com/kernel/kernel-images/server/lib/wsdrain"
)

func main() {
Expand Down Expand Up @@ -80,6 +82,9 @@ func main() {
os.Exit(1)
}

// ws conn tracker
wsRegistry := wsdrain.New()

// DevTools WebSocket upstream manager: tail Chromium supervisord log
const chromiumLogPath = "/var/log/supervisord/chromium"
upstreamMgr := devtoolsproxy.NewUpstreamManager(chromiumLogPath, slogger)
Expand Down Expand Up @@ -166,7 +171,7 @@ func main() {
// Uses WebSocket for bidirectional streaming, which works well through proxies.
r.Get("/process/{process_id}/attach", func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "process_id")
apiService.HandleProcessAttachWS(w, r, id)
apiService.HandleProcessAttachWS(w, r, id, wsRegistry)
})

// Serve extension files for Chrome policy-installed extensions
Expand Down Expand Up @@ -214,7 +219,7 @@ func main() {
rDevtools.Get("/json/list", jsonTargetHandler)
rDevtools.Get("/json/list/", jsonTargetHandler)
rDevtools.Get("/*", func(w http.ResponseWriter, r *http.Request) {
devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz, telemetrySession.Publish).ServeHTTP(w, r)
devtoolsproxy.WebSocketProxyHandler(upstreamMgr, slogger, config.LogCDPMessages, stz, telemetrySession.Publish, wsRegistry).ServeHTTP(w, r)
})

srvDevtools := &http.Server{
Expand All @@ -240,6 +245,7 @@ func main() {
rChromeDriver.Handle("/*", chromedriverproxy.Handler(slogger, &chromedriverproxy.Options{
ChromeDriverUpstream: config.ChromeDriverUpstreamAddr,
DevToolsProxyAddr: config.DevToolsProxyAddr,
Registry: wsRegistry,
}))

srvChromeDriver := &http.Server{
Expand Down Expand Up @@ -285,6 +291,12 @@ func main() {
g.Go(func() error {
return apiService.Shutdown(shutdownCtx)
})
g.Go(func() error {
if n := wsRegistry.CloseAll(websocket.StatusGoingAway, "browser shutting down"); n > 0 {
slogger.Info("closed active websocket connections for shutdown", "count", n)
}
return nil
})
g.Go(func() error {
upstreamMgr.Stop()
return srvDevtools.Shutdown(shutdownCtx)
Expand Down
6 changes: 6 additions & 0 deletions server/lib/chromedriverproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strings"

"github.com/coder/websocket"
"github.com/kernel/kernel-images/server/lib/wsdrain"
"github.com/kernel/kernel-images/server/lib/wsproxy"
)

Expand All @@ -25,6 +26,9 @@ const (
type Options struct {
ChromeDriverUpstream string
DevToolsProxyAddr string
// Registry, when set, tracks accepted WebDriver/BiDi connections so they
// are closed with a Going Away frame on server shutdown.
Registry *wsdrain.Registry
}

func resolveOptions(opts *Options) Options {
Expand All @@ -41,6 +45,7 @@ func resolveOptions(opts *Options) Options {
if opts.DevToolsProxyAddr != "" {
resolved.DevToolsProxyAddr = opts.DevToolsProxyAddr
}
resolved.Registry = opts.Registry
return resolved
}

Expand Down Expand Up @@ -313,6 +318,7 @@ func proxyWebSocket(w http.ResponseWriter, r *http.Request, logger *slog.Logger,
DialOptions: dialOpts,
Logger: logger,
Transform: transform,
Registry: cfg.Registry,
})
}

Expand Down
6 changes: 5 additions & 1 deletion server/lib/devtoolsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/kernel/kernel-images/server/lib/events"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
"github.com/kernel/kernel-images/server/lib/scaletozero"
"github.com/kernel/kernel-images/server/lib/wsdrain"
"github.com/kernel/kernel-images/server/lib/wsproxy"
)

Expand Down Expand Up @@ -309,7 +310,7 @@ type EventPublisher func(ev events.Event) (events.Envelope, bool)
// If logCDPMessages is true, all CDP messages will be logged with their direction.
// publish is invoked on accept (cdp_connect) and on teardown (cdp_disconnect); pass
// nil to disable emission.
func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller, publish EventPublisher) http.Handler {
func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller, publish EventPublisher, reg *wsdrain.Registry) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Counts every relayed message so cdp_disconnect can report message_count.
var msgCount atomic.Int64
Expand Down Expand Up @@ -349,6 +350,9 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess
}
clientConn.SetReadLimit(100 * 1024 * 1024)

untrack := reg.Track(clientConn)
defer untrack()

publishCdpConnect(publish)
connectedAt := time.Now()

Expand Down
69 changes: 64 additions & 5 deletions server/lib/devtoolsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/kernel/kernel-images/server/lib/events"
oapi "github.com/kernel/kernel-images/server/lib/oapi"
"github.com/kernel/kernel-images/server/lib/scaletozero"
"github.com/kernel/kernel-images/server/lib/wsdrain"
"github.com/kernel/kernel-images/server/lib/wsproxy"
)

Expand Down Expand Up @@ -132,7 +133,7 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) {
// seed current upstream to echo server including path/query (bypass tailing)
mgr.setCurrent((&url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path, RawQuery: u.RawQuery}).String())

proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil)
proxy := WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil, nil)
proxySrv := httptest.NewServer(proxy)
defer proxySrv.Close()

Expand Down Expand Up @@ -164,6 +165,64 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) {
}
}

func TestWebSocketProxyHandler_RegistryClosesClientWithGoingAway(t *testing.T) {
echoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{OriginPatterns: []string{"*"}})
if err != nil {
return
}
defer c.Close(websocket.StatusNormalClosure, "")
ctx := r.Context()
for {
mt, msg, err := c.Read(ctx)
if err != nil {
return
}
if err := c.Write(ctx, mt, msg); err != nil {
return
}
}
}))
defer echoSrv.Close()

u, _ := url.Parse(echoSrv.URL)
logger := silentLogger()
mgr := NewUpstreamManager("/dev/null", logger)
mgr.setCurrent((&url.URL{Scheme: "ws", Host: u.Host, Path: "/echo"}).String())

reg := wsdrain.New()
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), nil, reg))
defer proxySrv.Close()

pu, _ := url.Parse(proxySrv.URL)
pu.Scheme = "ws"
ctx := context.Background()
conn, _, err := websocket.Dial(ctx, pu.String(), nil)
if err != nil {
t.Fatalf("dial proxy failed: %v", err)
}
defer conn.Close(websocket.StatusInternalError, "")

// Round-trip so the proxy session is fully established and registered.
if err := conn.Write(ctx, websocket.MessageText, []byte("ping")); err != nil {
t.Fatalf("write failed: %v", err)
}
if _, _, err := conn.Read(ctx); err != nil {
t.Fatalf("read failed: %v", err)
}

if n := reg.CloseAll(websocket.StatusGoingAway, "shutting down"); n != 1 {
t.Fatalf("CloseAll closed %d conns, want 1", n)
}

// The client should observe a 1001 Going Away, not the 1000 the proxy's
// own cleanup would otherwise send.
_, _, err = conn.Read(ctx)
if got := websocket.CloseStatus(err); got != websocket.StatusGoingAway {
t.Fatalf("client close status = %v (err %v), want StatusGoingAway", got, err)
}
}

func TestDialUpstreamWithRetry_RechecksCurrentAfterMissedUpdate(t *testing.T) {
// Start a working websocket upstream.
upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -461,7 +520,7 @@ func TestWebSocketProxyHandler_EmitsConnectAndDisconnect(t *testing.T) {
mgr.setCurrent(u.String())

rp := &recordingPublisher{}
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish))
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil))
defer proxySrv.Close()

pu, _ := url.Parse(proxySrv.URL)
Expand Down Expand Up @@ -636,7 +695,7 @@ func TestWebSocketProxyHandler_EmitsUpstreamChangedOnMidStreamRestart(t *testing
mgr.setCurrent(urlA.String())

rp := &recordingPublisher{}
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish))
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil))
defer proxySrv.Close()

pu, _ := url.Parse(proxySrv.URL)
Expand Down Expand Up @@ -720,7 +779,7 @@ func TestWebSocketProxyHandler_KicksClientOffStaleUpstreamOnURLChange(t *testing
mgr.setCurrent(urlA.String())

rp := &recordingPublisher{}
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish))
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil))
defer proxySrv.Close()

pu, _ := url.Parse(proxySrv.URL)
Expand Down Expand Up @@ -772,7 +831,7 @@ func TestWebSocketProxyHandler_EmitsUpstreamErrorOnDialFailure(t *testing.T) {
mgr.setCurrent(deadURL)

rp := &recordingPublisher{}
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish))
proxySrv := httptest.NewServer(WebSocketProxyHandler(mgr, logger, false, scaletozero.NewNoopController(), rp.publish, nil))
defer proxySrv.Close()

pu, _ := url.Parse(proxySrv.URL)
Expand Down
68 changes: 68 additions & 0 deletions server/lib/wsdrain/wsdrain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Package wsdrain tracks live WebSocket connections so they can be closed with
// a single status code when the server shuts down.
package wsdrain

import (
"sync"

"github.com/coder/websocket"
)

// Conn is the subset of *websocket.Conn the registry needs.
type Conn interface {
Close(code websocket.StatusCode, reason string) error
}

// Registry tracks active connections. Construct one with New. All methods are
// safe for concurrent use and tolerate a nil receiver, so callers may pass a
// nil *Registry to disable tracking.
type Registry struct {
mu sync.Mutex
conns map[Conn]struct{}
}

func New() *Registry {
return &Registry{conns: make(map[Conn]struct{})}
}

// Track registers conn and returns a function that removes it. The returned
// function is idempotent; call it (e.g. via defer) when the connection ends.
func (r *Registry) Track(conn Conn) func() {
if r == nil || conn == nil {
return func() {}
}
r.mu.Lock()
r.conns[conn] = struct{}{}
r.mu.Unlock()

var once sync.Once
return func() {
once.Do(func() {
r.mu.Lock()
delete(r.conns, conn)
r.mu.Unlock()
})
}
}

// CloseAll closes every tracked connection with the given code and reason and
// returns how many it closed. Connections are snapshotted under the lock and
// closed outside it. Close errors are ignored: the connection is being
// discarded regardless, and the first Close wins, so a later normal-closure
// from the connection's own teardown does not override the code sent here.
func (r *Registry) CloseAll(code websocket.StatusCode, reason string) int {
if r == nil {
return 0
}
r.mu.Lock()
conns := make([]Conn, 0, len(r.conns))
for c := range r.conns {
conns = append(conns, c)
}
r.mu.Unlock()

for _, c := range conns {
_ = c.Close(code, reason)
}
return len(conns)
}
56 changes: 56 additions & 0 deletions server/lib/wsdrain/wsdrain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package wsdrain

import (
"testing"

"github.com/coder/websocket"
)

type fakeConn struct {
closes []websocket.StatusCode
}

func (c *fakeConn) Close(code websocket.StatusCode, _ string) error {
c.closes = append(c.closes, code)
return nil
}

func TestCloseAllClosesTrackedConns(t *testing.T) {
r := New()
a, b := &fakeConn{}, &fakeConn{}
r.Track(a)
r.Track(b)

if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 2 {
t.Fatalf("CloseAll returned %d, want 2", n)
}
for _, c := range []*fakeConn{a, b} {
if len(c.closes) != 1 || c.closes[0] != websocket.StatusGoingAway {
t.Fatalf("conn closed with %v, want one StatusGoingAway", c.closes)
}
}
}

func TestUntrackRemovesConn(t *testing.T) {
r := New()
c := &fakeConn{}
untrack := r.Track(c)
untrack()
untrack() // idempotent

if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 0 {
t.Fatalf("CloseAll returned %d after untrack, want 0", n)
}
if len(c.closes) != 0 {
t.Fatalf("untracked conn was closed: %v", c.closes)
}
}

func TestNilRegistryIsNoop(t *testing.T) {
var r *Registry
untrack := r.Track(&fakeConn{})
untrack()
if n := r.CloseAll(websocket.StatusGoingAway, "bye"); n != 0 {
t.Fatalf("nil registry CloseAll returned %d, want 0", n)
}
}
Loading
Loading