From 5a5ea0af9b840c7795cf601999bb17df165619ca Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 18 Mar 2026 02:33:49 +0530 Subject: [PATCH 1/8] feat: Add client secret expiry tracking and automatic renew for DCR Signed-off-by: Sanskarzz --- pkg/auth/discovery/discovery.go | 42 +++ pkg/auth/remote/config.go | 10 +- pkg/auth/remote/handler.go | 44 ++- pkg/auth/remote/persisting_token_source.go | 18 +- pkg/auth/remote/secret_renewal.go | 223 ++++++++++++ pkg/auth/remote/secret_renewal_test.go | 383 +++++++++++++++++++++ pkg/runner/runner.go | 37 +- 7 files changed, 750 insertions(+), 7 deletions(-) create mode 100644 pkg/auth/remote/secret_renewal.go create mode 100644 pkg/auth/remote/secret_renewal_test.go diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index ee28036c5d..e240d68ef9 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -512,6 +512,12 @@ type OAuthFlowConfig struct { Resource string // RFC 8707 resource indicator (optional) OAuthParams map[string]string ScopeParamName string // Override scope query parameter name (e.g., "user_scope" for Slack) + + // DCR renewal metadata — populated by handleDynamicRegistration and threaded + // into OAuthFlowResult so callers can persist the data for RFC 7592 operations. + SecretExpiry time.Time // zero means the secret never expires + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string } // OAuthFlowResult contains the result of an OAuth flow @@ -527,6 +533,14 @@ type OAuthFlowResult struct { // DCR client credentials for persistence (obtained during Dynamic Client Registration) ClientID string ClientSecret string //nolint:gosec // G117: field legitimately holds sensitive data + + // DCR renewal metadata (RFC 7591 §3.2.1 / RFC 7592). + // SecretExpiry is zero when the provider did not issue an expiring secret. + // RegistrationAccessToken and RegistrationClientURI are empty when the + // provider does not support RFC 7592 management operations. + SecretExpiry time.Time + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -660,6 +674,30 @@ func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuth config.TokenURL = resolution.TokenEndpoint } + // Store DCR renewal metadata for RFC 7592 operations. + // client_secret_expires_at == 0 means the secret never expires (RFC 7591 §3.2.1). + if registrationResponse.ClientSecretExpiresAt > 0 { + config.SecretExpiry = time.Unix(registrationResponse.ClientSecretExpiresAt, 0) + } + config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken + config.RegistrationClientURI = registrationResponse.RegistrationClientURI + + if registrationResponse.RegistrationAccessToken != "" { + slog.Debug("DCR response includes registration access token for RFC 7592 operations") + } + + // Store DCR renewal metadata for RFC 7592 operations. + // client_secret_expires_at == 0 means the secret never expires (RFC 7591 §3.2.1). + if registrationResponse.ClientSecretExpiresAt > 0 { + config.SecretExpiry = time.Unix(registrationResponse.ClientSecretExpiresAt, 0) + } + config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken + config.RegistrationClientURI = registrationResponse.RegistrationClientURI + + if registrationResponse.RegistrationAccessToken != "" { + slog.Debug("DCR response includes registration access token for RFC 7592 operations") + } + return nil } @@ -892,6 +930,10 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF Expiry: tokenResult.Expiry, ClientID: oauthConfig.ClientID, ClientSecret: oauthConfig.ClientSecret, + // DCR renewal metadata — populated only when dynamic registration was performed. + SecretExpiry: config.SecretExpiry, + RegistrationAccessToken: config.RegistrationAccessToken, + RegistrationClientURI: config.RegistrationClientURI, }, nil } diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 8675036dda..7b66b6649e 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/stacklok/toolhive-core/registry/types" + registry "github.com/stacklok/toolhive-core/registry/types" httpval "github.com/stacklok/toolhive-core/validation/http" ) @@ -68,7 +68,8 @@ type Config struct { // ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server). // A zero value means the secret does not expire. CachedSecretExpiry time.Time `json:"cached_secret_expiry,omitempty" yaml:"cached_secret_expiry,omitempty"` - // RegistrationAccessToken is used to update/delete the client registration. + // CachedRegTokenRef is a secret manager reference to the registration_access_token + // returned in the DCR response. Used for RFC 7592 client update operations. // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` @@ -78,6 +79,10 @@ type Config struct { // rotation clears CachedClientID without touching the stable CIMD URL. // Read by resolveClientCredentials to send the correct client_id on token refresh. CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` + // CachedRegClientURI is the registration_client_uri from the DCR response. + // This is the endpoint used for RFC 7592 client read/update/delete operations. + // Stored as plain text since it is not sensitive. + CachedRegClientURI string `json:"cached_reg_client_uri,omitempty" yaml:"cached_reg_client_uri,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -184,6 +189,7 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedClientSecretRef = "" c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" + c.CachedRegClientURI = "" } // LogContext returns the upstream issuer and resolved client_id for use as diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index ce7d271288..dd63792c9d 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "strings" + "time" "golang.org/x/oauth2" @@ -207,7 +208,13 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // CIMD client IDs (HTTPS URLs) are stable constants and are stored separately below. if h.clientCredentialsPersister != nil && result.ClientID != "" && !oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { - if err := h.clientCredentialsPersister(result.ClientID, result.ClientSecret); err != nil { + if err := h.clientCredentialsPersister( + result.ClientID, + result.ClientSecret, + result.SecretExpiry, + result.RegistrationAccessToken, + result.RegistrationClientURI, + ); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { slog.Debug("Successfully persisted DCR client credentials for future restarts") @@ -232,6 +239,8 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // resolveClientCredentials returns the client ID and secret to use, preferring // cached DCR credentials over statically configured ones. +// If the cached client secret is expiring soon, it attempts renewal via RFC 7592 +// before returning the credentials. func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clientSecret string) { // First try to use statically configured credentials clientID = h.config.ClientID @@ -252,6 +261,18 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien clientID = h.config.CachedClientID slog.Debug("Using cached DCR client credentials", "client_id", clientID) + // Proactively renew the client secret if it is expiring soon (RFC 7592) + if h.isSecretExpiredOrExpiringSoon() { + slog.Info("Cached client secret is expiring soon, attempting renewal", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Failed to proactively renew client secret; continuing with existing secret", + "error", renewErr) + } else { + slog.Debug("Successfully renewed client secret ahead of expiry") + } + } + // Client secret is stored securely and may be empty for PKCE flows if h.config.CachedClientSecretRef != "" && h.secretProvider != nil { cachedClientSecret, err := h.secretProvider.GetSecret(ctx, h.config.CachedClientSecretRef) @@ -278,6 +299,27 @@ func (h *Handler) tryRestoreFromCachedTokens( return nil, fmt.Errorf("secret provider not configured, cannot restore cached tokens") } + // Check if the cached client secret is expired before attempting token refresh. + // If it has fully expired and renewal also fails we must force a fresh OAuth flow. + if h.isSecretExpiredOrExpiringSoon() { + slog.Info("Cached client secret is expiring or expired; attempting renewal before token restore", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Client secret renewal failed", "error", renewErr) + // Hard-fail only when the secret is already past its expiry. + // If we are still in the buffer window the existing secret may work. + if !h.config.CachedSecretExpiry.IsZero() && time.Now().After(h.config.CachedSecretExpiry) { + return nil, fmt.Errorf( + "client secret expired at %v and renewal failed: %w", + h.config.CachedSecretExpiry, renewErr) + } + // Still within buffer — log and continue with the existing (still-valid) secret + slog.Warn("Proceeding with expiring client secret after failed renewal attempt") + } else { + slog.Debug("Successfully renewed client secret before token restore") + } + } + refreshToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRefreshTokenRef) if err != nil { return nil, fmt.Errorf("failed to retrieve cached refresh token: %w", err) diff --git a/pkg/auth/remote/persisting_token_source.go b/pkg/auth/remote/persisting_token_source.go index 3ba90b6253..5777ba6edf 100644 --- a/pkg/auth/remote/persisting_token_source.go +++ b/pkg/auth/remote/persisting_token_source.go @@ -21,8 +21,22 @@ import ( type TokenPersister func(refreshToken string, expiry time.Time) error // ClientCredentialsPersister is called when DCR client credentials need to be persisted. -// This is used to store client_id and client_secret obtained during Dynamic Client Registration. -type ClientCredentialsPersister func(clientID, clientSecret string) error +// This is used to store client_id, client_secret, and renewal metadata obtained during +// Dynamic Client Registration (RFC 7591) and needed for secret renewal (RFC 7592). +// +// Parameters: +// - clientID: the registered client ID (public, stored as plain text) +// - clientSecret: the registered client secret (sensitive, stored via secret manager) +// - secretExpiry: when the client secret expires; zero value means it never expires +// - registrationAccessToken: bearer token for RFC 7592 management operations (sensitive) +// - registrationClientURI: endpoint for RFC 7592 client update/read operations (plain text) +type ClientCredentialsPersister func( + clientID string, + clientSecret string, + secretExpiry time.Time, + registrationAccessToken string, + registrationClientURI string, +) error // PersistingTokenSource wraps an oauth2.TokenSource and persists tokens // whenever they are refreshed. This enables session restoration across diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go new file mode 100644 index 0000000000..6815c82ad6 --- /dev/null +++ b/pkg/auth/remote/secret_renewal.go @@ -0,0 +1,223 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/stacklok/toolhive/pkg/networking" +) + +// secretExpiryBuffer is the lead time before expiry at which we proactively +// renew the client secret (RFC 7592). Renewal is attempted when the secret +// expires within this window, not only after expiry. +const secretExpiryBuffer = 24 * time.Hour + +// clientUpdateRequest is the body sent in a RFC 7592 §2.2 PUT request. +// Per the spec, all client metadata fields that were provided during +// registration must be included in the update request body. +type clientUpdateRequest struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name,omitempty"` + RedirectURIs []string `json:"redirect_uris,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` +} + +// clientUpdateResponse is the body returned by a RFC 7592 §2.1 response. +// The provider may rotate the registration_access_token; if present we must +// replace the stored one. +type clientUpdateResponse struct { + // Required fields mirrored from registration response + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec // G117: field holds sensitive data + + // Expiry fields + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + + // Management fields — registration_access_token may be rotated + RegistrationAccessToken string `json:"registration_access_token,omitempty"` //nolint:gosec + RegistrationClientURI string `json:"registration_client_uri,omitempty"` +} + +// isSecretExpiredOrExpiringSoon returns true when the cached client secret is +// already expired or will expire within secretExpiryBuffer. +// A zero CachedSecretExpiry means the secret never expires, so this returns false. +func (h *Handler) isSecretExpiredOrExpiringSoon() bool { + expiry := h.config.CachedSecretExpiry + if expiry.IsZero() { + return false // Non-expiring secret + } + return time.Now().After(expiry.Add(-secretExpiryBuffer)) +} + +// renewClientSecret attempts to renew the client secret using RFC 7592 §2.2. +// It retrieves the stored registration_access_token and sends a PUT request +// to the registration_client_uri with the current client metadata. +// +// On success the handler's config is updated with the new secret, expiry, and +// (if rotated) the new registration_access_token. +// +// Callers should log a warning and continue if renewal fails — the existing +// secret may still be valid for some time, or the provider may not support renewal. +func (h *Handler) renewClientSecret(ctx context.Context) error { + if h.config.CachedRegClientURI == "" { + return fmt.Errorf("registration_client_uri not available; cannot renew client secret (RFC 7592 not supported by this provider)") + } + + if h.config.CachedRegTokenRef == "" { + return fmt.Errorf("registration_access_token not available; cannot renew client secret (RFC 7592 not supported by this provider)") + } + + if h.secretProvider == nil { + return fmt.Errorf("secret provider not configured; cannot retrieve registration access token") + } + + // Retrieve the registration access token from the secret manager + regAccessToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRegTokenRef) + if err != nil { + return fmt.Errorf("failed to retrieve registration access token: %w", err) + } + + slog.Debug("Attempting RFC 7592 client secret renewal", + "registration_client_uri", h.config.CachedRegClientURI) + + // Validate the registration_client_uri before using it + if err := validateRegistrationClientURI(h.config.CachedRegClientURI); err != nil { + return fmt.Errorf("invalid registration_client_uri: %w", err) + } + + // Build the update request body with the current client metadata. + // Per RFC 7592 §2.2, the request MUST include all client metadata fields + // that were provided during the initial registration. + updateReq := clientUpdateRequest{ + ClientID: h.config.CachedClientID, + ClientName: "ToolHive MCP Client", + RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.config.CallbackPort)}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: "none", + } + + reqBody, err := json.Marshal(updateReq) + if err != nil { + return fmt.Errorf("failed to marshal client update request: %w", err) + } + + // Create PUT request per RFC 7592 §2.2 + req, err := http.NewRequestWithContext( + ctx, + http.MethodPut, + h.config.CachedRegClientURI, + strings.NewReader(string(reqBody)), + ) + if err != nil { + return fmt.Errorf("failed to create client update request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+regAccessToken) //nolint:gosec // G117 + + // Execute the request + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }, + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("client update request failed: %w", err) + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + slog.Debug("Failed to close renewal response body", "error", closeErr) + } + }() + + if resp.StatusCode != http.StatusOK { + errorBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return fmt.Errorf("client update request returned HTTP %d: %s", resp.StatusCode, string(errorBody)) + } + + // Parse the renewal response + const maxResponseSize = 1024 * 1024 // 1 MB + var updateResp clientUpdateResponse + if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&updateResp); err != nil { + return fmt.Errorf("failed to decode client update response: %w", err) + } + + if updateResp.ClientID == "" { + return fmt.Errorf("client update response missing client_id") + } + if updateResp.ClientSecret == "" { + return fmt.Errorf("client update response missing client_secret") + } + + // Persist the new secret via clientCredentialsPersister so it goes through + // the same secure-storage path as the initial registration. + if h.clientCredentialsPersister == nil { + return fmt.Errorf("client credentials persister not configured; cannot save renewed secret") + } + + var newExpiry time.Time + if updateResp.ClientSecretExpiresAt > 0 { + newExpiry = time.Unix(updateResp.ClientSecretExpiresAt, 0) + } + + // Use the rotated registration_access_token if provided; fall back to existing. + newRegToken := updateResp.RegistrationAccessToken + newRegURI := updateResp.RegistrationClientURI + if newRegURI == "" { + newRegURI = h.config.CachedRegClientURI + } + + if err := h.clientCredentialsPersister( + updateResp.ClientID, + updateResp.ClientSecret, + newExpiry, + newRegToken, + newRegURI, + ); err != nil { + return fmt.Errorf("failed to persist renewed client secret: %w", err) + } + + slog.Info("Successfully renewed client secret via RFC 7592", + "client_id", updateResp.ClientID, + "new_expiry_zero", newExpiry.IsZero(), + "reg_token_rotated", newRegToken != "") + + return nil +} + +// validateRegistrationClientURI validates that the registration_client_uri is +// a valid HTTPS URL (or localhost for development). +func validateRegistrationClientURI(registrationClientURI string) error { + if registrationClientURI == "" { + return fmt.Errorf("registration_client_uri is empty") + } + + parsedURL, err := url.Parse(registrationClientURI) + if err != nil { + return fmt.Errorf("invalid registration_client_uri URL: %w", err) + } + + if parsedURL.Scheme != "https" && !networking.IsLocalhost(parsedURL.Host) { + return fmt.Errorf("registration_client_uri must use HTTPS: %s", registrationClientURI) + } + + return nil +} diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go new file mode 100644 index 0000000000..30756103e4 --- /dev/null +++ b/pkg/auth/remote/secret_renewal_test.go @@ -0,0 +1,383 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/secrets" +) + +// mockSecretProvider is a simple in-memory secret store for tests. +// It implements the full secrets.Provider interface. +type mockSecretProvider struct { + secrets map[string]string +} + +func newMockSecretProvider(initial map[string]string) *mockSecretProvider { + if initial == nil { + initial = make(map[string]string) + } + return &mockSecretProvider{secrets: initial} +} + +func (m *mockSecretProvider) GetSecret(_ context.Context, name string) (string, error) { + v, ok := m.secrets[name] + if !ok { + return "", fmt.Errorf("secret %q not found", name) + } + return v, nil +} + +func (m *mockSecretProvider) SetSecret(_ context.Context, name, value string) error { + m.secrets[name] = value + return nil +} + +func (m *mockSecretProvider) DeleteSecret(_ context.Context, name string) error { + delete(m.secrets, name) + return nil +} + +func (m *mockSecretProvider) ListSecrets(_ context.Context) ([]secrets.SecretDescription, error) { + result := make([]secrets.SecretDescription, 0, len(m.secrets)) + for k := range m.secrets { + result = append(result, secrets.SecretDescription{Key: k}) + } + return result, nil +} + +func (m *mockSecretProvider) Cleanup() error { return nil } + +func (m *mockSecretProvider) Capabilities() secrets.ProviderCapabilities { + return secrets.ProviderCapabilities{ + CanRead: true, + CanWrite: true, + CanDelete: true, + CanList: true, + } +} + +// TestIsSecretExpiredOrExpiringSoon tests the expiry helper on various time scenarios. +func TestIsSecretExpiredOrExpiringSoon(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expiry time.Time + wantExpired bool + }{ + { + name: "zero expiry means never expires", + expiry: time.Time{}, + wantExpired: false, + }, + { + name: "expiry far in the future — not expiring", + expiry: time.Now().Add(48 * time.Hour), + wantExpired: false, + }, + { + name: "expiry within 24h buffer — expiring soon", + expiry: time.Now().Add(12 * time.Hour), + wantExpired: true, + }, + { + name: "expiry in the past — already expired", + expiry: time.Now().Add(-1 * time.Hour), + wantExpired: true, + }, + { + name: "expiry exactly at buffer boundary — expiring soon", + expiry: time.Now().Add(secretExpiryBuffer - time.Minute), + wantExpired: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedSecretExpiry: tt.expiry, + }, + } + assert.Equal(t, tt.wantExpired, h.isSecretExpiredOrExpiringSoon()) + }) + } +} + +// TestValidateRegistrationClientURI tests URI validation. +func TestValidateRegistrationClientURI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + uri string + wantErr bool + }{ + { + name: "empty URI", + uri: "", + wantErr: true, + }, + { + name: "valid HTTPS URI", + uri: "https://example.com/oauth/register/client-id", + wantErr: false, + }, + { + name: "HTTP URI for non-localhost is rejected", + uri: "http://example.com/oauth/register/client-id", + wantErr: true, + }, + { + name: "localhost HTTP is allowed (development)", + uri: "http://localhost:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "127.0.0.1 HTTP is allowed (development)", + uri: "http://127.0.0.1:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "invalid URL", + uri: "://bad-url", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateRegistrationClientURI(tt.uri) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestRenewClientSecret_MissingConfig tests early-exit conditions. +func TestRenewClientSecret_MissingConfig(t *testing.T) { + t.Parallel() + + t.Run("missing registration_client_uri", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "", + CachedRegTokenRef: "some-ref", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_client_uri not available") + }) + + t.Run("missing registration_token_ref", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_access_token not available") + }) + + t.Run("missing secret provider", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "some-ref", + }, + secretProvider: nil, // no provider + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "secret provider not configured") + }) +} + +// TestRenewClientSecret_Success tests the happy path with a mock RFC 7592 server. +func TestRenewClientSecret_Success(t *testing.T) { + t.Parallel() + + newSecret := "new-client-secret-xyz" + newExpiry := time.Now().Add(24 * time.Hour * 30).Unix() + newRegToken := "new-registration-access-token" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RFC 7592 §2.2: must be PUT with Bearer auth + assert.Equal(t, http.MethodPut, r.Method) + assert.Contains(t, r.Header.Get("Authorization"), "Bearer reg-access-token") + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + // Return the updated registration response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": newSecret, + "client_secret_expires_at": newExpiry, + "registration_access_token": newRegToken, + "registration_client_uri": "http://" + r.Host + r.URL.Path, + }) + })) + defer server.Close() + + // Set up persister capture + var persistedClientID, persistedSecret, persistedRegToken, persistedRegURI string + var persistedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-secret-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-secret-ref": "reg-access-token", + }), + clientCredentialsPersister: func( + clientID, secret string, + expiry time.Time, + regToken, regURI string, + ) error { + persistedClientID = clientID + persistedSecret = secret + persistedExpiry = expiry + persistedRegToken = regToken + persistedRegURI = regURI + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "test-client-id", persistedClientID) + assert.Equal(t, newSecret, persistedSecret) + assert.Equal(t, newRegToken, persistedRegToken) + assert.False(t, persistedExpiry.IsZero(), "expiry should be set") + assert.NotEmpty(t, persistedRegURI) +} + +// TestRenewClientSecret_ServerError tests error propagation when the server returns non-200. +func TestRenewClientSecret_ServerError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "bad-token", + }), + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _ string) error { + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +// TestRenewClientSecret_NoPersister tests failure when persister is not set. +func TestRenewClientSecret_NoPersister(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + }) + })) + defer server.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: nil, // no persister + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "client credentials persister not configured") +} + +// TestRenewClientSecret_ZeroExpiryInResponse tests that a zero client_secret_expires_at +// is correctly interpreted as a non-expiring secret. +func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + "client_secret_expires_at": 0, // never expires + }) + })) + defer server.Close() + + var capturedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _ string) error { + capturedExpiry = expiry + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + assert.True(t, capturedExpiry.IsZero(), "zero client_secret_expires_at must produce zero time.Time") +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 079a7d88db..fd5a91e374 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -780,7 +780,13 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo }) // Set up client credentials persister for DCR (Dynamic Client Registration) - authHandler.SetClientCredentialsPersister(func(clientID, clientSecret string) error { + authHandler.SetClientCredentialsPersister(func( + clientID string, + clientSecret string, + secretExpiry time.Time, + registrationAccessToken string, + registrationClientURI string, + ) error { // Store client ID directly (it's public information) r.Config.RemoteAuthConfig.CachedClientID = clientID @@ -801,12 +807,39 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName } + // Store secret expiry (zero value = never expires per RFC 7591 §3.2.1) + r.Config.RemoteAuthConfig.CachedSecretExpiry = secretExpiry + + // Store registration_access_token securely (needed for RFC 7592 renewal) + if registrationAccessToken != "" { + regTokenSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_REG_TOKEN_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate registration token secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, registrationAccessToken, secretManager); err != nil { + return fmt.Errorf("failed to store registration access token: %w", err) + } + r.Config.RemoteAuthConfig.CachedRegTokenRef = regTokenSecretName + slog.Debug("Stored DCR registration access token for RFC 7592 operations") + } + + // Store registration_client_uri as plain text (not sensitive) + r.Config.RemoteAuthConfig.CachedRegClientURI = registrationClientURI + // Save the updated config to persist the credentials if err := r.Config.SaveState(ctx); err != nil { return fmt.Errorf("failed to save config with client credentials: %w", err) } - slog.Debug("Stored DCR client credentials", "client_id", clientID) + slog.Debug("Stored DCR client credentials", "client_id", clientID, + "has_expiry", !secretExpiry.IsZero(), + "has_reg_token", registrationAccessToken != "", + "has_reg_uri", registrationClientURI != "") return nil }) } From 3ed7b119753653b228996df428d57e643bfb088d Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 18 Mar 2026 19:47:55 +0530 Subject: [PATCH 2/8] fix: CI lint and docs Signed-off-by: Sanskarzz --- docs/server/docs.go | 44 ++++++- docs/server/swagger.json | 44 ++++++- docs/server/swagger.yaml | 9 +- pkg/auth/remote/secret_renewal.go | 31 +++-- pkg/auth/remote/secret_renewal_test.go | 12 +- pkg/runner/config_test.go | 28 +++- pkg/runner/runner.go | 169 +++++++++++++------------ 7 files changed, 227 insertions(+), 110 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index fa76390328..5aa1247d7f 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -262,8 +262,12 @@ const docTemplate = `{ "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { @@ -2242,6 +2246,44 @@ const docTemplate = `{ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_transport_types.ProxyMode": { + "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", + "enum": [ + "sse", + "streamable-http", + "sse", + "streamable-http" + ], + "type": "string", + "x-enum-varnames": [ + "ProxyModeSSE", + "ProxyModeStreamableHTTP" + ] + }, + "github_com_stacklok_toolhive_pkg_transport_types.TransportType": { + "description": "Transport is the transport mode (stdio, sse, or streamable-http)", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector", + "stdio", + "sse", + "streamable-http", + "inspector", + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string", + "x-enum-varnames": [ + "TransportTypeStdio", + "TransportTypeSSE", + "TransportTypeStreamableHTTP", + "TransportTypeInspector" + ] + }, "permissions.InboundNetworkPermissions": { "description": "Inbound defines inbound network permissions", "properties": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index eb71c3efec..a1e29bc48c 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -255,8 +255,12 @@ "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { @@ -2235,6 +2239,44 @@ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_transport_types.ProxyMode": { + "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", + "enum": [ + "sse", + "streamable-http", + "sse", + "streamable-http" + ], + "type": "string", + "x-enum-varnames": [ + "ProxyModeSSE", + "ProxyModeStreamableHTTP" + ] + }, + "github_com_stacklok_toolhive_pkg_transport_types.TransportType": { + "description": "Transport is the transport mode (stdio, sse, or streamable-http)", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector", + "stdio", + "sse", + "streamable-http", + "inspector", + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string", + "x-enum-varnames": [ + "TransportTypeStdio", + "TransportTypeSSE", + "TransportTypeStreamableHTTP", + "TransportTypeInspector" + ] + }, "permissions.InboundNetworkPermissions": { "description": "Inbound defines inbound network permissions", "properties": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6e9d446e50..efb3e5a9d6 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -274,9 +274,16 @@ components: contains the reference to retrieve it (e.g., "OAUTH_REFRESH_TOKEN_workload"). This enables session restoration without requiring a new browser-based login. type: string + cached_reg_client_uri: + description: |- + CachedRegClientURI is the registration_client_uri from the DCR response. + This is the endpoint used for RFC 7592 client read/update/delete operations. + Stored as plain text since it is not sensitive. + type: string cached_reg_token_ref: description: |- - RegistrationAccessToken is used to update/delete the client registration. + CachedRegTokenRef is a secret manager reference to the registration_access_token + returned in the DCR response. Used for RFC 7592 client update operations. Stored as a secret reference since it's sensitive. type: string cached_secret_expiry: diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go index 6815c82ad6..01c1a58929 100644 --- a/pkg/auth/remote/secret_renewal.go +++ b/pkg/auth/remote/secret_renewal.go @@ -71,16 +71,8 @@ func (h *Handler) isSecretExpiredOrExpiringSoon() bool { // Callers should log a warning and continue if renewal fails — the existing // secret may still be valid for some time, or the provider may not support renewal. func (h *Handler) renewClientSecret(ctx context.Context) error { - if h.config.CachedRegClientURI == "" { - return fmt.Errorf("registration_client_uri not available; cannot renew client secret (RFC 7592 not supported by this provider)") - } - - if h.config.CachedRegTokenRef == "" { - return fmt.Errorf("registration_access_token not available; cannot renew client secret (RFC 7592 not supported by this provider)") - } - - if h.secretProvider == nil { - return fmt.Errorf("secret provider not configured; cannot retrieve registration access token") + if err := h.validateRenewalPrerequisites(); err != nil { + return err } // Retrieve the registration access token from the secret manager @@ -167,8 +159,23 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { return fmt.Errorf("client update response missing client_secret") } - // Persist the new secret via clientCredentialsPersister so it goes through - // the same secure-storage path as the initial registration. + return h.persistRenewedSecret(updateResp) +} + +func (h *Handler) validateRenewalPrerequisites() error { + if h.config.CachedRegClientURI == "" { + return fmt.Errorf("registration_client_uri missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.config.CachedRegTokenRef == "" { + return fmt.Errorf("registration_access_token missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.secretProvider == nil { + return fmt.Errorf("secret provider not configured; cannot retrieve registration access token") + } + return nil +} + +func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { if h.clientCredentialsPersister == nil { return fmt.Errorf("client credentials persister not configured; cannot save renewed secret") } diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go index 30756103e4..9ad65e0bbf 100644 --- a/pkg/auth/remote/secret_renewal_test.go +++ b/pkg/auth/remote/secret_renewal_test.go @@ -57,12 +57,12 @@ func (m *mockSecretProvider) ListSecrets(_ context.Context) ([]secrets.SecretDes return result, nil } -func (m *mockSecretProvider) Cleanup() error { return nil } +func (*mockSecretProvider) Cleanup() error { return nil } -func (m *mockSecretProvider) Capabilities() secrets.ProviderCapabilities { +func (*mockSecretProvider) Capabilities() secrets.ProviderCapabilities { return secrets.ProviderCapabilities{ - CanRead: true, - CanWrite: true, + CanRead: true, + CanWrite: true, CanDelete: true, CanList: true, } @@ -353,8 +353,8 @@ func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(map[string]interface{}{ - "client_id": "test-client-id", - "client_secret": "new-secret", + "client_id": "test-client-id", + "client_secret": "new-secret", "client_secret_expires_at": 0, // never expires }) })) diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index ad66252c9a..f933741236 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -174,6 +174,22 @@ func TestRunConfig_NormalizeProxyMode(t *testing.T) { // Note: This test uses actual port finding logic, so it may fail if ports are in use func TestRunConfig_WithPorts(t *testing.T) { t.Parallel() + // Find available ports dynamically to avoid flaky failures + port1 := networking.FindAvailable() + require.NotZero(t, port1, "should find an available proxy port for SSE") + targetPort1 := networking.FindAvailable() + require.NotZero(t, targetPort1, "should find an available target port for SSE") + + port2 := networking.FindAvailable() + require.NotZero(t, port2, "should find an available proxy port for HTTP") + targetPort2 := networking.FindAvailable() + require.NotZero(t, targetPort2, "should find an available target port for HTTP") + + port3 := networking.FindAvailable() + require.NotZero(t, port3, "should find an available proxy port for Stdio") + targetPort3 := networking.FindAvailable() + require.NotZero(t, targetPort3, "should find an available target port for Stdio") + testCases := []struct { name string config *RunConfig @@ -184,8 +200,8 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "SSE transport with specific ports", config: &RunConfig{Transport: types.TransportTypeSSE}, - port: 8001, - targetPort: 9001, + port: port1, + targetPort: targetPort1, expectError: false, }, { @@ -198,15 +214,15 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "Streamable HTTP transport with specific ports", config: &RunConfig{Transport: types.TransportTypeStreamableHTTP}, - port: 8002, - targetPort: 9002, + port: port2, + targetPort: targetPort2, expectError: false, }, { name: "Stdio transport with specific port", config: &RunConfig{Transport: types.TransportTypeStdio}, - port: 8003, - targetPort: 9003, // This should be ignored for stdio + port: port3, + targetPort: targetPort3, // This should be ignored for stdio expectError: false, }, } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index fd5a91e374..01c7440dcb 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -751,106 +751,109 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo // Set up token persister to save tokens across restarts if secretManager != nil { authHandler.SetTokenPersister(func(refreshToken string, expiry time.Time) error { - // Generate a unique secret name for this workload's refresh token - secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_REFRESH_TOKEN_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate secret name: %w", err) - } - - // Store the refresh token in the secret manager - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { - return fmt.Errorf("failed to store refresh token: %w", err) - } - - // Store the secret reference (not the actual token) in the config - r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName - r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry - - // Save the updated config to persist the reference - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with token reference: %w", err) - } - - slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) - return nil + return r.persistRefreshToken(ctx, secretManager, refreshToken, expiry) }) // Set up client credentials persister for DCR (Dynamic Client Registration) authHandler.SetClientCredentialsPersister(func( - clientID string, - clientSecret string, + clientID, clientSecret string, secretExpiry time.Time, - registrationAccessToken string, - registrationClientURI string, + regAccessToken, regClientURI string, ) error { - // Store client ID directly (it's public information) - r.Config.RemoteAuthConfig.CachedClientID = clientID - - // Only store client secret if it's non-empty (PKCE flows may not have one) - if clientSecret != "" { - clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_CLIENT_SECRET_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate client secret secret name: %w", err) - } + return r.persistClientCredentials(ctx, secretManager, clientID, clientSecret, secretExpiry, regAccessToken, regClientURI) + }) + } - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { - return fmt.Errorf("failed to store client secret: %w", err) - } - r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName - } + // Perform authentication + tokenSource, err := authHandler.Authenticate(ctx, r.Config.RemoteURL) + if err != nil { + return nil, fmt.Errorf("remote authentication failed: %w", err) + } - // Store secret expiry (zero value = never expires per RFC 7591 §3.2.1) - r.Config.RemoteAuthConfig.CachedSecretExpiry = secretExpiry + return tokenSource, nil +} - // Store registration_access_token securely (needed for RFC 7592 renewal) - if registrationAccessToken != "" { - regTokenSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_REG_TOKEN_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate registration token secret name: %w", err) - } +func (r *Runner) persistRefreshToken( + ctx context.Context, + secretManager secrets.Provider, + refreshToken string, + expiry time.Time, +) error { + secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_REFRESH_TOKEN_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate secret name: %w", err) + } - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, registrationAccessToken, secretManager); err != nil { - return fmt.Errorf("failed to store registration access token: %w", err) - } - r.Config.RemoteAuthConfig.CachedRegTokenRef = regTokenSecretName - slog.Debug("Stored DCR registration access token for RFC 7592 operations") - } + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { + return fmt.Errorf("failed to store refresh token: %w", err) + } - // Store registration_client_uri as plain text (not sensitive) - r.Config.RemoteAuthConfig.CachedRegClientURI = registrationClientURI + r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName + r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry - // Save the updated config to persist the credentials - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with client credentials: %w", err) - } + if err := r.Config.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with token reference: %w", err) + } - slog.Debug("Stored DCR client credentials", "client_id", clientID, - "has_expiry", !secretExpiry.IsZero(), - "has_reg_token", registrationAccessToken != "", - "has_reg_uri", registrationClientURI != "") - return nil - }) + slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) + return nil +} + +func (r *Runner) persistClientCredentials( + ctx context.Context, + secretManager secrets.Provider, + clientID, clientSecret string, + secretExpiry time.Time, + regAccessToken, regClientURI string, +) error { + r.Config.RemoteAuthConfig.CachedClientID = clientID + + if clientSecret != "" { + clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_CLIENT_SECRET_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate client secret secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { + return fmt.Errorf("failed to store client secret: %w", err) + } + r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName } - // Perform authentication - tokenSource, err := authHandler.Authenticate(ctx, r.Config.RemoteURL) - if err != nil { - return nil, fmt.Errorf("remote authentication failed: %w", err) + r.Config.RemoteAuthConfig.CachedSecretExpiry = secretExpiry + + if regAccessToken != "" { + regTokenSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix(r.Config.Name, "OAUTH_REG_TOKEN_", secretManager) + if err != nil { + return fmt.Errorf("failed to generate registration token secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, regAccessToken, secretManager); err != nil { + return fmt.Errorf("failed to store registration access token: %w", err) + } + r.Config.RemoteAuthConfig.CachedRegTokenRef = regTokenSecretName + slog.Debug("Stored DCR registration access token for RFC 7592 operations") } - return tokenSource, nil + r.Config.RemoteAuthConfig.CachedRegClientURI = regClientURI + + if err := r.Config.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with client credentials: %w", err) + } + + slog.Debug("Stored DCR client credentials", "client_id", clientID, + "has_expiry", !secretExpiry.IsZero(), + "has_reg_token", regAccessToken != "", + "has_reg_uri", regClientURI != "") + return nil } // Cleanup performs cleanup operations for the runner, including shutting down all middleware. From ed8e219e48f7dc47d34b51aa50bae19be3f5956f Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Thu, 19 Mar 2026 01:55:18 +0530 Subject: [PATCH 3/8] fix: test-coverage Signed-off-by: Sanskarzz --- pkg/auth/remote/secret_renewal_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go index 9ad65e0bbf..04ae124764 100644 --- a/pkg/auth/remote/secret_renewal_test.go +++ b/pkg/auth/remote/secret_renewal_test.go @@ -187,7 +187,7 @@ func TestRenewClientSecret_MissingConfig(t *testing.T) { } err := h.renewClientSecret(context.Background()) require.Error(t, err) - assert.Contains(t, err.Error(), "registration_client_uri not available") + assert.Contains(t, err.Error(), "registration_client_uri missing") }) t.Run("missing registration_token_ref", func(t *testing.T) { @@ -201,7 +201,7 @@ func TestRenewClientSecret_MissingConfig(t *testing.T) { } err := h.renewClientSecret(context.Background()) require.Error(t, err) - assert.Contains(t, err.Error(), "registration_access_token not available") + assert.Contains(t, err.Error(), "registration_access_token missing") }) t.Run("missing secret provider", func(t *testing.T) { From 66dc84a5ca2174dc2997cd9db3b9b90fb9224706 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Thu, 19 Mar 2026 02:57:51 +0530 Subject: [PATCH 4/8] fix: swag docs Signed-off-by: Sanskarzz --- docs/server/docs.go | 190 ++++++++++++++++++++------------- docs/server/swagger.json | 225 ++++++++++++++++++++++++++------------- docs/server/swagger.yaml | 165 +++++++++++++++++----------- 3 files changed, 373 insertions(+), 207 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index 5aa1247d7f..df602489e3 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -8,6 +8,63 @@ const docTemplate = `{ "schemes": {{ marshal .Schemes }}, "components": { "schemas": { + "auth.TokenValidatorConfig": { + "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", + "properties": { + "allowPrivateIP": { + "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", + "type": "boolean" + }, + "audience": { + "description": "Audience is the expected audience for the token", + "type": "string" + }, + "authTokenFile": { + "description": "AuthTokenFile is the path to file containing bearer token for authentication", + "type": "string" + }, + "cacertPath": { + "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", + "type": "string" + }, + "clientID": { + "description": "ClientID is the OIDC client ID", + "type": "string" + }, + "clientSecret": { + "description": "ClientSecret is the optional OIDC client secret for introspection", + "type": "string" + }, + "insecureAllowHTTP": { + "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", + "type": "boolean" + }, + "introspectionURL": { + "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", + "type": "string" + }, + "issuer": { + "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", + "type": "string" + }, + "jwksurl": { + "description": "JWKSURL is the URL to fetch the JWKS from", + "type": "string" + }, + "resourceURL": { + "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", + "type": "string" + }, + "scopes": { + "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive-core_registry_types.Registry": { "description": "Full registry data", "properties": { @@ -114,63 +171,6 @@ const docTemplate = `{ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig": { - "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", - "properties": { - "allowPrivateIP": { - "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", - "type": "boolean" - }, - "audience": { - "description": "Audience is the expected audience for the token", - "type": "string" - }, - "authTokenFile": { - "description": "AuthTokenFile is the path to file containing bearer token for authentication", - "type": "string" - }, - "cacertPath": { - "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", - "type": "string" - }, - "clientID": { - "description": "ClientID is the OIDC client ID", - "type": "string" - }, - "clientSecret": { - "description": "ClientSecret is the optional OIDC client secret for introspection", - "type": "string" - }, - "insecureAllowHTTP": { - "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", - "type": "boolean" - }, - "introspectionURL": { - "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", - "type": "string" - }, - "issuer": { - "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", - "type": "string" - }, - "jwksurl": { - "description": "JWKSURL is the URL to fetch the JWKS from", - "type": "string" - }, - "resourceURL": { - "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", - "type": "string" - }, - "scopes": { - "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", - "items": { - "type": "string" - }, - "type": "array" - } - }, - "type": "object" - }, "github_com_stacklok_toolhive_pkg_auth_awssts.Config": { "description": "AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services", "properties": { @@ -549,7 +549,7 @@ const docTemplate = `{ "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig" }, "storage": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig" + "$ref": "#/components/schemas/storage.RunConfig" }, "token_lifespans": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig" @@ -1041,7 +1041,19 @@ const docTemplate = `{ "type": "string" }, "status": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" + "description": "Status is the current status of the workload.", + "enum": [ + "running", + "stopped", + "error", + "starting", + "stopping", + "unhealthy", + "removing", + "unknown", + "unauthenticated" + ], + "type": "string" }, "status_context": { "description": "StatusContext provides additional context about the workload's status.\nThe exact meaning is determined by the status and the underlying runtime.", @@ -1056,7 +1068,14 @@ const docTemplate = `{ "uniqueItems": false }, "transport_type": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" + "description": "TransportType is the type of transport used for this workload.", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string" }, "url": { "description": "URL is the URL of the workload exposed by the ToolHive proxy.", @@ -1293,7 +1312,7 @@ const docTemplate = `{ "type": "string" }, "ignore_config": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config" + "$ref": "#/components/schemas/ignore.Config" }, "image": { "description": "Image is the Docker image to run", @@ -1318,7 +1337,7 @@ const docTemplate = `{ "middleware_configs": { "description": "MiddlewareConfigs contains the list of middleware to apply to the transport\nand the configuration for each middleware.", "items": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig" + "$ref": "#/components/schemas/types.MiddlewareConfig" }, "type": "array", "uniqueItems": false @@ -1336,7 +1355,7 @@ const docTemplate = `{ "type": "string" }, "oidc_config": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig" + "$ref": "#/components/schemas/auth.TokenValidatorConfig" }, "permission_profile_name_or_path": { "description": "PermissionProfileNameOrPath is the name or path of the permission profile", @@ -1347,7 +1366,12 @@ const docTemplate = `{ "type": "integer" }, "proxy_mode": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.ProxyMode" + "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", + "enum": [ + "sse", + "streamable-http" + ], + "type": "string" }, "publish": { "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", @@ -1444,7 +1468,14 @@ const docTemplate = `{ "type": "object" }, "transport": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" + "description": "Transport is the transport mode (stdio, sse, or streamable-http)", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string" }, "trust_proxy_headers": { "description": "TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies", @@ -1836,15 +1867,16 @@ const docTemplate = `{ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig": { + "ignore.Config": { + "description": "IgnoreConfig contains configuration for ignore processing", "properties": { - "parameters": { - "description": "Parameters is a JSON object containing the middleware parameters.\nIt is stored as a raw message to allow flexible parameter types.", - "type": "object" + "loadGlobal": { + "description": "Whether to load global ignore patterns", + "type": "boolean" }, - "type": { - "description": "Type is a string representing the middleware type.", - "type": "string" + "printOverlays": { + "description": "Whether to print resolved overlay paths for debugging", + "type": "boolean" } }, "type": "object" @@ -3442,7 +3474,19 @@ const docTemplate = `{ "description": "Response containing workload status information", "properties": { "status": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" + "description": "Current status of the workload", + "enum": [ + "running", + "stopped", + "error", + "starting", + "stopping", + "unhealthy", + "removing", + "unknown", + "unauthenticated" + ], + "type": "string" } }, "type": "object" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index a1e29bc48c..f59a859b95 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1,6 +1,63 @@ { "components": { "schemas": { + "auth.TokenValidatorConfig": { + "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", + "properties": { + "allowPrivateIP": { + "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", + "type": "boolean" + }, + "audience": { + "description": "Audience is the expected audience for the token", + "type": "string" + }, + "authTokenFile": { + "description": "AuthTokenFile is the path to file containing bearer token for authentication", + "type": "string" + }, + "cacertPath": { + "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", + "type": "string" + }, + "clientID": { + "description": "ClientID is the OIDC client ID", + "type": "string" + }, + "clientSecret": { + "description": "ClientSecret is the optional OIDC client secret for introspection", + "type": "string" + }, + "insecureAllowHTTP": { + "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", + "type": "boolean" + }, + "introspectionURL": { + "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", + "type": "string" + }, + "issuer": { + "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", + "type": "string" + }, + "jwksurl": { + "description": "JWKSURL is the URL to fetch the JWKS from", + "type": "string" + }, + "resourceURL": { + "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", + "type": "string" + }, + "scopes": { + "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive-core_registry_types.Registry": { "description": "Full registry data", "properties": { @@ -107,63 +164,6 @@ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig": { - "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", - "properties": { - "allowPrivateIP": { - "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", - "type": "boolean" - }, - "audience": { - "description": "Audience is the expected audience for the token", - "type": "string" - }, - "authTokenFile": { - "description": "AuthTokenFile is the path to file containing bearer token for authentication", - "type": "string" - }, - "cacertPath": { - "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", - "type": "string" - }, - "clientID": { - "description": "ClientID is the OIDC client ID", - "type": "string" - }, - "clientSecret": { - "description": "ClientSecret is the optional OIDC client secret for introspection", - "type": "string" - }, - "insecureAllowHTTP": { - "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", - "type": "boolean" - }, - "introspectionURL": { - "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", - "type": "string" - }, - "issuer": { - "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", - "type": "string" - }, - "jwksurl": { - "description": "JWKSURL is the URL to fetch the JWKS from", - "type": "string" - }, - "resourceURL": { - "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", - "type": "string" - }, - "scopes": { - "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", - "items": { - "type": "string" - }, - "type": "array" - } - }, - "type": "object" - }, "github_com_stacklok_toolhive_pkg_auth_awssts.Config": { "description": "AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services", "properties": { @@ -542,7 +542,7 @@ "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig" }, "storage": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig" + "$ref": "#/components/schemas/storage.RunConfig" }, "token_lifespans": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig" @@ -1034,7 +1034,19 @@ "type": "string" }, "status": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" + "description": "Status is the current status of the workload.", + "enum": [ + "running", + "stopped", + "error", + "starting", + "stopping", + "unhealthy", + "removing", + "unknown", + "unauthenticated" + ], + "type": "string" }, "status_context": { "description": "StatusContext provides additional context about the workload's status.\nThe exact meaning is determined by the status and the underlying runtime.", @@ -1049,7 +1061,14 @@ "uniqueItems": false }, "transport_type": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" + "description": "TransportType is the type of transport used for this workload.", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string" }, "url": { "description": "URL is the URL of the workload exposed by the ToolHive proxy.", @@ -1286,7 +1305,7 @@ "type": "string" }, "ignore_config": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config" + "$ref": "#/components/schemas/ignore.Config" }, "image": { "description": "Image is the Docker image to run", @@ -1311,7 +1330,15 @@ "middleware_configs": { "description": "MiddlewareConfigs contains the list of middleware to apply to the transport\nand the configuration for each middleware.", "items": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig" + "$ref": "#/components/schemas/types.MiddlewareConfig" + }, + "type": "array", + "uniqueItems": false + }, + "mutating_webhooks": { + "description": "MutatingWebhooks contains the configuration for mutating webhook middleware.\nMutating webhooks run before validating webhooks, per RFC THV-0017 ordering.", + "items": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config" }, "type": "array", "uniqueItems": false @@ -1329,7 +1356,7 @@ "type": "string" }, "oidc_config": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig" + "$ref": "#/components/schemas/auth.TokenValidatorConfig" }, "permission_profile_name_or_path": { "description": "PermissionProfileNameOrPath is the name or path of the permission profile", @@ -1340,7 +1367,39 @@ "type": "integer" }, "proxy_mode": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.ProxyMode" + "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", + "enum": [ + "sse", + "streamable-http" + ], + "type": "string" + }, + "publish": { + "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, + "rate_limit_config": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_cmd_thv-operator_api_v1beta1.RateLimitConfig" + }, + "rate_limit_namespace": { + "description": "RateLimitNamespace is the Kubernetes namespace for Redis key derivation.", + "type": "string" + }, + "registry_api_url": { + "description": "RegistryAPIURL is the registry API URL that served this server's metadata.\nEmpty when the server was not discovered via registry lookup.", + "type": "string" + }, + "registry_server_name": { + "description": "RegistryServerName is the registry entry name used to look up this server's metadata.\nEmpty when the server was not discovered via registry lookup.", + "type": "string" + }, + "registry_url": { + "description": "RegistryURL is the registry URL that served this server's metadata.\nEmpty when the server was not discovered via registry lookup.", + "type": "string" }, "publish": { "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", @@ -1437,7 +1496,14 @@ "type": "object" }, "transport": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" + "description": "Transport is the transport mode (stdio, sse, or streamable-http)", + "enum": [ + "stdio", + "sse", + "streamable-http", + "inspector" + ], + "type": "string" }, "trust_proxy_headers": { "description": "TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies", @@ -1829,15 +1895,16 @@ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig": { + "ignore.Config": { + "description": "IgnoreConfig contains configuration for ignore processing", "properties": { - "parameters": { - "description": "Parameters is a JSON object containing the middleware parameters.\nIt is stored as a raw message to allow flexible parameter types.", - "type": "object" + "loadGlobal": { + "description": "Whether to load global ignore patterns", + "type": "boolean" }, - "type": { - "description": "Type is a string representing the middleware type.", - "type": "string" + "printOverlays": { + "description": "Whether to print resolved overlay paths for debugging", + "type": "boolean" } }, "type": "object" @@ -3435,7 +3502,19 @@ "description": "Response containing workload status information", "properties": { "status": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" + "description": "Current status of the workload", + "enum": [ + "running", + "stopped", + "error", + "starting", + "stopping", + "unhealthy", + "removing", + "unknown", + "unauthenticated" + ], + "type": "string" } }, "type": "object" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index efb3e5a9d6..d6c1ddf142 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1,5 +1,57 @@ components: schemas: + auth.TokenValidatorConfig: + description: |- + DEPRECATED: Middleware configuration. + OIDCConfig contains OIDC configuration + properties: + allowPrivateIP: + description: AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses + type: boolean + audience: + description: Audience is the expected audience for the token + type: string + authTokenFile: + description: AuthTokenFile is the path to file containing bearer token for + authentication + type: string + cacertPath: + description: CACertPath is the path to the CA certificate bundle for HTTPS + requests + type: string + clientID: + description: ClientID is the OIDC client ID + type: string + clientSecret: + description: ClientSecret is the optional OIDC client secret for introspection + type: string + insecureAllowHTTP: + description: |- + InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing + WARNING: This is insecure and should NEVER be used in production + type: boolean + introspectionURL: + description: IntrospectionURL is the optional introspection endpoint for + validating tokens + type: string + issuer: + description: Issuer is the OIDC issuer URL (e.g., https://accounts.google.com) + type: string + jwksurl: + description: JWKSURL is the URL to fetch the JWKS from + type: string + resourceURL: + description: ResourceURL is the explicit resource URL for OAuth discovery + (RFC 9728) + type: string + scopes: + description: |- + Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728) + If empty, defaults to ["openid"] + items: + type: string + type: array + type: object github_com_stacklok_toolhive-core_registry_types.Registry: description: Full registry data properties: @@ -121,58 +173,6 @@ components: +optional type: integer type: object - github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig: - description: |- - DEPRECATED: Middleware configuration. - OIDCConfig contains OIDC configuration - properties: - allowPrivateIP: - description: AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses - type: boolean - audience: - description: Audience is the expected audience for the token - type: string - authTokenFile: - description: AuthTokenFile is the path to file containing bearer token for - authentication - type: string - cacertPath: - description: CACertPath is the path to the CA certificate bundle for HTTPS - requests - type: string - clientID: - description: ClientID is the OIDC client ID - type: string - clientSecret: - description: ClientSecret is the optional OIDC client secret for introspection - type: string - insecureAllowHTTP: - description: |- - InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing - WARNING: This is insecure and should NEVER be used in production - type: boolean - introspectionURL: - description: IntrospectionURL is the optional introspection endpoint for - validating tokens - type: string - issuer: - description: Issuer is the OIDC issuer URL (e.g., https://accounts.google.com) - type: string - jwksurl: - description: JWKSURL is the URL to fetch the JWKS from - type: string - resourceURL: - description: ResourceURL is the explicit resource URL for OAuth discovery - (RFC 9728) - type: string - scopes: - description: |- - Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728) - If empty, defaults to ["openid"] - items: - type: string - type: array - type: object github_com_stacklok_toolhive_pkg_auth_awssts.Config: description: AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services @@ -611,7 +611,7 @@ components: signing_key_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig' storage: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig' + $ref: '#/components/schemas/storage.RunConfig' token_lifespans: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig' upstreams: @@ -1078,7 +1078,18 @@ components: restart) type: string status: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus' + description: Status is the current status of the workload. + enum: + - running + - stopped + - error + - starting + - stopping + - unhealthy + - removing + - unknown + - unauthenticated + type: string status_context: description: |- StatusContext provides additional context about the workload's status. @@ -1091,7 +1102,13 @@ components: type: array uniqueItems: false transport_type: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType' + description: TransportType is the type of transport used for this workload. + enum: + - stdio + - sse + - streamable-http + - inspector + type: string url: description: URL is the URL of the workload exposed by the ToolHive proxy. type: string @@ -1299,7 +1316,7 @@ components: description: Host is the host for the HTTP proxy type: string ignore_config: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config' + $ref: '#/components/schemas/ignore.Config' image: description: Image is the Docker image to run type: string @@ -1329,7 +1346,15 @@ components: MiddlewareConfigs contains the list of middleware to apply to the transport and the configuration for each middleware. items: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig' + $ref: '#/components/schemas/types.MiddlewareConfig' + type: array + uniqueItems: false + mutating_webhooks: + description: |- + MutatingWebhooks contains the configuration for mutating webhook middleware. + Mutating webhooks run before validating webhooks, per RFC THV-0017 ordering. + items: + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config' type: array uniqueItems: false mutating_webhooks: @@ -1344,7 +1369,7 @@ components: description: Name is the name of the MCP server type: string oidc_config: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig' + $ref: '#/components/schemas/auth.TokenValidatorConfig' permission_profile_name_or_path: description: PermissionProfileNameOrPath is the name or path of the permission profile @@ -1451,7 +1476,13 @@ components: ToolsOverride is a map from an actual tool to its overridden name and/or description type: object transport: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType' + description: Transport is the transport mode (stdio, sse, or streamable-http) + enum: + - stdio + - sse + - streamable-http + - inspector + type: string trust_proxy_headers: description: TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies @@ -1808,7 +1839,8 @@ components: +optional type: boolean type: object - github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig: + ignore.Config: + description: IgnoreConfig contains configuration for ignore processing properties: parameters: description: |- @@ -3006,7 +3038,18 @@ components: description: Response containing workload status information properties: status: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus' + description: Current status of the workload + enum: + - running + - stopped + - error + - starting + - stopping + - unhealthy + - removing + - unknown + - unauthenticated + type: string type: object registry.EnvVar: properties: From 89651035e98beb1fd854b3e84e3c430beb9218c4 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Sun, 22 Mar 2026 05:18:48 +0530 Subject: [PATCH 5/8] fix: add tokenEndpointauth Signed-off-by: Sanskarzz --- docs/server/docs.go | 4 + docs/server/swagger.json | 4 + docs/server/swagger.yaml | 5 + pkg/auth/discovery/discovery.go | 4 + pkg/auth/oauth/dynamic_registration_test.go | 748 ++++++++++++++++++++ pkg/auth/remote/config.go | 4 + pkg/auth/remote/handler.go | 1 + pkg/auth/remote/persisting_token_source.go | 2 + pkg/auth/remote/secret_renewal.go | 3 +- pkg/auth/remote/secret_renewal_test.go | 227 +++++- pkg/oauth/constants.go | 70 ++ pkg/runner/runner.go | 7 +- 12 files changed, 1074 insertions(+), 5 deletions(-) create mode 100644 pkg/auth/oauth/dynamic_registration_test.go create mode 100644 pkg/oauth/constants.go diff --git a/docs/server/docs.go b/docs/server/docs.go index df602489e3..57d72de1ff 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -274,6 +274,10 @@ const docTemplate = `{ "description": "ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server).\nA zero value means the secret does not expire.", "type": "string" }, + "cached_token_auth_method": { + "description": "CachedTokenEndpointAuthMethod is the auth method used for the token endpoint\n(e.g., \"client_secret_basic\", \"none\"). Persisted for RFC 7592 updates.", + "type": "string" + }, "cached_token_expiry": { "type": "string" }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index f59a859b95..b6cff39120 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -267,6 +267,10 @@ "description": "ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server).\nA zero value means the secret does not expire.", "type": "string" }, + "cached_token_auth_method": { + "description": "CachedTokenEndpointAuthMethod is the auth method used for the token endpoint\n(e.g., \"client_secret_basic\", \"none\"). Persisted for RFC 7592 updates.", + "type": "string" + }, "cached_token_expiry": { "type": "string" }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index d6c1ddf142..72d029647c 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -291,6 +291,11 @@ components: ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server). A zero value means the secret does not expire. type: string + cached_token_auth_method: + description: |- + CachedTokenEndpointAuthMethod is the auth method used for the token endpoint + (e.g., "client_secret_basic", "none"). Persisted for RFC 7592 updates. + type: string cached_token_expiry: type: string callback_port: diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index e240d68ef9..32acddce02 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -518,6 +518,7 @@ type OAuthFlowConfig struct { SecretExpiry time.Time // zero means the secret never expires RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data RegistrationClientURI string + TokenEndpointAuthMethod string } // OAuthFlowResult contains the result of an OAuth flow @@ -541,6 +542,7 @@ type OAuthFlowResult struct { SecretExpiry time.Time RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data RegistrationClientURI string + TokenEndpointAuthMethod string } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -693,6 +695,7 @@ func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuth } config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken config.RegistrationClientURI = registrationResponse.RegistrationClientURI + config.TokenEndpointAuthMethod = registrationResponse.TokenEndpointAuthMethod if registrationResponse.RegistrationAccessToken != "" { slog.Debug("DCR response includes registration access token for RFC 7592 operations") @@ -934,6 +937,7 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF SecretExpiry: config.SecretExpiry, RegistrationAccessToken: config.RegistrationAccessToken, RegistrationClientURI: config.RegistrationClientURI, + TokenEndpointAuthMethod: config.TokenEndpointAuthMethod, }, nil } diff --git a/pkg/auth/oauth/dynamic_registration_test.go b/pkg/auth/oauth/dynamic_registration_test.go new file mode 100644 index 0000000000..db195fe764 --- /dev/null +++ b/pkg/auth/oauth/dynamic_registration_test.go @@ -0,0 +1,748 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + oauthproto "github.com/stacklok/toolhive/pkg/oauth" +) + +func TestDiscoverOIDCEndpointsWithRegistration(t *testing.T) { + t.Parallel() + tests := []struct { + name string + issuer string + response string + expectedError bool + expectedResult *oauthproto.OIDCDiscoveryDocument + }{ + { + name: "valid OIDC discovery with registration endpoint", + issuer: "https://example.com", + response: `{ + "issuer": "{{SERVER_URL}}", + "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", + "token_endpoint": "{{SERVER_URL}}/oauth/token", + "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", + "jwks_uri": "{{SERVER_URL}}/oauth/jwks", + "registration_endpoint": "{{SERVER_URL}}/oauth/register" + }`, + expectedError: false, + expectedResult: &oauthproto.OIDCDiscoveryDocument{ + AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ + Issuer: "https://example.com", + AuthorizationEndpoint: "https://example.com/oauth/authorize", + TokenEndpoint: "https://example.com/oauth/token", + UserinfoEndpoint: "https://example.com/oauth/userinfo", + JWKSURI: "https://example.com/oauth/jwks", + RegistrationEndpoint: "https://example.com/oauth/register", + }, + }, + }, + { + name: "valid OIDC discovery without registration endpoint", + issuer: "https://example.com", + response: `{ + "issuer": "{{SERVER_URL}}", + "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", + "token_endpoint": "{{SERVER_URL}}/oauth/token", + "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", + "jwks_uri": "{{SERVER_URL}}/oauth/jwks" + }`, + expectedError: false, + expectedResult: &oauthproto.OIDCDiscoveryDocument{ + AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ + Issuer: "https://example.com", + AuthorizationEndpoint: "https://example.com/oauth/authorize", + TokenEndpoint: "https://example.com/oauth/token", + UserinfoEndpoint: "https://example.com/oauth/userinfo", + JWKSURI: "https://example.com/oauth/jwks", + RegistrationEndpoint: "", + }, + }, + }, + { + name: "invalid issuer URL", + issuer: "not-a-url", + expectedError: true, + }, + { + name: "non-HTTPS issuer", + issuer: "http://example.com", + expectedError: true, + }, + { + name: "localhost HTTP allowed for development", + issuer: "http://localhost:8080", + response: `{ + "issuer": "{{SERVER_URL}}", + "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", + "token_endpoint": "{{SERVER_URL}}/oauth/token", + "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", + "jwks_uri": "{{SERVER_URL}}/oauth/jwks", + "registration_endpoint": "{{SERVER_URL}}/oauth/register" + }`, + expectedError: false, + expectedResult: &oauthproto.OIDCDiscoveryDocument{ + AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ + Issuer: "http://localhost:8080", + AuthorizationEndpoint: "http://localhost:8080/oauth/authorize", + TokenEndpoint: "http://localhost:8080/oauth/token", + UserinfoEndpoint: "http://localhost:8080/oauth/userinfo", + JWKSURI: "http://localhost:8080/oauth/jwks", + RegistrationEndpoint: "http://localhost:8080/oauth/register", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var server *httptest.Server + var responseTemplate string + + if tt.response != "" { + responseTemplate = tt.response + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle both OIDC and OAuth discovery endpoints + if r.URL.Path == oauthproto.WellKnownOIDCPath || + r.URL.Path == oauthproto.WellKnownOAuthServerPath { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Replace placeholder with actual server URL + response := strings.ReplaceAll(responseTemplate, "{{SERVER_URL}}", server.URL) + w.Write([]byte(response)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + } + + issuer := tt.issuer + if server != nil { + // For test server, use the actual server URL + issuer = server.URL + } + + result, err := DiscoverOIDCEndpoints(context.Background(), issuer) + + if tt.expectedError { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + if server != nil { + // For test server, we can't predict the exact URLs, so just check structure + assert.NotEmpty(t, result.Issuer) + assert.NotEmpty(t, result.AuthorizationEndpoint) + assert.NotEmpty(t, result.TokenEndpoint) + if tt.expectedResult.RegistrationEndpoint != "" { + assert.NotEmpty(t, result.RegistrationEndpoint) + } + } else { + // For static tests, check exact values + assert.Equal(t, tt.expectedResult.Issuer, result.Issuer) + assert.Equal(t, tt.expectedResult.AuthorizationEndpoint, result.AuthorizationEndpoint) + assert.Equal(t, tt.expectedResult.TokenEndpoint, result.TokenEndpoint) + assert.Equal(t, tt.expectedResult.RegistrationEndpoint, result.RegistrationEndpoint) + } + } + }) + } +} + +func TestNewDynamicClientRegistrationRequest(t *testing.T) { + t.Parallel() + tests := []struct { + name string + scopes []string + callbackPort int + expected *DynamicClientRegistrationRequest + }{ + { + name: "basic request", + scopes: []string{"openid", "profile"}, + callbackPort: 8080, + expected: &DynamicClientRegistrationRequest{ + ClientName: "ToolHive MCP Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + TokenEndpointAuthMethod: "client_secret_post", + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + Scopes: []string{"openid", "profile"}, + }, + }, + { + name: "empty scopes", + scopes: []string{}, + callbackPort: 8666, + expected: &DynamicClientRegistrationRequest{ + ClientName: "ToolHive MCP Client", + RedirectURIs: []string{"http://localhost:8666/callback"}, + TokenEndpointAuthMethod: "client_secret_post", + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + Scopes: []string{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := NewDynamicClientRegistrationRequest(tt.scopes, tt.callbackPort) + + assert.Equal(t, tt.expected.ClientName, result.ClientName) + assert.Equal(t, tt.expected.RedirectURIs, result.RedirectURIs) + assert.Equal(t, tt.expected.TokenEndpointAuthMethod, result.TokenEndpointAuthMethod) + assert.Equal(t, tt.expected.GrantTypes, result.GrantTypes) + assert.Equal(t, tt.expected.ResponseTypes, result.ResponseTypes) + assert.Equal(t, tt.expected.Scopes, result.Scopes) + }) + } +} + +func TestDynamicClientRegistrationRequest_ScopeSerialization(t *testing.T) { + t.Parallel() + + // This test verifies RFC 7591 Section 2 compliance for scope serialization. + // Per the spec, scopes MUST be serialized as a space-delimited string, not a JSON array. + // Empty/nil scopes should result in the scope field being omitted entirely (omitempty), + // which is RFC 7591 compliant since the scope parameter is optional. + + tests := []struct { + name string + scopes []string + shouldOmitScope bool + expectedScopeJSON string // Expected scope field in JSON, empty if omitted + }{ + { + name: "nil scopes should omit scope field entirely", + scopes: nil, + shouldOmitScope: true, + }, + { + name: "empty slice scopes should omit scope field entirely", + scopes: []string{}, + shouldOmitScope: true, + }, + { + name: "single scope should be space-delimited string per RFC 7591", + scopes: []string{"openid"}, + shouldOmitScope: false, + expectedScopeJSON: `"scope":"openid"`, + }, + { + name: "multiple scopes should be space-delimited string per RFC 7591", + scopes: []string{"openid", "profile"}, + shouldOmitScope: false, + expectedScopeJSON: `"scope":"openid profile"`, + }, + { + name: "three scopes should be space-delimited string per RFC 7591", + scopes: []string{"openid", "profile", "email"}, + shouldOmitScope: false, + expectedScopeJSON: `"scope":"openid profile email"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create request with specified scopes + request := NewDynamicClientRegistrationRequest(tt.scopes, 8666) + + // Marshal to JSON + jsonBytes, err := json.Marshal(request) + require.NoError(t, err, "JSON marshaling should succeed") + + jsonStr := string(jsonBytes) + + // Verify scope field behavior + if tt.shouldOmitScope { + assert.NotContains(t, jsonStr, `"scope"`, + "JSON should NOT contain scope field when scopes are empty/nil (omitempty behavior)") + } else { + assert.Contains(t, jsonStr, tt.expectedScopeJSON, + "JSON should contain expected scope field") + } + + // Verify other required fields are always present + assert.Contains(t, jsonStr, `"redirect_uris"`, "redirect_uris should be present") + assert.Contains(t, jsonStr, `"client_name"`, "client_name should be present") + assert.Contains(t, jsonStr, `"grant_types"`, "grant_types should be present") + }) + } +} + +func TestRegisterClientDynamically(t *testing.T) { + t.Parallel() + tests := []struct { + name string + request *DynamicClientRegistrationRequest + response string + responseStatus int + expectedError bool + expectedResult *DynamicClientRegistrationResponse + }{ + { + name: "successful registration", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + TokenEndpointAuthMethod: "none", + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + Scopes: []string{"openid", "profile"}, + }, + response: `{ + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "client_id_issued_at": 1234567890, + "client_secret_expires_at": 0, + "registration_access_token": "reg-token", + "registration_client_uri": "https://example.com/oauth/register/test-client-id" + }`, + responseStatus: http.StatusCreated, + expectedError: false, + expectedResult: &DynamicClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + ClientIDIssuedAt: 1234567890, + ClientSecretExpiresAt: 0, + RegistrationAccessToken: "reg-token", + RegistrationClientURI: "https://example.com/oauth/register/test-client-id", + }, + }, + { + name: "registration without client secret (PKCE flow)", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + TokenEndpointAuthMethod: "none", + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + }, + response: `{ + "client_id": "test-client-id", + "client_id_issued_at": 1234567890 + }`, + responseStatus: http.StatusCreated, + expectedError: false, + expectedResult: &DynamicClientRegistrationResponse{ + ClientID: "test-client-id", + ClientIDIssuedAt: 1234567890, + }, + }, + { + name: "server error", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + }, + response: `{"error": "invalid_request", "error_description": "Invalid request"}`, + responseStatus: http.StatusBadRequest, + expectedError: true, + }, + { + name: "DCR not supported - 404 Not Found", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + }, + response: `{"error": "not_found"}`, + responseStatus: http.StatusNotFound, + expectedError: true, + }, + { + name: "DCR not supported - 405 Method Not Allowed", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + }, + response: `{"error": "method_not_allowed"}`, + responseStatus: http.StatusMethodNotAllowed, + expectedError: true, + }, + { + name: "DCR not supported - 501 Not Implemented", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + }, + response: `{"error": "not_implemented", "error_description": "Dynamic Client Registration is not supported"}`, + responseStatus: http.StatusNotImplemented, + expectedError: true, + }, + { + name: "invalid request - no redirect URIs", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + }, + expectedError: true, + }, + { + name: "invalid request - scope with spaces", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + Scopes: []string{"openid", "profile email", "another"}, + }, + expectedError: true, + }, + { + name: "invalid request - scope with leading space", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + Scopes: []string{" openid"}, + }, + expectedError: true, + }, + { + name: "invalid request - scope with trailing space", + request: &DynamicClientRegistrationRequest{ + ClientName: "Test Client", + RedirectURIs: []string{"http://localhost:8080/callback"}, + Scopes: []string{"openid "}, + }, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var server *httptest.Server + if tt.response != "" { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.responseStatus) + w.Write([]byte(tt.response)) + })) + defer server.Close() + } + + var registrationEndpoint string + if server != nil { + registrationEndpoint = server.URL + } else { + registrationEndpoint = "https://example.com/oauth/register" + } + + result, err := RegisterClientDynamically(context.Background(), registrationEndpoint, tt.request) + + if tt.expectedError { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tt.expectedResult.ClientID, result.ClientID) + assert.Equal(t, tt.expectedResult.ClientSecret, result.ClientSecret) + assert.Equal(t, tt.expectedResult.ClientIDIssuedAt, result.ClientIDIssuedAt) + assert.Equal(t, tt.expectedResult.RegistrationAccessToken, result.RegistrationAccessToken) + assert.Equal(t, tt.expectedResult.RegistrationClientURI, result.RegistrationClientURI) + } + }) + } +} + +func TestDynamicClientRegistrationRequest_Defaults(t *testing.T) { + t.Parallel() + // Test that default values are set correctly + request := &DynamicClientRegistrationRequest{ + RedirectURIs: []string{"http://localhost:8080/callback"}, + } + + // Serialize to JSON to verify defaults + data, err := json.Marshal(request) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + // Verify that required fields are present + assert.Contains(t, result, "redirect_uris") + assert.Equal(t, []interface{}{"http://localhost:8080/callback"}, result["redirect_uris"]) +} + +// TestDynamicClientRegistrationResponse_Validation tests that the response validation works correctly +func TestDynamicClientRegistrationResponse_Validation(t *testing.T) { + t.Parallel() + // Test that response validation works correctly + validResponse := &DynamicClientRegistrationResponse{ + ClientID: "test-client-id", + } + + // Serialize to JSON + data, err := json.Marshal(validResponse) + require.NoError(t, err) + + var result DynamicClientRegistrationResponse + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "test-client-id", result.ClientID) +} + +func TestDiscoverOIDCEndpointsWithRegistrationFallback(t *testing.T) { + t.Parallel() + + // Test case: OIDC well-known succeeds but lacks registration_endpoint, + // OAuth authorization server well-known has it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + baseURL := "http://" + r.Host + switch r.URL.Path { + case oauthproto.WellKnownOIDCPath: + // OIDC discovery - no registration_endpoint + response := `{ + "issuer": "` + baseURL + `", + "authorization_endpoint": "` + baseURL + `/oauth/authorize", + "token_endpoint": "` + baseURL + `/oauth/token", + "userinfo_endpoint": "` + baseURL + `/oauth/userinfo", + "jwks_uri": "` + baseURL + `/oauth/jwks" + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + case oauthproto.WellKnownOAuthServerPath: + // OAuth authorization server - has registration_endpoint + response := `{ + "issuer": "` + baseURL + `", + "authorization_endpoint": "` + baseURL + `/oauth/authorize", + "token_endpoint": "` + baseURL + `/oauth/token", + "registration_endpoint": "` + baseURL + `/oauth/register" + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + result, err := DiscoverOIDCEndpoints(context.Background(), server.URL) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, server.URL, result.Issuer) + assert.NotEmpty(t, result.AuthorizationEndpoint) + assert.NotEmpty(t, result.TokenEndpoint) + // Registration endpoint should be found from OAuth authorization server well-known + assert.NotEmpty(t, result.RegistrationEndpoint, "registration_endpoint should be found via OAuth authorization server fallback") + assert.Equal(t, server.URL+"/oauth/register", result.RegistrationEndpoint) +} + +func TestDiscoverOIDCEndpointsWithRegistrationFallbackIssuerMismatch(t *testing.T) { + t.Parallel() + + // Test case: OIDC and OAuth have different issuers - should not merge + // Use DiscoverActualIssuer which doesn't validate issuer, allowing us to test the merge logic + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + baseURL := "http://" + r.Host + switch r.URL.Path { + case oauthproto.WellKnownOIDCPath: + // OIDC discovery - no registration_endpoint, different issuer + response := `{ + "issuer": "https://oidc.example.com", + "authorization_endpoint": "` + baseURL + `/oauth/authorize", + "token_endpoint": "` + baseURL + `/oauth/token", + "userinfo_endpoint": "` + baseURL + `/oauth/userinfo", + "jwks_uri": "` + baseURL + `/oauth/jwks" + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + case oauthproto.WellKnownOAuthServerPath: + // OAuth authorization server - has registration_endpoint but different issuer + response := `{ + "issuer": "https://oauth.example.com", + "authorization_endpoint": "` + baseURL + `/oauth/authorize", + "token_endpoint": "` + baseURL + `/oauth/token", + "registration_endpoint": "` + baseURL + `/oauth/register" + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Use DiscoverActualIssuer which doesn't validate issuer, allowing us to test merge logic + result, err := DiscoverActualIssuer(context.Background(), server.URL) + + require.NoError(t, err) + require.NotNil(t, result) + // Registration endpoint should NOT be merged due to issuer mismatch + assert.Empty(t, result.RegistrationEndpoint, "registration_endpoint should not be merged when issuers don't match") +} + +// TestIsLocalhost is already defined in oidc_test.go + +// TestScopeList_MarshalJSON tests that the ScopeList marshaling works correctly +// and produces RFC 7591 compliant space-delimited strings. +func TestScopeList_MarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scopes ScopeList + wantJSON string + wantOmit bool // If true, expect omitempty to hide the field + }{ + { + name: "nil scopes => empty string (omitempty will hide at struct level)", + scopes: nil, + wantJSON: `""`, + wantOmit: true, + }, + { + name: "empty slice => empty string (omitempty will hide at struct level)", + scopes: ScopeList{}, + wantJSON: `""`, + wantOmit: true, + }, + { + name: "single scope => string", + scopes: ScopeList{"openid"}, + wantJSON: `"openid"`, + }, + { + name: "two scopes => space-delimited string", + scopes: ScopeList{"openid", "profile"}, + wantJSON: `"openid profile"`, + }, + { + name: "three scopes => space-delimited string", + scopes: ScopeList{"openid", "profile", "email"}, + wantJSON: `"openid profile email"`, + }, + { + name: "scopes with special characters", + scopes: ScopeList{"read:user", "write:repo"}, + wantJSON: `"read:user write:repo"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + jsonBytes, err := json.Marshal(tt.scopes) + require.NoError(t, err, "marshaling should succeed") + + jsonStr := string(jsonBytes) + assert.Equal(t, tt.wantJSON, jsonStr, "marshaled JSON should match expected format") + + // Verify omitempty behavior in a struct + // Note: omitempty checks the Go value (empty slice) before calling MarshalJSON, + // so empty slices are omitted regardless of what MarshalJSON returns. + if tt.wantOmit { + type testStruct struct { + Scope ScopeList `json:"scope,omitempty"` + } + s := testStruct{Scope: tt.scopes} + structJSON, err := json.Marshal(s) + require.NoError(t, err) + assert.Equal(t, "{}", string(structJSON), "omitempty should hide empty scope field") + } + }) + } +} + +// TestScopeList_UnmarshalJSON tests that the ScopeList unmarshaling works correctly. +func TestScopeList_UnmarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + jsonIn string + want []string + wantErr bool + }{ + { + name: "space-delimited string", + jsonIn: `"openid profile email"`, + want: []string{"openid", "profile", "email"}, + }, + { + name: "empty string => nil", + jsonIn: `""`, + want: nil, + }, + { + name: "string with extra spaces", + jsonIn: `" openid profile "`, + want: []string{"openid", "profile"}, + }, + { + name: "normal array", + jsonIn: `["openid","profile","email"]`, + want: []string{"openid", "profile", "email"}, + }, + { + name: "array with whitespace and empties", + jsonIn: `[" openid ",""," profile "]`, + want: []string{"openid", "profile"}, + }, + { + name: "all-empty array => nil", + jsonIn: `[""," "]`, + want: nil, + }, + { + name: "explicit null => nil", + jsonIn: `null`, + want: nil, + }, + { + name: "invalid type (number)", + jsonIn: `123`, + wantErr: true, + }, + { + name: "invalid type (object)", + jsonIn: `{"not":"valid"}`, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt // capture loop variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var s ScopeList + err := json.Unmarshal([]byte(tt.jsonIn), &s) + + if tt.wantErr { + assert.Error(t, err, "expected error but got none") + return + } + + assert.NoError(t, err, "unexpected unmarshal error") + assert.Equal(t, tt.want, []string(s)) + }) + } +} diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 7b66b6649e..b88e93566a 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -83,6 +83,9 @@ type Config struct { // This is the endpoint used for RFC 7592 client read/update/delete operations. // Stored as plain text since it is not sensitive. CachedRegClientURI string `json:"cached_reg_client_uri,omitempty" yaml:"cached_reg_client_uri,omitempty"` + // CachedTokenEndpointAuthMethod is the auth method used for the token endpoint + // (e.g., "client_secret_basic", "none"). Persisted for RFC 7592 updates. + CachedTokenEndpointAuthMethod string `json:"cached_token_auth_method,omitempty" yaml:"cached_token_auth_method,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -190,6 +193,7 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" c.CachedRegClientURI = "" + c.CachedTokenEndpointAuthMethod = "" } // LogContext returns the upstream issuer and resolved client_id for use as diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index dd63792c9d..c16a2459e5 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -214,6 +214,7 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. result.SecretExpiry, result.RegistrationAccessToken, result.RegistrationClientURI, + result.TokenEndpointAuthMethod, ); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { diff --git a/pkg/auth/remote/persisting_token_source.go b/pkg/auth/remote/persisting_token_source.go index 5777ba6edf..99f810edf2 100644 --- a/pkg/auth/remote/persisting_token_source.go +++ b/pkg/auth/remote/persisting_token_source.go @@ -30,12 +30,14 @@ type TokenPersister func(refreshToken string, expiry time.Time) error // - secretExpiry: when the client secret expires; zero value means it never expires // - registrationAccessToken: bearer token for RFC 7592 management operations (sensitive) // - registrationClientURI: endpoint for RFC 7592 client update/read operations (plain text) +// - tokenEndpointAuthMethod: the auth method used for the token endpoint (e.g., "client_secret_basic", "none") type ClientCredentialsPersister func( clientID string, clientSecret string, secretExpiry time.Time, registrationAccessToken string, registrationClientURI string, + tokenEndpointAuthMethod string, ) error // PersistingTokenSource wraps an oauth2.TokenSource and persists tokens diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go index 01c1a58929..fa2f167dee 100644 --- a/pkg/auth/remote/secret_renewal.go +++ b/pkg/auth/remote/secret_renewal.go @@ -98,7 +98,7 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.config.CallbackPort)}, GrantTypes: []string{"authorization_code", "refresh_token"}, ResponseTypes: []string{"code"}, - TokenEndpointAuthMethod: "none", + TokenEndpointAuthMethod: h.config.CachedTokenEndpointAuthMethod, } reqBody, err := json.Marshal(updateReq) @@ -198,6 +198,7 @@ func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { newExpiry, newRegToken, newRegURI, + h.config.CachedTokenEndpointAuthMethod, ); err != nil { return fmt.Errorf("failed to persist renewed client secret: %w", err) } diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go index 04ae124764..0848cf835d 100644 --- a/pkg/auth/remote/secret_renewal_test.go +++ b/pkg/auth/remote/secret_renewal_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -263,7 +264,7 @@ func TestRenewClientSecret_Success(t *testing.T) { clientCredentialsPersister: func( clientID, secret string, expiry time.Time, - regToken, regURI string, + regToken, regURI, _ string, ) error { persistedClientID = clientID persistedSecret = secret @@ -303,7 +304,7 @@ func TestRenewClientSecret_ServerError(t *testing.T) { secretProvider: newMockSecretProvider(map[string]string{ "reg-token-ref": "bad-token", }), - clientCredentialsPersister: func(_, _ string, _ time.Time, _, _ string) error { + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string) error { return nil }, } @@ -371,7 +372,7 @@ func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { secretProvider: newMockSecretProvider(map[string]string{ "reg-token-ref": "some-token", }), - clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _ string) error { + clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _, _ string) error { capturedExpiry = expiry return nil }, @@ -381,3 +382,223 @@ func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { require.NoError(t, err) assert.True(t, capturedExpiry.IsZero(), "zero client_secret_expires_at must produce zero time.Time") } + +func TestRenewClientSecret_MalformedJSON(t *testing.T) { + t.Parallel() + + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid-json`)) + })) + defer svc.Close() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: svc.URL, + CachedRegTokenRef: "rat-ref", + }, + secretProvider: &mockSecretProvider{ + secrets: map[string]string{"rat-ref": "rat-token"}, + }, + } + + err := h.renewClientSecret(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to decode client update response") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestRenewClientSecret_MissingFields(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response string + wantErr string + }{ + { + name: "missing_client_id", + response: `{"client_secret": "new-secret"}`, + wantErr: "client update response missing client_id", + }, + { + name: "missing_client_secret", + response: `{"client_id": "test-client-id"}`, + wantErr: "client update response missing client_secret", + }, + } + + for _, tt := range tests { + tt := tt // capture loop variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tt.response)) + })) + defer svc.Close() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: svc.URL, + CachedRegTokenRef: "rat-ref", + }, + secretProvider: &mockSecretProvider{ + secrets: map[string]string{"rat-ref": "rat-token"}, + }, + } + + err := h.renewClientSecret(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("unexpected error message: %v", err) + } + }) + } +} + +func TestValidateRegistrationClientURI_Internal(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uri string + wantErr bool + }{ + {"empty", "", true}, + {"malformed", "://foo", true}, + {"http_external", "http://example.com/reg", true}, + {"https_external", "https://example.com/reg", false}, + {"http_localhost", "http://localhost:8080/reg", false}, + {"http_127_0_0_1", "http://127.0.0.1:8080/reg", false}, + } + + for _, tt := range tests { + tt := tt // capture loop variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateRegistrationClientURI(tt.uri) + if (err != nil) != tt.wantErr { + t.Errorf("validateRegistrationClientURI() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHandler_Restore_RenewSuccess(t *testing.T) { + t.Parallel() + + // Initial setup: secret expiring in 1 hour + expiry := time.Now().Add(1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"client_id": "test-client", "client_secret": "new-secret"}`)) + })) + defer svc.Close() + + var persistedID, persistedSecret string + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL, + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: &mockSecretProvider{ + secrets: map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }, + }, + clientCredentialsPersister: func(id, secret string, _ time.Time, _, _, _ string) error { + persistedID = id + persistedSecret = secret + return nil + }, + } + + // Calling tryRestoreFromCachedTokens should trigger renewal because of the 1h expiry. + // We expect an error because it will try to refresh the token and fail (no token endpoint). + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cached tokens are invalid or expired") + + // But renewal DID happen + assert.Equal(t, "test-client", persistedID) + assert.Equal(t, "new-secret", persistedSecret) +} + +func TestHandler_Restore_RenewFail_Soft(t *testing.T) { + t.Parallel() + + // Initial setup: secret expiring in 1 hour + expiry := time.Now().Add(1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer svc.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL, + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: &mockSecretProvider{ + secrets: map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }, + }, + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string) error { return nil }, + } + + // Renewal fails, but since it's only "expiring soon", it should continue (and then fail on token refresh) + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cached tokens are invalid or expired") +} + +func TestHandler_Restore_RenewFail_Hard(t *testing.T) { + t.Parallel() + + // Initial setup: secret already expired + expiry := time.Now().Add(-1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer svc.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL, + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: &mockSecretProvider{ + secrets: map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }, + }, + clientCredentialsPersister: func(string, string, time.Time, string, string, string) error { return nil }, + } + + // Renewal fails and it's fully expired -> fatal error + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "client secret expired at") + assert.Contains(t, err.Error(), "and renewal failed") +} diff --git a/pkg/oauth/constants.go b/pkg/oauth/constants.go new file mode 100644 index 0000000000..9b1ee98763 --- /dev/null +++ b/pkg/oauth/constants.go @@ -0,0 +1,70 @@ +// Copyright 2025 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package oauth provides RFC-defined types and constants for OAuth 2.0 and OpenID Connect. +// This package contains ONLY protocol-level definitions with no business logic. +// It serves as a shared foundation for both OAuth clients (consumers) and servers (producers). +package oauth + +// Well-known endpoint paths as defined by RFC 8414, OpenID Connect Discovery 1.0, and RFC 9728. +const ( + // WellKnownOIDCPath is the standard OIDC discovery endpoint path + // per OpenID Connect Discovery 1.0 specification. + WellKnownOIDCPath = "/.well-known/openid-configuration" + + // WellKnownOAuthServerPath is the standard OAuth authorization server metadata endpoint path + // per RFC 8414 (OAuth 2.0 Authorization Server Metadata). + WellKnownOAuthServerPath = "/.well-known/oauth-authorization-server" + + // WellKnownOAuthResourcePath is the RFC 9728 standard path for OAuth Protected Resource metadata. + // Per RFC 9728 Section 3, this endpoint and any subpaths under it should be accessible + // without authentication to enable OIDC/OAuth discovery. + WellKnownOAuthResourcePath = "/.well-known/oauth-protected-resource" +) + +// Grant types as defined by RFC 6749. +const ( + // GrantTypeAuthorizationCode is the authorization code grant type (RFC 6749 Section 4.1). + GrantTypeAuthorizationCode = "authorization_code" + + // GrantTypeRefreshToken is the refresh token grant type (RFC 6749 Section 6). + GrantTypeRefreshToken = "refresh_token" +) + +// Response types as defined by RFC 6749. +const ( + // ResponseTypeCode is the authorization code response type (RFC 6749 Section 4.1.1). + ResponseTypeCode = "code" +) + +// Token endpoint authentication methods as defined by RFC 7591. +const ( + // TokenEndpointAuthMethodNone indicates no client authentication (public clients). + // Typically used with PKCE for native/mobile applications. + TokenEndpointAuthMethodNone = "none" + + // TokenEndpointAuthMethodClientSecretPost indicates client authentication via + // client_id and client_secret in the request body. + TokenEndpointAuthMethodClientSecretPost = "client_secret_post" + + // TokenEndpointAuthMethodClientSecretBasic indicates client authentication via + // HTTP Basic Authentication. + TokenEndpointAuthMethodClientSecretBasic = "client_secret_basic" +) + +// PKCE (Proof Key for Code Exchange) methods as defined by RFC 7636. +const ( + // PKCEMethodS256 uses SHA-256 hash of the code verifier (recommended). + PKCEMethodS256 = "S256" +) diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 01c7440dcb..b7b140bf82 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -759,8 +759,11 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo clientID, clientSecret string, secretExpiry time.Time, regAccessToken, regClientURI string, + tokenEndpointAuthMethod string, ) error { - return r.persistClientCredentials(ctx, secretManager, clientID, clientSecret, secretExpiry, regAccessToken, regClientURI) + return r.persistClientCredentials( + ctx, secretManager, clientID, clientSecret, + secretExpiry, regAccessToken, regClientURI, tokenEndpointAuthMethod) }) } @@ -809,6 +812,7 @@ func (r *Runner) persistClientCredentials( clientID, clientSecret string, secretExpiry time.Time, regAccessToken, regClientURI string, + tokenEndpointAuthMethod string, ) error { r.Config.RemoteAuthConfig.CachedClientID = clientID @@ -844,6 +848,7 @@ func (r *Runner) persistClientCredentials( } r.Config.RemoteAuthConfig.CachedRegClientURI = regClientURI + r.Config.RemoteAuthConfig.CachedTokenEndpointAuthMethod = tokenEndpointAuthMethod if err := r.Config.SaveState(ctx); err != nil { return fmt.Errorf("failed to save config with client credentials: %w", err) From 6447fd8d2a97fee466919fed11233ce9c338cb34 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Sun, 22 Mar 2026 15:12:26 +0530 Subject: [PATCH 6/8] fix: CI coverage Signed-off-by: Sanskarzz --- pkg/state/runconfig.go | 43 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/pkg/state/runconfig.go b/pkg/state/runconfig.go index ea1fc91c7d..d94cc419aa 100644 --- a/pkg/state/runconfig.go +++ b/pkg/state/runconfig.go @@ -6,11 +6,12 @@ package state import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" - "github.com/stacklok/toolhive/pkg/workloads/types/errors" + werr "github.com/stacklok/toolhive/pkg/workloads/types/errors" ) // LoadRunConfigJSON loads a run configuration from the state store and returns the raw reader @@ -27,7 +28,7 @@ func LoadRunConfigJSON(ctx context.Context, name string) (io.ReadCloser, error) return nil, fmt.Errorf("failed to check if run configuration exists: %w", err) } if !exists { - return nil, fmt.Errorf("%w: %s", errors.ErrRunConfigNotFound, name) + return nil, fmt.Errorf("%w: %s", werr.ErrRunConfigNotFound, name) } // Get a reader for the state @@ -121,6 +122,44 @@ func LoadRunConfig[T any](ctx context.Context, name string, readJSONFunc ReadJSO return readJSONFunc(reader) } +// ReadRunConfigJSON deserializes a run configuration from JSON read from the provided reader +// This is a generic JSON deserializer for any type that can be unmarshalled from JSON +func ReadRunConfigJSON[T any](r io.Reader) (*T, error) { + var config T + decoder := json.NewDecoder(r) + if err := decoder.Decode(&config); err != nil { + if errors.Is(err, io.EOF) { + return &config, nil + } + return nil, err + } + return &config, nil +} + +// LoadRunConfigOfType loads a run configuration of a specific type T from the state store +func LoadRunConfigOfType[T any](ctx context.Context, name string) (*T, error) { + return LoadRunConfig(ctx, name, ReadRunConfigJSON[T]) +} + +// RunConfigReadJSONFunc defines the function signature for reading a RunConfig from JSON +// This allows us to accept the runner.ReadJSON function without creating a circular dependency +type RunConfigReadJSONFunc func(r io.Reader) (interface{}, error) + +// LoadRunConfigWithFunc loads a run configuration using a provided read function +func LoadRunConfigWithFunc(ctx context.Context, name string, readFunc RunConfigReadJSONFunc) (interface{}, error) { + reader, err := LoadRunConfigJSON(ctx, name) + if err != nil { + return nil, err + } + defer func() { + if err := reader.Close(); err != nil { + slog.Warn("Failed to close reader", "error", err) + } + }() + + return readFunc(reader) +} + // ReadJSON deserializes JSON from the provided reader into a generic interface // This function is moved from the runner package to avoid circular dependencies func ReadJSON(r io.Reader, target interface{}) error { From 64558145e2ea1180cad24440805773879fc40b38 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Sat, 23 May 2026 14:39:03 +0530 Subject: [PATCH 7/8] revert to review comments Signed-off-by: Sanskarzz --- docs/server/docs.go | 232 ++++++------------ docs/server/swagger.json | 267 ++++++--------------- docs/server/swagger.yaml | 171 ++++++------- pkg/auth/discovery/discovery.go | 52 ++-- pkg/auth/remote/config.go | 5 + pkg/auth/remote/handler.go | 12 +- pkg/auth/remote/persisting_token_source.go | 2 + pkg/auth/remote/secret_renewal.go | 50 +++- pkg/auth/remote/secret_renewal_test.go | 195 ++++++--------- pkg/runner/runner.go | 55 +++-- 10 files changed, 409 insertions(+), 632 deletions(-) diff --git a/docs/server/docs.go b/docs/server/docs.go index 57d72de1ff..c4a611c28f 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -8,63 +8,6 @@ const docTemplate = `{ "schemes": {{ marshal .Schemes }}, "components": { "schemas": { - "auth.TokenValidatorConfig": { - "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", - "properties": { - "allowPrivateIP": { - "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", - "type": "boolean" - }, - "audience": { - "description": "Audience is the expected audience for the token", - "type": "string" - }, - "authTokenFile": { - "description": "AuthTokenFile is the path to file containing bearer token for authentication", - "type": "string" - }, - "cacertPath": { - "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", - "type": "string" - }, - "clientID": { - "description": "ClientID is the OIDC client ID", - "type": "string" - }, - "clientSecret": { - "description": "ClientSecret is the optional OIDC client secret for introspection", - "type": "string" - }, - "insecureAllowHTTP": { - "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", - "type": "boolean" - }, - "introspectionURL": { - "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", - "type": "string" - }, - "issuer": { - "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", - "type": "string" - }, - "jwksurl": { - "description": "JWKSURL is the URL to fetch the JWKS from", - "type": "string" - }, - "resourceURL": { - "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", - "type": "string" - }, - "scopes": { - "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", - "items": { - "type": "string" - }, - "type": "array" - } - }, - "type": "object" - }, "github_com_stacklok_toolhive-core_registry_types.Registry": { "description": "Full registry data", "properties": { @@ -171,6 +114,63 @@ const docTemplate = `{ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig": { + "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", + "properties": { + "allowPrivateIP": { + "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", + "type": "boolean" + }, + "audience": { + "description": "Audience is the expected audience for the token", + "type": "string" + }, + "authTokenFile": { + "description": "AuthTokenFile is the path to file containing bearer token for authentication", + "type": "string" + }, + "cacertPath": { + "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", + "type": "string" + }, + "clientID": { + "description": "ClientID is the OIDC client ID", + "type": "string" + }, + "clientSecret": { + "description": "ClientSecret is the optional OIDC client secret for introspection", + "type": "string" + }, + "insecureAllowHTTP": { + "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", + "type": "boolean" + }, + "introspectionURL": { + "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", + "type": "string" + }, + "issuer": { + "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", + "type": "string" + }, + "jwksurl": { + "description": "JWKSURL is the URL to fetch the JWKS from", + "type": "string" + }, + "resourceURL": { + "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", + "type": "string" + }, + "scopes": { + "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive_pkg_auth_awssts.Config": { "description": "AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services", "properties": { @@ -258,6 +258,10 @@ const docTemplate = `{ "cached_client_secret_ref": { "type": "string" }, + "cached_dcr_callback_port": { + "description": "CachedDCRCallbackPort is the callback port that was actually registered\nduring DCR. It may differ from CallbackPort when the requested port was\nunavailable and a fallback port was selected.", + "type": "integer" + }, "cached_refresh_token_ref": { "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" @@ -553,7 +557,7 @@ const docTemplate = `{ "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig" }, "storage": { - "$ref": "#/components/schemas/storage.RunConfig" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig" }, "token_lifespans": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig" @@ -1045,19 +1049,7 @@ const docTemplate = `{ "type": "string" }, "status": { - "description": "Status is the current status of the workload.", - "enum": [ - "running", - "stopped", - "error", - "starting", - "stopping", - "unhealthy", - "removing", - "unknown", - "unauthenticated" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" }, "status_context": { "description": "StatusContext provides additional context about the workload's status.\nThe exact meaning is determined by the status and the underlying runtime.", @@ -1072,14 +1064,7 @@ const docTemplate = `{ "uniqueItems": false }, "transport_type": { - "description": "TransportType is the type of transport used for this workload.", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" }, "url": { "description": "URL is the URL of the workload exposed by the ToolHive proxy.", @@ -1316,7 +1301,7 @@ const docTemplate = `{ "type": "string" }, "ignore_config": { - "$ref": "#/components/schemas/ignore.Config" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config" }, "image": { "description": "Image is the Docker image to run", @@ -1341,7 +1326,7 @@ const docTemplate = `{ "middleware_configs": { "description": "MiddlewareConfigs contains the list of middleware to apply to the transport\nand the configuration for each middleware.", "items": { - "$ref": "#/components/schemas/types.MiddlewareConfig" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig" }, "type": "array", "uniqueItems": false @@ -1359,7 +1344,7 @@ const docTemplate = `{ "type": "string" }, "oidc_config": { - "$ref": "#/components/schemas/auth.TokenValidatorConfig" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig" }, "permission_profile_name_or_path": { "description": "PermissionProfileNameOrPath is the name or path of the permission profile", @@ -1370,12 +1355,7 @@ const docTemplate = `{ "type": "integer" }, "proxy_mode": { - "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", - "enum": [ - "sse", - "streamable-http" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.ProxyMode" }, "publish": { "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", @@ -1472,14 +1452,7 @@ const docTemplate = `{ "type": "object" }, "transport": { - "description": "Transport is the transport mode (stdio, sse, or streamable-http)", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" }, "trust_proxy_headers": { "description": "TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies", @@ -1871,16 +1844,15 @@ const docTemplate = `{ }, "type": "object" }, - "ignore.Config": { - "description": "IgnoreConfig contains configuration for ignore processing", + "github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig": { "properties": { - "loadGlobal": { - "description": "Whether to load global ignore patterns", - "type": "boolean" + "parameters": { + "description": "Parameters is a JSON object containing the middleware parameters.\nIt is stored as a raw message to allow flexible parameter types.", + "type": "object" }, - "printOverlays": { - "description": "Whether to print resolved overlay paths for debugging", - "type": "boolean" + "type": { + "description": "Type is a string representing the middleware type.", + "type": "string" } }, "type": "object" @@ -2282,44 +2254,6 @@ const docTemplate = `{ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_transport_types.ProxyMode": { - "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", - "enum": [ - "sse", - "streamable-http", - "sse", - "streamable-http" - ], - "type": "string", - "x-enum-varnames": [ - "ProxyModeSSE", - "ProxyModeStreamableHTTP" - ] - }, - "github_com_stacklok_toolhive_pkg_transport_types.TransportType": { - "description": "Transport is the transport mode (stdio, sse, or streamable-http)", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector", - "stdio", - "sse", - "streamable-http", - "inspector", - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string", - "x-enum-varnames": [ - "TransportTypeStdio", - "TransportTypeSSE", - "TransportTypeStreamableHTTP", - "TransportTypeInspector" - ] - }, "permissions.InboundNetworkPermissions": { "description": "Inbound defines inbound network permissions", "properties": { @@ -3478,19 +3412,7 @@ const docTemplate = `{ "description": "Response containing workload status information", "properties": { "status": { - "description": "Current status of the workload", - "enum": [ - "running", - "stopped", - "error", - "starting", - "stopping", - "unhealthy", - "removing", - "unknown", - "unauthenticated" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" } }, "type": "object" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index b6cff39120..388b8a9a75 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1,63 +1,6 @@ { "components": { "schemas": { - "auth.TokenValidatorConfig": { - "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", - "properties": { - "allowPrivateIP": { - "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", - "type": "boolean" - }, - "audience": { - "description": "Audience is the expected audience for the token", - "type": "string" - }, - "authTokenFile": { - "description": "AuthTokenFile is the path to file containing bearer token for authentication", - "type": "string" - }, - "cacertPath": { - "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", - "type": "string" - }, - "clientID": { - "description": "ClientID is the OIDC client ID", - "type": "string" - }, - "clientSecret": { - "description": "ClientSecret is the optional OIDC client secret for introspection", - "type": "string" - }, - "insecureAllowHTTP": { - "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", - "type": "boolean" - }, - "introspectionURL": { - "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", - "type": "string" - }, - "issuer": { - "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", - "type": "string" - }, - "jwksurl": { - "description": "JWKSURL is the URL to fetch the JWKS from", - "type": "string" - }, - "resourceURL": { - "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", - "type": "string" - }, - "scopes": { - "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", - "items": { - "type": "string" - }, - "type": "array" - } - }, - "type": "object" - }, "github_com_stacklok_toolhive-core_registry_types.Registry": { "description": "Full registry data", "properties": { @@ -164,6 +107,63 @@ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig": { + "description": "DEPRECATED: Middleware configuration.\nOIDCConfig contains OIDC configuration", + "properties": { + "allowPrivateIP": { + "description": "AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses", + "type": "boolean" + }, + "audience": { + "description": "Audience is the expected audience for the token", + "type": "string" + }, + "authTokenFile": { + "description": "AuthTokenFile is the path to file containing bearer token for authentication", + "type": "string" + }, + "cacertPath": { + "description": "CACertPath is the path to the CA certificate bundle for HTTPS requests", + "type": "string" + }, + "clientID": { + "description": "ClientID is the OIDC client ID", + "type": "string" + }, + "clientSecret": { + "description": "ClientSecret is the optional OIDC client secret for introspection", + "type": "string" + }, + "insecureAllowHTTP": { + "description": "InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing\nWARNING: This is insecure and should NEVER be used in production", + "type": "boolean" + }, + "introspectionURL": { + "description": "IntrospectionURL is the optional introspection endpoint for validating tokens", + "type": "string" + }, + "issuer": { + "description": "Issuer is the OIDC issuer URL (e.g., https://accounts.google.com)", + "type": "string" + }, + "jwksurl": { + "description": "JWKSURL is the URL to fetch the JWKS from", + "type": "string" + }, + "resourceURL": { + "description": "ResourceURL is the explicit resource URL for OAuth discovery (RFC 9728)", + "type": "string" + }, + "scopes": { + "description": "Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728)\nIf empty, defaults to [\"openid\"]", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive_pkg_auth_awssts.Config": { "description": "AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services", "properties": { @@ -251,6 +251,10 @@ "cached_client_secret_ref": { "type": "string" }, + "cached_dcr_callback_port": { + "description": "CachedDCRCallbackPort is the callback port that was actually registered\nduring DCR. It may differ from CallbackPort when the requested port was\nunavailable and a fallback port was selected.", + "type": "integer" + }, "cached_refresh_token_ref": { "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" @@ -546,7 +550,7 @@ "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig" }, "storage": { - "$ref": "#/components/schemas/storage.RunConfig" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig" }, "token_lifespans": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig" @@ -1038,19 +1042,7 @@ "type": "string" }, "status": { - "description": "Status is the current status of the workload.", - "enum": [ - "running", - "stopped", - "error", - "starting", - "stopping", - "unhealthy", - "removing", - "unknown", - "unauthenticated" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" }, "status_context": { "description": "StatusContext provides additional context about the workload's status.\nThe exact meaning is determined by the status and the underlying runtime.", @@ -1065,14 +1057,7 @@ "uniqueItems": false }, "transport_type": { - "description": "TransportType is the type of transport used for this workload.", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" }, "url": { "description": "URL is the URL of the workload exposed by the ToolHive proxy.", @@ -1309,7 +1294,7 @@ "type": "string" }, "ignore_config": { - "$ref": "#/components/schemas/ignore.Config" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config" }, "image": { "description": "Image is the Docker image to run", @@ -1334,15 +1319,7 @@ "middleware_configs": { "description": "MiddlewareConfigs contains the list of middleware to apply to the transport\nand the configuration for each middleware.", "items": { - "$ref": "#/components/schemas/types.MiddlewareConfig" - }, - "type": "array", - "uniqueItems": false - }, - "mutating_webhooks": { - "description": "MutatingWebhooks contains the configuration for mutating webhook middleware.\nMutating webhooks run before validating webhooks, per RFC THV-0017 ordering.", - "items": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig" }, "type": "array", "uniqueItems": false @@ -1360,7 +1337,7 @@ "type": "string" }, "oidc_config": { - "$ref": "#/components/schemas/auth.TokenValidatorConfig" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig" }, "permission_profile_name_or_path": { "description": "PermissionProfileNameOrPath is the name or path of the permission profile", @@ -1371,39 +1348,7 @@ "type": "integer" }, "proxy_mode": { - "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", - "enum": [ - "sse", - "streamable-http" - ], - "type": "string" - }, - "publish": { - "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", - "items": { - "type": "string" - }, - "type": "array", - "uniqueItems": false - }, - "rate_limit_config": { - "$ref": "#/components/schemas/github_com_stacklok_toolhive_cmd_thv-operator_api_v1beta1.RateLimitConfig" - }, - "rate_limit_namespace": { - "description": "RateLimitNamespace is the Kubernetes namespace for Redis key derivation.", - "type": "string" - }, - "registry_api_url": { - "description": "RegistryAPIURL is the registry API URL that served this server's metadata.\nEmpty when the server was not discovered via registry lookup.", - "type": "string" - }, - "registry_server_name": { - "description": "RegistryServerName is the registry entry name used to look up this server's metadata.\nEmpty when the server was not discovered via registry lookup.", - "type": "string" - }, - "registry_url": { - "description": "RegistryURL is the registry URL that served this server's metadata.\nEmpty when the server was not discovered via registry lookup.", - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.ProxyMode" }, "publish": { "description": "Publish lists ports to publish to the host in format \"hostPort:containerPort\"", @@ -1500,14 +1445,7 @@ "type": "object" }, "transport": { - "description": "Transport is the transport mode (stdio, sse, or streamable-http)", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType" }, "trust_proxy_headers": { "description": "TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies", @@ -1899,16 +1837,15 @@ }, "type": "object" }, - "ignore.Config": { - "description": "IgnoreConfig contains configuration for ignore processing", + "github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig": { "properties": { - "loadGlobal": { - "description": "Whether to load global ignore patterns", - "type": "boolean" + "parameters": { + "description": "Parameters is a JSON object containing the middleware parameters.\nIt is stored as a raw message to allow flexible parameter types.", + "type": "object" }, - "printOverlays": { - "description": "Whether to print resolved overlay paths for debugging", - "type": "boolean" + "type": { + "description": "Type is a string representing the middleware type.", + "type": "string" } }, "type": "object" @@ -2310,44 +2247,6 @@ }, "type": "object" }, - "github_com_stacklok_toolhive_pkg_transport_types.ProxyMode": { - "description": "ProxyMode is the proxy mode for stdio transport (\"sse\" or \"streamable-http\")\nNote: \"sse\" is deprecated; use \"streamable-http\" instead.", - "enum": [ - "sse", - "streamable-http", - "sse", - "streamable-http" - ], - "type": "string", - "x-enum-varnames": [ - "ProxyModeSSE", - "ProxyModeStreamableHTTP" - ] - }, - "github_com_stacklok_toolhive_pkg_transport_types.TransportType": { - "description": "Transport is the transport mode (stdio, sse, or streamable-http)", - "enum": [ - "stdio", - "sse", - "streamable-http", - "inspector", - "stdio", - "sse", - "streamable-http", - "inspector", - "stdio", - "sse", - "streamable-http", - "inspector" - ], - "type": "string", - "x-enum-varnames": [ - "TransportTypeStdio", - "TransportTypeSSE", - "TransportTypeStreamableHTTP", - "TransportTypeInspector" - ] - }, "permissions.InboundNetworkPermissions": { "description": "Inbound defines inbound network permissions", "properties": { @@ -3506,19 +3405,7 @@ "description": "Response containing workload status information", "properties": { "status": { - "description": "Current status of the workload", - "enum": [ - "running", - "stopped", - "error", - "starting", - "stopping", - "unhealthy", - "removing", - "unknown", - "unauthenticated" - ], - "type": "string" + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus" } }, "type": "object" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 72d029647c..fab7757d0e 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1,57 +1,5 @@ components: schemas: - auth.TokenValidatorConfig: - description: |- - DEPRECATED: Middleware configuration. - OIDCConfig contains OIDC configuration - properties: - allowPrivateIP: - description: AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses - type: boolean - audience: - description: Audience is the expected audience for the token - type: string - authTokenFile: - description: AuthTokenFile is the path to file containing bearer token for - authentication - type: string - cacertPath: - description: CACertPath is the path to the CA certificate bundle for HTTPS - requests - type: string - clientID: - description: ClientID is the OIDC client ID - type: string - clientSecret: - description: ClientSecret is the optional OIDC client secret for introspection - type: string - insecureAllowHTTP: - description: |- - InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing - WARNING: This is insecure and should NEVER be used in production - type: boolean - introspectionURL: - description: IntrospectionURL is the optional introspection endpoint for - validating tokens - type: string - issuer: - description: Issuer is the OIDC issuer URL (e.g., https://accounts.google.com) - type: string - jwksurl: - description: JWKSURL is the URL to fetch the JWKS from - type: string - resourceURL: - description: ResourceURL is the explicit resource URL for OAuth discovery - (RFC 9728) - type: string - scopes: - description: |- - Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728) - If empty, defaults to ["openid"] - items: - type: string - type: array - type: object github_com_stacklok_toolhive-core_registry_types.Registry: description: Full registry data properties: @@ -173,6 +121,58 @@ components: +optional type: integer type: object + github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig: + description: |- + DEPRECATED: Middleware configuration. + OIDCConfig contains OIDC configuration + properties: + allowPrivateIP: + description: AllowPrivateIP allows JWKS/OIDC endpoints on private IP addresses + type: boolean + audience: + description: Audience is the expected audience for the token + type: string + authTokenFile: + description: AuthTokenFile is the path to file containing bearer token for + authentication + type: string + cacertPath: + description: CACertPath is the path to the CA certificate bundle for HTTPS + requests + type: string + clientID: + description: ClientID is the OIDC client ID + type: string + clientSecret: + description: ClientSecret is the optional OIDC client secret for introspection + type: string + insecureAllowHTTP: + description: |- + InsecureAllowHTTP allows HTTP (non-HTTPS) OIDC issuers for development/testing + WARNING: This is insecure and should NEVER be used in production + type: boolean + introspectionURL: + description: IntrospectionURL is the optional introspection endpoint for + validating tokens + type: string + issuer: + description: Issuer is the OIDC issuer URL (e.g., https://accounts.google.com) + type: string + jwksurl: + description: JWKSURL is the URL to fetch the JWKS from + type: string + resourceURL: + description: ResourceURL is the explicit resource URL for OAuth discovery + (RFC 9728) + type: string + scopes: + description: |- + Scopes is the list of OAuth scopes to advertise in the well-known endpoint (RFC 9728) + If empty, defaults to ["openid"] + items: + type: string + type: array + type: object github_com_stacklok_toolhive_pkg_auth_awssts.Config: description: AWSStsConfig contains AWS STS token exchange configuration for accessing AWS services @@ -267,6 +267,12 @@ components: type: string cached_client_secret_ref: type: string + cached_dcr_callback_port: + description: |- + CachedDCRCallbackPort is the callback port that was actually registered + during DCR. It may differ from CallbackPort when the requested port was + unavailable and a fallback port was selected. + type: integer cached_refresh_token_ref: description: |- Cached OAuth token reference for persistence across restarts. @@ -616,7 +622,7 @@ components: signing_key_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.SigningKeyRunConfig' storage: - $ref: '#/components/schemas/storage.RunConfig' + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver_storage.RunConfig' token_lifespans: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_authserver.TokenLifespanRunConfig' upstreams: @@ -1083,18 +1089,7 @@ components: restart) type: string status: - description: Status is the current status of the workload. - enum: - - running - - stopped - - error - - starting - - stopping - - unhealthy - - removing - - unknown - - unauthenticated - type: string + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus' status_context: description: |- StatusContext provides additional context about the workload's status. @@ -1107,13 +1102,7 @@ components: type: array uniqueItems: false transport_type: - description: TransportType is the type of transport used for this workload. - enum: - - stdio - - sse - - streamable-http - - inspector - type: string + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType' url: description: URL is the URL of the workload exposed by the ToolHive proxy. type: string @@ -1321,7 +1310,7 @@ components: description: Host is the host for the HTTP proxy type: string ignore_config: - $ref: '#/components/schemas/ignore.Config' + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_ignore.Config' image: description: Image is the Docker image to run type: string @@ -1351,15 +1340,7 @@ components: MiddlewareConfigs contains the list of middleware to apply to the transport and the configuration for each middleware. items: - $ref: '#/components/schemas/types.MiddlewareConfig' - type: array - uniqueItems: false - mutating_webhooks: - description: |- - MutatingWebhooks contains the configuration for mutating webhook middleware. - Mutating webhooks run before validating webhooks, per RFC THV-0017 ordering. - items: - $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_webhook.Config' + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig' type: array uniqueItems: false mutating_webhooks: @@ -1374,7 +1355,7 @@ components: description: Name is the name of the MCP server type: string oidc_config: - $ref: '#/components/schemas/auth.TokenValidatorConfig' + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_auth.TokenValidatorConfig' permission_profile_name_or_path: description: PermissionProfileNameOrPath is the name or path of the permission profile @@ -1481,13 +1462,7 @@ components: ToolsOverride is a map from an actual tool to its overridden name and/or description type: object transport: - description: Transport is the transport mode (stdio, sse, or streamable-http) - enum: - - stdio - - sse - - streamable-http - - inspector - type: string + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_transport_types.TransportType' trust_proxy_headers: description: TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies @@ -1844,8 +1819,7 @@ components: +optional type: boolean type: object - ignore.Config: - description: IgnoreConfig contains configuration for ignore processing + github_com_stacklok_toolhive_pkg_transport_types.MiddlewareConfig: properties: parameters: description: |- @@ -3043,18 +3017,7 @@ components: description: Response containing workload status information properties: status: - description: Current status of the workload - enum: - - running - - stopped - - error - - starting - - stopping - - unhealthy - - removing - - unknown - - unauthenticated - type: string + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_container_runtime.WorkloadStatus' type: object registry.EnvVar: properties: diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 32acddce02..31561228a2 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -519,6 +519,7 @@ type OAuthFlowConfig struct { RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data RegistrationClientURI string TokenEndpointAuthMethod string + RegisteredCallbackPort int } // OAuthFlowResult contains the result of an OAuth flow @@ -543,6 +544,7 @@ type OAuthFlowResult struct { RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data RegistrationClientURI string TokenEndpointAuthMethod string + RegisteredCallbackPort int } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -615,20 +617,12 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi // exists. // // One consequence of option (b) is that the resolver's RFC 7591 §3.2.1 -// expiry-driven refetch does NOT participate in the CLI's cross- -// invocation persistence loop: each PerformOAuthFlow call builds a fresh -// in-memory store, so a "cached but expired" entry from the previous -// invocation never reaches the resolver. Cross-invocation expiry is also -// NOT enforced by the remote handler's gate today — -// HasCachedClientCredentials only checks CachedClientID != "" and does -// not consult CachedSecretExpiry, so an expired-but-still-cached client -// gets reused on the next invocation and surfaces as a token-endpoint -// failure rather than a clean DCR re-registration. Tightening the gate -// to also check CachedSecretExpiry is open follow-up work; the -// behaviour today is "cross-invocation expiry is unhandled". Within a -// single invocation, the resolver's expiry check is still in the loop -// and would fire if the same call site somehow registered, persisted, -// and re-queried the in-memory store — but the CLI never does this today. +// expiry-driven refetch does NOT participate in the CLI's cross-invocation +// persistence loop: each PerformOAuthFlow call builds a fresh in-memory store, +// so a cached entry from a previous invocation never reaches the resolver. +// Cross-invocation client-secret expiry is handled instead by the remote +// handler, which consults CachedSecretExpiry and renews through RFC 7592 before +// cached credentials are used. // // Wrapping the remote handler's secretProvider into a dcr.CredentialStore // adapter (option (a)) would close that loop and is the natural follow-up; @@ -677,27 +671,14 @@ func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuth } // Store DCR renewal metadata for RFC 7592 operations. - // client_secret_expires_at == 0 means the secret never expires (RFC 7591 §3.2.1). - if registrationResponse.ClientSecretExpiresAt > 0 { - config.SecretExpiry = time.Unix(registrationResponse.ClientSecretExpiresAt, 0) - } - config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken - config.RegistrationClientURI = registrationResponse.RegistrationClientURI - - if registrationResponse.RegistrationAccessToken != "" { - slog.Debug("DCR response includes registration access token for RFC 7592 operations") - } - - // Store DCR renewal metadata for RFC 7592 operations. - // client_secret_expires_at == 0 means the secret never expires (RFC 7591 §3.2.1). - if registrationResponse.ClientSecretExpiresAt > 0 { - config.SecretExpiry = time.Unix(registrationResponse.ClientSecretExpiresAt, 0) - } - config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken - config.RegistrationClientURI = registrationResponse.RegistrationClientURI - config.TokenEndpointAuthMethod = registrationResponse.TokenEndpointAuthMethod - - if registrationResponse.RegistrationAccessToken != "" { + // A zero ClientSecretExpiresAt means the secret never expires (RFC 7591 §3.2.1). + config.SecretExpiry = resolution.ClientSecretExpiresAt + config.RegistrationAccessToken = resolution.RegistrationAccessToken + config.RegistrationClientURI = resolution.RegistrationClientURI + config.TokenEndpointAuthMethod = resolution.TokenEndpointAuthMethod + config.RegisteredCallbackPort = config.CallbackPort + + if resolution.RegistrationAccessToken != "" { slog.Debug("DCR response includes registration access token for RFC 7592 operations") } @@ -938,6 +919,7 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF RegistrationAccessToken: config.RegistrationAccessToken, RegistrationClientURI: config.RegistrationClientURI, TokenEndpointAuthMethod: config.TokenEndpointAuthMethod, + RegisteredCallbackPort: config.RegisteredCallbackPort, }, nil } diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index b88e93566a..733e301024 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -86,6 +86,10 @@ type Config struct { // CachedTokenEndpointAuthMethod is the auth method used for the token endpoint // (e.g., "client_secret_basic", "none"). Persisted for RFC 7592 updates. CachedTokenEndpointAuthMethod string `json:"cached_token_auth_method,omitempty" yaml:"cached_token_auth_method,omitempty"` + // CachedDCRCallbackPort is the callback port that was actually registered + // during DCR. It may differ from CallbackPort when the requested port was + // unavailable and a fallback port was selected. + CachedDCRCallbackPort int `json:"cached_dcr_callback_port,omitempty" yaml:"cached_dcr_callback_port,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -194,6 +198,7 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedRegTokenRef = "" c.CachedRegClientURI = "" c.CachedTokenEndpointAuthMethod = "" + c.CachedDCRCallbackPort = 0 } // LogContext returns the upstream issuer and resolved client_id for use as diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index c16a2459e5..608415a57b 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/oauthproto" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -25,6 +26,7 @@ type Handler struct { tokenPersister TokenPersister clientCredentialsPersister ClientCredentialsPersister secretProvider secrets.Provider + httpClient networking.HTTPClient } // NewHandler creates a new remote authentication handler @@ -51,6 +53,11 @@ func (h *Handler) SetClientCredentialsPersister(persister ClientCredentialsPersi h.clientCredentialsPersister = persister } +// SetHTTPClient sets the HTTP client used for RFC 7592 registration management requests. +func (h *Handler) SetHTTPClient(client networking.HTTPClient) { + h.httpClient = client +} + // Authenticate is the main entry point for remote MCP server authentication func (h *Handler) Authenticate(ctx context.Context, remoteURL string) (oauth2.TokenSource, error) { // Priority 1: Bearer token authentication (if configured) @@ -215,6 +222,7 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. result.RegistrationAccessToken, result.RegistrationClientURI, result.TokenEndpointAuthMethod, + result.RegisteredCallbackPort, ); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { @@ -264,7 +272,7 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien // Proactively renew the client secret if it is expiring soon (RFC 7592) if h.isSecretExpiredOrExpiringSoon() { - slog.Info("Cached client secret is expiring soon, attempting renewal", + slog.Debug("Cached client secret is expiring soon, attempting renewal", "expiry", h.config.CachedSecretExpiry) if renewErr := h.renewClientSecret(ctx); renewErr != nil { slog.Warn("Failed to proactively renew client secret; continuing with existing secret", @@ -303,7 +311,7 @@ func (h *Handler) tryRestoreFromCachedTokens( // Check if the cached client secret is expired before attempting token refresh. // If it has fully expired and renewal also fails we must force a fresh OAuth flow. if h.isSecretExpiredOrExpiringSoon() { - slog.Info("Cached client secret is expiring or expired; attempting renewal before token restore", + slog.Debug("Cached client secret is expiring or expired; attempting renewal before token restore", "expiry", h.config.CachedSecretExpiry) if renewErr := h.renewClientSecret(ctx); renewErr != nil { slog.Warn("Client secret renewal failed", "error", renewErr) diff --git a/pkg/auth/remote/persisting_token_source.go b/pkg/auth/remote/persisting_token_source.go index 99f810edf2..11e92081e9 100644 --- a/pkg/auth/remote/persisting_token_source.go +++ b/pkg/auth/remote/persisting_token_source.go @@ -31,6 +31,7 @@ type TokenPersister func(refreshToken string, expiry time.Time) error // - registrationAccessToken: bearer token for RFC 7592 management operations (sensitive) // - registrationClientURI: endpoint for RFC 7592 client update/read operations (plain text) // - tokenEndpointAuthMethod: the auth method used for the token endpoint (e.g., "client_secret_basic", "none") +// - registeredCallbackPort: the callback port used in the original DCR redirect URI type ClientCredentialsPersister func( clientID string, clientSecret string, @@ -38,6 +39,7 @@ type ClientCredentialsPersister func( registrationAccessToken string, registrationClientURI string, tokenEndpointAuthMethod string, + registeredCallbackPort int, ) error // PersistingTokenSource wraps an oauth2.TokenSource and persists tokens diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go index fa2f167dee..fb6b479155 100644 --- a/pkg/auth/remote/secret_renewal.go +++ b/pkg/auth/remote/secret_renewal.go @@ -4,6 +4,7 @@ package remote import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,10 +12,10 @@ import ( "log/slog" "net/http" "net/url" - "strings" "time" "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/oauthproto" ) // secretExpiryBuffer is the lead time before expiry at which we proactively @@ -22,6 +23,14 @@ import ( // expires within this window, not only after expiry. const secretExpiryBuffer = 24 * time.Hour +var defaultRenewalHTTPClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }, +} + // clientUpdateRequest is the body sent in a RFC 7592 §2.2 PUT request. // Per the spec, all client metadata fields that were provided during // registration must be included in the update request body. @@ -94,8 +103,8 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { // that were provided during the initial registration. updateReq := clientUpdateRequest{ ClientID: h.config.CachedClientID, - ClientName: "ToolHive MCP Client", - RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.config.CallbackPort)}, + ClientName: oauthproto.ToolHiveMCPClientName, + RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.registeredDCRCallbackPort())}, GrantTypes: []string{"authorization_code", "refresh_token"}, ResponseTypes: []string{"code"}, TokenEndpointAuthMethod: h.config.CachedTokenEndpointAuthMethod, @@ -111,7 +120,7 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { ctx, http.MethodPut, h.config.CachedRegClientURI, - strings.NewReader(string(reqBody)), + bytes.NewReader(reqBody), ) if err != nil { return fmt.Errorf("failed to create client update request: %w", err) @@ -119,15 +128,11 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+regAccessToken) //nolint:gosec // G117 + req.Header.Set("Authorization", "Bearer "+regAccessToken) - // Execute the request - httpClient := &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - }, + httpClient := h.httpClient + if httpClient == nil { + httpClient = defaultRenewalHTTPClient } resp, err := httpClient.Do(req) @@ -142,6 +147,7 @@ func (h *Handler) renewClientSecret(ctx context.Context) error { if resp.StatusCode != http.StatusOK { errorBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + _, _ = io.Copy(io.Discard, resp.Body) return fmt.Errorf("client update request returned HTTP %d: %s", resp.StatusCode, string(errorBody)) } @@ -175,6 +181,13 @@ func (h *Handler) validateRenewalPrerequisites() error { return nil } +func (h *Handler) registeredDCRCallbackPort() int { + if h.config.CachedDCRCallbackPort != 0 { + return h.config.CachedDCRCallbackPort + } + return h.config.CallbackPort +} + func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { if h.clientCredentialsPersister == nil { return fmt.Errorf("client credentials persister not configured; cannot save renewed secret") @@ -199,11 +212,19 @@ func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { newRegToken, newRegURI, h.config.CachedTokenEndpointAuthMethod, + h.registeredDCRCallbackPort(), ); err != nil { return fmt.Errorf("failed to persist renewed client secret: %w", err) } - slog.Info("Successfully renewed client secret via RFC 7592", + h.config.CachedClientID = updateResp.ClientID + h.config.CachedSecretExpiry = newExpiry + h.config.CachedRegClientURI = newRegURI + if h.config.CachedDCRCallbackPort == 0 { + h.config.CachedDCRCallbackPort = h.config.CallbackPort + } + + slog.Debug("Successfully renewed client secret via RFC 7592", "client_id", updateResp.ClientID, "new_expiry_zero", newExpiry.IsZero(), "reg_token_rotated", newRegToken != "") @@ -226,6 +247,9 @@ func validateRegistrationClientURI(registrationClientURI string) error { if parsedURL.Scheme != "https" && !networking.IsLocalhost(parsedURL.Host) { return fmt.Errorf("registration_client_uri must use HTTPS: %s", registrationClientURI) } + if parsedURL.Path == "" || parsedURL.Path == "/" { + return fmt.Errorf("registration_client_uri must include a non-root path: %s", registrationClientURI) + } return nil } diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go index 0848cf835d..c574d1761d 100644 --- a/pkg/auth/remote/secret_renewal_test.go +++ b/pkg/auth/remote/secret_renewal_test.go @@ -9,64 +9,31 @@ import ( "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" - "github.com/stacklok/toolhive/pkg/secrets" + secretmocks "github.com/stacklok/toolhive/pkg/secrets/mocks" ) -// mockSecretProvider is a simple in-memory secret store for tests. -// It implements the full secrets.Provider interface. -type mockSecretProvider struct { - secrets map[string]string -} - -func newMockSecretProvider(initial map[string]string) *mockSecretProvider { - if initial == nil { - initial = make(map[string]string) - } - return &mockSecretProvider{secrets: initial} -} - -func (m *mockSecretProvider) GetSecret(_ context.Context, name string) (string, error) { - v, ok := m.secrets[name] - if !ok { - return "", fmt.Errorf("secret %q not found", name) - } - return v, nil -} - -func (m *mockSecretProvider) SetSecret(_ context.Context, name, value string) error { - m.secrets[name] = value - return nil -} - -func (m *mockSecretProvider) DeleteSecret(_ context.Context, name string) error { - delete(m.secrets, name) - return nil -} +func newTestSecretProvider(t *testing.T, values map[string]string) *secretmocks.MockProvider { + t.Helper() -func (m *mockSecretProvider) ListSecrets(_ context.Context) ([]secrets.SecretDescription, error) { - result := make([]secrets.SecretDescription, 0, len(m.secrets)) - for k := range m.secrets { - result = append(result, secrets.SecretDescription{Key: k}) - } - return result, nil -} - -func (*mockSecretProvider) Cleanup() error { return nil } - -func (*mockSecretProvider) Capabilities() secrets.ProviderCapabilities { - return secrets.ProviderCapabilities{ - CanRead: true, - CanWrite: true, - CanDelete: true, - CanList: true, - } + ctrl := gomock.NewController(t) + provider := secretmocks.NewMockProvider(ctrl) + provider.EXPECT().GetSecret(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, name string) (string, error) { + value, ok := values[name] + if !ok { + return "", fmt.Errorf("secret %q not found", name) + } + return value, nil + }, + ).AnyTimes() + return provider } // TestIsSecretExpiredOrExpiringSoon tests the expiry helper on various time scenarios. @@ -158,6 +125,11 @@ func TestValidateRegistrationClientURI(t *testing.T) { uri: "://bad-url", wantErr: true, }, + { + name: "root path URI is rejected", + uri: "https://example.com/", + wantErr: true, + }, } for _, tt := range tests { @@ -235,6 +207,10 @@ func TestRenewClientSecret_Success(t *testing.T) { assert.Contains(t, r.Header.Get("Authorization"), "Bearer reg-access-token") assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + var updateReq clientUpdateRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&updateReq)) + assert.Equal(t, []string{"http://localhost:9876/callback"}, updateReq.RedirectURIs) + // Return the updated registration response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -246,7 +222,7 @@ func TestRenewClientSecret_Success(t *testing.T) { "registration_client_uri": "http://" + r.Host + r.URL.Path, }) })) - defer server.Close() + t.Cleanup(server.Close) // Set up persister capture var persistedClientID, persistedSecret, persistedRegToken, persistedRegURI string @@ -254,23 +230,27 @@ func TestRenewClientSecret_Success(t *testing.T) { h := &Handler{ config: &Config{ - CachedClientID: "test-client-id", - CachedRegClientURI: server.URL + "/register/test-client-id", - CachedRegTokenRef: "reg-token-secret-ref", + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-secret-ref", + CallbackPort: 8666, + CachedDCRCallbackPort: 9876, }, - secretProvider: newMockSecretProvider(map[string]string{ + secretProvider: newTestSecretProvider(t, map[string]string{ "reg-token-secret-ref": "reg-access-token", }), clientCredentialsPersister: func( clientID, secret string, expiry time.Time, regToken, regURI, _ string, + callbackPort int, ) error { persistedClientID = clientID persistedSecret = secret persistedExpiry = expiry persistedRegToken = regToken persistedRegURI = regURI + assert.Equal(t, 9876, callbackPort) return nil }, } @@ -293,7 +273,7 @@ func TestRenewClientSecret_ServerError(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) })) - defer server.Close() + t.Cleanup(server.Close) h := &Handler{ config: &Config{ @@ -301,10 +281,10 @@ func TestRenewClientSecret_ServerError(t *testing.T) { CachedRegClientURI: server.URL + "/register/test-client-id", CachedRegTokenRef: "reg-token-ref", }, - secretProvider: newMockSecretProvider(map[string]string{ + secretProvider: newTestSecretProvider(t, map[string]string{ "reg-token-ref": "bad-token", }), - clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string) error { + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string, _ int) error { return nil }, } @@ -326,7 +306,7 @@ func TestRenewClientSecret_NoPersister(t *testing.T) { "client_secret": "new-secret", }) })) - defer server.Close() + t.Cleanup(server.Close) h := &Handler{ config: &Config{ @@ -334,7 +314,7 @@ func TestRenewClientSecret_NoPersister(t *testing.T) { CachedRegClientURI: server.URL + "/register/test-client-id", CachedRegTokenRef: "reg-token-ref", }, - secretProvider: newMockSecretProvider(map[string]string{ + secretProvider: newTestSecretProvider(t, map[string]string{ "reg-token-ref": "some-token", }), clientCredentialsPersister: nil, // no persister @@ -359,7 +339,7 @@ func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { "client_secret_expires_at": 0, // never expires }) })) - defer server.Close() + t.Cleanup(server.Close) var capturedExpiry time.Time @@ -369,10 +349,10 @@ func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { CachedRegClientURI: server.URL + "/register/test-client-id", CachedRegTokenRef: "reg-token-ref", }, - secretProvider: newMockSecretProvider(map[string]string{ + secretProvider: newTestSecretProvider(t, map[string]string{ "reg-token-ref": "some-token", }), - clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _, _ string) error { + clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _, _ string, _ int) error { capturedExpiry = expiry return nil }, @@ -391,25 +371,19 @@ func TestRenewClientSecret_MalformedJSON(t *testing.T) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{invalid-json`)) })) - defer svc.Close() + t.Cleanup(svc.Close) h := &Handler{ config: &Config{ - CachedRegClientURI: svc.URL, + CachedRegClientURI: svc.URL + "/register/test-client-id", CachedRegTokenRef: "rat-ref", }, - secretProvider: &mockSecretProvider{ - secrets: map[string]string{"rat-ref": "rat-token"}, - }, + secretProvider: newTestSecretProvider(t, map[string]string{"rat-ref": "rat-token"}), } err := h.renewClientSecret(context.Background()) - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), "failed to decode client update response") { - t.Errorf("unexpected error message: %v", err) - } + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode client update response") } func TestRenewClientSecret_MissingFields(t *testing.T) { @@ -441,25 +415,19 @@ func TestRenewClientSecret_MissingFields(t *testing.T) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(tt.response)) })) - defer svc.Close() + t.Cleanup(svc.Close) h := &Handler{ config: &Config{ - CachedRegClientURI: svc.URL, + CachedRegClientURI: svc.URL + "/register/test-client-id", CachedRegTokenRef: "rat-ref", }, - secretProvider: &mockSecretProvider{ - secrets: map[string]string{"rat-ref": "rat-token"}, - }, + secretProvider: newTestSecretProvider(t, map[string]string{"rat-ref": "rat-token"}), } err := h.renewClientSecret(context.Background()) - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("unexpected error message: %v", err) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) }) } } @@ -475,6 +443,7 @@ func TestValidateRegistrationClientURI_Internal(t *testing.T) { {"malformed", "://foo", true}, {"http_external", "http://example.com/reg", true}, {"https_external", "https://example.com/reg", false}, + {"https_root_path", "https://example.com/", true}, {"http_localhost", "http://localhost:8080/reg", false}, {"http_127_0_0_1", "http://127.0.0.1:8080/reg", false}, } @@ -484,9 +453,7 @@ func TestValidateRegistrationClientURI_Internal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() err := validateRegistrationClientURI(tt.uri) - if (err != nil) != tt.wantErr { - t.Errorf("validateRegistrationClientURI() error = %v, wantErr %v", err, tt.wantErr) - } + assert.Equal(t, tt.wantErr, err != nil) }) } } @@ -499,28 +466,28 @@ func TestHandler_Restore_RenewSuccess(t *testing.T) { svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"client_id": "test-client", "client_secret": "new-secret"}`)) + _, _ = w.Write([]byte(`{"client_id": "test-client", "client_secret": "new-secret", "client_secret_expires_at": 4102444800}`)) })) - defer svc.Close() + t.Cleanup(svc.Close) var persistedID, persistedSecret string + renewalRequests := 0 h := &Handler{ config: &Config{ CachedClientID: "test-client", CachedSecretExpiry: expiry, - CachedRegClientURI: svc.URL, + CachedRegClientURI: svc.URL + "/register/test-client", CachedRegTokenRef: "rat-ref", CachedRefreshTokenRef: "refresh-token-ref", }, - secretProvider: &mockSecretProvider{ - secrets: map[string]string{ - "rat-ref": "rat-token", - "refresh-token-ref": "some-refresh-token", - }, - }, - clientCredentialsPersister: func(id, secret string, _ time.Time, _, _, _ string) error { + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(id, secret string, _ time.Time, _, _, _ string, _ int) error { persistedID = id persistedSecret = secret + renewalRequests++ return nil }, } @@ -534,6 +501,8 @@ func TestHandler_Restore_RenewSuccess(t *testing.T) { // But renewal DID happen assert.Equal(t, "test-client", persistedID) assert.Equal(t, "new-secret", persistedSecret) + assert.Equal(t, 1, renewalRequests) + assert.False(t, h.isSecretExpiredOrExpiringSoon()) } func TestHandler_Restore_RenewFail_Soft(t *testing.T) { @@ -544,23 +513,21 @@ func TestHandler_Restore_RenewFail_Soft(t *testing.T) { svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) - defer svc.Close() + t.Cleanup(svc.Close) h := &Handler{ config: &Config{ CachedClientID: "test-client", CachedSecretExpiry: expiry, - CachedRegClientURI: svc.URL, + CachedRegClientURI: svc.URL + "/register/test-client", CachedRegTokenRef: "rat-ref", CachedRefreshTokenRef: "refresh-token-ref", }, - secretProvider: &mockSecretProvider{ - secrets: map[string]string{ - "rat-ref": "rat-token", - "refresh-token-ref": "some-refresh-token", - }, - }, - clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string) error { return nil }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string, _ int) error { return nil }, } // Renewal fails, but since it's only "expiring soon", it should continue (and then fail on token refresh) @@ -577,23 +544,21 @@ func TestHandler_Restore_RenewFail_Hard(t *testing.T) { svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) - defer svc.Close() + t.Cleanup(svc.Close) h := &Handler{ config: &Config{ CachedClientID: "test-client", CachedSecretExpiry: expiry, - CachedRegClientURI: svc.URL, + CachedRegClientURI: svc.URL + "/register/test-client", CachedRegTokenRef: "rat-ref", CachedRefreshTokenRef: "refresh-token-ref", }, - secretProvider: &mockSecretProvider{ - secrets: map[string]string{ - "rat-ref": "rat-token", - "refresh-token-ref": "some-refresh-token", - }, - }, - clientCredentialsPersister: func(string, string, time.Time, string, string, string) error { return nil }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(string, string, time.Time, string, string, string, int) error { return nil }, } // Renewal fails and it's fully expired -> fatal error diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index b7b140bf82..8d5dd93019 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -760,10 +760,11 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo secretExpiry time.Time, regAccessToken, regClientURI string, tokenEndpointAuthMethod string, + registeredCallbackPort int, ) error { return r.persistClientCredentials( ctx, secretManager, clientID, clientSecret, - secretExpiry, regAccessToken, regClientURI, tokenEndpointAuthMethod) + secretExpiry, regAccessToken, regClientURI, tokenEndpointAuthMethod, registeredCallbackPort) }) } @@ -813,47 +814,65 @@ func (r *Runner) persistClientCredentials( secretExpiry time.Time, regAccessToken, regClientURI string, tokenEndpointAuthMethod string, + registeredCallbackPort int, ) error { - r.Config.RemoteAuthConfig.CachedClientID = clientID + updatedConfig := *r.Config + updatedRemoteAuthConfig := remote.Config{} + if r.Config.RemoteAuthConfig != nil { + updatedRemoteAuthConfig = *r.Config.RemoteAuthConfig + } + updatedConfig.RemoteAuthConfig = &updatedRemoteAuthConfig + updatedRemoteAuthConfig.CachedClientID = clientID if clientSecret != "" { - clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_CLIENT_SECRET_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate client secret secret name: %w", err) + clientSecretSecretName := updatedRemoteAuthConfig.CachedClientSecretRef + if clientSecretSecretName == "" { + var err error + clientSecretSecretName, err = authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_CLIENT_SECRET_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate client secret secret name: %w", err) + } } if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { return fmt.Errorf("failed to store client secret: %w", err) } - r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName + updatedRemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName } - r.Config.RemoteAuthConfig.CachedSecretExpiry = secretExpiry + updatedRemoteAuthConfig.CachedSecretExpiry = secretExpiry if regAccessToken != "" { - regTokenSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix(r.Config.Name, "OAUTH_REG_TOKEN_", secretManager) - if err != nil { - return fmt.Errorf("failed to generate registration token secret name: %w", err) + regTokenSecretName := updatedRemoteAuthConfig.CachedRegTokenRef + if regTokenSecretName == "" { + var err error + regTokenSecretName, err = authsecrets.GenerateUniqueSecretNameWithPrefix(r.Config.Name, "OAUTH_REG_TOKEN_", secretManager) + if err != nil { + return fmt.Errorf("failed to generate registration token secret name: %w", err) + } } if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, regAccessToken, secretManager); err != nil { return fmt.Errorf("failed to store registration access token: %w", err) } - r.Config.RemoteAuthConfig.CachedRegTokenRef = regTokenSecretName + updatedRemoteAuthConfig.CachedRegTokenRef = regTokenSecretName slog.Debug("Stored DCR registration access token for RFC 7592 operations") } - r.Config.RemoteAuthConfig.CachedRegClientURI = regClientURI - r.Config.RemoteAuthConfig.CachedTokenEndpointAuthMethod = tokenEndpointAuthMethod + updatedRemoteAuthConfig.CachedRegClientURI = regClientURI + updatedRemoteAuthConfig.CachedTokenEndpointAuthMethod = tokenEndpointAuthMethod + updatedRemoteAuthConfig.CachedDCRCallbackPort = registeredCallbackPort - if err := r.Config.SaveState(ctx); err != nil { + if err := updatedConfig.SaveState(ctx); err != nil { return fmt.Errorf("failed to save config with client credentials: %w", err) } + r.Config.RemoteAuthConfig = &updatedRemoteAuthConfig + slog.Debug("Stored DCR client credentials", "client_id", clientID, "has_expiry", !secretExpiry.IsZero(), "has_reg_token", regAccessToken != "", From 305268efb87cdd087a0dc936e950950394eabfc8 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Sun, 24 May 2026 01:18:39 +0530 Subject: [PATCH 8/8] remove unused constants Signed-off-by: Sanskarzz --- pkg/auth/oauth/dynamic_registration_test.go | 748 -------------------- pkg/oauth/constants.go | 70 -- 2 files changed, 818 deletions(-) delete mode 100644 pkg/auth/oauth/dynamic_registration_test.go delete mode 100644 pkg/oauth/constants.go diff --git a/pkg/auth/oauth/dynamic_registration_test.go b/pkg/auth/oauth/dynamic_registration_test.go deleted file mode 100644 index db195fe764..0000000000 --- a/pkg/auth/oauth/dynamic_registration_test.go +++ /dev/null @@ -1,748 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package oauth - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - oauthproto "github.com/stacklok/toolhive/pkg/oauth" -) - -func TestDiscoverOIDCEndpointsWithRegistration(t *testing.T) { - t.Parallel() - tests := []struct { - name string - issuer string - response string - expectedError bool - expectedResult *oauthproto.OIDCDiscoveryDocument - }{ - { - name: "valid OIDC discovery with registration endpoint", - issuer: "https://example.com", - response: `{ - "issuer": "{{SERVER_URL}}", - "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", - "token_endpoint": "{{SERVER_URL}}/oauth/token", - "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", - "jwks_uri": "{{SERVER_URL}}/oauth/jwks", - "registration_endpoint": "{{SERVER_URL}}/oauth/register" - }`, - expectedError: false, - expectedResult: &oauthproto.OIDCDiscoveryDocument{ - AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ - Issuer: "https://example.com", - AuthorizationEndpoint: "https://example.com/oauth/authorize", - TokenEndpoint: "https://example.com/oauth/token", - UserinfoEndpoint: "https://example.com/oauth/userinfo", - JWKSURI: "https://example.com/oauth/jwks", - RegistrationEndpoint: "https://example.com/oauth/register", - }, - }, - }, - { - name: "valid OIDC discovery without registration endpoint", - issuer: "https://example.com", - response: `{ - "issuer": "{{SERVER_URL}}", - "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", - "token_endpoint": "{{SERVER_URL}}/oauth/token", - "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", - "jwks_uri": "{{SERVER_URL}}/oauth/jwks" - }`, - expectedError: false, - expectedResult: &oauthproto.OIDCDiscoveryDocument{ - AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ - Issuer: "https://example.com", - AuthorizationEndpoint: "https://example.com/oauth/authorize", - TokenEndpoint: "https://example.com/oauth/token", - UserinfoEndpoint: "https://example.com/oauth/userinfo", - JWKSURI: "https://example.com/oauth/jwks", - RegistrationEndpoint: "", - }, - }, - }, - { - name: "invalid issuer URL", - issuer: "not-a-url", - expectedError: true, - }, - { - name: "non-HTTPS issuer", - issuer: "http://example.com", - expectedError: true, - }, - { - name: "localhost HTTP allowed for development", - issuer: "http://localhost:8080", - response: `{ - "issuer": "{{SERVER_URL}}", - "authorization_endpoint": "{{SERVER_URL}}/oauth/authorize", - "token_endpoint": "{{SERVER_URL}}/oauth/token", - "userinfo_endpoint": "{{SERVER_URL}}/oauth/userinfo", - "jwks_uri": "{{SERVER_URL}}/oauth/jwks", - "registration_endpoint": "{{SERVER_URL}}/oauth/register" - }`, - expectedError: false, - expectedResult: &oauthproto.OIDCDiscoveryDocument{ - AuthorizationServerMetadata: oauthproto.AuthorizationServerMetadata{ - Issuer: "http://localhost:8080", - AuthorizationEndpoint: "http://localhost:8080/oauth/authorize", - TokenEndpoint: "http://localhost:8080/oauth/token", - UserinfoEndpoint: "http://localhost:8080/oauth/userinfo", - JWKSURI: "http://localhost:8080/oauth/jwks", - RegistrationEndpoint: "http://localhost:8080/oauth/register", - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - var server *httptest.Server - var responseTemplate string - - if tt.response != "" { - responseTemplate = tt.response - server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle both OIDC and OAuth discovery endpoints - if r.URL.Path == oauthproto.WellKnownOIDCPath || - r.URL.Path == oauthproto.WellKnownOAuthServerPath { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - // Replace placeholder with actual server URL - response := strings.ReplaceAll(responseTemplate, "{{SERVER_URL}}", server.URL) - w.Write([]byte(response)) - } else { - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - } - - issuer := tt.issuer - if server != nil { - // For test server, use the actual server URL - issuer = server.URL - } - - result, err := DiscoverOIDCEndpoints(context.Background(), issuer) - - if tt.expectedError { - assert.Error(t, err) - assert.Nil(t, result) - } else { - assert.NoError(t, err) - assert.NotNil(t, result) - if server != nil { - // For test server, we can't predict the exact URLs, so just check structure - assert.NotEmpty(t, result.Issuer) - assert.NotEmpty(t, result.AuthorizationEndpoint) - assert.NotEmpty(t, result.TokenEndpoint) - if tt.expectedResult.RegistrationEndpoint != "" { - assert.NotEmpty(t, result.RegistrationEndpoint) - } - } else { - // For static tests, check exact values - assert.Equal(t, tt.expectedResult.Issuer, result.Issuer) - assert.Equal(t, tt.expectedResult.AuthorizationEndpoint, result.AuthorizationEndpoint) - assert.Equal(t, tt.expectedResult.TokenEndpoint, result.TokenEndpoint) - assert.Equal(t, tt.expectedResult.RegistrationEndpoint, result.RegistrationEndpoint) - } - } - }) - } -} - -func TestNewDynamicClientRegistrationRequest(t *testing.T) { - t.Parallel() - tests := []struct { - name string - scopes []string - callbackPort int - expected *DynamicClientRegistrationRequest - }{ - { - name: "basic request", - scopes: []string{"openid", "profile"}, - callbackPort: 8080, - expected: &DynamicClientRegistrationRequest{ - ClientName: "ToolHive MCP Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - TokenEndpointAuthMethod: "client_secret_post", - GrantTypes: []string{"authorization_code", "refresh_token"}, - ResponseTypes: []string{"code"}, - Scopes: []string{"openid", "profile"}, - }, - }, - { - name: "empty scopes", - scopes: []string{}, - callbackPort: 8666, - expected: &DynamicClientRegistrationRequest{ - ClientName: "ToolHive MCP Client", - RedirectURIs: []string{"http://localhost:8666/callback"}, - TokenEndpointAuthMethod: "client_secret_post", - GrantTypes: []string{"authorization_code", "refresh_token"}, - ResponseTypes: []string{"code"}, - Scopes: []string{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - result := NewDynamicClientRegistrationRequest(tt.scopes, tt.callbackPort) - - assert.Equal(t, tt.expected.ClientName, result.ClientName) - assert.Equal(t, tt.expected.RedirectURIs, result.RedirectURIs) - assert.Equal(t, tt.expected.TokenEndpointAuthMethod, result.TokenEndpointAuthMethod) - assert.Equal(t, tt.expected.GrantTypes, result.GrantTypes) - assert.Equal(t, tt.expected.ResponseTypes, result.ResponseTypes) - assert.Equal(t, tt.expected.Scopes, result.Scopes) - }) - } -} - -func TestDynamicClientRegistrationRequest_ScopeSerialization(t *testing.T) { - t.Parallel() - - // This test verifies RFC 7591 Section 2 compliance for scope serialization. - // Per the spec, scopes MUST be serialized as a space-delimited string, not a JSON array. - // Empty/nil scopes should result in the scope field being omitted entirely (omitempty), - // which is RFC 7591 compliant since the scope parameter is optional. - - tests := []struct { - name string - scopes []string - shouldOmitScope bool - expectedScopeJSON string // Expected scope field in JSON, empty if omitted - }{ - { - name: "nil scopes should omit scope field entirely", - scopes: nil, - shouldOmitScope: true, - }, - { - name: "empty slice scopes should omit scope field entirely", - scopes: []string{}, - shouldOmitScope: true, - }, - { - name: "single scope should be space-delimited string per RFC 7591", - scopes: []string{"openid"}, - shouldOmitScope: false, - expectedScopeJSON: `"scope":"openid"`, - }, - { - name: "multiple scopes should be space-delimited string per RFC 7591", - scopes: []string{"openid", "profile"}, - shouldOmitScope: false, - expectedScopeJSON: `"scope":"openid profile"`, - }, - { - name: "three scopes should be space-delimited string per RFC 7591", - scopes: []string{"openid", "profile", "email"}, - shouldOmitScope: false, - expectedScopeJSON: `"scope":"openid profile email"`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - // Create request with specified scopes - request := NewDynamicClientRegistrationRequest(tt.scopes, 8666) - - // Marshal to JSON - jsonBytes, err := json.Marshal(request) - require.NoError(t, err, "JSON marshaling should succeed") - - jsonStr := string(jsonBytes) - - // Verify scope field behavior - if tt.shouldOmitScope { - assert.NotContains(t, jsonStr, `"scope"`, - "JSON should NOT contain scope field when scopes are empty/nil (omitempty behavior)") - } else { - assert.Contains(t, jsonStr, tt.expectedScopeJSON, - "JSON should contain expected scope field") - } - - // Verify other required fields are always present - assert.Contains(t, jsonStr, `"redirect_uris"`, "redirect_uris should be present") - assert.Contains(t, jsonStr, `"client_name"`, "client_name should be present") - assert.Contains(t, jsonStr, `"grant_types"`, "grant_types should be present") - }) - } -} - -func TestRegisterClientDynamically(t *testing.T) { - t.Parallel() - tests := []struct { - name string - request *DynamicClientRegistrationRequest - response string - responseStatus int - expectedError bool - expectedResult *DynamicClientRegistrationResponse - }{ - { - name: "successful registration", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - TokenEndpointAuthMethod: "none", - GrantTypes: []string{"authorization_code"}, - ResponseTypes: []string{"code"}, - Scopes: []string{"openid", "profile"}, - }, - response: `{ - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "client_id_issued_at": 1234567890, - "client_secret_expires_at": 0, - "registration_access_token": "reg-token", - "registration_client_uri": "https://example.com/oauth/register/test-client-id" - }`, - responseStatus: http.StatusCreated, - expectedError: false, - expectedResult: &DynamicClientRegistrationResponse{ - ClientID: "test-client-id", - ClientSecret: "test-client-secret", - ClientIDIssuedAt: 1234567890, - ClientSecretExpiresAt: 0, - RegistrationAccessToken: "reg-token", - RegistrationClientURI: "https://example.com/oauth/register/test-client-id", - }, - }, - { - name: "registration without client secret (PKCE flow)", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - TokenEndpointAuthMethod: "none", - GrantTypes: []string{"authorization_code"}, - ResponseTypes: []string{"code"}, - }, - response: `{ - "client_id": "test-client-id", - "client_id_issued_at": 1234567890 - }`, - responseStatus: http.StatusCreated, - expectedError: false, - expectedResult: &DynamicClientRegistrationResponse{ - ClientID: "test-client-id", - ClientIDIssuedAt: 1234567890, - }, - }, - { - name: "server error", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - }, - response: `{"error": "invalid_request", "error_description": "Invalid request"}`, - responseStatus: http.StatusBadRequest, - expectedError: true, - }, - { - name: "DCR not supported - 404 Not Found", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - }, - response: `{"error": "not_found"}`, - responseStatus: http.StatusNotFound, - expectedError: true, - }, - { - name: "DCR not supported - 405 Method Not Allowed", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - }, - response: `{"error": "method_not_allowed"}`, - responseStatus: http.StatusMethodNotAllowed, - expectedError: true, - }, - { - name: "DCR not supported - 501 Not Implemented", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - }, - response: `{"error": "not_implemented", "error_description": "Dynamic Client Registration is not supported"}`, - responseStatus: http.StatusNotImplemented, - expectedError: true, - }, - { - name: "invalid request - no redirect URIs", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - }, - expectedError: true, - }, - { - name: "invalid request - scope with spaces", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - Scopes: []string{"openid", "profile email", "another"}, - }, - expectedError: true, - }, - { - name: "invalid request - scope with leading space", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - Scopes: []string{" openid"}, - }, - expectedError: true, - }, - { - name: "invalid request - scope with trailing space", - request: &DynamicClientRegistrationRequest{ - ClientName: "Test Client", - RedirectURIs: []string{"http://localhost:8080/callback"}, - Scopes: []string{"openid "}, - }, - expectedError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - var server *httptest.Server - if tt.response != "" { - server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - assert.Equal(t, "application/json", r.Header.Get("Accept")) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(tt.responseStatus) - w.Write([]byte(tt.response)) - })) - defer server.Close() - } - - var registrationEndpoint string - if server != nil { - registrationEndpoint = server.URL - } else { - registrationEndpoint = "https://example.com/oauth/register" - } - - result, err := RegisterClientDynamically(context.Background(), registrationEndpoint, tt.request) - - if tt.expectedError { - assert.Error(t, err) - assert.Nil(t, result) - } else { - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, tt.expectedResult.ClientID, result.ClientID) - assert.Equal(t, tt.expectedResult.ClientSecret, result.ClientSecret) - assert.Equal(t, tt.expectedResult.ClientIDIssuedAt, result.ClientIDIssuedAt) - assert.Equal(t, tt.expectedResult.RegistrationAccessToken, result.RegistrationAccessToken) - assert.Equal(t, tt.expectedResult.RegistrationClientURI, result.RegistrationClientURI) - } - }) - } -} - -func TestDynamicClientRegistrationRequest_Defaults(t *testing.T) { - t.Parallel() - // Test that default values are set correctly - request := &DynamicClientRegistrationRequest{ - RedirectURIs: []string{"http://localhost:8080/callback"}, - } - - // Serialize to JSON to verify defaults - data, err := json.Marshal(request) - require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(data, &result) - require.NoError(t, err) - - // Verify that required fields are present - assert.Contains(t, result, "redirect_uris") - assert.Equal(t, []interface{}{"http://localhost:8080/callback"}, result["redirect_uris"]) -} - -// TestDynamicClientRegistrationResponse_Validation tests that the response validation works correctly -func TestDynamicClientRegistrationResponse_Validation(t *testing.T) { - t.Parallel() - // Test that response validation works correctly - validResponse := &DynamicClientRegistrationResponse{ - ClientID: "test-client-id", - } - - // Serialize to JSON - data, err := json.Marshal(validResponse) - require.NoError(t, err) - - var result DynamicClientRegistrationResponse - err = json.Unmarshal(data, &result) - require.NoError(t, err) - - assert.Equal(t, "test-client-id", result.ClientID) -} - -func TestDiscoverOIDCEndpointsWithRegistrationFallback(t *testing.T) { - t.Parallel() - - // Test case: OIDC well-known succeeds but lacks registration_endpoint, - // OAuth authorization server well-known has it - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - baseURL := "http://" + r.Host - switch r.URL.Path { - case oauthproto.WellKnownOIDCPath: - // OIDC discovery - no registration_endpoint - response := `{ - "issuer": "` + baseURL + `", - "authorization_endpoint": "` + baseURL + `/oauth/authorize", - "token_endpoint": "` + baseURL + `/oauth/token", - "userinfo_endpoint": "` + baseURL + `/oauth/userinfo", - "jwks_uri": "` + baseURL + `/oauth/jwks" - }` - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) - case oauthproto.WellKnownOAuthServerPath: - // OAuth authorization server - has registration_endpoint - response := `{ - "issuer": "` + baseURL + `", - "authorization_endpoint": "` + baseURL + `/oauth/authorize", - "token_endpoint": "` + baseURL + `/oauth/token", - "registration_endpoint": "` + baseURL + `/oauth/register" - }` - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - result, err := DiscoverOIDCEndpoints(context.Background(), server.URL) - - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, server.URL, result.Issuer) - assert.NotEmpty(t, result.AuthorizationEndpoint) - assert.NotEmpty(t, result.TokenEndpoint) - // Registration endpoint should be found from OAuth authorization server well-known - assert.NotEmpty(t, result.RegistrationEndpoint, "registration_endpoint should be found via OAuth authorization server fallback") - assert.Equal(t, server.URL+"/oauth/register", result.RegistrationEndpoint) -} - -func TestDiscoverOIDCEndpointsWithRegistrationFallbackIssuerMismatch(t *testing.T) { - t.Parallel() - - // Test case: OIDC and OAuth have different issuers - should not merge - // Use DiscoverActualIssuer which doesn't validate issuer, allowing us to test the merge logic - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - baseURL := "http://" + r.Host - switch r.URL.Path { - case oauthproto.WellKnownOIDCPath: - // OIDC discovery - no registration_endpoint, different issuer - response := `{ - "issuer": "https://oidc.example.com", - "authorization_endpoint": "` + baseURL + `/oauth/authorize", - "token_endpoint": "` + baseURL + `/oauth/token", - "userinfo_endpoint": "` + baseURL + `/oauth/userinfo", - "jwks_uri": "` + baseURL + `/oauth/jwks" - }` - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) - case oauthproto.WellKnownOAuthServerPath: - // OAuth authorization server - has registration_endpoint but different issuer - response := `{ - "issuer": "https://oauth.example.com", - "authorization_endpoint": "` + baseURL + `/oauth/authorize", - "token_endpoint": "` + baseURL + `/oauth/token", - "registration_endpoint": "` + baseURL + `/oauth/register" - }` - w.WriteHeader(http.StatusOK) - w.Write([]byte(response)) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - // Use DiscoverActualIssuer which doesn't validate issuer, allowing us to test merge logic - result, err := DiscoverActualIssuer(context.Background(), server.URL) - - require.NoError(t, err) - require.NotNil(t, result) - // Registration endpoint should NOT be merged due to issuer mismatch - assert.Empty(t, result.RegistrationEndpoint, "registration_endpoint should not be merged when issuers don't match") -} - -// TestIsLocalhost is already defined in oidc_test.go - -// TestScopeList_MarshalJSON tests that the ScopeList marshaling works correctly -// and produces RFC 7591 compliant space-delimited strings. -func TestScopeList_MarshalJSON(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - scopes ScopeList - wantJSON string - wantOmit bool // If true, expect omitempty to hide the field - }{ - { - name: "nil scopes => empty string (omitempty will hide at struct level)", - scopes: nil, - wantJSON: `""`, - wantOmit: true, - }, - { - name: "empty slice => empty string (omitempty will hide at struct level)", - scopes: ScopeList{}, - wantJSON: `""`, - wantOmit: true, - }, - { - name: "single scope => string", - scopes: ScopeList{"openid"}, - wantJSON: `"openid"`, - }, - { - name: "two scopes => space-delimited string", - scopes: ScopeList{"openid", "profile"}, - wantJSON: `"openid profile"`, - }, - { - name: "three scopes => space-delimited string", - scopes: ScopeList{"openid", "profile", "email"}, - wantJSON: `"openid profile email"`, - }, - { - name: "scopes with special characters", - scopes: ScopeList{"read:user", "write:repo"}, - wantJSON: `"read:user write:repo"`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - jsonBytes, err := json.Marshal(tt.scopes) - require.NoError(t, err, "marshaling should succeed") - - jsonStr := string(jsonBytes) - assert.Equal(t, tt.wantJSON, jsonStr, "marshaled JSON should match expected format") - - // Verify omitempty behavior in a struct - // Note: omitempty checks the Go value (empty slice) before calling MarshalJSON, - // so empty slices are omitted regardless of what MarshalJSON returns. - if tt.wantOmit { - type testStruct struct { - Scope ScopeList `json:"scope,omitempty"` - } - s := testStruct{Scope: tt.scopes} - structJSON, err := json.Marshal(s) - require.NoError(t, err) - assert.Equal(t, "{}", string(structJSON), "omitempty should hide empty scope field") - } - }) - } -} - -// TestScopeList_UnmarshalJSON tests that the ScopeList unmarshaling works correctly. -func TestScopeList_UnmarshalJSON(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - jsonIn string - want []string - wantErr bool - }{ - { - name: "space-delimited string", - jsonIn: `"openid profile email"`, - want: []string{"openid", "profile", "email"}, - }, - { - name: "empty string => nil", - jsonIn: `""`, - want: nil, - }, - { - name: "string with extra spaces", - jsonIn: `" openid profile "`, - want: []string{"openid", "profile"}, - }, - { - name: "normal array", - jsonIn: `["openid","profile","email"]`, - want: []string{"openid", "profile", "email"}, - }, - { - name: "array with whitespace and empties", - jsonIn: `[" openid ",""," profile "]`, - want: []string{"openid", "profile"}, - }, - { - name: "all-empty array => nil", - jsonIn: `[""," "]`, - want: nil, - }, - { - name: "explicit null => nil", - jsonIn: `null`, - want: nil, - }, - { - name: "invalid type (number)", - jsonIn: `123`, - wantErr: true, - }, - { - name: "invalid type (object)", - jsonIn: `{"not":"valid"}`, - wantErr: true, - }, - } - - for _, tt := range tests { - tt := tt // capture loop variable - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - var s ScopeList - err := json.Unmarshal([]byte(tt.jsonIn), &s) - - if tt.wantErr { - assert.Error(t, err, "expected error but got none") - return - } - - assert.NoError(t, err, "unexpected unmarshal error") - assert.Equal(t, tt.want, []string(s)) - }) - } -} diff --git a/pkg/oauth/constants.go b/pkg/oauth/constants.go deleted file mode 100644 index 9b1ee98763..0000000000 --- a/pkg/oauth/constants.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2025 Stacklok, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package oauth provides RFC-defined types and constants for OAuth 2.0 and OpenID Connect. -// This package contains ONLY protocol-level definitions with no business logic. -// It serves as a shared foundation for both OAuth clients (consumers) and servers (producers). -package oauth - -// Well-known endpoint paths as defined by RFC 8414, OpenID Connect Discovery 1.0, and RFC 9728. -const ( - // WellKnownOIDCPath is the standard OIDC discovery endpoint path - // per OpenID Connect Discovery 1.0 specification. - WellKnownOIDCPath = "/.well-known/openid-configuration" - - // WellKnownOAuthServerPath is the standard OAuth authorization server metadata endpoint path - // per RFC 8414 (OAuth 2.0 Authorization Server Metadata). - WellKnownOAuthServerPath = "/.well-known/oauth-authorization-server" - - // WellKnownOAuthResourcePath is the RFC 9728 standard path for OAuth Protected Resource metadata. - // Per RFC 9728 Section 3, this endpoint and any subpaths under it should be accessible - // without authentication to enable OIDC/OAuth discovery. - WellKnownOAuthResourcePath = "/.well-known/oauth-protected-resource" -) - -// Grant types as defined by RFC 6749. -const ( - // GrantTypeAuthorizationCode is the authorization code grant type (RFC 6749 Section 4.1). - GrantTypeAuthorizationCode = "authorization_code" - - // GrantTypeRefreshToken is the refresh token grant type (RFC 6749 Section 6). - GrantTypeRefreshToken = "refresh_token" -) - -// Response types as defined by RFC 6749. -const ( - // ResponseTypeCode is the authorization code response type (RFC 6749 Section 4.1.1). - ResponseTypeCode = "code" -) - -// Token endpoint authentication methods as defined by RFC 7591. -const ( - // TokenEndpointAuthMethodNone indicates no client authentication (public clients). - // Typically used with PKCE for native/mobile applications. - TokenEndpointAuthMethodNone = "none" - - // TokenEndpointAuthMethodClientSecretPost indicates client authentication via - // client_id and client_secret in the request body. - TokenEndpointAuthMethodClientSecretPost = "client_secret_post" - - // TokenEndpointAuthMethodClientSecretBasic indicates client authentication via - // HTTP Basic Authentication. - TokenEndpointAuthMethodClientSecretBasic = "client_secret_basic" -) - -// PKCE (Proof Key for Code Exchange) methods as defined by RFC 7636. -const ( - // PKCEMethodS256 uses SHA-256 hash of the code verifier (recommended). - PKCEMethodS256 = "S256" -)