Skip to content

Commit 96806ae

Browse files
fix: address review feedback — double-sleep bug, dead test field, ctx-cancel test, nits
Assisted-By: docker-agent
1 parent 44554bd commit 96806ae

3 files changed

Lines changed: 89 additions & 42 deletions

File tree

pkg/modelerrors/modelerrors_test.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,24 +330,25 @@ func TestClassifyModelError(t *testing.T) {
330330
retryAfter time.Duration // value to pass as the retryAfter param
331331
wantRetryable bool
332332
wantRateLimited bool
333+
wantRetryAfter time.Duration // expected retryAfterOut
333334
}{
334-
{name: "nil", err: nil, wantRetryable: false, wantRateLimited: false},
335-
{name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false},
336-
{name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false},
337-
{name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false},
335+
{name: "nil", err: nil, wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
336+
{name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
337+
{name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
338+
{name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
338339
// 429 rate limit cases (retryAfter passed in by caller from provider.ExtractRetryAfter)
339-
{name: "429 message only, no header", err: errors.New("POST /v1/chat: 429 Too Many Requests"), retryAfter: 0, wantRetryable: false, wantRateLimited: true},
340-
{name: "429 message only, with header", err: errors.New("POST /v1/chat: 429 Too Many Requests"), retryAfter: 30 * time.Second, wantRetryable: false, wantRateLimited: true},
340+
{name: "429 message only, no header", err: errors.New("POST /v1/chat: 429 Too Many Requests"), retryAfter: 0, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0},
341+
{name: "429 message only, with header", err: errors.New("POST /v1/chat: 429 Too Many Requests"), retryAfter: 30 * time.Second, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 30 * time.Second},
341342
// Retryable server errors
342-
{name: "500 internal server error", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false},
343-
{name: "502 bad gateway", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false},
344-
{name: "503 service unavailable", err: errors.New("503 service unavailable"), wantRetryable: true, wantRateLimited: false},
345-
{name: "504 gateway timeout", err: errors.New("504 gateway timeout"), wantRetryable: true, wantRateLimited: false},
343+
{name: "500 internal server error", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false, wantRetryAfter: 0},
344+
{name: "502 bad gateway", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false, wantRetryAfter: 0},
345+
{name: "503 service unavailable", err: errors.New("503 service unavailable"), wantRetryable: true, wantRateLimited: false, wantRetryAfter: 0},
346+
{name: "504 gateway timeout", err: errors.New("504 gateway timeout"), wantRetryable: true, wantRateLimited: false, wantRetryAfter: 0},
346347
// Non-retryable errors (message-based, no SDK types needed here)
347-
{name: "401 unauthorized", err: errors.New("401 unauthorized"), wantRetryable: false, wantRateLimited: false},
348-
{name: "403 forbidden", err: errors.New("403 forbidden"), wantRetryable: false, wantRateLimited: false},
348+
{name: "401 unauthorized", err: errors.New("401 unauthorized"), wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
349+
{name: "403 forbidden", err: errors.New("403 forbidden"), wantRetryable: false, wantRateLimited: false, wantRetryAfter: 0},
349350
// Network errors
350-
{name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false},
351+
{name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false, wantRetryAfter: 0},
351352
}
352353

353354
for _, tt := range tests {
@@ -356,7 +357,7 @@ func TestClassifyModelError(t *testing.T) {
356357
retryable, rateLimited, retryAfterOut := ClassifyModelError(tt.err, tt.retryAfter)
357358
assert.Equal(t, tt.wantRetryable, retryable, "retryable mismatch")
358359
assert.Equal(t, tt.wantRateLimited, rateLimited, "rateLimited mismatch")
359-
assert.GreaterOrEqual(t, retryAfterOut, time.Duration(0), "retryAfterOut should never be negative")
360+
assert.Equal(t, tt.wantRetryAfter, retryAfterOut, "retryAfter mismatch")
360361
})
361362
}
362363

pkg/runtime/fallback.go

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,6 @@ func logFallbackAttempt(agentName string, model modelWithFallback, attempt, maxR
6969
}
7070
}
7171

72-
// logRetryBackoff logs when we're backing off before a retry
73-
func logRetryBackoff(agentName, modelID string, attempt int, backoff time.Duration) {
74-
slog.Debug("Backing off before retry",
75-
"agent", agentName,
76-
"model", modelID,
77-
"attempt", attempt+1,
78-
"backoff", backoff)
79-
}
80-
8172
// getCooldownState returns the current cooldown state for an agent (thread-safe).
8273
// Returns nil if no cooldown is active or if cooldown has expired.
8374
// Expired entries are evicted to prevent stale state accumulation.
@@ -228,15 +219,6 @@ func (r *LocalRuntime) tryModelWithFallback(
228219
return streamResult{}, nil, ctx.Err()
229220
}
230221

231-
// Apply backoff before retry (not on first attempt of each model)
232-
if attempt > 0 {
233-
backoff := modelerrors.CalculateBackoff(attempt - 1)
234-
logRetryBackoff(a.Name(), modelEntry.provider.ID(), attempt, backoff)
235-
if !modelerrors.SleepWithContext(ctx, backoff) {
236-
return streamResult{}, nil, ctx.Err()
237-
}
238-
}
239-
240222
// Emit fallback event when transitioning to a new model (but not when starting in cooldown)
241223
if chainIdx > startIndex && attempt == 0 {
242224
logFallbackAttempt(a.Name(), modelEntry, attempt, fallbackRetries, lastErr)
@@ -272,7 +254,7 @@ func (r *LocalRuntime) tryModelWithFallback(
272254
return streamResult{}, nil, err
273255
}
274256

275-
decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable)
257+
decision := handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable)
276258
if decision == retryDecisionReturn {
277259
return streamResult{}, nil, ctx.Err()
278260
} else if decision == retryDecisionBreak {
@@ -292,7 +274,7 @@ func (r *LocalRuntime) tryModelWithFallback(
292274
return streamResult{}, nil, err
293275
}
294276

295-
decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable)
277+
decision := handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable)
296278
if decision == retryDecisionReturn {
297279
return streamResult{}, nil, ctx.Err()
298280
} else if decision == retryDecisionBreak {
@@ -335,10 +317,12 @@ func (r *LocalRuntime) tryModelWithFallback(
335317
type retryDecision int
336318

337319
const (
338-
// retryDecisionContinue means retry the same model (backoff already applied).
339-
retryDecisionContinue retryDecision = iota
340320
// retryDecisionBreak means skip to the next model in the fallback chain.
341-
retryDecisionBreak
321+
// This is the zero value — safe default: skip to next model rather than
322+
// accidentally retrying or returning early.
323+
retryDecisionBreak retryDecision = iota
324+
// retryDecisionContinue means retry the same model (sleep already applied).
325+
retryDecisionContinue
342326
// retryDecisionReturn means context was cancelled; return immediately.
343327
retryDecisionReturn
344328
)
@@ -348,9 +332,12 @@ const (
348332
// - retryDecisionBreak — non-retryable error or 429 with fallbacks; skip to next model
349333
// - retryDecisionContinue — retryable error or 429 without fallbacks; retry same model
350334
//
335+
// All sleeping (both 5xx backoff and 429 Retry-After) is performed here so the
336+
// outer loop never needs its own sleep path.
337+
//
351338
// Side-effect: sets *primaryFailedWithNonRetryable when the primary model fails with a
352339
// non-retryable (or rate-limited-with-fallbacks) error.
353-
func (r *LocalRuntime) handleModelError(
340+
func handleModelError(
354341
ctx context.Context,
355342
err error,
356343
a *agent.Agent,
@@ -363,11 +350,12 @@ func (r *LocalRuntime) handleModelError(
363350

364351
if rateLimited {
365352
if hasFallbacks {
366-
// Fallbacks available → skip to next model immediately (existing behaviour).
367-
slog.Warn("Rate limited with fallbacks available, skipping to next model",
353+
// Fallbacks available → skip to next model immediately.
354+
slog.Warn("Rate limited, skipping model",
368355
"agent", a.Name(),
369356
"model", modelEntry.provider.ID(),
370-
"retry_after", retryAfter)
357+
"retry_after", retryAfter,
358+
"error", err)
371359
if !modelEntry.isFallback {
372360
*primaryFailedWithNonRetryable = true
373361
}
@@ -391,7 +379,8 @@ func (r *LocalRuntime) handleModelError(
391379
"model", modelEntry.provider.ID(),
392380
"attempt", attempt+1,
393381
"wait", waitDuration,
394-
"retry_after_from_header", retryAfter > 0)
382+
"retry_after_from_header", retryAfter > 0,
383+
"error", err)
395384
if !modelerrors.SleepWithContext(ctx, waitDuration) {
396385
return retryDecisionReturn
397386
}
@@ -409,10 +398,16 @@ func (r *LocalRuntime) handleModelError(
409398
return retryDecisionBreak
410399
}
411400

401+
// Retryable (5xx, timeouts): sleep with backoff then retry same model.
402+
waitDuration := modelerrors.CalculateBackoff(attempt)
412403
slog.Warn("Retryable error from model",
413404
"agent", a.Name(),
414405
"model", modelEntry.provider.ID(),
415406
"attempt", attempt+1,
407+
"wait", waitDuration,
416408
"error", err)
409+
if !modelerrors.SleepWithContext(ctx, waitDuration) {
410+
return retryDecisionReturn
411+
}
417412
return retryDecisionContinue
418413
}

pkg/runtime/fallback_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,54 @@ func TestFallback500RetryableWithBackoff(t *testing.T) {
613613
assert.Equal(t, 2, primary.callCount, "primary should be called twice: 1 failure + 1 success")
614614
})
615615
}
616+
617+
func TestFallback429WithoutFallbacksContextCancelled(t *testing.T) {
618+
synctest.Test(t, func(t *testing.T) {
619+
// Model always returns 429 with no fallbacks; handleModelError will sleep before
620+
// retrying. We cancel the context while it is sleeping and verify that RunStream
621+
// returns promptly (stream channel closed) rather than hanging until the backoff
622+
// expires.
623+
primary := &failingProvider{
624+
id: "primary/always-429",
625+
err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"),
626+
}
627+
628+
root := agent.New("root", "test",
629+
agent.WithModel(primary),
630+
// No fallback models; 429 will be retried with backoff.
631+
// Use many retries to ensure the runtime would block for a long time
632+
// without context cancellation.
633+
agent.WithFallbackRetries(5),
634+
)
635+
636+
tm := team.New(team.WithAgents(root))
637+
rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
638+
require.NoError(t, err)
639+
640+
ctx, cancel := context.WithCancel(t.Context())
641+
defer cancel()
642+
643+
sess := session.New(session.WithUserMessage("test"))
644+
sess.Title = "429 Context Cancel Test"
645+
646+
// Cancel the context from a goroutine once every other goroutine in the bubble
647+
// is durably blocked (i.e., the retry sleep has started). synctest.Wait() returns
648+
// only when all other bubble goroutines are blocked, so the runtime is mid-sleep.
649+
go func() {
650+
synctest.Wait()
651+
cancel()
652+
}()
653+
654+
// Drain the stream. If context cancellation is properly handled, RunStream
655+
// closes the channel mid-sleep. If not, the bubble's fake clock would advance
656+
// past each sleep and the loop would still end, only after the retries run out.
657+
var eventCount int
658+
for range rt.RunStream(ctx, sess) {
659+
eventCount++
660+
}
661+
// The primary was called at least once before the sleep started.
662+
// We can't assert on eventCount because no content is produced — just verify
663+
// the channel closed (loop above completed without deadlock).
664+
_ = eventCount
665+
})
666+
}

0 commit comments

Comments
 (0)