Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions pkg/modelerrors/modelerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"log/slog"
"math/rand"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"time"

"github.com/anthropics/anthropic-sdk-go"
openai "github.com/openai/openai-go/v3"
"google.golang.org/genai"
)

Expand All @@ -27,6 +30,14 @@ const (
backoffJitter = 0.1
)

// maxRetryAfterWait is the longest Retry-After delay the package considers
// reasonable to honor (one minute). It exists so that a misbehaving server
// cannot stall the agent for an arbitrarily long time by advertising a huge
// Retry-After value.
const maxRetryAfterWait = time.Minute
Comment thread
rumpl marked this conversation as resolved.
Outdated

// MaxRetryAfterWait is the exported alias of maxRetryAfterWait: the upper
// bound callers outside this package should apply to Retry-After durations.
const MaxRetryAfterWait time.Duration = maxRetryAfterWait

// Default fallback configuration.
const (
// DefaultRetries is the default number of retries per model with exponential
Expand Down Expand Up @@ -296,6 +307,92 @@ func IsRetryableModelError(err error) bool {
return false
}

// ExtractRetryAfter extracts the Retry-After duration from an HTTP error response.
// Works with Anthropic and OpenAI SDK error types that expose *http.Response.
// Returns 0 if no Retry-After header is present or the error type is unsupported.
func ExtractRetryAfter(err error) time.Duration {
Comment thread
rumpl marked this conversation as resolved.
Outdated
var resp *http.Response

if anthropicErr, ok := errors.AsType[*anthropic.Error](err); ok {
resp = anthropicErr.Response
} else if openaiErr, ok := errors.AsType[*openai.Error](err); ok {
resp = openaiErr.Response
}

if resp == nil {
return 0
}

return parseRetryAfterHeader(resp.Header.Get("Retry-After"))
}

// parseRetryAfterHeader parses the value of a Retry-After header.
//
// Both forms allowed by RFC 7231 §7.1.3 are accepted: a number of seconds
// (integer) and an HTTP-date. Surrounding whitespace is ignored.
//
// It returns 0 if the value is empty, invalid, or yields a non-positive
// duration. A seconds value too large to be represented as a time.Duration
// saturates at the maximum duration instead of overflowing, so the result is
// never negative. No upper bound (such as a caller-side cap) is applied here.
func parseRetryAfterHeader(value string) time.Duration {
	// Largest representable time.Duration; used to saturate on overflow.
	const maxDuration = time.Duration(1<<63 - 1)

	value = strings.TrimSpace(value)
	if value == "" {
		return 0
	}

	// Try integer seconds first (most common for rate limits).
	if seconds, err := strconv.Atoi(value); err == nil {
		if seconds <= 0 {
			// A numeric value can never also be a valid HTTP-date.
			return 0
		}
		// seconds * time.Second would wrap around int64 for very large
		// values; clamp so the caller still sees "a very long wait".
		if int64(seconds) > int64(maxDuration/time.Second) {
			return maxDuration
		}
		return time.Duration(seconds) * time.Second
	}

	// Otherwise try the HTTP-date form; dates in the past are treated as
	// "no wait".
	if t, err := http.ParseTime(value); err == nil {
		if d := time.Until(t); d > 0 {
			return d
		}
	}
	return 0
}

// ClassifyModelError decides how the retry/fallback logic should treat err.
//
// Results:
//   - retryable: true means retry the SAME model with backoff (5xx, timeouts).
//   - rateLimited: true means the error is an HTTP 429; the caller chooses
//     between retrying and falling back based on its configuration.
//   - retryAfter: the wait suggested by the Retry-After header. It is only
//     populated when rateLimited is true and is 0 otherwise.
//
// retryable is always false when rateLimited is true: the caller is expected
// to retry when no fallback is configured, or to move on to the next model
// when one is.
//
// IsRetryableModelError and IsRetryableStatusCode are left as they are for
// backward compatibility; this function is the authoritative classifier for
// the retry loop.
func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time.Duration) {
	// Conditions that are never worth retrying, checked in order.
	switch {
	case err == nil:
		return false, false, 0
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		// The caller gave up or ran out of time; retrying cannot help.
		return false, false, 0
	case IsContextOverflowError(err):
		// Re-sending the same oversized payload will always fail.
		return false, false, 0
	}

	switch code := ExtractHTTPStatusCode(err); code {
	case http.StatusTooManyRequests:
		// 429: rate limited, so retry-vs-fallback is the caller's decision.
		return false, true, ExtractRetryAfter(err)
	case 0:
		// No structured status code: rely on IsRetryableModelError, which
		// covers net.Error and message-pattern matching.
		return IsRetryableModelError(err), false, 0
	default:
		// Any other status: retryable only for the known codes (5xx, 408, 529).
		return IsRetryableStatusCode(code), false, 0
	}
}

