diff --git a/pkg/http/handler/hijack.go b/pkg/http/handler/hijack.go index 2559302639f7..c7e6d95f45e0 100644 --- a/pkg/http/handler/hijack.go +++ b/pkg/http/handler/hijack.go @@ -17,11 +17,15 @@ limitations under the License. package handler import ( + "bufio" "cmp" "context" + "net" "net/http" "sync/atomic" "time" + + "knative.dev/pkg/websocket" ) // HijackTracker is used to track Websocket Connections @@ -39,6 +43,17 @@ type HijackTracker struct { inflight atomic.Int64 } +type hijackTrackerResponseWriter struct { + http.ResponseWriter + tracker *HijackTracker +} + +type trackedConn struct { + net.Conn + tracker *HijackTracker + closed atomic.Bool +} + // Drain should be called after http.Server:Shutdown returns func (s *HijackTracker) Drain(ctx context.Context) error { pollInterval := cmp.Or(s.PollInterval, time.Second) @@ -62,5 +77,37 @@ func (s *HijackTracker) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.inflight.Add(1) defer s.inflight.Add(-1) - s.Handler.ServeHTTP(w, r) + s.Handler.ServeHTTP(&hijackTrackerResponseWriter{ + ResponseWriter: w, + tracker: s, + }, r) +} + +func (w *hijackTrackerResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func (w *hijackTrackerResponseWriter) Flush() { + w.ResponseWriter.(http.Flusher).Flush() +} + +func (w *hijackTrackerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn, rw, err := websocket.HijackIfPossible(w.ResponseWriter) + if err != nil { + return nil, nil, err + } + + w.tracker.inflight.Add(1) + return &trackedConn{ + Conn: conn, + tracker: w.tracker, + }, rw, nil +} + +func (c *trackedConn) Close() error { + err := c.Conn.Close() + if c.closed.CompareAndSwap(false, true) { + c.tracker.inflight.Add(-1) + } + return err } diff --git a/pkg/http/handler/hijack_test.go b/pkg/http/handler/hijack_test.go index 18bf84dd4568..00e366155f43 100644 --- a/pkg/http/handler/hijack_test.go +++ b/pkg/http/handler/hijack_test.go @@ -17,14 +17,36 @@ limitations under the License. package handler import ( + "bufio" "context" "errors" + "net" "net/http" "net/http/httptest" "testing" "time" ) +type hijackableResponseWriter struct { + *httptest.ResponseRecorder + conn net.Conn +} + +func (w *hijackableResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + client, server := net.Pipe() + w.conn = client + return server, bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)), nil +} + +type flushableResponseWriter struct { + *hijackableResponseWriter + flushed bool +} + +func (w *flushableResponseWriter) Flush() { + w.flushed = true +} + func TestHijackTrackerNoHijack(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil) @@ -43,20 +65,32 @@ func TestHijackTrackerNoHijack(t *testing.T) { } func TestHijackTrackerConnectionHijacked(t *testing.T) { - w := httptest.NewRecorder() + w := &hijackableResponseWriter{ResponseRecorder: httptest.NewRecorder()} r := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil) inHandler := make(chan struct{}) - handlerWait := make(chan struct{}) + connClosed := make(chan struct{}) drainResult := make(chan error, 1) h := &HijackTracker{ PollInterval: 10 * time.Millisecond, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { close(inHandler) - <-handlerWait + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Fatalf("Hijack() = %v", err) + } + go func() { + <-connClosed + conn.Close() + }() }), } + defer func() { + if w.conn != nil { + w.conn.Close() + } + }() go func() { h.ServeHTTP(w, r) @@ -75,15 +109,15 @@ func TestHijackTrackerConnectionHijacked(t *testing.T) { select { case <-time.After(250 * time.Millisecond): case <-drainResult: - t.Fatal("drain returned befoce handler was finished") + t.Fatal("drain returned before hijacked connection was closed") } - close(handlerWait) + close(connClosed) var err error select { case <-time.After(1 * time.Second): - t.Fatal("Drain was not unblocked when the handler returned") + t.Fatal("Drain was not unblocked when the hijacked connection closed") case err = <-drainResult: } @@ -93,11 +127,10 @@ func TestHijackTrackerConnectionHijacked(t *testing.T) { } func TestHijackTrackerConnectionHijackedTimeout(t *testing.T) { - w := httptest.NewRecorder() + w := &hijackableResponseWriter{ResponseRecorder: httptest.NewRecorder()} r := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil) inHandler := make(chan struct{}) - handlerWait := make(chan struct{}) drainStarted := make(chan struct{}) drainResult := make(chan error, 1) @@ -105,9 +138,18 @@ func TestHijackTrackerConnectionHijackedTimeout(t *testing.T) { PollInterval: 10 * time.Millisecond, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { close(inHandler) - <-handlerWait + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Fatalf("Hijack() = %v", err) + } + _ = conn }), } + defer func() { + if w.conn != nil { + w.conn.Close() + } + }() go func() { h.ServeHTTP(w, r) @@ -122,9 +164,6 @@ func TestHijackTrackerConnectionHijackedTimeout(t *testing.T) { }() <-drainStarted - // note: this is deferred to unblock the go-routine - // to clean up the test - defer close(handlerWait) var err error select { @@ -137,3 +176,27 @@ func TestHijackTrackerConnectionHijackedTimeout(t *testing.T) { t.Fatal("unexpected error draining", err) } } + +func TestHijackTrackerResponseWriterPreservesOptionalInterfaces(t *testing.T) { + w := &flushableResponseWriter{ + hijackableResponseWriter: &hijackableResponseWriter{ + ResponseRecorder: httptest.NewRecorder(), + }, + } + + h := &HijackTracker{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.(http.Flusher).Flush() + got := w.(interface{ Unwrap() http.ResponseWriter }).Unwrap() + if got != w.(*hijackTrackerResponseWriter).ResponseWriter { + t.Fatal("Unwrap() did not return the underlying writer") + } + }), + } + + h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "http://somehost.com", nil)) + + if !w.flushed { + t.Fatal("Flush() was not forwarded to the underlying writer") + } +}