Skip to content

Commit 6ed4780

Browse files
committed
fix(ccm): mark credential unavailable on refresh failure, handle poll 401
tryRefreshCredentials now returns error and calls markCredentialsUnavailable when lock acquisition or file write permission fails. getAccessToken propagates the error instead of silently returning the expired token. pollUsage handles 401 by attempting auth recovery and marking unavailable on failure. All credential error paths now use Error log level instead of Debug. Startup checks expired tokens eagerly via tryRefreshCredentials.
1 parent cf11e0e commit 6ed4780

4 files changed

Lines changed: 94 additions & 34 deletions

File tree

service/ccm/credential_default.go

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,15 @@ func (c *defaultCredential) start() error {
158158
}
159159
err = c.ensureCredentialWatcher()
160160
if err != nil {
161-
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
161+
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
162162
}
163163
err = c.reloadCredentials(true)
164164
if err != nil {
165-
c.logger.Warn("initial credential load for ", c.tag, ": ", err)
165+
c.logger.Error("initial credential load for ", c.tag, ": ", err)
166+
}
167+
if c.credentials != nil && c.credentials.needsRefresh() &&
168+
slices.Contains(c.credentials.Scopes, "user:inference") {
169+
c.tryRefreshCredentials(false)
166170
}
167171
if c.usageTracker != nil {
168172
err = c.usageTracker.Load()
@@ -240,7 +244,10 @@ func (c *defaultCredential) getAccessToken() (string, error) {
240244
if !currentCredentials.needsRefresh() || !slices.Contains(currentCredentials.Scopes, "user:inference") {
241245
return currentCredentials.AccessToken, nil
242246
}
243-
c.tryRefreshCredentials(false)
247+
refreshErr := c.tryRefreshCredentials(false)
248+
if refreshErr != nil {
249+
return "", refreshErr
250+
}
244251
c.access.RLock()
245252
defer c.access.RUnlock()
246253
if c.credentials != nil && c.credentials.AccessToken != "" {
@@ -354,23 +361,25 @@ func (c *defaultCredential) shouldAttemptRefresh(credentials *oauthCredentials,
354361
return credentials.needsRefresh()
355362
}
356363

357-
func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
364+
func (c *defaultCredential) tryRefreshCredentials(force bool) error {
358365
latestCredentials, err := platformReadCredentials(c.credentialPath)
359366
if err == nil && latestCredentials != nil {
360367
c.absorbCredentials(latestCredentials)
361368
}
362369
currentCredentials := c.currentCredentials()
363370
if !c.shouldAttemptRefresh(currentCredentials, force) {
364-
return false
371+
return nil
365372
}
366373
acquireLock := c.acquireLock
367374
if acquireLock == nil {
368375
acquireLock = acquireCredentialLock
369376
}
370377
release, err := acquireLock(c.configDir)
371378
if err != nil {
372-
c.logger.Debug("acquire credential lock for ", c.tag, ": ", err)
373-
return false
379+
lockErr := E.Cause(err, "acquire credential lock for ", c.tag)
380+
c.logger.Error(lockErr)
381+
c.markCredentialsUnavailable(lockErr)
382+
return lockErr
374383
}
375384
defer release()
376385

@@ -382,30 +391,35 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
382391
currentCredentials = c.currentCredentials()
383392
}
384393
if !c.shouldAttemptRefresh(currentCredentials, force) {
385-
return false
394+
return nil
386395
}
387-
if err := platformCanWriteCredentials(c.credentialPath); err != nil {
388-
c.logger.Debug("credential file not writable for ", c.tag, ": ", err)
389-
return false
396+
err = platformCanWriteCredentials(c.credentialPath)
397+
if err != nil {
398+
writeErr := E.Cause(err, "credential file not writable for ", c.tag)
399+
c.logger.Error(writeErr)
400+
c.markCredentialsUnavailable(writeErr)
401+
return writeErr
390402
}
391403

392404
baseCredentials := cloneCredentials(currentCredentials)
393405
refreshResult, retryDelay, err := refreshToken(c.serviceContext, c.forwardHTTPClient, currentCredentials)
394406
if err != nil {
395407
if retryDelay != 0 {
396-
c.logger.Debug("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
408+
c.logger.Error("refresh token for ", c.tag, ": retry delay=", retryDelay, ", error=", err)
397409
} else {
398-
c.logger.Debug("refresh token for ", c.tag, ": ", err)
410+
c.logger.Error("refresh token for ", c.tag, ": ", err)
399411
}
400412
latestCredentials, readErr := platformReadCredentials(c.credentialPath)
401413
if readErr == nil && latestCredentials != nil {
402414
c.absorbCredentials(latestCredentials)
403-
return latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh())
415+
if latestCredentials.AccessToken != "" && (latestCredentials.AccessToken != baseCredentials.AccessToken || !latestCredentials.needsRefresh()) {
416+
return nil
417+
}
404418
}
405-
return false
419+
return E.Cause(err, "refresh token for ", c.tag)
406420
}
407421
if refreshResult == nil || refreshResult.Credentials == nil {
408-
return false
422+
return E.New("refresh token for ", c.tag, ": empty result")
409423
}
410424

411425
refreshedCredentials := cloneCredentials(refreshResult.Credentials)
@@ -419,7 +433,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
419433
if c.needsProfileHydration() {
420434
profileSnapshot, profileErr := c.fetchProfileSnapshot(c.forwardHTTPClient, refreshedCredentials.AccessToken)
421435
if profileErr != nil {
422-
c.logger.Debug("fetch profile for ", c.tag, ": ", profileErr)
436+
c.logger.Error("fetch profile for ", c.tag, ": ", profileErr)
423437
} else if profileSnapshot != nil {
424438
credentialsChanged := c.applyProfileSnapshot(profileSnapshot)
425439
c.persistOAuthAccount()
@@ -428,7 +442,7 @@ func (c *defaultCredential) tryRefreshCredentials(force bool) bool {
428442
}
429443
}
430444
}
431-
return true
445+
return nil
432446
}
433447

434448
func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
@@ -439,7 +453,10 @@ func (c *defaultCredential) recoverAuthFailure(failedAccessToken string) bool {
439453
return true
440454
}
441455
}
442-
c.tryRefreshCredentials(true)
456+
err = c.tryRefreshCredentials(true)
457+
if err != nil {
458+
return false
459+
}
443460
currentCredentials := c.currentCredentials()
444461
return currentCredentials != nil && currentCredentials.AccessToken != "" && currentCredentials.AccessToken != failedAccessToken
445462
}
@@ -924,7 +941,16 @@ func (c *defaultCredential) pollUsage() {
924941
return
925942
}
926943
body, _ := io.ReadAll(response.Body)
927-
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
944+
if response.StatusCode == http.StatusUnauthorized {
945+
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
946+
if !c.recoverAuthFailure(accessToken) {
947+
c.markCredentialsUnavailable(E.New("poll usage unauthorized for ", c.tag))
948+
}
949+
return
950+
}
951+
if !c.isPollBackoffAtCap() {
952+
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
953+
}
928954
c.incrementPollFailures()
929955
return
930956
}
@@ -941,7 +967,9 @@ func (c *defaultCredential) pollUsage() {
941967
}
942968
err = json.NewDecoder(response.Body).Decode(&usageResponse)
943969
if err != nil {
944-
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
970+
if !c.isPollBackoffAtCap() {
971+
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
972+
}
945973
c.incrementPollFailures()
946974
return
947975
}
@@ -982,7 +1010,7 @@ func (c *defaultCredential) pollUsage() {
9821010
if needsProfileFetch {
9831011
profileSnapshot, err := c.fetchProfileSnapshot(httpClient, accessToken)
9841012
if err != nil {
985-
c.logger.Debug("fetch profile for ", c.tag, ": ", err)
1013+
c.logger.Error("fetch profile for ", c.tag, ": ", err)
9861014
return
9871015
}
9881016
if profileSnapshot != nil {

service/ccm/credential_default_test.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"time"
1010
)
1111

12-
func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
12+
func TestGetAccessTokenMarksUnavailableWhenLockFails(t *testing.T) {
1313
t.Parallel()
1414

1515
directory := t.TempDir()
@@ -32,15 +32,47 @@ func TestGetAccessTokenReturnsExistingTokenWhenLockFails(t *testing.T) {
3232
}
3333

3434
credential.acquireLock = func(string) (func(), error) {
35-
return nil, errors.New("locked")
35+
return nil, errors.New("permission denied")
3636
}
3737

38-
token, err := credential.getAccessToken()
39-
if err != nil {
38+
_, err := credential.getAccessToken()
39+
if err == nil {
40+
t.Fatal("expected error when lock acquisition fails, got nil")
41+
}
42+
if credential.isUsable() {
43+
t.Fatal("credential should be marked unavailable after lock failure")
44+
}
45+
}
46+
47+
func TestGetAccessTokenMarksUnavailableOnUnwritableFile(t *testing.T) {
48+
t.Parallel()
49+
50+
directory := t.TempDir()
51+
credentialPath := filepath.Join(directory, ".credentials.json")
52+
writeTestCredentials(t, credentialPath, &oauthCredentials{
53+
AccessToken: "old-token",
54+
RefreshToken: "refresh-token",
55+
ExpiresAt: time.Now().Add(-time.Minute).UnixMilli(),
56+
Scopes: []string{"user:profile", "user:inference"},
57+
})
58+
59+
credential := newTestDefaultCredential(t, credentialPath, roundTripFunc(func(request *http.Request) (*http.Response, error) {
60+
t.Fatal("refresh should not be attempted when file is not writable")
61+
return nil, nil
62+
}))
63+
if err := credential.reloadCredentials(true); err != nil {
4064
t.Fatal(err)
4165
}
42-
if token != "old-token" {
43-
t.Fatalf("expected old token, got %q", token)
66+
67+
os.Chmod(credentialPath, 0o444)
68+
t.Cleanup(func() { os.Chmod(credentialPath, 0o644) })
69+
70+
_, err := credential.getAccessToken()
71+
if err == nil {
72+
t.Fatal("expected error when credential file is not writable, got nil")
73+
}
74+
if credential.isUsable() {
75+
t.Fatal("credential should be marked unavailable after write permission failure")
4476
}
4577
}
4678

service/ccm/credential_external.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,29 +598,29 @@ func (c *externalCredential) pollUsage() {
598598
ctx := c.getReverseContext()
599599
response, err := c.doPollUsageRequest(ctx)
600600
if err != nil {
601-
c.logger.Debug("poll usage for ", c.tag, ": ", err)
601+
c.logger.Error("poll usage for ", c.tag, ": ", err)
602602
c.incrementPollFailures()
603603
return
604604
}
605605
defer response.Body.Close()
606606

607607
if response.StatusCode != http.StatusOK {
608608
body, _ := io.ReadAll(response.Body)
609-
c.logger.Debug("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
609+
c.logger.Error("poll usage for ", c.tag, ": status ", response.StatusCode, " ", string(body))
610610
c.incrementPollFailures()
611611
return
612612
}
613613

614614
body, err := io.ReadAll(response.Body)
615615
if err != nil {
616-
c.logger.Debug("poll usage for ", c.tag, ": read body: ", err)
616+
c.logger.Error("poll usage for ", c.tag, ": read body: ", err)
617617
c.incrementPollFailures()
618618
return
619619
}
620620
var rawFields map[string]json.RawMessage
621621
err = json.Unmarshal(body, &rawFields)
622622
if err != nil {
623-
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
623+
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
624624
c.incrementPollFailures()
625625
return
626626
}
@@ -634,7 +634,7 @@ func (c *externalCredential) pollUsage() {
634634
var statusResponse statusPayload
635635
err = json.Unmarshal(body, &statusResponse)
636636
if err != nil {
637-
c.logger.Debug("poll usage for ", c.tag, ": decode: ", err)
637+
c.logger.Error("poll usage for ", c.tag, ": decode: ", err)
638638
c.incrementPollFailures()
639639
return
640640
}

service/ccm/credential_file.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (c *defaultCredential) retryCredentialReloadIfNeeded() {
7575

7676
err := c.ensureCredentialWatcher()
7777
if err != nil {
78-
c.logger.Debug("start credential watcher for ", c.tag, ": ", err)
78+
c.logger.Error("start credential watcher for ", c.tag, ": ", err)
7979
}
8080
_ = c.reloadCredentials(false)
8181
}

0 commit comments

Comments
 (0)