diff --git a/docs/middleware.md b/docs/middleware.md index 03962e9d9e..c3cc4eb1f4 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -136,7 +136,8 @@ sequenceDiagram - Extract the token session ID (`tsid`) claim from the ToolHive JWT - Look up the stored upstream IdP tokens associated with that session - Inject the upstream access token into the request (replacing Authorization header or using a custom header) -- Gracefully proceed without modification if tokens are unavailable or expired +- Return 401 Unauthorized with WWW-Authenticate header when tokens are expired or not found +- Gracefully proceed without modification if identity, session ID, or storage is unavailable **Configuration**: @@ -147,9 +148,11 @@ sequenceDiagram **Behavior**: - **Automatic activation**: Enabled whenever the embedded auth server is configured, even without explicit `UpstreamSwapConfig` -- **Expired tokens**: Logs a warning but continues with the expired token (backend will reject if necessary) -- **Missing tokens**: Proceeds without modification (logs debug message) +- **Expired tokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required +- **Tokens not found**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required +- **Missing identity/tsid**: Proceeds without modification (logs debug message) - **Storage unavailable**: Proceeds without modification (logs warning) +- **Other storage errors**: Returns 503 Service Unavailable to fail closed (logs warning) **Context Data Used**: - Identity from Authentication middleware (specifically the `tsid` claim) diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index c049325717..5af02e099c 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -7,6 +7,7 @@ package upstreamswap import ( "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -113,6 +114,15 @@ func validateConfig(cfg *Config) error { return nil } +// writeUpstreamAuthRequired writes a 401 response with a WWW-Authenticate Bearer +// challenge per RFC 6750 Section 3.1, signalling that the caller must re-authenticate +// with the upstream IdP. +func writeUpstreamAuthRequired(w http.ResponseWriter) { + w.Header().Set("WWW-Authenticate", + `Bearer error="invalid_token", error_description="upstream token is no longer valid; re-authentication required"`) + http.Error(w, "upstream authentication required", http.StatusUnauthorized) +} + // injectionFunc is a function that injects a token into an HTTP request. type injectionFunc func(*http.Request, string) @@ -183,15 +193,28 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle if err != nil { slog.Warn("Failed to get upstream tokens", "middleware", "upstreamswap", "error", err) - next.ServeHTTP(w, r) + // Token is expired, was not found, or failed binding validation + // (e.g., subject/client mismatch). All three are client-attributable + // errors that require the caller to re-authenticate with the upstream IdP. + if errors.Is(err, storage.ErrExpired) || + errors.Is(err, storage.ErrNotFound) || + errors.Is(err, storage.ErrInvalidBinding) { + writeUpstreamAuthRequired(w) + return + } + // Other storage errors: fail closed to avoid bypassing the token swap + http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) return } - // 5. Check if expired (MVP: just log warning, continue with token) + // 5. Check if expired + // Defense in depth: some storage implementations may return tokens + // without checking expiry (the interface does not require it). if tokens.IsExpired(time.Now()) { slog.Warn("Upstream tokens expired", "middleware", "upstreamswap") - // Continue with expired token - backend will reject if needed + writeUpstreamAuthRequired(w) + return } // 6. Inject access token diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index f77029ece8..85f66c31c4 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -193,46 +193,63 @@ func TestMiddleware_StorageUnavailable(t *testing.T) { assert.True(t, nextCalled, "next handler should be called") } -func TestMiddleware_TokensNotFound(t *testing.T) { +func TestMiddleware_ClientAttributableStorageErrors_Returns401(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() + tests := []struct { + name string + err error + }{ + {"not found", storage.ErrNotFound}, + {"expired", storage.ErrExpired}, + {"invalid binding", storage.ErrInvalidBinding}, + } - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123"). - Return(nil, storage.ErrNotFound) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() - storageGetter := func() storage.UpstreamTokenStorage { - return mockStorage - } + ctrl := gomock.NewController(t) + defer ctrl.Finish() - cfg := &Config{} - middleware := createMiddlewareFunc(cfg, storageGetter) + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetUpstreamTokens(gomock.Any(), "session-123"). + Return(nil, tt.err) - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) + storageGetter := func() storage.UpstreamTokenStorage { + return mockStorage + } - handler := middleware(nextHandler) + cfg := &Config{} + middleware := createMiddlewareFunc(cfg, storageGetter) - req := httptest.NewRequest(http.MethodGet, "/test", nil) - identity := &auth.Identity{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - session.TokenSessionIDClaimKey: "session-123", - }, - } - ctx := auth.WithIdentity(req.Context(), identity) - req = req.WithContext(ctx) + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + handler := middleware(nextHandler) - assert.True(t, nextCalled, "next handler should be called") + req := httptest.NewRequest(http.MethodGet, "/test", nil) + identity := &auth.Identity{ + Subject: "user123", + Claims: map[string]any{ + "sub": "user123", + session.TokenSessionIDClaimKey: "session-123", + }, + } + ctx := auth.WithIdentity(req.Context(), identity) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.False(t, nextCalled, "next handler should NOT be called") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) + }) + } } func TestMiddleware_StorageError(t *testing.T) { @@ -274,7 +291,8 @@ func TestMiddleware_StorageError(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.True(t, nextCalled, "next handler should be called despite error") + assert.False(t, nextCalled, "next handler should NOT be called on storage error") + assert.Equal(t, http.StatusServiceUnavailable, rr.Code) } func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { @@ -382,7 +400,7 @@ func TestMiddleware_CustomHeader(t *testing.T) { assert.Equal(t, "Bearer original-token", capturedAuthHeader) } -func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) { +func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -406,9 +424,9 @@ func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) { cfg := &Config{} middleware := createMiddlewareFunc(cfg, storageGetter) - var capturedAuthHeader string - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedAuthHeader = r.Header.Get("Authorization") + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true }) handler := middleware(nextHandler) @@ -427,8 +445,9 @@ func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - // MVP: Should continue with expired token - assert.Equal(t, "Bearer expired-upstream-token", capturedAuthHeader) + assert.False(t, nextCalled, "next handler should NOT be called for expired tokens") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) } func TestMiddleware_EmptySelectedToken(t *testing.T) {