// CalculateBackoff returns the backoff duration for a given attempt (0-indexed).
// Uses exponential backoff with jitter.
func CalculateBackoff(attempt int) time.Duration {
Expand Down
193 changes: 193 additions & 0 deletions pkg/modelerrors/modelerrors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ import (
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/anthropics/anthropic-sdk-go"
openai "github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -278,3 +282,192 @@ func TestFormatError(t *testing.T) {
assert.Equal(t, "authentication failed", FormatError(err))
})
}

// makeAnthropicError builds an *anthropic.Error carrying the given HTTP status
// code. When retryAfterValue is non-empty it is set as the response's
// Retry-After header; otherwise the response has no such header.
// Test helper for ExtractRetryAfter and ClassifyModelError.
func makeAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error {
	response := httptest.NewRecorder().Result()
	response.StatusCode = statusCode
	response.Header = http.Header{}
	if retryAfterValue != "" {
		response.Header.Set("Retry-After", retryAfterValue)
	}

	// anthropic.Error.Error() dereferences Request, so a non-nil request is
	// required. The error is ignored because the method and URL are constants.
	request, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody)

	return &anthropic.Error{
		StatusCode: statusCode,
		Response:   response,
		Request:    request,
	}
}

// makeOpenAIError builds an *openai.Error carrying the given HTTP status code.
// When retryAfterValue is non-empty it is set as the response's Retry-After
// header; otherwise the response has no such header.
// Test helper for ExtractRetryAfter and ClassifyModelError.
func makeOpenAIError(statusCode int, retryAfterValue string) *openai.Error {
	response := httptest.NewRecorder().Result()
	response.StatusCode = statusCode
	response.Header = http.Header{}
	if retryAfterValue != "" {
		response.Header.Set("Retry-After", retryAfterValue)
	}

	// openai.Error.Error() dereferences Request, so a non-nil request is
	// required. The error is ignored because the method and URL are constants.
	request, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody)

	return &openai.Error{
		StatusCode: statusCode,
		Response:   response,
		Request:    request,
	}
}

func TestParseRetryAfterHeader(t *testing.T) {
t.Parallel()

tests := []struct {
name string
value string
expected time.Duration
}{
{name: "empty", value: "", expected: 0},
{name: "zero seconds", value: "0", expected: 0},
{name: "negative seconds", value: "-1", expected: 0},
{name: "invalid string", value: "foo", expected: 0},
{name: "5 seconds", value: "5", expected: 5 * time.Second},
{name: "30 seconds", value: "30", expected: 30 * time.Second},
{name: "120 seconds", value: "120", expected: 120 * time.Second},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := parseRetryAfterHeader(tt.value)
assert.Equal(t, tt.expected, got, "parseRetryAfterHeader(%q)", tt.value)
})
}

t.Run("HTTP-date in the future", func(t *testing.T) {
t.Parallel()
// Use a time 10 seconds in the future
future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat)
got := parseRetryAfterHeader(future)
assert.Greater(t, got, 0*time.Second, "should return positive duration for future HTTP-date")
assert.LessOrEqual(t, got, 11*time.Second, "should not exceed ~10s for near-future date")
})

t.Run("HTTP-date in the past", func(t *testing.T) {
t.Parallel()
past := time.Now().Add(-10 * time.Second).UTC().Format(http.TimeFormat)
got := parseRetryAfterHeader(past)
assert.Equal(t, 0*time.Second, got, "should return 0 for past HTTP-date")
})
}

