diff --git a/cmd/thv/app/common.go b/cmd/thv/app/common.go index 2cd23028c1..ef46aa7d1b 100644 --- a/cmd/thv/app/common.go +++ b/cmd/thv/app/common.go @@ -157,12 +157,14 @@ func completeLogsArgs(cmd *cobra.Command, args []string, _ string) ([]string, co } // workloadStatusIndicator returns the status string with a visual indicator prepended -// for statuses that warrant user attention (unauthenticated, policy_stopped). +// for statuses that warrant user attention (unauthenticated, auth_retrying, policy_stopped). // All other statuses are returned as plain strings. func workloadStatusIndicator(status runtime.WorkloadStatus) string { switch status { case runtime.WorkloadStatusUnauthenticated: return "โš ๏ธ " + string(status) + case runtime.WorkloadStatusAuthRetrying: + return "๐Ÿ”„ " + string(status) case runtime.WorkloadStatusPolicyStopped: return "๐Ÿšซ " + string(status) case runtime.WorkloadStatusRunning, runtime.WorkloadStatusStopped, runtime.WorkloadStatusError, diff --git a/cmd/thv/app/common_test.go b/cmd/thv/app/common_test.go index 8a74bd20c0..754e8130ef 100644 --- a/cmd/thv/app/common_test.go +++ b/cmd/thv/app/common_test.go @@ -4,9 +4,12 @@ package app import ( + "strings" "testing" "github.com/spf13/cobra" + + "github.com/stacklok/toolhive/pkg/container/runtime" ) func TestAddFormatFlag(t *testing.T) { @@ -266,3 +269,38 @@ func TestIsOIDCEnabled(t *testing.T) { }) } } + +func TestWorkloadStatusIndicator(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status runtime.WorkloadStatus + wantHas string // substring that must appear + wantExact string // if non-empty, must match exactly + }{ + {"unauthenticated has โš ๏ธ prefix", runtime.WorkloadStatusUnauthenticated, "โš ๏ธ", ""}, + {"auth_retrying has ๐Ÿ”„ prefix", runtime.WorkloadStatusAuthRetrying, "๐Ÿ”„", ""}, + {"policy_stopped has ๐Ÿšซ prefix", runtime.WorkloadStatusPolicyStopped, "๐Ÿšซ", ""}, + {"running passes through plain", runtime.WorkloadStatusRunning, "", "running"}, + 
{"stopped passes through plain", runtime.WorkloadStatusStopped, "", "stopped"}, + {"unhealthy passes through plain", runtime.WorkloadStatusUnhealthy, "", "unhealthy"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := workloadStatusIndicator(tc.status) + if tc.wantExact != "" && got != tc.wantExact { + t.Errorf("workloadStatusIndicator(%q) = %q, want exact %q", + tc.status, got, tc.wantExact) + } + if tc.wantHas != "" && !strings.Contains(got, tc.wantHas) { + t.Errorf("workloadStatusIndicator(%q) = %q, want substring %q", + tc.status, got, tc.wantHas) + } + if !strings.Contains(got, string(tc.status)) { + t.Errorf("workloadStatusIndicator(%q) = %q, must include status name", + tc.status, got) + } + }) + } +} diff --git a/cmd/thv/app/list.go b/cmd/thv/app/list.go index a3943bbd60..8a0d4efc81 100644 --- a/cmd/thv/app/list.go +++ b/cmd/thv/app/list.go @@ -164,7 +164,7 @@ func printTextOutput(workloadList []core.Workload) { // Print workload information for _, c := range workloadList { - // Highlight unauthenticated and policy-stopped workloads with indicators + // Highlight unauthenticated, auth-retrying, and policy-stopped workloads with indicators status := workloadStatusIndicator(c.Status) // Print workload information diff --git a/cmd/thv/app/ui/styles.go b/cmd/thv/app/ui/styles.go index f9293a0abe..008440096b 100644 --- a/cmd/thv/app/ui/styles.go +++ b/cmd/thv/app/ui/styles.go @@ -71,6 +71,9 @@ var ( pillUnauthed = lipgloss.NewStyle(). Background(bgWarning).Foreground(ColorYellow). Padding(0, 1).Render("โš  unauthed") + pillAuthRetrying = lipgloss.NewStyle(). + Background(bgWarning).Foreground(ColorYellow). 
+ Padding(0, 1).Render("๐Ÿ”„ retrying") keyStyle = lipgloss.NewStyle().Foreground(ColorDim2) portStyle = lipgloss.NewStyle().Foreground(ColorCyan).Bold(true) @@ -97,6 +100,8 @@ func RenderStatusDot(status rt.WorkloadStatus) string { return dotWarning case rt.WorkloadStatusUnauthenticated: return dotWarning + case rt.WorkloadStatusAuthRetrying: + return dotWarning case rt.WorkloadStatusRemoving: return dotWarning case rt.WorkloadStatusPolicyStopped: @@ -128,6 +133,8 @@ func RenderStatusPill(status rt.WorkloadStatus) string { return pillUnknown case rt.WorkloadStatusUnauthenticated: return pillUnauthed + case rt.WorkloadStatusAuthRetrying: + return pillAuthRetrying case rt.WorkloadStatusPolicyStopped: return pillStopped default: diff --git a/docs/arch/02-core-concepts.md b/docs/arch/02-core-concepts.md index 84e565ee3f..52259edae4 100644 --- a/docs/arch/02-core-concepts.md +++ b/docs/arch/02-core-concepts.md @@ -38,6 +38,7 @@ A **workload** is the fundamental deployment unit in ToolHive. It represents eve - `error` - Workload encountered an error - `unhealthy` - Workload is running but unhealthy - `unauthenticated` - Remote workload cannot authenticate (expired tokens) +- `auth_retrying` - Remote workload's token refresh is failing transiently; monitor is still retrying until success (โ†’ `running`) or the configured ceiling (โ†’ `unauthenticated`) **Implementation:** - Interface: `pkg/workloads/manager.go` diff --git a/docs/arch/08-workloads-lifecycle.md b/docs/arch/08-workloads-lifecycle.md index 22bd6ddd33..1e24f438fb 100644 --- a/docs/arch/08-workloads-lifecycle.md +++ b/docs/arch/08-workloads-lifecycle.md @@ -22,6 +22,7 @@ stateDiagram-v2 Running --> Stopping: Stop Running --> Unhealthy: Health Failed Running --> Unauthenticated: Auth Failed + Running --> AuthRetrying: Transient Token Refresh Failures Running --> Stopped: Container Exit Stopping --> Stopped: Success @@ -31,6 +32,9 @@ stateDiagram-v2 Unauthenticated --> Starting: Re-authenticate Unauthenticated 
--> Removing: Delete + AuthRetrying --> Running: Refresh Succeeds + AuthRetrying --> Unauthenticated: Ceiling Exceeded or Permanent Error + Removing --> [*]: Success Error --> Starting: Restart Error --> Removing: Delete @@ -38,7 +42,15 @@ stateDiagram-v2 **States**: `pkg/container/runtime/types.go` - `starting`, `running`, `stopping`, `stopped` -- `removing`, `error`, `unhealthy`, `unauthenticated` +- `removing`, `error`, `unhealthy`, `unauthenticated`, `auth_retrying` + +The `auth_retrying` cadence and ceiling can be tuned via environment +variables on the proxy process: + +- `TOOLHIVE_TOKEN_AUTH_RETRYING_TICK_INTERVAL` (default `10m`): cadence + between background refresh attempts during the AuthRetrying window. +- `TOOLHIVE_TOKEN_AUTH_RETRYING_MAX_ELAPSED` (default `24h`): ceiling + before the workload is finally marked `unauthenticated`. ## Core Operations diff --git a/docs/server/docs.go b/docs/server/docs.go index fa76390328..2937bf0841 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -941,6 +941,7 @@ const docTemplate = `{ "removing", "unknown", "unauthenticated", + "auth_retrying", "policy_stopped", "running", "stopped", @@ -951,6 +952,7 @@ const docTemplate = `{ "removing", "unknown", "unauthenticated", + "auth_retrying", "policy_stopped", "running", "stopped", @@ -961,6 +963,7 @@ const docTemplate = `{ "removing", "unknown", "unauthenticated", + "auth_retrying", "policy_stopped" ], "type": "string", @@ -974,6 +977,7 @@ const docTemplate = `{ "WorkloadStatusRemoving", "WorkloadStatusUnknown", "WorkloadStatusUnauthenticated", + "WorkloadStatusAuthRetrying", "WorkloadStatusPolicyStopped" ] }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index eb71c3efec..a9f4ac4c3c 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -934,6 +934,7 @@ "removing", "unknown", "unauthenticated", + "auth_retrying", "policy_stopped", "running", "stopped", @@ -944,6 +945,7 @@ "removing", "unknown", "unauthenticated", + 
"auth_retrying", "policy_stopped", "running", "stopped", @@ -954,6 +956,7 @@ "removing", "unknown", "unauthenticated", + "auth_retrying", "policy_stopped" ], "type": "string", @@ -967,6 +970,7 @@ "WorkloadStatusRemoving", "WorkloadStatusUnknown", "WorkloadStatusUnauthenticated", + "WorkloadStatusAuthRetrying", "WorkloadStatusPolicyStopped" ] }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6e9d446e50..c2d7ce589f 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -972,6 +972,7 @@ components: - removing - unknown - unauthenticated + - auth_retrying - policy_stopped - running - stopped @@ -982,6 +983,7 @@ components: - removing - unknown - unauthenticated + - auth_retrying - policy_stopped - running - stopped @@ -992,6 +994,7 @@ components: - removing - unknown - unauthenticated + - auth_retrying - policy_stopped type: string x-enum-varnames: @@ -1004,6 +1007,7 @@ components: - WorkloadStatusRemoving - WorkloadStatusUnknown - WorkloadStatusUnauthenticated + - WorkloadStatusAuthRetrying - WorkloadStatusPolicyStopped github_com_stacklok_toolhive_pkg_container_templates.RuntimeConfig: description: |- diff --git a/pkg/api/v1/workload_types.go b/pkg/api/v1/workload_types.go index 180e4e5045..ecae512a4a 100644 --- a/pkg/api/v1/workload_types.go +++ b/pkg/api/v1/workload_types.go @@ -32,7 +32,7 @@ type workloadListResponse struct { type workloadStatusResponse struct { // Current status of the workload //nolint:lll // enums tag needed for swagger generation with --parseDependencyLevel - Status runtime.WorkloadStatus `json:"status" enums:"running,stopped,error,starting,stopping,unhealthy,removing,unknown,unauthenticated,policy_stopped"` + Status runtime.WorkloadStatus `json:"status" enums:"running,stopped,error,starting,stopping,unhealthy,removing,unknown,unauthenticated,auth_retrying,policy_stopped"` } // updateRequest represents the request to update an existing workload diff --git a/pkg/auth/monitored_token_source.go 
b/pkg/auth/monitored_token_source.go index a3f204a489..8384d5b69b 100644 --- a/pkg/auth/monitored_token_source.go +++ b/pkg/auth/monitored_token_source.go @@ -39,6 +39,19 @@ const ( // tokenRefreshMaxElapsedTime is the default maximum elapsed time for all retry attempts. // Override with TOOLHIVE_TOKEN_REFRESH_MAX_ELAPSED_TIME (e.g. "10m"). tokenRefreshMaxElapsedTime = 5 * time.Minute + // authRetryingTickInterval is the default cadence between background refresh + // attempts once the short retry has exhausted on a transient error. The + // short retry inside transientRefresher.retry already performs exponential + // backoff; this cross-tick layer uses a fixed cadence to avoid layering + // exponential on top of exponential and to give operators a predictable + // retry signature when investigating long outages. + // Override with TOOLHIVE_TOKEN_AUTH_RETRYING_TICK_INTERVAL (e.g. "5m", "30m"). + authRetryingTickInterval = 10 * time.Minute + // authRetryingMaxElapsed is the ceiling on how long the monitor will stay + // in the AuthRetrying transient-failure window before giving up and + // marking the workload unauthenticated. + // Override with TOOLHIVE_TOKEN_AUTH_RETRYING_MAX_ELAPSED (e.g. "12h", "48h"). 
+ authRetryingMaxElapsed = 24 * time.Hour ) const ( @@ -50,6 +63,10 @@ const ( tokenRefreshMaxElapsedTimeEnv = "TOOLHIVE_TOKEN_REFRESH_MAX_ELAPSED_TIME" // #nosec G101 โ€” not credentials, just max tries tokenRefreshMaxTriesEnv = "TOOLHIVE_TOKEN_REFRESH_MAX_TRIES" + // #nosec G101 โ€” not credentials, just retry tick interval + authRetryingTickIntervalEnv = "TOOLHIVE_TOKEN_AUTH_RETRYING_TICK_INTERVAL" + // #nosec G101 โ€” not credentials, just max elapsed time + authRetryingMaxElapsedEnv = "TOOLHIVE_TOKEN_AUTH_RETRYING_MAX_ELAPSED" ) // resolveTokenRefreshInitialRetryInterval returns the initial retry interval for @@ -97,6 +114,20 @@ func resolveTokenRefreshMaxElapsedTime() time.Duration { ) } +// resolveAuthRetryingTickInterval returns the cadence between monitor refresh +// attempts during the AuthRetrying transient-failure window, reading from +// TOOLHIVE_TOKEN_AUTH_RETRYING_TICK_INTERVAL if set, otherwise the default. +func resolveAuthRetryingTickInterval() time.Duration { + return resolveDurationEnv(authRetryingTickIntervalEnv, authRetryingTickInterval) +} + +// resolveAuthRetryingMaxElapsed returns the ceiling on how long the monitor +// will stay in AuthRetrying before giving up, reading from +// TOOLHIVE_TOKEN_AUTH_RETRYING_MAX_ELAPSED if set, otherwise the default. +func resolveAuthRetryingMaxElapsed() time.Duration { + return resolveDurationEnv(authRetryingMaxElapsedEnv, authRetryingMaxElapsed) +} + // resolveDurationEnv reads a duration from the given environment variable. // Returns defaultVal if the variable is unset or its value is not a valid // positive duration. @@ -208,14 +239,35 @@ func (r *transientRefresher) retry(ctx context.Context, origErr error) (*oauth2. ) } -// MonitoredTokenSource is a wrapper around an oauth2.TokenSource that monitors authentication -// failures and automatically marks workloads as unauthenticated when tokens expire or fail. 
+// MonitoredTokenSource is a wrapper around an oauth2.TokenSource that monitors +// authentication failures and surfaces them through workload status transitions. // It provides both per-request token retrieval and background monitoring. // -// When the background monitor encounters a token refresh failure it retries with exponential -// backoff rather than immediately marking the workload as unauthenticated. This handles -// scenarios like overnight VPN disconnects where the token refresh endpoint is temporarily -// unreachable. +// Failure handling has two layers: +// +// - Short retry (inside transientRefresher.retry): exponential backoff up to +// a ~5min total window. Retries transient failures (DNS errors, dropped +// connections, OAuth server 5xx/429, WAF 4xx-with-HTML) without changing +// workload status. Failures that resolve within the window never leave +// Running; failures that outlast the window fall through to AuthRetrying. +// +// - AuthRetrying ceiling (this struct): after the short retry exhausts on a +// still-transient error, the workload transitions to AuthRetrying and the +// monitor keeps running on a longer cadence (default 10min). On the next +// successful refresh, the workload returns to Running. After a ceiling +// (default 24h) the workload is finally marked Unauthenticated and the +// monitor stops. +// +// Hot callers โ€” request-path Token() calls, e.g. from the token injection +// middleware serving a live MCP request โ€” fast-fail with the cached error +// during AuthRetrying so they return 503+Retry-After immediately rather +// than blocking on a singleflight retry that would exhaust again on every +// call. +// +// Hot callers and the background monitor can both drive state transitions +// independently (the monitor via its tick, hot callers via short-retry +// exhaustion in Token). When the upstream alternates between transient +// failures and brief recoveries, the workload status oscillates with it. 
type MonitoredTokenSource struct { tokenSource oauth2.TokenSource workloadName string @@ -231,6 +283,20 @@ type MonitoredTokenSource struct { stopped chan struct{} timer *time.Timer + + // AuthRetrying state, guarded by mu. A non-zero transientStartedAt + // means the workload is in the AuthRetrying window (short retry + // exhausted, ceiling not yet reached). lastTransientErr is the error + // returned to hot callers via Token() during the window so they fail + // fast instead of joining a singleflight that would block for the + // full short-retry duration on every tick. Written by the monitor + // goroutine and by Token() callers who reach the short-retry + // exhaustion branch; read by Token() callers during the fast-fail + // check. mu is held only across pure field reads/writes โ€” never + // across a statusUpdater call, which does I/O. + mu sync.Mutex + transientStartedAt time.Time + lastTransientErr error } // NewMonitoredTokenSource creates a new MonitoredTokenSource that wraps the provided @@ -287,15 +353,40 @@ func (mts *MonitoredTokenSource) Stopped() <-chan struct{} { return mts.stopped } -// Token retrieves a token, retrying with exponential backoff on transient -// errors and marking the workload as unauthenticated on non-transient errors. -// See isTransientNetworkError for the classification rule. Context -// cancellation (workload removal) stops the retry without marking the -// workload as unauthenticated. +// Token retrieves a token. On a non-transient error, it marks the workload +// unauthenticated and returns immediately. On a transient error, behavior +// depends on monitor state โ€” see the paragraphs below. Context cancellation +// (workload removal) stops any in-flight retry without marking the workload +// unauthenticated. See isTransientNetworkError for the classification rule. 
+// +// When the monitor is in the AuthRetrying transient-failure window +// (short retry exhausted, ceiling not yet reached), Token() fast-fails +// with the cached error rather than joining a singleflight retry that +// would hang for the full short-retry duration on every call. Hot +// callers see 503+Retry-After until the next monitor tick observes +// upstream recovery and clears the state โ€” only the monitor's onTick +// (which calls the raw token source directly) can do that. // -// Concurrent callers are deduplicated via singleflight so that only one retry -// loop runs at a time during transient failures. +// If a hot caller's own short retry exhausts on a still-transient error, +// the workload transitions to AuthRetrying and the monitor stays alive โ€” +// subsequent hot callers fast-fail until the next monitor tick clears the +// state or extends it. If the monitor has already stopped (a prior +// permanent error closed it), enterAuthRetrying is a no-op; the hot +// caller returns the error without changing workload status. +// +// During the short retry on a transient failure, concurrent callers +// joining at the same time are deduplicated via singleflight so that +// only one retry loop runs at a time. Callers that arrive AFTER the +// leader's singleflight call has returned start their own retry โ€” they +// are not deduplicated against past calls. This is why hot-caller- +// driven AuthRetrying entry (after the retry exhausts) is load-bearing: +// without it, sequential hot callers would each pay the full short- +// retry duration against a broken endpoint. 
func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) { + if err := mts.fastFailIfAuthRetrying(); err != nil { + return nil, err + } + tok, err := mts.tokenSource.Token() if err == nil { return tok, nil @@ -309,19 +400,38 @@ func (mts *MonitoredTokenSource) Token() (*oauth2.Token, error) { return nil, err } + // If the monitor has already stopped (prior permanent error or 24h + // ceiling), do not enter the short retry. The workload is terminal; + // making hot callers each pay the full short-retry duration against + // a known-broken endpoint is exactly the pathology AuthRetrying was + // introduced to avoid. + select { + case <-mts.stopMonitoring: + return nil, err + default: + } + // Transient network error โ€” funnel all concurrent callers through a // single retry loop so we don't hammer the token endpoint. tok, err = mts.refresher.Refresh(mts.monitoringCtx, err) - if err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - mts.markAsUnauthenticated( - fmt.Sprintf("Token refresh failed after retries: %v", err), - isPermanentTokenEndpointError(err), - ) - } + if err == nil { + return tok, nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + if !isTransientNetworkError(err) { + mts.markAsUnauthenticated( + fmt.Sprintf("Token refresh failed after retries: %v", err), + isPermanentTokenEndpointError(err), + ) return nil, err } - return tok, nil + // Short retry exhausted on a still-transient error โ†’ enter AuthRetrying + // and keep the monitor alive. Subsequent hot callers fast-fail; the + // next monitor tick will decide whether to recover or extend. 
+ mts.enterAuthRetrying(err) + return nil, err } // StartBackgroundMonitoring starts the background monitoring goroutine that checks @@ -368,22 +478,90 @@ func (mts *MonitoredTokenSource) resetTimer(d time.Duration) { mts.timer.Reset(d) } -// onTick calls Token() to refresh the token and returns the next check delay. -// Token() handles transient error retries and marks the workload as unauthenticated -// on permanent failures. +// onTick performs a token refresh attempt on the background monitor goroutine +// and returns whether to stop monitoring and the next tick delay. +// +// State transitions handled here: +// +// - Success: clear any AuthRetrying state (โ†’ Running) and schedule the +// next tick at the new token's expiry. +// - Permanent error: mark Unauthenticated immediately and stop monitoring +// (existing semantics). +// - Transient error while in AuthRetrying: check the ceiling. If exceeded, +// mark Unauthenticated. Otherwise refresh the cached error and reschedule +// at the AuthRetrying cadence. +// - Transient error while NOT in AuthRetrying: run the existing short retry +// once (singleflight + exponential backoff). If it succeeds, return to +// normal expiry-based scheduling. If it exhausts on a still-transient +// error, enter AuthRetrying. func (mts *MonitoredTokenSource) onTick() (bool, time.Duration) { - tok, err := mts.Token() - if err != nil { + // Call the raw token source directly โ€” NOT Token() โ€” so the monitor + // bypasses the AuthRetrying fast-fail and the singleflight retry. The + // monitor owns those state transitions itself. 
+ tok, err := mts.tokenSource.Token() + + if err == nil { + mts.exitAuthRetrying() + if tok == nil || tok.Expiry.IsZero() { + return true, 0 + } + return false, waitUntilExpiry(tok.Expiry) + } + + if !isTransientNetworkError(err) { + mts.markAsUnauthenticated( + fmt.Sprintf("Token retrieval failed: %v", err), + isPermanentTokenEndpointError(err), + ) return true, 0 } - if tok == nil || tok.Expiry.IsZero() { + + if mts.inAuthRetrying() { + if mts.ceilingExceeded() { + mts.markAsUnauthenticated( + fmt.Sprintf("Token refresh failed transiently for over %s: %v", + resolveAuthRetryingMaxElapsed(), err), + false, // transient by definition โ†’ DCR Warn correctly silent + ) + return true, 0 + } + mts.enterAuthRetrying(err) + return false, resolveAuthRetryingTickInterval() + } + + // First transient failure on this tick: run the existing short retry once + // so we benefit from singleflight + exponential backoff for brief blips. + tok, err = mts.refresher.Refresh(mts.monitoringCtx, err) + if err == nil { + if tok == nil || tok.Expiry.IsZero() { + return true, 0 + } + return false, waitUntilExpiry(tok.Expiry) + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true, 0 + } + if !isTransientNetworkError(err) { + mts.markAsUnauthenticated( + fmt.Sprintf("Token refresh failed after retries: %v", err), + isPermanentTokenEndpointError(err), + ) return true, 0 } - wait := time.Until(tok.Expiry) + + // Short retry exhausted on a still-transient error โ†’ enter AuthRetrying. + mts.enterAuthRetrying(err) + return false, resolveAuthRetryingTickInterval() +} + +// waitUntilExpiry returns the duration until expiry, clamped to a minimum +// of one second so the monitor doesn't spin on a stale or near-past expiry. 
+func waitUntilExpiry(expiry time.Time) time.Duration { + wait := time.Until(expiry) if wait < time.Second { wait = time.Second } - return false, wait + return wait } // isTransientNetworkError reports whether err represents a transient @@ -541,32 +719,169 @@ func isOAuthParseError(err error) bool { // of stopMonitoring share stopOnce, so a caller (e.g. a tight Token() // loop) cannot spam the record on every call after the workload has // already transitioned to Unauthenticated. +// +// Ordering matters: stopMonitoring is closed first so any concurrent +// enterAuthRetrying call sees the gate closed before it acquires the +// field mutex, eliminating the race where a hot caller could write +// AuthRetrying to the status file just after we've written +// Unauthenticated. After the channel close, AuthRetrying state fields +// are cleared (in case enterAuthRetrying had already set them between +// monitor exit and this call) before the Unauthenticated status emit. func (mts *MonitoredTokenSource) markAsUnauthenticated(reason string, permanent4xx bool) { + firstClose := false + mts.stopOnce.Do(func() { + firstClose = true + close(mts.stopMonitoring) + }) + + mts.mu.Lock() + mts.transientStartedAt = time.Time{} + mts.lastTransientErr = nil + mts.mu.Unlock() + _ = mts.statusUpdater.SetWorkloadStatus( context.Background(), mts.workloadName, runtime.WorkloadStatusUnauthenticated, reason, ) - mts.stopOnce.Do(func() { - // A permanent 4xx from the token endpoint commonly indicates the - // cached client (DCR or CIMD) is no longer recognised โ€” but the - // same branch fires for revoked consent, disabled accounts, and - // statically configured clients, so the message has to be honest - // about the variability. Gating on clientID != "" suppresses the - // log entirely for workloads where no client_id context is - // available; the operator-correlation it provides would be empty. 
- if permanent4xx && mts.clientID != "" { - //nolint:gosec // G706: client_id is public metadata per RFC 7591. - slog.Warn( - "token endpoint returned a permanent error; if this workload uses "+ - "cached DCR or CIMD credentials they may be stale โ€” delete the "+ - "cached credentials and restart to re-register.", - "workload", mts.workloadName, - "upstream", mts.upstream, - "client_id", mts.clientID, - ) - } - close(mts.stopMonitoring) - }) + + // Emit the DCR/CIMD remediation hint at most once per transition. + // A permanent 4xx from the token endpoint commonly indicates the + // cached client (DCR or CIMD) is no longer recognised โ€” but the same + // branch fires for revoked consent, disabled accounts, and statically + // configured clients, so the message has to be honest about the + // variability. Gating on clientID != "" suppresses the log entirely + // for workloads where no client_id context is available; the + // operator-correlation it provides would be empty. + if firstClose && permanent4xx && mts.clientID != "" { + //nolint:gosec // G706: client_id is public metadata per RFC 7591. + slog.Warn( + "token endpoint returned a permanent error; if this workload uses "+ + "cached DCR or CIMD credentials they may be stale โ€” delete the "+ + "cached credentials and restart to re-register.", + "workload", mts.workloadName, + "upstream", mts.upstream, + "client_id", mts.clientID, + ) + } +} + +// inAuthRetrying reports whether the monitor is currently in the +// AuthRetrying transient-failure window. +func (mts *MonitoredTokenSource) inAuthRetrying() bool { + mts.mu.Lock() + defer mts.mu.Unlock() + return !mts.transientStartedAt.IsZero() +} + +// fastFailIfAuthRetrying returns the cached transient error when the +// monitor is in AuthRetrying so hot callers (e.g. the token injection +// middleware) return 503+Retry-After immediately rather than blocking +// on a singleflight retry that will exhaust again. 
+func (mts *MonitoredTokenSource) fastFailIfAuthRetrying() error { + mts.mu.Lock() + defer mts.mu.Unlock() + if mts.transientStartedAt.IsZero() { + return nil + } + // enterAuthRetrying sets transientStartedAt and lastTransientErr together + // under mu, so under normal flow lastTransientErr is non-nil here. Guard + // against a future refactor that breaks the invariant; fmt.Errorf("%w", nil) + // returns a non-nil error that errors.Is/Unwrap handle awkwardly. + if mts.lastTransientErr == nil { + return fmt.Errorf("auth retrying since %s", + mts.transientStartedAt.Format(time.RFC3339)) + } + return fmt.Errorf("auth retrying since %s: %w", + mts.transientStartedAt.Format(time.RFC3339), mts.lastTransientErr) +} + +// ceilingExceeded reports whether the workload has been in AuthRetrying +// for longer than the configured ceiling. Returns false if not in +// AuthRetrying. +func (mts *MonitoredTokenSource) ceilingExceeded() bool { + mts.mu.Lock() + defer mts.mu.Unlock() + if mts.transientStartedAt.IsZero() { + return false + } + return time.Since(mts.transientStartedAt) > resolveAuthRetryingMaxElapsed() +} + +// enterAuthRetrying marks the workload as AuthRetrying on the first call +// and refreshes the cached error on subsequent calls. The status +// transition is emitted only on first entry to avoid spamming the +// status file every tick. +// +// The gate check is performed under mts.mu so that markAsUnauthenticated +// (which closes the channel before acquiring mu) cannot interleave to +// re-populate transientStartedAt/lastTransientErr after clearing them. +// A narrower disk-write inversion is still possible (auth_retrying +// written briefly after Unauthenticated); the in-memory state is +// correct, and the disk entry resolves on the next workload restart +// (runner.go resets non-Unauthenticated statuses to Running). 
+func (mts *MonitoredTokenSource) enterAuthRetrying(err error) { + mts.mu.Lock() + select { + case <-mts.stopMonitoring: + mts.mu.Unlock() + return + default: + } + firstEntry := mts.transientStartedAt.IsZero() + if firstEntry { + mts.transientStartedAt = time.Now() + } + elapsed := time.Since(mts.transientStartedAt) + mts.lastTransientErr = err + mts.mu.Unlock() + + if firstEntry { + slog.Warn("token refresh entering AuthRetrying after short retry exhaustion; will retry on a longer cadence", + "workload", mts.workloadName, + "tick_interval", resolveAuthRetryingTickInterval(), + "max_elapsed", resolveAuthRetryingMaxElapsed(), + "error", err, + ) + _ = mts.statusUpdater.SetWorkloadStatus( + context.Background(), + mts.workloadName, + runtime.WorkloadStatusAuthRetrying, + fmt.Sprintf("Token refresh failing transiently; retrying every %s: %v", + resolveAuthRetryingTickInterval(), err), + ) + return + } + slog.Warn("token refresh still failing transiently", + "workload", mts.workloadName, + "elapsed", elapsed, + "error", err, + ) +} + +// exitAuthRetrying clears the AuthRetrying window and emits the Running +// status transition if the workload was previously in AuthRetrying. No-op +// if not in AuthRetrying. 
+func (mts *MonitoredTokenSource) exitAuthRetrying() { + mts.mu.Lock() + wasInAuthRetrying := !mts.transientStartedAt.IsZero() + if wasInAuthRetrying { + mts.transientStartedAt = time.Time{} + mts.lastTransientErr = nil + } + mts.mu.Unlock() + + if !wasInAuthRetrying { + return + } + slog.Info("token refresh recovered; exiting AuthRetrying state", + "workload", mts.workloadName, + ) + _ = mts.statusUpdater.SetWorkloadStatus( + context.Background(), + mts.workloadName, + runtime.WorkloadStatusRunning, + "", + ) } diff --git a/pkg/auth/monitored_token_source_test.go b/pkg/auth/monitored_token_source_test.go index c5e3137222..7c85dfee9d 100644 --- a/pkg/auth/monitored_token_source_test.go +++ b/pkg/auth/monitored_token_source_test.go @@ -4,9 +4,11 @@ package auth import ( + "bytes" "context" "errors" "fmt" + "log/slog" "net" "net/http" "net/http/httptest" @@ -23,6 +25,7 @@ import ( "golang.org/x/oauth2" rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/oauthproto/oauthtest" statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" ) @@ -648,7 +651,8 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T statusUpdater, _ := newMockStatusUpdater(ctrl) retrying := tokenSource.notifyOnCall(2) - ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater) + ats.refresher.newBackOff = fastBackOff ats.StartBackgroundMonitoring() <-retrying // Ensure the retry loop has been entered before cancelling. @@ -667,7 +671,8 @@ func TestMonitoredTokenSource_BackgroundMonitor_ErrorClassification(t *testing.T Return(nil). 
Times(1) - ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater) + ats.refresher.newBackOff = fastBackOff ats.StartBackgroundMonitoring() <-ats.Stopped() // Monitor stops itself after marking unauthenticated. @@ -856,7 +861,8 @@ func TestMonitoredTokenSource_TransientErrorRetriesAndSucceeds(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater) + ats.refresher.newBackOff = fastBackOff ats.StartBackgroundMonitoring() // Block until the monitor has successfully recovered, then stop it. @@ -895,7 +901,8 @@ func TestMonitoredTokenSource_TransientErrorContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater) + ats.refresher.newBackOff = fastBackOff ats.StartBackgroundMonitoring() // Cancel once we know the retry loop is running, then wait for clean exit. @@ -949,10 +956,711 @@ func TestMonitoredTokenSource_TransientThenNonTransientMarksUnauthenticated(t *t ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats := NewMonitoredTokenSource(ctx, tokenSource, "test-workload", "", "", statusUpdater) + ats.refresher.newBackOff = fastBackOff ats.StartBackgroundMonitoring() // Monitor stops itself after the non-transient error; wait for that. <-ats.Stopped() // gomock verifies SetWorkloadStatus was called exactly once. 
}
+
+// --- AuthRetrying state machine ---
+
+// drainShortRetryEnv overrides the short-retry tuning env vars (3 tries,
+// 200ms max elapsed) so the short retry inside transientRefresher exhausts
+// in tens of milliseconds rather than minutes. Together with fastBackOff on
+// the refresher, this turns the short-retry window into a no-op for the
+// AuthRetrying tests.
+//
+// Note: because this helper calls t.Setenv, the calling test cannot use
+// t.Parallel() (the environment is process-wide). The AuthRetrying tests
+// below therefore do not call t.Parallel().
+func drainShortRetryEnv(t *testing.T) {
+	t.Helper()
+	// Applied in order: max tries first, then max elapsed time.
+	overrides := [][2]string{
+		{tokenRefreshMaxTriesEnv, "3"},
+		{tokenRefreshMaxElapsedTimeEnv, "200ms"},
+	}
+	for _, kv := range overrides {
+		t.Setenv(kv[0], kv[1])
+	}
+}
+
+// TestMonitor_EnterAuthRetryingAfterShortRetryExhausts asserts that once the
+// short-retry window inside transientRefresher exhausts on a still-transient
+// error, the monitor transitions the workload to AuthRetrying and keeps the
+// monitor goroutine alive (i.e. does NOT mark Unauthenticated).
+func TestMonitor_EnterAuthRetryingAfterShortRetryExhausts(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	drainShortRetryEnv(t)
+	t.Setenv(authRetryingTickIntervalEnv, "20ms")
+	t.Setenv(authRetryingMaxElapsedEnv, "10s") // ample ceiling — must NOT be reached
+
+	statusUpdater, statusManager := newMockStatusUpdater(ctrl)
+	tokenSource := newMockTokenSource()
+
+	transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}}
+	tokenSource.setTokenFn(func() (*oauth2.Token, error) {
+		if tokenSource.callCount == 1 {
+			return &oauth2.Token{
+				AccessToken: "initial-token",
+				Expiry:      time.Now().Add(10 * time.Millisecond),
+			}, nil
+		}
+		return nil, transientErr
+	})
+
+	authRetryingCalled := make(chan struct{})
+	statusManager.EXPECT().
+		SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()).
+ DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(authRetryingCalled) + return nil + }). + Times(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-authRetryingCalled: + case <-time.After(2 * time.Second): + t.Fatal("did not transition to AuthRetrying within 2s") + } + + // Monitor must remain alive after entering AuthRetrying. + time.Sleep(80 * time.Millisecond) + select { + case <-ats.Stopped(): + t.Fatal("monitor stopped after AuthRetrying; expected it to stay alive") + default: + } + if !ats.inAuthRetrying() { + t.Fatal("expected inAuthRetrying() to be true") + } + + cancel() + <-ats.Stopped() +} + +// TestMonitor_AuthRetryingRecoversToRunning asserts that when the token +// endpoint recovers during AuthRetrying, the next successful refresh transitions +// the workload back to Running and clears the cached transient state. +func TestMonitor_AuthRetryingRecoversToRunning(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "20ms") + t.Setenv(authRetryingMaxElapsedEnv, "10s") + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + recovered := make(chan struct{}) + var once sync.Once + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return &oauth2.Token{ + AccessToken: "initial-token", + Expiry: time.Now().Add(10 * time.Millisecond), + }, nil + } + // Recover once the AuthRetrying transition has been observed and a + // post-transition tick fires. 
+ select { + case <-recovered: + return &oauth2.Token{ + AccessToken: "recovered-token", + Expiry: time.Now().Add(time.Hour), + }, nil + default: + return nil, transientErr + } + }) + + runningCalled := make(chan struct{}) + gomock.InOrder( + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + once.Do(func() { close(recovered) }) + return nil + }). + Times(1), + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusRunning, ""). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(runningCalled) + return nil + }). + Times(1), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-runningCalled: + case <-time.After(2 * time.Second): + t.Fatal("did not recover to Running within 2s") + } + if ats.inAuthRetrying() { + t.Fatal("expected inAuthRetrying() to be false after recovery") + } + + cancel() + <-ats.Stopped() +} + +// TestMonitor_AuthRetryingCeilingTransitionsToUnauthenticated asserts that +// after the configured ceiling elapses while still in AuthRetrying, the +// monitor gives up and marks the workload Unauthenticated. +func TestMonitor_AuthRetryingCeilingTransitionsToUnauthenticated(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "20ms") + // Ceiling = 200ms (vs e.g. 50ms): leaves enough margin that the + // SetWorkloadStatus(AuthRetrying) emit + the first post-entry tick + // don't accidentally race the ceiling check on a slow runner. 
+ t.Setenv(authRetryingMaxElapsedEnv, "200ms") + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return &oauth2.Token{ + AccessToken: "initial-token", + Expiry: time.Now().Add(10 * time.Millisecond), + }, nil + } + return nil, transientErr + }) + + gomock.InOrder( + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + Return(nil). + Times(1), + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusUnauthenticated, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, reason string) error { + if !strings.Contains(reason, "transiently for over") { + t.Errorf("expected ceiling-specific reason; got %q", reason) + } + return nil + }). + Times(1), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-ats.Stopped(): + case <-time.After(2 * time.Second): + t.Fatal("monitor did not stop after AuthRetrying ceiling exceeded") + } +} + +// TestToken_HotCallerFastFailsDuringAuthRetrying asserts that a hot +// caller (see MonitoredTokenSource type doc) calling Token() during the +// AuthRetrying window gets the cached error immediately, without +// re-entering the short-retry loop against the still-broken endpoint. 
+func TestToken_HotCallerFastFailsDuringAuthRetrying(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "500ms") // long gap so the hot call sits inside it + t.Setenv(authRetryingMaxElapsedEnv, "10s") + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return &oauth2.Token{ + AccessToken: "initial-token", + Expiry: time.Now().Add(10 * time.Millisecond), + }, nil + } + return nil, transientErr + }) + + authRetryingCalled := make(chan struct{}) + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(authRetryingCalled) + return nil + }). + Times(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-authRetryingCalled: + case <-time.After(2 * time.Second): + t.Fatal("did not enter AuthRetrying within 2s") + } + + // Snapshot underlying source's call count, then make a hot call. 
+ tokenSource.mu.Lock() + beforeCount := tokenSource.callCount + tokenSource.mu.Unlock() + + start := time.Now() + _, hotErr := ats.Token() + elapsed := time.Since(start) + + if hotErr == nil { + t.Fatal("expected Token() to fail-fast during AuthRetrying") + } + if elapsed > 100*time.Millisecond { + t.Fatalf("Token() took %v, expected fast-fail (<100ms)", elapsed) + } + if !errors.Is(hotErr, transientErr) { + t.Errorf("expected Token() error to wrap transientErr; got %v", hotErr) + } + + tokenSource.mu.Lock() + afterCount := tokenSource.callCount + tokenSource.mu.Unlock() + if afterCount > beforeCount { + t.Errorf("Token() invoked underlying source during AuthRetrying (calls %d โ†’ %d)", beforeCount, afterCount) + } + + cancel() + <-ats.Stopped() +} + +// TestMonitor_PermanentErrorDuringAuthRetryingTickGivesUpImmediately asserts +// that a permanent OAuth error during a post-AuthRetrying tick stops the +// monitor immediately, without waiting for the ceiling. +func TestMonitor_PermanentErrorDuringAuthRetryingTickGivesUpImmediately(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "20ms") + t.Setenv(authRetryingMaxElapsedEnv, "10s") // ample ceiling โ€” must NOT be reached + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + permanentErr := createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`) + + authRetryingCalled := make(chan struct{}) + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return &oauth2.Token{ + AccessToken: "initial-token", + Expiry: time.Now().Add(10 * time.Millisecond), + }, nil + } + select { + case <-authRetryingCalled: + return nil, permanentErr + default: + return nil, transientErr + } + }) + + 
gomock.InOrder( + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(authRetryingCalled) + return nil + }). + Times(1), + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusUnauthenticated, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, reason string) error { + if !strings.Contains(reason, "invalid_grant") { + t.Errorf("expected unauthenticated reason to mention invalid_grant; got %q", reason) + } + return nil + }). + Times(1), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-ats.Stopped(): + case <-time.After(2 * time.Second): + t.Fatal("monitor did not transition Unauthenticated within 2s") + } +} + +// TestMonitor_DCRWarnSilentOnCeilingGiveUp asserts that when the monitor gives +// up at the AuthRetrying ceiling, the DCR/CIMD remediation warning does NOT +// fire โ€” a transient ceiling is not a "stale cached credentials" signal. +// +// Not run in parallel: slog.SetDefault is process-wide. +func TestMonitor_DCRWarnSilentOnCeilingGiveUp(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "20ms") + t.Setenv(authRetryingMaxElapsedEnv, "200ms") + + // Capture slog output by swapping the default logger for the lifetime + // of this test. bytes.Buffer is not goroutine-safe, but the test reads + // it only after <-ats.Stopped() โ€” by which point the monitor goroutine + // has exited and no further writes occur. 
+ var logBuf bytes.Buffer + prevLogger := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(&logBuf, nil))) + t.Cleanup(func() { slog.SetDefault(prevLogger) }) + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return &oauth2.Token{ + AccessToken: "initial-token", + Expiry: time.Now().Add(10 * time.Millisecond), + }, nil + } + return nil, transientErr + }) + + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + Return(nil). + Times(1) + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusUnauthenticated, gomock.Any()). + Return(nil). + Times(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Construct with a non-empty client_id so the DCR Warn *could* fire on a + // permanent-classified error; we're verifying it doesn't fire on a + // transient-classified ceiling give-up. + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", + "https://issuer.example.com", "client-123", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-ats.Stopped(): + case <-time.After(2 * time.Second): + t.Fatal("monitor did not stop within 2s") + } + + if strings.Contains(logBuf.String(), "delete the cached credentials") { + t.Errorf("DCR remediation Warn fired on transient ceiling give-up; output:\n%s", logBuf.String()) + } +} + +// --- end-to-end integration tests against a real HTTP OAuth server --- +// +// The tests above use a mock oauth2.TokenSource (newMockTokenSource) and +// drive the state machine via synthetic errors. 
The tests below wire the +// real golang.org/x/oauth2 ReuseTokenSource against the scriptable +// oauthtest.ControllableServer so that errors are produced by the OAuth +// library parsing an actual HTTP response. + +// TestIntegration_AuthRetryingRecoversAfterRealWAFBlock drives the full +// state machine end-to-end against a real OAuth HTTP server: initial +// success โ†’ real WAF-style 403 HTML refresh failure โ†’ AuthRetrying โ†’ +// flip server back to success โ†’ Running. Unlike the mock-based tests +// above, this exercises the actual golang.org/x/oauth2 response parsing +// and isTransientNetworkError classification of *oauth2.RetrieveError +// values constructed by the library from real HTTP responses. +func TestIntegration_AuthRetryingRecoversAfterRealWAFBlock(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "100ms") + t.Setenv(authRetryingMaxElapsedEnv, "30s") + + srv := oauthtest.NewControllableServer() + defer srv.Close() + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + + authRetryingCalled := make(chan struct{}) + runningCalled := make(chan struct{}) + + gomock.InOrder( + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(authRetryingCalled) + return nil + }). + Times(1), + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusRunning, ""). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, _ string) error { + close(runningCalled) + return nil + }). 
+ Times(1), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, oauthtest.NewRealTokenSource(srv.URL), + "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + // Give the monitor a moment to do its first successful refresh. + initialRefreshDeadline := time.Now().Add(2 * time.Second) + for srv.RequestCount() < 1 && time.Now().Before(initialRefreshDeadline) { + time.Sleep(20 * time.Millisecond) + } + if srv.RequestCount() < 1 { + t.Fatal("monitor did not issue an initial refresh against the fake server") + } + + // Flip to WAF block. Next monitor tick should exhaust short retry + // and transition to AuthRetrying. + srv.SetMode(oauthtest.ModeWAFBlock) + select { + case <-authRetryingCalled: + case <-time.After(10 * time.Second): + t.Fatalf("did not transition to AuthRetrying within 10s; server saw %d requests", srv.RequestCount()) + } + + // Recover. Next AuthRetrying tick should refresh successfully and + // transition back to Running. + srv.SetMode(oauthtest.ModeSuccess) + select { + case <-runningCalled: + case <-time.After(10 * time.Second): + t.Fatalf("did not recover to Running within 10s; server saw %d requests", srv.RequestCount()) + } + if ats.inAuthRetrying() { + t.Error("expected inAuthRetrying() to be false after recovery") + } + + cancel() + <-ats.Stopped() +} + +// TestIntegration_AuthRetryingCeilingThroughRealOAuthServer drives the +// monitor through a complete ceiling timeout end-to-end against a real +// OAuth server: real WAF block persists โ†’ AuthRetrying โ†’ ceiling +// exceeded โ†’ Unauthenticated. The ceiling here is intentionally tight +// (200ms) so the test completes quickly; in production the default is +// 24h. 
+func TestIntegration_AuthRetryingCeilingThroughRealOAuthServer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + t.Setenv(authRetryingTickIntervalEnv, "50ms") + t.Setenv(authRetryingMaxElapsedEnv, "200ms") + + srv := oauthtest.NewControllableServer() + defer srv.Close() + // Start in WAF-block mode so the first refresh attempt fails. + srv.SetMode(oauthtest.ModeWAFBlock) + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + gomock.InOrder( + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusAuthRetrying, gomock.Any()). + Return(nil). + Times(1), + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusUnauthenticated, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, _ rt.WorkloadStatus, reason string) error { + if !strings.Contains(reason, "transiently for over") { + t.Errorf("expected ceiling-specific reason; got %q", reason) + } + return nil + }). + Times(1), + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, oauthtest.NewRealTokenSource(srv.URL), + "test-workload", "", "", statusUpdater, fastBackOff) + ats.StartBackgroundMonitoring() + + select { + case <-ats.Stopped(): + case <-time.After(5 * time.Second): + t.Fatalf("monitor did not reach ceiling within 5s; server saw %d requests", srv.RequestCount()) + } +} + +// TestToken_DoesNotEnterAuthRetryingAfterMonitorStopped asserts the +// stopMonitoring gate in enterAuthRetrying: once a permanent error has +// closed the monitor, a subsequent hot caller observing transient errors +// must NOT transition the workload back into AuthRetrying. Without the +// gate, the workload would be stuck at AuthRetrying with no monitor alive +// to honor the ceiling or drive recovery. 
+func TestToken_DoesNotEnterAuthRetryingAfterMonitorStopped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + drainShortRetryEnv(t) + // Tick interval and ceiling don't matter for this test โ€” the monitor + // never runs. Set them so any spurious test path completes quickly. + t.Setenv(authRetryingTickIntervalEnv, "10s") + t.Setenv(authRetryingMaxElapsedEnv, "10s") + + statusUpdater, statusManager := newMockStatusUpdater(ctrl) + tokenSource := newMockTokenSource() + + permanentErr := createRetrieveErrorWithCode(http.StatusBadRequest, "invalid_grant", `{"error":"invalid_grant"}`) + transientErr := &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}} + + tokenSource.setTokenFn(func() (*oauth2.Token, error) { + if tokenSource.callCount == 1 { + return nil, permanentErr + } + return nil, transientErr + }) + + // Expect exactly one Unauthenticated transition. Crucially, no + // AuthRetrying transition is allowed โ€” gomock fails the test if one + // fires unexpectedly. + statusManager.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", rt.WorkloadStatusUnauthenticated, gomock.Any()). + Return(nil). + Times(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + + // Drive the workload to Unauthenticated via a hot-caller call. + if _, err := ats.Token(); err == nil { + t.Fatal("expected permanent error from initial Token() call") + } + + // Subsequent hot-caller calls observing transient errors must NOT + // transition back to AuthRetrying, and must NOT pay the full short- + // retry duration against the broken endpoint. Without the + // stopMonitoring gate in Token(), each call would spend ~50ms+ in + // refresher.Refresh before returning. 
+	for i := 0; i < 3; i++ {
+		start := time.Now()
+		if _, err := ats.Token(); err == nil {
+			t.Fatalf("hot-caller call %d unexpectedly succeeded", i)
+		}
+		if elapsed := time.Since(start); elapsed > 5*time.Millisecond {
+			t.Errorf("hot-caller call %d took %v after monitor stopped; "+
+				"expected fast return (<5ms) via stopMonitoring gate, "+
+				"not a full short-retry pass", i, elapsed)
+		}
+	}
+
+	if ats.inAuthRetrying() {
+		t.Error("workload entered AuthRetrying after monitor stopped; gate failed")
+	}
+}
+
+// TestConcurrent_EnterAuthRetryingAndMarkAsUnauthenticated checks an
+// in-memory invariant: when markAsUnauthenticated runs concurrently with
+// enterAuthRetrying, transientStartedAt and lastTransientErr end up cleared
+// once both have finished. If enterAuthRetrying performed its gate check
+// before acquiring mu, a hostile interleaving would let it re-populate the
+// fields *after* markAsUnauthenticated cleared them, leaving hot callers
+// stuck fast-failing against a dead monitor via fastFailIfAuthRetrying
+// (which reads the fields, not the channel).
+func TestConcurrent_EnterAuthRetryingAndMarkAsUnauthenticated(t *testing.T) {
+	t.Parallel()
+	// Repeat 1000 times so the runtime scheduler gets enough chances to
+	// interleave the two goroutines unfavourably across runs.
+	const iterations = 1000
+	iter := 0
+	for iter < iterations {
+		runConcurrentEnterAndMark(t, iter)
+		iter++
+	}
+}
+
+func runConcurrentEnterAndMark(t *testing.T, iter int) {
+	t.Helper()
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	statusUpdater, statusManager := newMockStatusUpdater(ctrl)
+	tokenSource := newMockTokenSource()
+
+	// Either status transition may or may not fire depending on which
+	// goroutine wins the gate check; we are testing the in-memory
+	// invariant, not the disk-write order.
+	statusManager.EXPECT().
+		SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+		Return(nil).
+ AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Skip StartBackgroundMonitoring โ€” we drive both callers ourselves. + ats := newMonitoredTokenSourceWithBackOff(ctx, tokenSource, "test-workload", "", "", statusUpdater, fastBackOff) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + ats.enterAuthRetrying(errors.New("transient")) + }() + go func() { + defer wg.Done() + ats.markAsUnauthenticated("permanent error", false) + }() + wg.Wait() + + ats.mu.Lock() + startedAt := ats.transientStartedAt + lastErr := ats.lastTransientErr + ats.mu.Unlock() + + if !startedAt.IsZero() { + t.Fatalf("iter %d: transientStartedAt should be zero after markAsUnauthenticated, got %v", iter, startedAt) + } + if lastErr != nil { + t.Fatalf("iter %d: lastTransientErr should be nil after markAsUnauthenticated, got %v", iter, lastErr) + } + // Guard against a future regression that drops the stopMonitoring + // close in markAsUnauthenticated: the in-memory invariant above + // depends on the channel being closed, so a regression there would + // silently break the gate check in enterAuthRetrying. + select { + case <-ats.stopMonitoring: + default: + t.Fatalf("iter %d: stopMonitoring should be closed after markAsUnauthenticated", iter) + } +} diff --git a/pkg/container/runtime/types.go b/pkg/container/runtime/types.go index b65c518820..f71b3d0ee0 100644 --- a/pkg/container/runtime/types.go +++ b/pkg/container/runtime/types.go @@ -44,6 +44,13 @@ const ( // WorkloadStatusUnauthenticated indicates that the workload is running but // cannot authenticate with the remote MCP server (e.g., expired refresh token). WorkloadStatusUnauthenticated WorkloadStatus = "unauthenticated" + // WorkloadStatusAuthRetrying indicates that background authentication + // refresh is failing transiently and the monitor is still retrying. 
+ // The workload may recover on its own (โ†’ Running) or be marked + // unauthenticated when the configured ceiling is exceeded + // (โ†’ Unauthenticated). Hot requests fail fast with 503+Retry-After + // while in this state. + WorkloadStatusAuthRetrying WorkloadStatus = "auth_retrying" // WorkloadStatusPolicyStopped indicates that the workload was stopped by // policy enforcement. The StatusContext field carries the human-readable reason. WorkloadStatusPolicyStopped WorkloadStatus = "policy_stopped" diff --git a/pkg/core/workload.go b/pkg/core/workload.go index 06199504dd..c124cc523e 100644 --- a/pkg/core/workload.go +++ b/pkg/core/workload.go @@ -33,7 +33,7 @@ type Workload struct { ProxyMode string `json:"proxy_mode,omitempty"` // Status is the current status of the workload. //nolint:lll // enums tag needed for swagger generation with --parseDependencyLevel - Status runtime.WorkloadStatus `json:"status" enums:"running,stopped,error,starting,stopping,unhealthy,removing,unknown,unauthenticated,policy_stopped"` + Status runtime.WorkloadStatus `json:"status" enums:"running,stopped,error,starting,stopping,unhealthy,removing,unknown,unauthenticated,auth_retrying,policy_stopped"` // StatusContext provides additional context about the workload's status. // The exact meaning is determined by the status and the underlying runtime. StatusContext string `json:"status_context,omitempty"` diff --git a/pkg/oauthproto/oauthtest/server.go b/pkg/oauthproto/oauthtest/server.go new file mode 100644 index 0000000000..95709d5423 --- /dev/null +++ b/pkg/oauthproto/oauthtest/server.go @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauthtest + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "time" + + "golang.org/x/oauth2" +) + +// serverExpiresInSeconds is the expires_in value baked into success +// responses. 
One second is short enough that the issued token's Expiry
+// triggers a refresh attempt on every monitor tick once the caller has
+// passed the short-retry leeway (which is what tests exercising
+// transition timing want). Cannot use 0 here: golang.org/x/oauth2 leaves
+// the issued Token.Expiry zero when expires_in is zero, which downstream
+// monitors typically interpret as "no further checking needed" and exit.
+const serverExpiresInSeconds = 1
+
+// FailureMode controls how a ControllableServer responds to a token
+// endpoint request. ModeSuccess returns a valid JSON token response; the
+// other modes produce error shapes commonly observed in production.
+type FailureMode int
+
+const (
+	// ModeSuccess returns 200 with a valid JSON token response carrying
+	// access_token, token_type=Bearer, expires_in, and refresh_token.
+	// It is the zero value of FailureMode.
+	ModeSuccess FailureMode = iota
+	// ModeWAFBlock returns 403 with an HTML body and no RFC 6749 error code —
+	// the shape commonly returned by WAFs, CDNs, or reverse proxies that
+	// block a request before it reaches the OAuth server. Classified as
+	// transient (4xx without an OAuth error code) per RFC 6749 §5.2; see
+	// pkg/auth/monitored_token_source.go isTransientRetrieveError.
+	ModeWAFBlock
+	// ModeServerError returns 500. Treated as transient.
+	ModeServerError
+	// ModeInvalidGrant returns 400 with {"error":"invalid_grant"}. Treated as
+	// permanent (RFC 6749 §5.2).
+	ModeInvalidGrant
+)
+
+// ControllableServer is an httptest.NewServer with a programmable token
+// endpoint. Tests flip the mode at runtime to drive the token source under
+// test through specific error shapes. Used to write end-to-end tests that
+// exercise the real golang.org/x/oauth2 library against actual HTTP
+// responses rather than synthetic Go values.
+//
+// Construct with NewControllableServer, swap behavior with SetMode, and
+// close via the embedded *httptest.Server's Close().
+type ControllableServer struct {
+	*httptest.Server
+	// mu guards mode and refreshCount, which are touched by both the test
+	// goroutine (SetMode, RequestCount) and the HTTP handler.
+	mu           sync.Mutex
+	mode         FailureMode
+	refreshCount int
+}
+
+// NewControllableServer returns a server in ModeSuccess. Success responses
+// use a fixed 1-second expires_in (see serverExpiresInSeconds for rationale).
+func NewControllableServer() *ControllableServer {
+	s := &ControllableServer{mode: ModeSuccess}
+	s.Server = httptest.NewServer(http.HandlerFunc(s.handle))
+	return s
+}
+
+// SetMode swaps the response behavior. Concurrent-safe.
+func (s *ControllableServer) SetMode(m FailureMode) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mode = m
+}
+
+// RequestCount returns the number of token-endpoint requests observed so
+// far. Useful for tests that need to assert refresh activity occurred.
+func (s *ControllableServer) RequestCount() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.refreshCount
+}
+
+// handle serves every request to the test server: it counts the request,
+// snapshots the current mode under mu, and then writes the response shape
+// for that mode with the lock released.
+func (s *ControllableServer) handle(w http.ResponseWriter, _ *http.Request) {
+	s.mu.Lock()
+	s.refreshCount++
+	mode := s.mode
+	s.mu.Unlock()
+
+	// Write errors below are deliberately ignored: they only occur when the
+	// client has gone away, and the test observes that from the client side.
+	switch mode {
+	case ModeSuccess:
+		w.Header().Set("Content-Type", "application/json")
+		resp := map[string]any{
+			"access_token":  "test-access-token",
+			"token_type":    "Bearer",
+			"expires_in":    serverExpiresInSeconds,
+			"refresh_token": "test-refresh-token",
+		}
+		_ = json.NewEncoder(w).Encode(resp)
+	case ModeWAFBlock:
+		w.WriteHeader(http.StatusForbidden)
+		_, _ = w.Write([]byte(`Request blocked by web application firewall`))
+	case ModeServerError:
+		w.WriteHeader(http.StatusInternalServerError)
+		_, _ = w.Write([]byte("Internal Server Error"))
+	case ModeInvalidGrant:
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusBadRequest)
+		_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
+	default:
+		// Without this branch an unrecognized mode would produce an
+		// implicit, empty 200 response. Fail loudly instead so a newly
+		// added FailureMode that is not handled here is noticed.
+		http.Error(w, "oauthtest: unhandled failure mode", http.StatusInternalServerError)
+	}
+}
+
+// NewRealTokenSource builds a real golang.org/x/oauth2 token source
+// pointed at the given token endpoint URL (typically a ControllableServer's
+// URL).
The returned source is hardcoded with a fixed initial refresh
+// token ("test-refresh-token") and an expiry one hour in the past, so the
+// first Token() call always triggers an HTTP refresh against the endpoint.
+//
+// This is intentionally non-configurable. Tests that need to exercise
+// "valid current token, will not refresh yet" or "refresh token rotation"
+// scenarios should construct their own oauth2.TokenSource directly rather
+// than extending this helper.
+func NewRealTokenSource(tokenEndpointURL string) oauth2.TokenSource {
+	// Seed token: no access token and already expired, so the first use
+	// has to go to the token endpoint.
+	expired := &oauth2.Token{
+		RefreshToken: "test-refresh-token",
+		Expiry:       time.Now().Add(-time.Hour),
+	}
+	// Fixed client credentials; only the token URL varies per test.
+	conf := oauth2.Config{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		Endpoint:     oauth2.Endpoint{TokenURL: tokenEndpointURL},
+	}
+	return conf.TokenSource(context.Background(), expired)
+}
diff --git a/pkg/tui/helpers_test.go b/pkg/tui/helpers_test.go
index 0c404360de..44be2f7fda 100644
--- a/pkg/tui/helpers_test.go
+++ b/pkg/tui/helpers_test.go
@@ -193,8 +193,9 @@ func TestCountStatuses(t *testing.T) {
 				{Status: rt.WorkloadStatusRunning},
 				{Status: rt.WorkloadStatusUnauthenticated},
 				{Status: rt.WorkloadStatusUnhealthy},
+				{Status: rt.WorkloadStatusAuthRetrying},
 			},
