Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ sequenceDiagram
- Extract the token session ID (`tsid`) claim from the ToolHive JWT
- Look up the stored upstream IdP tokens associated with that session
- Inject the upstream access token into the request (replacing Authorization header or using a custom header)
- Gracefully proceed without modification if tokens are unavailable or expired
- Return 401 Unauthorized with WWW-Authenticate header when tokens are expired or not found
- Gracefully proceed without modification if identity, session ID, or storage is unavailable

**Configuration**:

Expand All @@ -147,9 +148,11 @@ sequenceDiagram

**Behavior**:
- **Automatic activation**: Enabled whenever the embedded auth server is configured, even without explicit `UpstreamSwapConfig`
- **Expired tokens**: Logs a warning but continues with the expired token (backend will reject if necessary)
- **Missing tokens**: Proceeds without modification (logs debug message)
- **Expired tokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required
- **Tokens not found**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required
- **Missing identity/tsid**: Proceeds without modification (logs debug message)
- **Storage unavailable**: Proceeds without modification (logs warning)
- **Other storage errors**: Returns 503 Service Unavailable to fail closed (logs warning)

**Context Data Used**:
- Identity from Authentication middleware (specifically the `tsid` claim)
Expand Down
29 changes: 26 additions & 3 deletions pkg/auth/upstreamswap/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package upstreamswap

import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
Expand Down Expand Up @@ -113,6 +114,15 @@ func validateConfig(cfg *Config) error {
return nil
}

// writeUpstreamAuthRequired writes a 401 response with a WWW-Authenticate Bearer
// challenge per RFC 6750 Section 3.1, signalling that the caller must re-authenticate
// with the upstream IdP.
func writeUpstreamAuthRequired(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate",
`Bearer error="invalid_token", error_description="upstream token is no longer valid; re-authentication required"`)
http.Error(w, "upstream authentication required", http.StatusUnauthorized)
}

// injectionFunc is a function that injects a token into an HTTP request.
type injectionFunc func(*http.Request, string)

Expand Down Expand Up @@ -183,15 +193,28 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
if err != nil {
slog.Warn("Failed to get upstream tokens",
"middleware", "upstreamswap", "error", err)
next.ServeHTTP(w, r)
// Token is expired, was not found, or failed binding validation
// (e.g., subject/client mismatch). All three are client-attributable
// errors that require the caller to re-authenticate with the upstream IdP.
if errors.Is(err, storage.ErrExpired) ||
errors.Is(err, storage.ErrNotFound) ||
errors.Is(err, storage.ErrInvalidBinding) {
writeUpstreamAuthRequired(w)
return
}
// Other storage errors: fail closed to avoid bypassing the token swap
http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable)
return
}

// 5. Check if expired (MVP: just log warning, continue with token)
// 5. Check if expired
// Defense in depth: some storage implementations may return tokens
// without checking expiry (the interface does not require it).
if tokens.IsExpired(time.Now()) {
slog.Warn("Upstream tokens expired",
"middleware", "upstreamswap")
// Continue with expired token - backend will reject if needed
writeUpstreamAuthRequired(w)
return
}

// 6. Inject access token
Expand Down
93 changes: 56 additions & 37 deletions pkg/auth/upstreamswap/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,46 +193,63 @@ func TestMiddleware_StorageUnavailable(t *testing.T) {
assert.True(t, nextCalled, "next handler should be called")
}

func TestMiddleware_TokensNotFound(t *testing.T) {
func TestMiddleware_ClientAttributableStorageErrors_Returns401(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
defer ctrl.Finish()
tests := []struct {
name string
err error
}{
{"not found", storage.ErrNotFound},
{"expired", storage.ErrExpired},
{"invalid binding", storage.ErrInvalidBinding},
}

mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl)
mockStorage.EXPECT().
GetUpstreamTokens(gomock.Any(), "session-123").
Return(nil, storage.ErrNotFound)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

storageGetter := func() storage.UpstreamTokenStorage {
return mockStorage
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()

cfg := &Config{}
middleware := createMiddlewareFunc(cfg, storageGetter)
mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl)
mockStorage.EXPECT().
GetUpstreamTokens(gomock.Any(), "session-123").
Return(nil, tt.err)

var nextCalled bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
nextCalled = true
})
storageGetter := func() storage.UpstreamTokenStorage {
return mockStorage
}

handler := middleware(nextHandler)
cfg := &Config{}
middleware := createMiddlewareFunc(cfg, storageGetter)

req := httptest.NewRequest(http.MethodGet, "/test", nil)
identity := &auth.Identity{
Subject: "user123",
Claims: map[string]any{
"sub": "user123",
session.TokenSessionIDClaimKey: "session-123",
},
}
ctx := auth.WithIdentity(req.Context(), identity)
req = req.WithContext(ctx)
var nextCalled bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
nextCalled = true
})

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
handler := middleware(nextHandler)

assert.True(t, nextCalled, "next handler should be called")
req := httptest.NewRequest(http.MethodGet, "/test", nil)
identity := &auth.Identity{
Subject: "user123",
Claims: map[string]any{
"sub": "user123",
session.TokenSessionIDClaimKey: "session-123",
},
}
ctx := auth.WithIdentity(req.Context(), identity)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.False(t, nextCalled, "next handler should NOT be called")
assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`)
})
}
}

func TestMiddleware_StorageError(t *testing.T) {
Expand Down Expand Up @@ -274,7 +291,8 @@ func TestMiddleware_StorageError(t *testing.T) {
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.True(t, nextCalled, "next handler should be called despite error")
assert.False(t, nextCalled, "next handler should NOT be called on storage error")
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
}

func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) {
Expand Down Expand Up @@ -382,7 +400,7 @@ func TestMiddleware_CustomHeader(t *testing.T) {
assert.Equal(t, "Bearer original-token", capturedAuthHeader)
}

func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) {
func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
Expand All @@ -406,9 +424,9 @@ func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) {
cfg := &Config{}
middleware := createMiddlewareFunc(cfg, storageGetter)

var capturedAuthHeader string
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedAuthHeader = r.Header.Get("Authorization")
var nextCalled bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
nextCalled = true
})

handler := middleware(nextHandler)
Expand All @@ -427,8 +445,9 @@ func TestMiddleware_ExpiredTokens_ContinuesWithWarning(t *testing.T) {
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

// MVP: Should continue with expired token
assert.Equal(t, "Bearer expired-upstream-token", capturedAuthHeader)
assert.False(t, nextCalled, "next handler should NOT be called for expired tokens")
assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`)
}

func TestMiddleware_EmptySelectedToken(t *testing.T) {
Expand Down
Loading