diff --git a/docs/server/docs.go b/docs/server/docs.go index fa76390328..c4a611c28f 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -258,18 +258,30 @@ const docTemplate = `{ "cached_client_secret_ref": { "type": "string" }, + "cached_dcr_callback_port": { + "description": "CachedDCRCallbackPort is the callback port that was actually registered\nduring DCR. It may differ from CallbackPort when the requested port was\nunavailable and a fallback port was selected.", + "type": "integer" + }, "cached_refresh_token_ref": { "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { "description": "ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server).\nA zero value means the secret does not expire.", "type": "string" }, + "cached_token_auth_method": { + "description": "CachedTokenEndpointAuthMethod is the auth method used for the token endpoint\n(e.g., \"client_secret_basic\", \"none\"). Persisted for RFC 7592 updates.", + "type": "string" + }, "cached_token_expiry": { "type": "string" }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index eb71c3efec..388b8a9a75 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -251,18 +251,30 @@ "cached_client_secret_ref": { "type": "string" }, + "cached_dcr_callback_port": { + "description": "CachedDCRCallbackPort is the callback port that was actually registered\nduring DCR. It may differ from CallbackPort when the requested port was\nunavailable and a fallback port was selected.", + "type": "integer" + }, "cached_refresh_token_ref": { "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { "description": "ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server).\nA zero value means the secret does not expire.", "type": "string" }, + "cached_token_auth_method": { + "description": "CachedTokenEndpointAuthMethod is the auth method used for the token endpoint\n(e.g., \"client_secret_basic\", \"none\"). Persisted for RFC 7592 updates.", + "type": "string" + }, "cached_token_expiry": { "type": "string" }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6e9d446e50..fab7757d0e 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -267,6 +267,12 @@ components: type: string cached_client_secret_ref: type: string + cached_dcr_callback_port: + description: |- + CachedDCRCallbackPort is the callback port that was actually registered + during DCR. It may differ from CallbackPort when the requested port was + unavailable and a fallback port was selected. + type: integer cached_refresh_token_ref: description: |- Cached OAuth token reference for persistence across restarts. @@ -274,9 +280,16 @@ components: contains the reference to retrieve it (e.g., "OAUTH_REFRESH_TOKEN_workload"). This enables session restoration without requiring a new browser-based login. type: string + cached_reg_client_uri: + description: |- + CachedRegClientURI is the registration_client_uri from the DCR response. + This is the endpoint used for RFC 7592 client read/update/delete operations. + Stored as plain text since it is not sensitive. + type: string cached_reg_token_ref: description: |- - RegistrationAccessToken is used to update/delete the client registration. + CachedRegTokenRef is a secret manager reference to the registration_access_token + returned in the DCR response. Used for RFC 7592 client update operations. Stored as a secret reference since it's sensitive. type: string cached_secret_expiry: @@ -284,6 +297,11 @@ components: ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server). A zero value means the secret does not expire. type: string + cached_token_auth_method: + description: |- + CachedTokenEndpointAuthMethod is the auth method used for the token endpoint + (e.g., "client_secret_basic", "none"). Persisted for RFC 7592 updates. + type: string cached_token_expiry: type: string callback_port: diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index ee28036c5d..31561228a2 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -512,6 +512,14 @@ type OAuthFlowConfig struct { Resource string // RFC 8707 resource indicator (optional) OAuthParams map[string]string ScopeParamName string // Override scope query parameter name (e.g., "user_scope" for Slack) + + // DCR renewal metadata — populated by handleDynamicRegistration and threaded + // into OAuthFlowResult so callers can persist the data for RFC 7592 operations. + SecretExpiry time.Time // zero means the secret never expires + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string + TokenEndpointAuthMethod string + RegisteredCallbackPort int } // OAuthFlowResult contains the result of an OAuth flow @@ -527,6 +535,16 @@ type OAuthFlowResult struct { // DCR client credentials for persistence (obtained during Dynamic Client Registration) ClientID string ClientSecret string //nolint:gosec // G117: field legitimately holds sensitive data + + // DCR renewal metadata (RFC 7591 §3.2.1 / RFC 7592). + // SecretExpiry is zero when the provider did not issue an expiring secret. + // RegistrationAccessToken and RegistrationClientURI are empty when the + // provider does not support RFC 7592 management operations. + SecretExpiry time.Time + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string + TokenEndpointAuthMethod string + RegisteredCallbackPort int } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -599,20 +617,12 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi // exists. // // One consequence of option (b) is that the resolver's RFC 7591 §3.2.1 -// expiry-driven refetch does NOT participate in the CLI's cross- -// invocation persistence loop: each PerformOAuthFlow call builds a fresh -// in-memory store, so a "cached but expired" entry from the previous -// invocation never reaches the resolver. Cross-invocation expiry is also -// NOT enforced by the remote handler's gate today — -// HasCachedClientCredentials only checks CachedClientID != "" and does -// not consult CachedSecretExpiry, so an expired-but-still-cached client -// gets reused on the next invocation and surfaces as a token-endpoint -// failure rather than a clean DCR re-registration. Tightening the gate -// to also check CachedSecretExpiry is open follow-up work; the -// behaviour today is "cross-invocation expiry is unhandled". Within a -// single invocation, the resolver's expiry check is still in the loop -// and would fire if the same call site somehow registered, persisted, -// and re-queried the in-memory store — but the CLI never does this today. +// expiry-driven refetch does NOT participate in the CLI's cross-invocation +// persistence loop: each PerformOAuthFlow call builds a fresh in-memory store, +// so a cached entry from a previous invocation never reaches the resolver. +// Cross-invocation client-secret expiry is handled instead by the remote +// handler, which consults CachedSecretExpiry and renews through RFC 7592 before +// cached credentials are used. // // Wrapping the remote handler's secretProvider into a dcr.CredentialStore // adapter (option (a)) would close that loop and is the natural follow-up; @@ -660,6 +670,18 @@ func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuth config.TokenURL = resolution.TokenEndpoint } + // Store DCR renewal metadata for RFC 7592 operations. + // A zero ClientSecretExpiresAt means the secret never expires (RFC 7591 §3.2.1). + config.SecretExpiry = resolution.ClientSecretExpiresAt + config.RegistrationAccessToken = resolution.RegistrationAccessToken + config.RegistrationClientURI = resolution.RegistrationClientURI + config.TokenEndpointAuthMethod = resolution.TokenEndpointAuthMethod + config.RegisteredCallbackPort = config.CallbackPort + + if resolution.RegistrationAccessToken != "" { + slog.Debug("DCR response includes registration access token for RFC 7592 operations") + } + return nil } @@ -892,6 +914,12 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF Expiry: tokenResult.Expiry, ClientID: oauthConfig.ClientID, ClientSecret: oauthConfig.ClientSecret, + // DCR renewal metadata — populated only when dynamic registration was performed. + SecretExpiry: config.SecretExpiry, + RegistrationAccessToken: config.RegistrationAccessToken, + RegistrationClientURI: config.RegistrationClientURI, + TokenEndpointAuthMethod: config.TokenEndpointAuthMethod, + RegisteredCallbackPort: config.RegisteredCallbackPort, }, nil } diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 8675036dda..733e301024 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/stacklok/toolhive-core/registry/types" + registry "github.com/stacklok/toolhive-core/registry/types" httpval "github.com/stacklok/toolhive-core/validation/http" ) @@ -68,7 +68,8 @@ type Config struct { // ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server). // A zero value means the secret does not expire. CachedSecretExpiry time.Time `json:"cached_secret_expiry,omitempty" yaml:"cached_secret_expiry,omitempty"` - // RegistrationAccessToken is used to update/delete the client registration. + // CachedRegTokenRef is a secret manager reference to the registration_access_token + // returned in the DCR response. Used for RFC 7592 client update operations. // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` @@ -78,6 +79,17 @@ type Config struct { // rotation clears CachedClientID without touching the stable CIMD URL. // Read by resolveClientCredentials to send the correct client_id on token refresh. CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` + // CachedRegClientURI is the registration_client_uri from the DCR response. + // This is the endpoint used for RFC 7592 client read/update/delete operations. + // Stored as plain text since it is not sensitive. + CachedRegClientURI string `json:"cached_reg_client_uri,omitempty" yaml:"cached_reg_client_uri,omitempty"` + // CachedTokenEndpointAuthMethod is the auth method used for the token endpoint + // (e.g., "client_secret_basic", "none"). Persisted for RFC 7592 updates. + CachedTokenEndpointAuthMethod string `json:"cached_token_auth_method,omitempty" yaml:"cached_token_auth_method,omitempty"` + // CachedDCRCallbackPort is the callback port that was actually registered + // during DCR. It may differ from CallbackPort when the requested port was + // unavailable and a fallback port was selected. + CachedDCRCallbackPort int `json:"cached_dcr_callback_port,omitempty" yaml:"cached_dcr_callback_port,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -184,6 +196,9 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedClientSecretRef = "" c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" + c.CachedRegClientURI = "" + c.CachedTokenEndpointAuthMethod = "" + c.CachedDCRCallbackPort = 0 } // LogContext returns the upstream issuer and resolved client_id for use as diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index ce7d271288..608415a57b 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -9,10 +9,12 @@ import ( "fmt" "log/slog" "strings" + "time" "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/oauthproto" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -24,6 +26,7 @@ type Handler struct { tokenPersister TokenPersister clientCredentialsPersister ClientCredentialsPersister secretProvider secrets.Provider + httpClient networking.HTTPClient } // NewHandler creates a new remote authentication handler @@ -50,6 +53,11 @@ func (h *Handler) SetClientCredentialsPersister(persister ClientCredentialsPersi h.clientCredentialsPersister = persister } +// SetHTTPClient sets the HTTP client used for RFC 7592 registration management requests. +func (h *Handler) SetHTTPClient(client networking.HTTPClient) { + h.httpClient = client +} + // Authenticate is the main entry point for remote MCP server authentication func (h *Handler) Authenticate(ctx context.Context, remoteURL string) (oauth2.TokenSource, error) { // Priority 1: Bearer token authentication (if configured) @@ -207,7 +215,15 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // CIMD client IDs (HTTPS URLs) are stable constants and are stored separately below. if h.clientCredentialsPersister != nil && result.ClientID != "" && !oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { - if err := h.clientCredentialsPersister(result.ClientID, result.ClientSecret); err != nil { + if err := h.clientCredentialsPersister( + result.ClientID, + result.ClientSecret, + result.SecretExpiry, + result.RegistrationAccessToken, + result.RegistrationClientURI, + result.TokenEndpointAuthMethod, + result.RegisteredCallbackPort, + ); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { slog.Debug("Successfully persisted DCR client credentials for future restarts") @@ -232,6 +248,8 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // resolveClientCredentials returns the client ID and secret to use, preferring // cached DCR credentials over statically configured ones. +// If the cached client secret is expiring soon, it attempts renewal via RFC 7592 +// before returning the credentials. func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clientSecret string) { // First try to use statically configured credentials clientID = h.config.ClientID @@ -252,6 +270,18 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien clientID = h.config.CachedClientID slog.Debug("Using cached DCR client credentials", "client_id", clientID) + // Proactively renew the client secret if it is expiring soon (RFC 7592) + if h.isSecretExpiredOrExpiringSoon() { + slog.Debug("Cached client secret is expiring soon, attempting renewal", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Failed to proactively renew client secret; continuing with existing secret", + "error", renewErr) + } else { + slog.Debug("Successfully renewed client secret ahead of expiry") + } + } + // Client secret is stored securely and may be empty for PKCE flows if h.config.CachedClientSecretRef != "" && h.secretProvider != nil { cachedClientSecret, err := h.secretProvider.GetSecret(ctx, h.config.CachedClientSecretRef) @@ -278,6 +308,27 @@ func (h *Handler) tryRestoreFromCachedTokens( return nil, fmt.Errorf("secret provider not configured, cannot restore cached tokens") } + // Check if the cached client secret is expired before attempting token refresh. + // If it has fully expired and renewal also fails we must force a fresh OAuth flow. + if h.isSecretExpiredOrExpiringSoon() { + slog.Debug("Cached client secret is expiring or expired; attempting renewal before token restore", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Client secret renewal failed", "error", renewErr) + // Hard-fail only when the secret is already past its expiry. + // If we are still in the buffer window the existing secret may work. + if !h.config.CachedSecretExpiry.IsZero() && time.Now().After(h.config.CachedSecretExpiry) { + return nil, fmt.Errorf( + "client secret expired at %v and renewal failed: %w", + h.config.CachedSecretExpiry, renewErr) + } + // Still within buffer — log and continue with the existing (still-valid) secret + slog.Warn("Proceeding with expiring client secret after failed renewal attempt") + } else { + slog.Debug("Successfully renewed client secret before token restore") + } + } + refreshToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRefreshTokenRef) if err != nil { return nil, fmt.Errorf("failed to retrieve cached refresh token: %w", err) diff --git a/pkg/auth/remote/persisting_token_source.go b/pkg/auth/remote/persisting_token_source.go index 3ba90b6253..11e92081e9 100644 --- a/pkg/auth/remote/persisting_token_source.go +++ b/pkg/auth/remote/persisting_token_source.go @@ -21,8 +21,26 @@ import ( type TokenPersister func(refreshToken string, expiry time.Time) error // ClientCredentialsPersister is called when DCR client credentials need to be persisted. -// This is used to store client_id and client_secret obtained during Dynamic Client Registration. -type ClientCredentialsPersister func(clientID, clientSecret string) error +// This is used to store client_id, client_secret, and renewal metadata obtained during +// Dynamic Client Registration (RFC 7591) and needed for secret renewal (RFC 7592). +// +// Parameters: +// - clientID: the registered client ID (public, stored as plain text) +// - clientSecret: the registered client secret (sensitive, stored via secret manager) +// - secretExpiry: when the client secret expires; zero value means it never expires +// - registrationAccessToken: bearer token for RFC 7592 management operations (sensitive) +// - registrationClientURI: endpoint for RFC 7592 client update/read operations (plain text) +// - tokenEndpointAuthMethod: the auth method used for the token endpoint (e.g., "client_secret_basic", "none") +// - registeredCallbackPort: the callback port used in the original DCR redirect URI +type ClientCredentialsPersister func( + clientID string, + clientSecret string, + secretExpiry time.Time, + registrationAccessToken string, + registrationClientURI string, + tokenEndpointAuthMethod string, + registeredCallbackPort int, +) error // PersistingTokenSource wraps an oauth2.TokenSource and persists tokens // whenever they are refreshed. This enables session restoration across diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go new file mode 100644 index 0000000000..fb6b479155 --- /dev/null +++ b/pkg/auth/remote/secret_renewal.go @@ -0,0 +1,255 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "time" + + "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/oauthproto" +) + +// secretExpiryBuffer is the lead time before expiry at which we proactively +// renew the client secret (RFC 7592). Renewal is attempted when the secret +// expires within this window, not only after expiry. +const secretExpiryBuffer = 24 * time.Hour + +var defaultRenewalHTTPClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }, +} + +// clientUpdateRequest is the body sent in a RFC 7592 §2.2 PUT request. +// Per the spec, all client metadata fields that were provided during +// registration must be included in the update request body. +type clientUpdateRequest struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name,omitempty"` + RedirectURIs []string `json:"redirect_uris,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` +} + +// clientUpdateResponse is the body returned by a RFC 7592 §2.1 response. +// The provider may rotate the registration_access_token; if present we must +// replace the stored one. +type clientUpdateResponse struct { + // Required fields mirrored from registration response + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec // G117: field holds sensitive data + + // Expiry fields + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + + // Management fields — registration_access_token may be rotated + RegistrationAccessToken string `json:"registration_access_token,omitempty"` //nolint:gosec + RegistrationClientURI string `json:"registration_client_uri,omitempty"` +} + +// isSecretExpiredOrExpiringSoon returns true when the cached client secret is +// already expired or will expire within secretExpiryBuffer. +// A zero CachedSecretExpiry means the secret never expires, so this returns false. +func (h *Handler) isSecretExpiredOrExpiringSoon() bool { + expiry := h.config.CachedSecretExpiry + if expiry.IsZero() { + return false // Non-expiring secret + } + return time.Now().After(expiry.Add(-secretExpiryBuffer)) +} + +// renewClientSecret attempts to renew the client secret using RFC 7592 §2.2. +// It retrieves the stored registration_access_token and sends a PUT request +// to the registration_client_uri with the current client metadata. +// +// On success the handler's config is updated with the new secret, expiry, and +// (if rotated) the new registration_access_token. +// +// Callers should log a warning and continue if renewal fails — the existing +// secret may still be valid for some time, or the provider may not support renewal. +func (h *Handler) renewClientSecret(ctx context.Context) error { + if err := h.validateRenewalPrerequisites(); err != nil { + return err + } + + // Retrieve the registration access token from the secret manager + regAccessToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRegTokenRef) + if err != nil { + return fmt.Errorf("failed to retrieve registration access token: %w", err) + } + + slog.Debug("Attempting RFC 7592 client secret renewal", + "registration_client_uri", h.config.CachedRegClientURI) + + // Validate the registration_client_uri before using it + if err := validateRegistrationClientURI(h.config.CachedRegClientURI); err != nil { + return fmt.Errorf("invalid registration_client_uri: %w", err) + } + + // Build the update request body with the current client metadata. + // Per RFC 7592 §2.2, the request MUST include all client metadata fields + // that were provided during the initial registration. + updateReq := clientUpdateRequest{ + ClientID: h.config.CachedClientID, + ClientName: oauthproto.ToolHiveMCPClientName, + RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.registeredDCRCallbackPort())}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: h.config.CachedTokenEndpointAuthMethod, + } + + reqBody, err := json.Marshal(updateReq) + if err != nil { + return fmt.Errorf("failed to marshal client update request: %w", err) + } + + // Create PUT request per RFC 7592 §2.2 + req, err := http.NewRequestWithContext( + ctx, + http.MethodPut, + h.config.CachedRegClientURI, + bytes.NewReader(reqBody), + ) + if err != nil { + return fmt.Errorf("failed to create client update request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+regAccessToken) + + httpClient := h.httpClient + if httpClient == nil { + httpClient = defaultRenewalHTTPClient + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("client update request failed: %w", err) + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + slog.Debug("Failed to close renewal response body", "error", closeErr) + } + }() + + if resp.StatusCode != http.StatusOK { + errorBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + _, _ = io.Copy(io.Discard, resp.Body) + return fmt.Errorf("client update request returned HTTP %d: %s", resp.StatusCode, string(errorBody)) + } + + // Parse the renewal response + const maxResponseSize = 1024 * 1024 // 1 MB + var updateResp clientUpdateResponse + if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&updateResp); err != nil { + return fmt.Errorf("failed to decode client update response: %w", err) + } + + if updateResp.ClientID == "" { + return fmt.Errorf("client update response missing client_id") + } + if updateResp.ClientSecret == "" { + return fmt.Errorf("client update response missing client_secret") + } + + return h.persistRenewedSecret(updateResp) +} + +func (h *Handler) validateRenewalPrerequisites() error { + if h.config.CachedRegClientURI == "" { + return fmt.Errorf("registration_client_uri missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.config.CachedRegTokenRef == "" { + return fmt.Errorf("registration_access_token missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.secretProvider == nil { + return fmt.Errorf("secret provider not configured; cannot retrieve registration access token") + } + return nil +} + +func (h *Handler) registeredDCRCallbackPort() int { + if h.config.CachedDCRCallbackPort != 0 { + return h.config.CachedDCRCallbackPort + } + return h.config.CallbackPort +} + +func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { + if h.clientCredentialsPersister == nil { + return fmt.Errorf("client credentials persister not configured; cannot save renewed secret") + } + + var newExpiry time.Time + if updateResp.ClientSecretExpiresAt > 0 { + newExpiry = time.Unix(updateResp.ClientSecretExpiresAt, 0) + } + + // Use the rotated registration_access_token if provided; fall back to existing. + newRegToken := updateResp.RegistrationAccessToken + newRegURI := updateResp.RegistrationClientURI + if newRegURI == "" { + newRegURI = h.config.CachedRegClientURI + } + + if err := h.clientCredentialsPersister( + updateResp.ClientID, + updateResp.ClientSecret, + newExpiry, + newRegToken, + newRegURI, + h.config.CachedTokenEndpointAuthMethod, + h.registeredDCRCallbackPort(), + ); err != nil { + return fmt.Errorf("failed to persist renewed client secret: %w", err) + } + + h.config.CachedClientID = updateResp.ClientID + h.config.CachedSecretExpiry = newExpiry + h.config.CachedRegClientURI = newRegURI + if h.config.CachedDCRCallbackPort == 0 { + h.config.CachedDCRCallbackPort = h.config.CallbackPort + } + + slog.Debug("Successfully renewed client secret via RFC 7592", + "client_id", updateResp.ClientID, + "new_expiry_zero", newExpiry.IsZero(), + "reg_token_rotated", newRegToken != "") + + return nil +} + +// validateRegistrationClientURI validates that the registration_client_uri is +// a valid HTTPS URL (or localhost for development). +func validateRegistrationClientURI(registrationClientURI string) error { + if registrationClientURI == "" { + return fmt.Errorf("registration_client_uri is empty") + } + + parsedURL, err := url.Parse(registrationClientURI) + if err != nil { + return fmt.Errorf("invalid registration_client_uri URL: %w", err) + } + + if parsedURL.Scheme != "https" && !networking.IsLocalhost(parsedURL.Host) { + return fmt.Errorf("registration_client_uri must use HTTPS: %s", registrationClientURI) + } + if parsedURL.Path == "" || parsedURL.Path == "/" { + return fmt.Errorf("registration_client_uri must include a non-root path: %s", registrationClientURI) + } + + return nil +} diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go new file mode 100644 index 0000000000..c574d1761d --- /dev/null +++ b/pkg/auth/remote/secret_renewal_test.go @@ -0,0 +1,569 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + secretmocks "github.com/stacklok/toolhive/pkg/secrets/mocks" +) + +func newTestSecretProvider(t *testing.T, values map[string]string) *secretmocks.MockProvider { + t.Helper() + + ctrl := gomock.NewController(t) + provider := secretmocks.NewMockProvider(ctrl) + provider.EXPECT().GetSecret(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, name string) (string, error) { + value, ok := values[name] + if !ok { + return "", fmt.Errorf("secret %q not found", name) + } + return value, nil + }, + ).AnyTimes() + return provider +} + +// TestIsSecretExpiredOrExpiringSoon tests the expiry helper on various time scenarios. +func TestIsSecretExpiredOrExpiringSoon(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expiry time.Time + wantExpired bool + }{ + { + name: "zero expiry means never expires", + expiry: time.Time{}, + wantExpired: false, + }, + { + name: "expiry far in the future — not expiring", + expiry: time.Now().Add(48 * time.Hour), + wantExpired: false, + }, + { + name: "expiry within 24h buffer — expiring soon", + expiry: time.Now().Add(12 * time.Hour), + wantExpired: true, + }, + { + name: "expiry in the past — already expired", + expiry: time.Now().Add(-1 * time.Hour), + wantExpired: true, + }, + { + name: "expiry exactly at buffer boundary — expiring soon", + expiry: time.Now().Add(secretExpiryBuffer - time.Minute), + wantExpired: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedSecretExpiry: tt.expiry, + }, + } + assert.Equal(t, tt.wantExpired, h.isSecretExpiredOrExpiringSoon()) + }) + } +} + +// TestValidateRegistrationClientURI tests URI validation. +func TestValidateRegistrationClientURI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + uri string + wantErr bool + }{ + { + name: "empty URI", + uri: "", + wantErr: true, + }, + { + name: "valid HTTPS URI", + uri: "https://example.com/oauth/register/client-id", + wantErr: false, + }, + { + name: "HTTP URI for non-localhost is rejected", + uri: "http://example.com/oauth/register/client-id", + wantErr: true, + }, + { + name: "localhost HTTP is allowed (development)", + uri: "http://localhost:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "127.0.0.1 HTTP is allowed (development)", + uri: "http://127.0.0.1:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "invalid URL", + uri: "://bad-url", + wantErr: true, + }, + { + name: "root path URI is rejected", + uri: "https://example.com/", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateRegistrationClientURI(tt.uri) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestRenewClientSecret_MissingConfig tests early-exit conditions. +func TestRenewClientSecret_MissingConfig(t *testing.T) { + t.Parallel() + + t.Run("missing registration_client_uri", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "", + CachedRegTokenRef: "some-ref", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_client_uri missing") + }) + + t.Run("missing registration_token_ref", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_access_token missing") + }) + + t.Run("missing secret provider", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "some-ref", + }, + secretProvider: nil, // no provider + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "secret provider not configured") + }) +} + +// TestRenewClientSecret_Success tests the happy path with a mock RFC 7592 server. +func TestRenewClientSecret_Success(t *testing.T) { + t.Parallel() + + newSecret := "new-client-secret-xyz" + newExpiry := time.Now().Add(24 * time.Hour * 30).Unix() + newRegToken := "new-registration-access-token" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RFC 7592 §2.2: must be PUT with Bearer auth + assert.Equal(t, http.MethodPut, r.Method) + assert.Contains(t, r.Header.Get("Authorization"), "Bearer reg-access-token") + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + var updateReq clientUpdateRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&updateReq)) + assert.Equal(t, []string{"http://localhost:9876/callback"}, updateReq.RedirectURIs) + + // Return the updated registration response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": newSecret, + "client_secret_expires_at": newExpiry, + "registration_access_token": newRegToken, + "registration_client_uri": "http://" + r.Host + r.URL.Path, + }) + })) + t.Cleanup(server.Close) + + // Set up persister capture + var persistedClientID, persistedSecret, persistedRegToken, persistedRegURI string + var persistedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-secret-ref", + CallbackPort: 8666, + CachedDCRCallbackPort: 9876, + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "reg-token-secret-ref": "reg-access-token", + }), + clientCredentialsPersister: func( + clientID, secret string, + expiry time.Time, + regToken, regURI, _ string, + callbackPort int, + ) error { + persistedClientID = clientID + persistedSecret = secret + persistedExpiry = expiry + persistedRegToken = regToken + persistedRegURI = regURI + assert.Equal(t, 9876, callbackPort) + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "test-client-id", persistedClientID) + assert.Equal(t, newSecret, persistedSecret) + assert.Equal(t, newRegToken, persistedRegToken) + assert.False(t, persistedExpiry.IsZero(), "expiry should be set") + assert.NotEmpty(t, persistedRegURI) +} + +// TestRenewClientSecret_ServerError tests error propagation when the server returns non-200. +func TestRenewClientSecret_ServerError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + t.Cleanup(server.Close) + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "reg-token-ref": "bad-token", + }), + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string, _ int) error { + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +// TestRenewClientSecret_NoPersister tests failure when persister is not set. +func TestRenewClientSecret_NoPersister(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + }) + })) + t.Cleanup(server.Close) + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: nil, // no persister + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "client credentials persister not configured") +} + +// TestRenewClientSecret_ZeroExpiryInResponse tests that a zero client_secret_expires_at +// is correctly interpreted as a non-expiring secret. +func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + "client_secret_expires_at": 0, // never expires + }) + })) + t.Cleanup(server.Close) + + var capturedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _, _ string, _ int) error { + capturedExpiry = expiry + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + assert.True(t, capturedExpiry.IsZero(), "zero client_secret_expires_at must produce zero time.Time") +} + +func TestRenewClientSecret_MalformedJSON(t *testing.T) { + t.Parallel() + + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid-json`)) + })) + t.Cleanup(svc.Close) + + h := &Handler{ + config: &Config{ + CachedRegClientURI: svc.URL + "/register/test-client-id", + CachedRegTokenRef: "rat-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{"rat-ref": "rat-token"}), + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode client update response") +} + +func TestRenewClientSecret_MissingFields(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response string + wantErr string + }{ + { + name: "missing_client_id", + response: `{"client_secret": "new-secret"}`, + wantErr: "client update response missing client_id", + }, + { + name: "missing_client_secret", + response: `{"client_id": "test-client-id"}`, + wantErr: "client update response missing client_secret", + }, + } + + for _, tt := range tests { + tt := tt // capture loop variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tt.response)) + })) + t.Cleanup(svc.Close) + + h := &Handler{ + config: &Config{ + CachedRegClientURI: svc.URL + "/register/test-client-id", + CachedRegTokenRef: "rat-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{"rat-ref": "rat-token"}), + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestValidateRegistrationClientURI_Internal(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uri string + wantErr bool + }{ + {"empty", "", true}, + {"malformed", "://foo", true}, + {"http_external", "http://example.com/reg", true}, + {"https_external", "https://example.com/reg", false}, + {"https_root_path", "https://example.com/", true}, + {"http_localhost", "http://localhost:8080/reg", false}, + {"http_127_0_0_1", "http://127.0.0.1:8080/reg", false}, + } + + for _, tt := range tests { + tt := tt // capture loop variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateRegistrationClientURI(tt.uri) + assert.Equal(t, tt.wantErr, err != nil) + }) + } +} + +func TestHandler_Restore_RenewSuccess(t *testing.T) { + t.Parallel() + + // Initial setup: secret expiring in 1 hour + expiry := time.Now().Add(1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"client_id": "test-client", "client_secret": "new-secret", "client_secret_expires_at": 4102444800}`)) + })) + t.Cleanup(svc.Close) + + var persistedID, persistedSecret string + renewalRequests := 0 + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL + "/register/test-client", + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(id, secret string, _ time.Time, _, _, _ string, _ int) error { + persistedID = id + persistedSecret = secret + renewalRequests++ + return nil + }, + } + + // Calling tryRestoreFromCachedTokens should trigger renewal because of the 1h expiry. + // We expect an error because it will try to refresh the token and fail (no token endpoint). + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cached tokens are invalid or expired") + + // But renewal DID happen + assert.Equal(t, "test-client", persistedID) + assert.Equal(t, "new-secret", persistedSecret) + assert.Equal(t, 1, renewalRequests) + assert.False(t, h.isSecretExpiredOrExpiringSoon()) +} + +func TestHandler_Restore_RenewFail_Soft(t *testing.T) { + t.Parallel() + + // Initial setup: secret expiring in 1 hour + expiry := time.Now().Add(1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(svc.Close) + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL + "/register/test-client", + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _, _ string, _ int) error { return nil }, + } + + // Renewal fails, but since it's only "expiring soon", it should continue (and then fail on token refresh) + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cached tokens are invalid or expired") +} + +func TestHandler_Restore_RenewFail_Hard(t *testing.T) { + t.Parallel() + + // Initial setup: secret already expired + expiry := time.Now().Add(-1 * time.Hour) + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(svc.Close) + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client", + CachedSecretExpiry: expiry, + CachedRegClientURI: svc.URL + "/register/test-client", + CachedRegTokenRef: "rat-ref", + CachedRefreshTokenRef: "refresh-token-ref", + }, + secretProvider: newTestSecretProvider(t, map[string]string{ + "rat-ref": "rat-token", + "refresh-token-ref": "some-refresh-token", + }), + clientCredentialsPersister: func(string, string, time.Time, string, string, string, int) error { return nil }, + } + + // Renewal fails and it's fully expired -> fatal error + _, err := h.tryRestoreFromCachedTokens(context.Background(), "http://issuer", nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "client secret expired at") + assert.Contains(t, err.Error(), "and renewal failed") +} diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index ad66252c9a..f933741236 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -174,6 +174,22 @@ func TestRunConfig_NormalizeProxyMode(t *testing.T) { // Note: This test uses actual port finding logic, so it may fail if ports are in use func TestRunConfig_WithPorts(t *testing.T) { t.Parallel() + // Find available ports dynamically to avoid flaky failures + port1 := networking.FindAvailable() + require.NotZero(t, port1, "should find an available proxy port for SSE") + targetPort1 := networking.FindAvailable() + require.NotZero(t, targetPort1, "should find an available target port for SSE") + + port2 := networking.FindAvailable() + require.NotZero(t, port2, "should find an available proxy port for HTTP") + targetPort2 := networking.FindAvailable() + require.NotZero(t, targetPort2, "should find an available target port for HTTP") + + port3 := networking.FindAvailable() + require.NotZero(t, port3, "should find an available proxy port for Stdio") + targetPort3 := networking.FindAvailable() + require.NotZero(t, targetPort3, "should find an available target port for Stdio") + testCases := []struct { name string config *RunConfig @@ -184,8 +200,8 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "SSE transport with specific ports", config: &RunConfig{Transport: types.TransportTypeSSE}, - port: 8001, - targetPort: 9001, + port: port1, + targetPort: targetPort1, expectError: false, }, { @@ -198,15 +214,15 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "Streamable HTTP transport with specific ports", config: &RunConfig{Transport: types.TransportTypeStreamableHTTP}, - port: 8002, - targetPort: 9002, + port: port2, + targetPort: targetPort2, expectError: false, }, { name: "Stdio transport with specific port", config: &RunConfig{Transport: types.TransportTypeStdio}, - port: 8003, - targetPort: 9003, // This should be ignored for stdio + port: port3, + targetPort: targetPort3, // This should be ignored for stdio expectError: false, }, } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 079a7d88db..8d5dd93019 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -751,73 +751,133 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo // Set up token persister to save tokens across restarts if secretManager != nil { authHandler.SetTokenPersister(func(refreshToken string, expiry time.Time) error { - // Generate a unique secret name for this workload's refresh token - secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + return r.persistRefreshToken(ctx, secretManager, refreshToken, expiry) + }) + + // Set up client credentials persister for DCR (Dynamic Client Registration) + authHandler.SetClientCredentialsPersister(func( + clientID, clientSecret string, + secretExpiry time.Time, + regAccessToken, regClientURI string, + tokenEndpointAuthMethod string, + registeredCallbackPort int, + ) error { + return r.persistClientCredentials( + ctx, secretManager, clientID, clientSecret, + secretExpiry, regAccessToken, regClientURI, tokenEndpointAuthMethod, registeredCallbackPort) + }) + } + + // Perform authentication + tokenSource, err := authHandler.Authenticate(ctx, r.Config.RemoteURL) + if err != nil { + return nil, fmt.Errorf("remote authentication failed: %w", err) + } + + return tokenSource, nil +} + +func (r *Runner) persistRefreshToken( + ctx context.Context, + secretManager secrets.Provider, + refreshToken string, + expiry time.Time, +) error { + secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_REFRESH_TOKEN_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { + return fmt.Errorf("failed to store refresh token: %w", err) + } + + r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName + r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry + + if err := r.Config.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with token reference: %w", err) + } + + slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) + return nil +} + +func (r *Runner) persistClientCredentials( + ctx context.Context, + secretManager secrets.Provider, + clientID, clientSecret string, + secretExpiry time.Time, + regAccessToken, regClientURI string, + tokenEndpointAuthMethod string, + registeredCallbackPort int, +) error { + updatedConfig := *r.Config + updatedRemoteAuthConfig := remote.Config{} + if r.Config.RemoteAuthConfig != nil { + updatedRemoteAuthConfig = *r.Config.RemoteAuthConfig + } + updatedConfig.RemoteAuthConfig = &updatedRemoteAuthConfig + updatedRemoteAuthConfig.CachedClientID = clientID + + if clientSecret != "" { + clientSecretSecretName := updatedRemoteAuthConfig.CachedClientSecretRef + if clientSecretSecretName == "" { + var err error + clientSecretSecretName, err = authsecrets.GenerateUniqueSecretNameWithPrefix( r.Config.Name, - "OAUTH_REFRESH_TOKEN_", + "OAUTH_CLIENT_SECRET_", secretManager, ) if err != nil { - return fmt.Errorf("failed to generate secret name: %w", err) + return fmt.Errorf("failed to generate client secret secret name: %w", err) } + } - // Store the refresh token in the secret manager - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { - return fmt.Errorf("failed to store refresh token: %w", err) - } + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { + return fmt.Errorf("failed to store client secret: %w", err) + } + updatedRemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName + } - // Store the secret reference (not the actual token) in the config - r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName - r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry + updatedRemoteAuthConfig.CachedSecretExpiry = secretExpiry - // Save the updated config to persist the reference - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with token reference: %w", err) + if regAccessToken != "" { + regTokenSecretName := updatedRemoteAuthConfig.CachedRegTokenRef + if regTokenSecretName == "" { + var err error + regTokenSecretName, err = authsecrets.GenerateUniqueSecretNameWithPrefix(r.Config.Name, "OAUTH_REG_TOKEN_", secretManager) + if err != nil { + return fmt.Errorf("failed to generate registration token secret name: %w", err) } + } - slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) - return nil - }) - - // Set up client credentials persister for DCR (Dynamic Client Registration) - authHandler.SetClientCredentialsPersister(func(clientID, clientSecret string) error { - // Store client ID directly (it's public information) - r.Config.RemoteAuthConfig.CachedClientID = clientID - - // Only store client secret if it's non-empty (PKCE flows may not have one) - if clientSecret != "" { - clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_CLIENT_SECRET_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate client secret secret name: %w", err) - } - - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { - return fmt.Errorf("failed to store client secret: %w", err) - } - r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName - } + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, regAccessToken, secretManager); err != nil { + return fmt.Errorf("failed to store registration access token: %w", err) + } + updatedRemoteAuthConfig.CachedRegTokenRef = regTokenSecretName + slog.Debug("Stored DCR registration access token for RFC 7592 operations") + } - // Save the updated config to persist the credentials - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with client credentials: %w", err) - } + updatedRemoteAuthConfig.CachedRegClientURI = regClientURI + updatedRemoteAuthConfig.CachedTokenEndpointAuthMethod = tokenEndpointAuthMethod + updatedRemoteAuthConfig.CachedDCRCallbackPort = registeredCallbackPort - slog.Debug("Stored DCR client credentials", "client_id", clientID) - return nil - }) + if err := updatedConfig.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with client credentials: %w", err) } - // Perform authentication - tokenSource, err := authHandler.Authenticate(ctx, r.Config.RemoteURL) - if err != nil { - return nil, fmt.Errorf("remote authentication failed: %w", err) - } + r.Config.RemoteAuthConfig = &updatedRemoteAuthConfig - return tokenSource, nil + slog.Debug("Stored DCR client credentials", "client_id", clientID, + "has_expiry", !secretExpiry.IsZero(), + "has_reg_token", regAccessToken != "", + "has_reg_uri", regClientURI != "") + return nil } // Cleanup performs cleanup operations for the runner, including shutting down all middleware. diff --git a/pkg/state/runconfig.go b/pkg/state/runconfig.go index ea1fc91c7d..d94cc419aa 100644 --- a/pkg/state/runconfig.go +++ b/pkg/state/runconfig.go @@ -6,11 +6,12 @@ package state import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" - "github.com/stacklok/toolhive/pkg/workloads/types/errors" + werr "github.com/stacklok/toolhive/pkg/workloads/types/errors" ) // LoadRunConfigJSON loads a run configuration from the state store and returns the raw reader @@ -27,7 +28,7 @@ func LoadRunConfigJSON(ctx context.Context, name string) (io.ReadCloser, error) return nil, fmt.Errorf("failed to check if run configuration exists: %w", err) } if !exists { - return nil, fmt.Errorf("%w: %s", errors.ErrRunConfigNotFound, name) + return nil, fmt.Errorf("%w: %s", werr.ErrRunConfigNotFound, name) } // Get a reader for the state @@ -121,6 +122,44 @@ func LoadRunConfig[T any](ctx context.Context, name string, readJSONFunc ReadJSO return readJSONFunc(reader) } +// ReadRunConfigJSON deserializes a run configuration from JSON read from the provided reader +// This is a generic JSON deserializer for any type that can be unmarshalled from JSON +func ReadRunConfigJSON[T any](r io.Reader) (*T, error) { + var config T + decoder := json.NewDecoder(r) + if err := decoder.Decode(&config); err != nil { + if errors.Is(err, io.EOF) { + return &config, nil + } + return nil, err + } + return &config, nil +} + +// LoadRunConfigOfType loads a run configuration of a specific type T from the state store +func LoadRunConfigOfType[T any](ctx context.Context, name string) (*T, error) { + return LoadRunConfig(ctx, name, ReadRunConfigJSON[T]) +} + +// RunConfigReadJSONFunc defines the function signature for reading a RunConfig from JSON +// This allows us to accept the runner.ReadJSON function without creating a circular dependency +type RunConfigReadJSONFunc func(r io.Reader) (interface{}, error) + +// LoadRunConfigWithFunc loads a run configuration using a provided read function +func LoadRunConfigWithFunc(ctx context.Context, name string, readFunc RunConfigReadJSONFunc) (interface{}, error) { + reader, err := LoadRunConfigJSON(ctx, name) + if err != nil { + return nil, err + } + defer func() { + if err := reader.Close(); err != nil { + slog.Warn("Failed to close reader", "error", err) + } + }() + + return readFunc(reader) +} + // ReadJSON deserializes JSON from the provided reader into a generic interface // This function is moved from the runner package to avoid circular dependencies func ReadJSON(r io.Reader, target interface{}) error {