-			expectedRunning: 3,
+			expectedRunning: 4,
 			expectedStopped: 0,
 		},
 		{
diff --git a/pkg/tui/view_helpers.go b/pkg/tui/view_helpers.go
index b322c17405..3d36a08bc5 100644
--- a/pkg/tui/view_helpers.go
+++ b/pkg/tui/view_helpers.go
@@ -177,7 +177,8 @@ func truncateSidebar(s string, n int) string {
 func countStatuses(list []core.Workload) (running, stopped int) {
 	for _, w := range list {
 		switch w.Status {
-		case rt.WorkloadStatusRunning, rt.WorkloadStatusUnauthenticated, rt.WorkloadStatusUnhealthy:
+		case rt.WorkloadStatusRunning, rt.WorkloadStatusUnauthenticated, rt.WorkloadStatusUnhealthy,
+			rt.WorkloadStatusAuthRetrying:
 			running++
 		case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStarting, rt.WorkloadStatusStopping,
rt.WorkloadStatusRemoving, rt.WorkloadStatusUnknown, diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index 8867da1131..4488f3cabf 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -161,6 +161,8 @@ const ( // This occurs when: // - Health checks succeed but response times exceed the degraded threshold (slow but working) // - Backend just recovered from failures and is in a stabilizing state + // - Background OAuth token refresh is failing transiently while the + // workload monitor retries (auth_retrying workload status) BackendDegraded BackendHealthStatus = "degraded" // BackendUnhealthy indicates the backend is not responding to health checks. diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index aa2e0eb442..b5538195d9 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -331,6 +331,14 @@ func mapWorkloadStatusToVMCPHealth(status rt.WorkloadStatus) vmcp.BackendHealthS return vmcp.BackendUnknown case rt.WorkloadStatusUnauthenticated: return vmcp.BackendUnauthenticated + case rt.WorkloadStatusAuthRetrying: + // Token refresh is in transient retry; the workload may yet recover + // without operator intervention. Map to BackendDegraded so vmcp + // distinguishes this from the terminal Unauthenticated case and + // keeps the backend in tool-discovery aggregation โ€” tool invocation + // will still fail with 503 until the token refresh recovers, but + // discovery callers see the backend's capabilities throughout. 
+ return vmcp.BackendDegraded case rt.WorkloadStatusPolicyStopped: return vmcp.BackendUnhealthy default: diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index ebbf383d10..8b8e415636 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -20,6 +20,7 @@ import ( runtimeMocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/runner" + "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/workloads/statuses" statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" ) @@ -2147,3 +2148,32 @@ func TestDefaultManager_ListWorkloadsUsingSecret(t *testing.T) { assert.NotNil(t, listFunc, "ListWorkloadsUsingSecret method should exist with correct signature") }) } + +func TestMapWorkloadStatusToVMCPHealth(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in runtime.WorkloadStatus + want vmcp.BackendHealthStatus + }{ + {"running โ†’ healthy", runtime.WorkloadStatusRunning, vmcp.BackendHealthy}, + {"unhealthy โ†’ unhealthy", runtime.WorkloadStatusUnhealthy, vmcp.BackendUnhealthy}, + {"stopped โ†’ unhealthy", runtime.WorkloadStatusStopped, vmcp.BackendUnhealthy}, + {"error โ†’ unhealthy", runtime.WorkloadStatusError, vmcp.BackendUnhealthy}, + {"stopping โ†’ unhealthy", runtime.WorkloadStatusStopping, vmcp.BackendUnhealthy}, + {"removing โ†’ unhealthy", runtime.WorkloadStatusRemoving, vmcp.BackendUnhealthy}, + {"starting โ†’ unknown", runtime.WorkloadStatusStarting, vmcp.BackendUnknown}, + {"unknown โ†’ unknown", runtime.WorkloadStatusUnknown, vmcp.BackendUnknown}, + {"unauthenticated โ†’ unauthenticated", runtime.WorkloadStatusUnauthenticated, vmcp.BackendUnauthenticated}, + {"auth_retrying โ†’ degraded", runtime.WorkloadStatusAuthRetrying, vmcp.BackendDegraded}, + {"policy_stopped โ†’ unhealthy", runtime.WorkloadStatusPolicyStopped, vmcp.BackendUnhealthy}, + {"unrecognized โ†’ unknown", 
runtime.WorkloadStatus("not_a_real_status"), vmcp.BackendUnknown}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := mapWorkloadStatusToVMCPHealth(tc.in) + assert.Equal(t, tc.want, got) + }) + } +}