func TestExtractRetryAfter(t *testing.T) {
t.Parallel()

t.Run("nil error returns 0", func(t *testing.T) {
t.Parallel()
assert.Equal(t, time.Duration(0), ExtractRetryAfter(nil))
})

t.Run("plain error returns 0", func(t *testing.T) {
t.Parallel()
assert.Equal(t, time.Duration(0), ExtractRetryAfter(errors.New("some error")))
})

t.Run("anthropic error with Retry-After seconds", func(t *testing.T) {
t.Parallel()
err := makeAnthropicError(429, "15")
assert.Equal(t, 15*time.Second, ExtractRetryAfter(err))
})

t.Run("anthropic error without Retry-After header", func(t *testing.T) {
t.Parallel()
err := makeAnthropicError(429, "")
assert.Equal(t, time.Duration(0), ExtractRetryAfter(err))
})

t.Run("openai error with Retry-After seconds", func(t *testing.T) {
t.Parallel()
err := makeOpenAIError(429, "30")
assert.Equal(t, 30*time.Second, ExtractRetryAfter(err))
})

t.Run("openai error without Retry-After header", func(t *testing.T) {
t.Parallel()
err := makeOpenAIError(429, "")
assert.Equal(t, time.Duration(0), ExtractRetryAfter(err))
})

t.Run("wrapped anthropic error", func(t *testing.T) {
t.Parallel()
anthropicErr := makeAnthropicError(429, "5")
wrapped := fmt.Errorf("model failed: %w", anthropicErr)
assert.Equal(t, 5*time.Second, ExtractRetryAfter(wrapped))
})
}

// TestClassifyModelError verifies the (retryable, rateLimited, retryAfter)
// classification for nil, context errors, context overflow, HTTP 429, other
// HTTP status codes, and network errors, and that the Retry-After value of a
// 429 response is propagated to the caller.
func TestClassifyModelError(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name            string
		err             error
		wantRetryable   bool
		wantRateLimited bool
	}{
		{name: "nil", err: nil, wantRetryable: false, wantRateLimited: false},
		{name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false},
		{name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false},
		{name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false},
		// 429 rate limit cases
		{name: "429 message only", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true},
		{name: "429 anthropic error no header", err: makeAnthropicError(429, ""), wantRetryable: false, wantRateLimited: true},
		{name: "429 openai error no header", err: makeOpenAIError(429, ""), wantRetryable: false, wantRateLimited: true},
		// Retryable server errors
		{name: "500 message", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false},
		{name: "500 anthropic error", err: makeAnthropicError(500, ""), wantRetryable: true, wantRateLimited: false},
		{name: "500 openai error", err: makeOpenAIError(500, ""), wantRetryable: true, wantRateLimited: false},
		{name: "502 bad gateway", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false},
		{name: "503 service unavailable", err: errors.New("503 service unavailable"), wantRetryable: true, wantRateLimited: false},
		{name: "504 gateway timeout", err: errors.New("504 gateway timeout"), wantRetryable: true, wantRateLimited: false},
		{name: "529 overloaded", err: makeAnthropicError(529, ""), wantRetryable: true, wantRateLimited: false},
		{name: "408 timeout", err: makeAnthropicError(408, ""), wantRetryable: true, wantRateLimited: false},
		// Non-retryable client errors
		{name: "400 bad request", err: makeAnthropicError(400, ""), wantRetryable: false, wantRateLimited: false},
		{name: "401 unauthorized", err: makeAnthropicError(401, ""), wantRetryable: false, wantRateLimited: false},
		{name: "403 forbidden", err: makeAnthropicError(403, ""), wantRetryable: false, wantRateLimited: false},
		// Network errors
		{name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			retryable, rateLimited, retryAfter := ClassifyModelError(tc.err)
			assert.Equal(t, tc.wantRetryable, retryable, "retryable mismatch")
			assert.Equal(t, tc.wantRateLimited, rateLimited, "rateLimited mismatch")
			// The exact value is pinned by the dedicated subtests below; here
			// we only require that it is never negative.
			assert.GreaterOrEqual(t, retryAfter, time.Duration(0), "retryAfter should never be negative")
		})
	}

	t.Run("429 with Retry-After header propagated", func(t *testing.T) {
		t.Parallel()
		err := makeAnthropicError(429, "20")
		retryable, rateLimited, retryAfter := ClassifyModelError(err)
		assert.False(t, retryable)
		assert.True(t, rateLimited)
		assert.Equal(t, 20*time.Second, retryAfter)
	})

	t.Run("429 openai with Retry-After header", func(t *testing.T) {
		t.Parallel()
		err := makeOpenAIError(429, "10")
		retryable, rateLimited, retryAfter := ClassifyModelError(err)
		assert.False(t, retryable)
		assert.True(t, rateLimited)
		assert.Equal(t, 10*time.Second, retryAfter)
	})
}
Loading
Loading