diff --git a/connection.go b/connection.go index 5b8a91c..3e6ac45 100644 --- a/connection.go +++ b/connection.go @@ -19,6 +19,7 @@ import ( context2 "github.com/databricks/databricks-sql-go/internal/compat/context" "github.com/databricks/databricks-sql-go/internal/config" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" + "github.com/databricks/databricks-sql-go/internal/retry" "github.com/databricks/databricks-sql-go/internal/rows" "github.com/databricks/databricks-sql-go/internal/sentinel" "github.com/databricks/databricks-sql-go/internal/thrift_protocol" @@ -647,17 +648,21 @@ var _ driver.ConnBeginTx = (*conn)(nil) var _ driver.NamedValueChecker = (*conn)(nil) func Succeeded(response *http.Response) bool { - if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 { - return true - } - return false + return statusInSuccessRange(response.StatusCode) +} + +// statusInSuccessRange returns true for the 2xx status codes the staging +// HTTP path treats as success: 200 OK / 201 Created / 202 Accepted / 204 +// No Content. Exposed separately from Succeeded so handlers can extend the +// accept set (e.g. REMOVE accepts 404 for idempotent-delete semantics). +func statusInSuccessRange(status int) bool { + return status == 200 || status == 201 || status == 202 || status == 204 } func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { if localFile == "" { return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil) } - client := &http.Client{} dat, err := os.Open(localFile) //nolint:gosec // localFile is provided by the application, not user input if err != nil { @@ -669,73 +674,222 @@ func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, header if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading local file info", err) } - - req, _ := http.NewRequest("PUT", presignedUrl, dat) - req.ContentLength = info.Size() // backend actually requires content length to be known - - for k, v := range headers { - req.Header.Set(k, v) - } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + size := info.Size() + + // Each retry attempt needs a fresh request because http.Client.Do consumes + // the request body. Rewind the *os.File between attempts so the server + // receives the full payload on every retry, not just attempt 1. + // + // Wrap the file in io.NopCloser so http.Client.Do can't close it — by + // default it closes any body that implements io.Closer, which would break + // the Seek on the next retry. The outer defer dat.Close() above owns the + // file's lifecycle. + reqFactory := func(attempt int) (*http.Request, error) { + if attempt > 0 { + if _, seekErr := dat.Seek(0, io.SeekStart); seekErr != nil { + return nil, seekErr + } + } + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPut, presignedUrl, io.NopCloser(dat)) + if reqErr != nil { + return nil, reqErr + } + req.ContentLength = size // backend actually requires content length to be known + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) + if _, err := c.doStagingRequestWithRetry(ctx, reqFactory); err != nil { + return err } return nil - } func (c *conn) handleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { if localFile == "" { return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil) } - client := &http.Client{} - req, _ := http.NewRequest("GET", presignedUrl, nil) - for k, v := range headers { - req.Header.Set(k, v) + reqFactory := func(_ int) (*http.Request, error) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, presignedUrl, nil) + if reqErr != nil { + return nil, reqErr + } + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - res, err := client.Do(req) + + content, err := c.doStagingRequestWithRetry(ctx, reqFactory) if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + return err } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) + if writeErr := os.WriteFile(localFile, content, 0644); writeErr != nil { //nolint:gosec + return dbsqlerrint.NewDriverError(ctx, "error writing local file", writeErr) + } + return nil +} - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) +func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { + reqFactory := func(_ int) (*http.Request, error) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodDelete, presignedUrl, nil) + if reqErr != nil { + return nil, reqErr + } + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - err = os.WriteFile(localFile, content, 0644) //nolint:gosec - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error writing local file", err) + // Treat 404 as success on REMOVE: DELETE is idempotent, and a 404 means + // the object is already absent — which is the post-condition the caller + // asked for. This also avoids spurious failures when a successful DELETE + // returns a transient 5xx mid-response and the retry sees 404 from the + // server having already applied the original request. + acceptStatus := func(status int) bool { + return statusInSuccessRange(status) || status == http.StatusNotFound + } + + if _, err := c.doStagingRequestWithRetryAccept(ctx, reqFactory, acceptStatus); err != nil { + return err } return nil } -func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { - client := &http.Client{} - req, _ := http.NewRequest("DELETE", presignedUrl, nil) - for k, v := range headers { - req.Header.Set(k, v) +// maxStagingErrorBodyBytes bounds the response body bytes included in +// terminal staging error messages. Proxies and misconfigured backends can +// return multi-MB error bodies; truncating keeps the driver error readable +// without dropping the typical S3 XML error code that fits well under 512B. +const maxStagingErrorBodyBytes = 512 + +// doStagingRequestWithRetry executes a staging HTTP request with retry on +// transient object-storage failures (ES-1911239). Wraps +// doStagingRequestWithRetryAccept with the default success predicate (2xx +// from statusInSuccessRange / Succeeded). +func (c *conn) doStagingRequestWithRetry(ctx context.Context, reqFactory func(attempt int) (*http.Request, error)) ([]byte, dbsqlerr.DBError) { + return c.doStagingRequestWithRetryAccept(ctx, reqFactory, statusInSuccessRange) +} + +// doStagingRequestWithRetryAccept is the generalized staging retry helper +// used by all three handleStaging* methods. Mirrors the CloudFetch retry +// path (ES-1892645) in semantics — same retryable status set, same +// exponential-backoff-with-jitter schedule, same RetryMax/RetryWaitMin/ +// RetryWaitMax config knobs — so behavior is consistent across the driver's +// two object-storage code paths. +// +// reqFactory must return a fresh *http.Request on each call. Attempt 0 is +// the initial request; attempt N>0 is a retry. The PUT path uses this to +// rewind the file body between attempts; other staging paths just construct +// a new request each time. +// +// acceptStatus reports whether a given HTTP status code should be treated +// as success. Most handlers pass statusInSuccessRange. The REMOVE handler +// extends this to also accept 404 (idempotent-delete semantics). +// +// On success returns the response body bytes. On terminal failure (non- +// retryable status, exhausted retries, or context cancellation) returns a +// dbsqlerr.DBError describing the final state. +func (c *conn) doStagingRequestWithRetryAccept( + ctx context.Context, + reqFactory func(attempt int) (*http.Request, error), + acceptStatus func(status int) bool, +) ([]byte, dbsqlerr.DBError) { + retryMax := c.cfg.RetryMax + if retryMax < 0 { + retryMax = 0 } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + client := &http.Client{} + + var ( + lastErr error + lastStatus int + lastBody []byte + lastRetryAfter string + ) + + for attempt := 0; attempt <= retryMax; attempt++ { + if attempt > 0 { + wait := retry.Backoff(attempt, c.cfg.RetryWaitMin, c.cfg.RetryWaitMax, lastRetryAfter) + logger.Debug().Msgf( + "staging: retrying HTTP request (attempt %d/%d) in %v; lastStatus=%d lastErr=%v", + attempt, retryMax, wait, lastStatus, lastErr, + ) + t := time.NewTimer(wait) + select { + case <-ctx.Done(): + if !t.Stop() { + <-t.C + } + return nil, dbsqlerrint.NewDriverError(ctx, "staging operation cancelled during retry backoff", ctx.Err()) + case <-t.C: + } + } + + req, reqErr := reqFactory(attempt) + if reqErr != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error building staging http request", reqErr) + } + + res, err := client.Do(req) + if err != nil { + // Caller cancellation is terminal; otherwise treat transport + // errors (TCP RST, TLS timeout, etc.) as transient. + if ctx.Err() != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error sending http request", ctx.Err()) + } + lastErr = err + lastStatus = 0 + lastRetryAfter = "" + continue + } + + body, readErr := io.ReadAll(res.Body) + res.Body.Close() //nolint:errcheck,gosec // G104: close after drain + + if readErr != nil { + if ctx.Err() != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error reading http response", ctx.Err()) + } + lastErr = readErr + lastStatus = 0 + lastRetryAfter = "" + continue + } + + if acceptStatus(res.StatusCode) { + return body, nil + } + + lastStatus = res.StatusCode + lastErr = nil + lastBody = body + lastRetryAfter = res.Header.Get("Retry-After") + + if !retry.IsRetryableStatus(res.StatusCode) { + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, truncateErrorBody(body)), nil) + } } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil) + if lastStatus != 0 { + // lastErr is nil here by construction: the HTTP-status branch above + // explicitly clears it on every iteration. The status code and body + // are captured in msg, so there's no underlying error to wrap. + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s (after %d retries)", lastStatus, truncateErrorBody(lastBody), retryMax), nil) } + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation HTTP request failed: %v (after %d retries)", lastErr, retryMax), lastErr) +} - return nil +// truncateErrorBody caps b at maxStagingErrorBodyBytes for inclusion in error +// messages, appending an indicator when truncation occurred. +func truncateErrorBody(b []byte) string { + if len(b) <= maxStagingErrorBodyBytes { + return string(b) + } + return fmt.Sprintf("%s... (%d bytes total, truncated)", b[:maxStagingErrorBodyBytes], len(b)) } func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool { diff --git a/connection_test.go b/connection_test.go index badf6a7..92f0837 100644 --- a/connection_test.go +++ b/connection_test.go @@ -4,6 +4,13 @@ import ( "context" "database/sql/driver" "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" "testing" "time" @@ -2035,3 +2042,317 @@ func getTestSession() *cli_service.TOpenSessionResp { }, }} } + +// TestConn_handleStagingRetry verifies that the staging-operation HTTP wrappers +// (handleStagingPut/Get/Remove) retry transient S3 errors. ES-1911239: FactSet +// hit intermittent HTTP 503 SlowDown on PUT against a Unity Catalog external +// volume; pre-fix any single 5xx failed the entire SQL statement. +// +// Retry behavior mirrors the CloudFetch fix from ES-1892645 / PR #355: +// - Retryable statuses: 408/429/500/502/503/504. +// - Exponential backoff with equal jitter, honoring RetryMax/RetryWaitMin/ +// RetryWaitMax from the connection config. +// - Integer Retry-After response header is honored (capped at RetryWaitMax). +// - Non-retryable statuses (e.g. 403) fail on the first attempt. +// - Context cancellation aborts backoff promptly. +func TestConn_handleStagingRetry(t *testing.T) { + // retryCfg returns a fast-backoff config so tests don't burn wall-clock + // on sleeps. RetryMax leaves room for several retries; RetryWaitMin/Max + // keep the worst-case test runtime under a second. + retryCfg := func() *config.Config { + cfg := config.WithDefaults() + cfg.RetryMax = 4 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + return cfg + } + + t.Run("PUT retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("SlowDown")) + return + } + // Drain the body so we exercise the retry-replay path for PUTs. + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("parquet bytes"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts), "expected 2 retries before success") + }) + + t.Run("GET retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + body := []byte("downloaded bytes") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "out.bin") + + c := &conn{cfg: retryCfg()} + err := c.handleStagingGet(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + + got, readErr := os.ReadFile(localFile) + assert.Nil(t, readErr) + assert.Equal(t, body, got, "GET should write the final-attempt body to local file") + }) + + t.Run("REMOVE retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("REMOVE treats 503-then-404 as success (idempotent delete)", func(t *testing.T) { + // The first DELETE may have applied server-side even though the + // response was 503 (load balancer dies mid-response, etc.). The + // retry then sees 404 — the object is already gone. The caller's + // post-condition ("object absent") is satisfied, so this is success. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err, "503 then 404 on REMOVE should succeed: the object is absent, which is the caller's intent") + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("REMOVE treats first-attempt 404 as success (idempotent delete)", func(t *testing.T) { + // DELETE on a non-existent object is success: the post-condition + // ("object absent") is already true. Documents the behavior change + // vs. the pre-retry implementation, which surfaced 404 as failure. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err, "404 on REMOVE should always be success") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts), "404 must not trigger a retry") + }) + + t.Run("PUT first-attempt 404 still fails (only REMOVE treats 404 as success)", func(t *testing.T) { + // Guard against the 404-as-success behavior leaking into the other + // handlers. PUT/GET must still treat 404 as a terminal failure. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.NotNil(t, err, "404 on PUT must remain a terminal failure") + assert.ErrorContains(t, err, "404") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT retries transient HTTP 500", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusInternalServerError) + return + } + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT fails after exhausting retries on persistent 503", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + cfg := retryCfg() + cfg.RetryMax = 2 + c := &conn{cfg: cfg} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "503") + // initial attempt + RetryMax retries + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT does not retry non-retryable status (403)", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + started := time.Now() + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + elapsed := time.Since(started) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "403") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts), "non-retryable status must fail on first attempt") + // retryCfg's RetryWaitMin is 1ms; if a backoff fired by mistake we'd + // observe at least that. 50ms gives headroom for slow CI without + // masking an accidental retry. + assert.Less(t, elapsed, 50*time.Millisecond, "non-retryable status must not trigger backoff") + }) + + t.Run("PUT replays the file body on each retry", func(t *testing.T) { + // Verifies that the retry implementation correctly handles the request + // body lifecycle: an os.File consumed by attempt N must be rewound or + // re-opened before attempt N+1, otherwise the server sees a zero-length + // body on retries. + var ( + mu sync.Mutex + receivedSizes []int64 + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + mu.Lock() + receivedSizes = append(receivedSizes, int64(len(body))) + n := len(receivedSizes) + mu.Unlock() + if n < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + payload := []byte("important parquet data that must be re-sent on each retry") + if err := os.WriteFile(localFile, payload, 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + mu.Lock() + defer mu.Unlock() + assert.Equal(t, 3, len(receivedSizes), "expected 3 PUT attempts") + for i, sz := range receivedSizes { + assert.Equal(t, int64(len(payload)), sz, "attempt %d received %d bytes, expected full payload of %d bytes", i+1, sz, len(payload)) + } + }) + + t.Run("PUT respects context cancellation during backoff", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + cfg := retryCfg() + cfg.RetryMax = 5 + cfg.RetryWaitMin = 500 * time.Millisecond + cfg.RetryWaitMax = 1 * time.Second + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + c := &conn{cfg: cfg} + started := time.Now() + err := c.handleStagingPut(ctx, server.URL, nil, localFile) + elapsed := time.Since(started) + + assert.NotNil(t, err) + // Cancellation should land well before the full retry budget elapses. + // 5 retries * 500ms+ minimum backoff = 2.5s+ without cancellation. + // 2s gives generous headroom on slow CI runners without masking a + // regression where cancellation is honored only at retry boundaries. + assert.Less(t, elapsed, 2*time.Second, "context cancel should abort PUT retry backoff promptly") + }) +} diff --git a/internal/retry/retry.go b/internal/retry/retry.go new file mode 100644 index 0000000..87fc9ad --- /dev/null +++ b/internal/retry/retry.go @@ -0,0 +1,67 @@ +// Package retry provides shared HTTP retry/backoff helpers for transient +// object-storage failures (S3 SlowDown, 5xx, etc.). Used by the CloudFetch +// download path and the staging-operation (PUT/GET/REMOVE) handlers so both +// share a single retryable-status set and backoff schedule. +package retry + +import ( + "math" + "math/rand" + "net/http" + "strconv" + "time" +) + +// RetryableStatuses lists HTTP status codes from object storage that indicate +// transient conditions and warrant a retry. Mirrors AWS S3 guidance for +// SlowDown (503) / InternalError (500) plus the general 408/429/502/504. +var RetryableStatuses = map[int]struct{}{ + http.StatusRequestTimeout: {}, // 408 + http.StatusTooManyRequests: {}, // 429 + http.StatusInternalServerError: {}, // 500 + http.StatusBadGateway: {}, // 502 + http.StatusServiceUnavailable: {}, // 503 + http.StatusGatewayTimeout: {}, // 504 +} + +// IsRetryableStatus reports whether the given HTTP status code is a transient +// object-storage failure that warrants a retry. +func IsRetryableStatus(status int) bool { + _, ok := RetryableStatuses[status] + return ok +} + +// Backoff returns the wait before retry attempt N (1-based). The base delay +// is exponential — waitMin * 2^(attempt-1) capped at waitMax — with equal +// jitter applied: the actual sleep is uniformly distributed in +// [base/2, base]. Equal jitter (rather than no jitter) spreads synchronized +// retries across concurrent callers, which would otherwise hammer the storage +// endpoint in lockstep after a region-wide blip. If the server returned a +// parseable integer Retry-After header, that value (in seconds) is honored +// instead, capped at waitMax. HTTP-date Retry-After values are ignored — +// same as the Thrift client's backoff. +func Backoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration { + if retryAfter != "" { + if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 { + d := time.Duration(secs) * time.Second + if d > waitMax { + return waitMax + } + return d + } + } + + expo := float64(waitMin) * math.Pow(2, float64(attempt-1)) + if expo > float64(waitMax) || math.IsInf(expo, 0) { + expo = float64(waitMax) + } + base := time.Duration(expo) + if base <= 0 { + return 0 + } + half := base / 2 + if half <= 0 { + return base + } + return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic +} diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go new file mode 100644 index 0000000..a4934dd --- /dev/null +++ b/internal/retry/retry_test.go @@ -0,0 +1,69 @@ +package retry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBackoff(t *testing.T) { + t.Run("retry-after integer seconds is honored", func(t *testing.T) { + got := Backoff(1, 100*time.Millisecond, 60*time.Second, "2") + assert.Equal(t, 2*time.Second, got) + }) + + t.Run("retry-after is capped at waitMax", func(t *testing.T) { + got := Backoff(1, 100*time.Millisecond, 1*time.Second, "100") + assert.Equal(t, 1*time.Second, got) + }) + + t.Run("retry-after http-date is ignored, falls back to exponential", func(t *testing.T) { + minWait := 100 * time.Millisecond + got := Backoff(1, minWait, 10*time.Second, "Tue, 15 Nov 1994 08:12:31 GMT") + // attempt=1 base = minWait; equal jitter in [minWait/2, minWait] + assert.GreaterOrEqual(t, got, minWait/2) + assert.LessOrEqual(t, got, minWait) + }) + + t.Run("exponential is capped at waitMax", func(t *testing.T) { + maxWait := 200 * time.Millisecond + // 100ms * 2^9 = 51200ms, capped at 200ms; equal jitter -> [100ms, 200ms] + for i := 0; i < 50; i++ { + got := Backoff(10, 100*time.Millisecond, maxWait, "") + assert.GreaterOrEqual(t, got, maxWait/2) + assert.LessOrEqual(t, got, maxWait) + } + }) + + t.Run("base grows exponentially with attempt", func(t *testing.T) { + minWait, maxWait := 100*time.Millisecond, 10*time.Second + // attempt=1 -> base 100ms, jitter [50ms,100ms] + // attempt=3 -> base 400ms, jitter [200ms,400ms] + for i := 0; i < 50; i++ { + got1 := Backoff(1, minWait, maxWait, "") + got3 := Backoff(3, minWait, maxWait, "") + assert.GreaterOrEqual(t, got1, 50*time.Millisecond) + assert.LessOrEqual(t, got1, 100*time.Millisecond) + assert.GreaterOrEqual(t, got3, 200*time.Millisecond) + assert.LessOrEqual(t, got3, 400*time.Millisecond) + } + }) + + t.Run("zero waitMin returns zero", func(t *testing.T) { + got := Backoff(1, 0, 0, "") + assert.Equal(t, time.Duration(0), got) + }) +} + +func TestIsRetryableStatus(t *testing.T) { + retryable := []int{408, 429, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 302, 400, 401, 403, 404, 409, 410, 501} + + for _, s := range retryable { + assert.True(t, IsRetryableStatus(s), "%d should be retryable", s) + } + for _, s := range notRetryable { + assert.False(t, IsRetryableStatus(s), "%d should not be retryable", s) + } +} diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 67dfd25..889524d 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -5,13 +5,11 @@ import ( "context" "fmt" "io" - "math" - "math/rand" - "strconv" "strings" "time" "github.com/databricks/databricks-sql-go/internal/config" + "github.com/databricks/databricks-sql-go/internal/retry" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/pierrec/lz4/v4" "github.com/pkg/errors" @@ -402,7 +400,7 @@ func fetchBatchBytes( for attempt := 0; attempt <= retryMax; attempt++ { if attempt > 0 { - wait := cloudFetchBackoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter) + wait := retry.Backoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter) logger.Debug().Msgf( "CloudFetch: retrying download of link at offset %d (attempt %d/%d) in %v; lastStatus=%d lastErr=%v", link.StartRowOffset, attempt, retryMax, wait, lastStatus, lastErr, @@ -475,7 +473,7 @@ func fetchBatchBytes( lastErr = nil lastRetryAfter = res.Header.Get("Retry-After") - if !isCloudFetchRetryableStatus(res.StatusCode) { + if !retry.IsRetryableStatus(res.StatusCode) { msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode) return nil, dbsqlerrint.NewDriverError(ctx, msg, nil) } @@ -491,59 +489,6 @@ func fetchBatchBytes( return nil, dbsqlerrint.NewDriverError(ctx, msg, lastErr) } -// cloudFetchRetryableStatuses lists HTTP status codes from object storage that -// indicate transient conditions and warrant a retry. Mirrors AWS S3 guidance -// for SlowDown (503) / InternalError (500) plus the general 408/429/502/504. -var cloudFetchRetryableStatuses = map[int]struct{}{ - http.StatusRequestTimeout: {}, // 408 - http.StatusTooManyRequests: {}, // 429 - http.StatusInternalServerError: {}, // 500 - http.StatusBadGateway: {}, // 502 - http.StatusServiceUnavailable: {}, // 503 - http.StatusGatewayTimeout: {}, // 504 -} - -func isCloudFetchRetryableStatus(status int) bool { - _, ok := cloudFetchRetryableStatuses[status] - return ok -} - -// cloudFetchBackoff returns the wait before retry attempt N (1-based). The -// base delay is exponential — waitMin * 2^(attempt-1) capped at waitMax — with -// equal jitter applied: the actual sleep is uniformly distributed in -// [base/2, base]. Equal jitter (rather than no jitter) is used to spread -// synchronized retries across the up-to-MaxDownloadThreads concurrent -// downloads, which would otherwise hammer the storage endpoint in lockstep -// after a region-wide blip. If the server returned a parseable integer -// Retry-After header, that value (in seconds) is honored instead, capped at -// waitMax. HTTP-date Retry-After values are ignored — same as the Thrift -// client's backoff. -func cloudFetchBackoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration { - if retryAfter != "" { - if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 { - d := time.Duration(secs) * time.Second - if d > waitMax { - return waitMax - } - return d - } - } - - expo := float64(waitMin) * math.Pow(2, float64(attempt-1)) - if expo > float64(waitMax) || math.IsInf(expo, 0) { - expo = float64(waitMax) - } - base := time.Duration(expo) - if base <= 0 { - return 0 - } - half := base / 2 - if half <= 0 { - return base - } - return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic -} - func getReader(r io.Reader, useLz4Compression bool) io.Reader { if useLz4Compression { return lz4.NewReader(r) diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index d8d942b..d66ded3 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -697,67 +697,6 @@ func TestCloudFetchIterator(t *testing.T) { }) } -func TestCloudFetchBackoff(t *testing.T) { - t.Run("retry-after integer seconds is honored", func(t *testing.T) { - got := cloudFetchBackoff(1, 100*time.Millisecond, 60*time.Second, "2") - assert.Equal(t, 2*time.Second, got) - }) - - t.Run("retry-after is capped at waitMax", func(t *testing.T) { - got := cloudFetchBackoff(1, 100*time.Millisecond, 1*time.Second, "100") - assert.Equal(t, 1*time.Second, got) - }) - - t.Run("retry-after http-date is ignored, falls back to exponential", func(t *testing.T) { - minWait := 100 * time.Millisecond - got := cloudFetchBackoff(1, minWait, 10*time.Second, "Tue, 15 Nov 1994 08:12:31 GMT") - // attempt=1 base = minWait; equal jitter in [minWait/2, minWait] - assert.GreaterOrEqual(t, got, minWait/2) - assert.LessOrEqual(t, got, minWait) - }) - - t.Run("exponential is capped at waitMax", func(t *testing.T) { - maxWait := 200 * time.Millisecond - // 100ms * 2^9 = 51200ms, capped at 200ms; equal jitter -> [100ms, 200ms] - for i := 0; i < 50; i++ { - got := cloudFetchBackoff(10, 100*time.Millisecond, maxWait, "") - assert.GreaterOrEqual(t, got, maxWait/2) - assert.LessOrEqual(t, got, maxWait) - } - }) - - t.Run("base grows exponentially with attempt", func(t *testing.T) { - minWait, maxWait := 100*time.Millisecond, 10*time.Second - // attempt=1 -> base 100ms, jitter [50ms,100ms] - // attempt=3 -> base 400ms, jitter [200ms,400ms] - for i := 0; i < 50; i++ { - got1 := cloudFetchBackoff(1, minWait, maxWait, "") - got3 := cloudFetchBackoff(3, minWait, maxWait, "") - assert.GreaterOrEqual(t, got1, 50*time.Millisecond) - assert.LessOrEqual(t, got1, 100*time.Millisecond) - assert.GreaterOrEqual(t, got3, 200*time.Millisecond) - assert.LessOrEqual(t, got3, 400*time.Millisecond) - } - }) - - t.Run("zero waitMin returns zero", func(t *testing.T) { - got := cloudFetchBackoff(1, 0, 0, "") - assert.Equal(t, time.Duration(0), got) - }) -} - -func TestCloudFetchRetryableStatus(t *testing.T) { - retryable := []int{408, 429, 500, 502, 503, 504} - notRetryable := []int{200, 201, 301, 302, 400, 401, 403, 404, 409, 410, 501} - - for _, s := range retryable { - assert.True(t, isCloudFetchRetryableStatus(s), "%d should be retryable", s) - } - for _, s := range notRetryable { - assert.False(t, isCloudFetchRetryableStatus(s), "%d should not be retryable", s) - } -} - func TestCloudFetchSchemaOverride(t *testing.T) { // Reproduces ES-1804970: When the server result cache serves Arrow IPC files // from a prior query, the embedded schema has stale column names. The