From 0be12562df7376cf55b8d294e4b828f8d64682a9 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 21 May 2026 06:52:43 +0000 Subject: [PATCH] [ES-1911239] Retry transient S3 errors on staging PUT/GET/REMOVE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sibling of ES-1892645/PR #355. The three staging-operation HTTP wrappers in connection.go (handleStagingPut/Get/Remove) make a single client.Do call with no retry — any single transient S3 5xx (e.g. 503 SlowDown during load) fails the entire SQL statement permanently. Adds retry-with-exponential-backoff to the staging path with the same semantics as the CloudFetch fix: - Retryable statuses: 408/429/500/502/503/504 - Equal-jitter exponential backoff capped at RetryWaitMax - Integer Retry-After response header honored - Context cancellation aborts backoff promptly - Reuses existing RetryMax/RetryWaitMin/RetryWaitMax config knobs (consistent with the CloudFetch path the customer asked about) The PUT path needs special handling: http.Client.Do consumes the request body (an *os.File), so the retry helper rewinds the file with Seek(0, SeekStart) between attempts and wraps it in io.NopCloser so the client can't close the file on us. Factors the shared retry primitives (RetryableStatuses, IsRetryableStatus, Backoff) into a new internal/retry package so the CloudFetch path (internal/rows/arrowbased/batchloader.go) and the staging path share one implementation. This addresses the "two divergent retry implementations" follow-up from the #355 review. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- connection.go | 246 +++++++++++--- connection_test.go | 321 +++++++++++++++++++ internal/retry/retry.go | 67 ++++ internal/retry/retry_test.go | 69 ++++ internal/rows/arrowbased/batchloader.go | 61 +--- internal/rows/arrowbased/batchloader_test.go | 61 ---- 6 files changed, 660 insertions(+), 165 deletions(-) create mode 100644 internal/retry/retry.go create mode 100644 internal/retry/retry_test.go diff --git a/connection.go b/connection.go index 5b8a91c..3e6ac45 100644 --- a/connection.go +++ b/connection.go @@ -19,6 +19,7 @@ import ( context2 "github.com/databricks/databricks-sql-go/internal/compat/context" "github.com/databricks/databricks-sql-go/internal/config" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" + "github.com/databricks/databricks-sql-go/internal/retry" "github.com/databricks/databricks-sql-go/internal/rows" "github.com/databricks/databricks-sql-go/internal/sentinel" "github.com/databricks/databricks-sql-go/internal/thrift_protocol" @@ -647,17 +648,21 @@ var _ driver.ConnBeginTx = (*conn)(nil) var _ driver.NamedValueChecker = (*conn)(nil) func Succeeded(response *http.Response) bool { - if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 { - return true - } - return false + return statusInSuccessRange(response.StatusCode) +} + +// statusInSuccessRange returns true for the 2xx status codes the staging +// HTTP path treats as success: 200 OK / 201 Created / 202 Accepted / 204 +// No Content. Exposed separately from Succeeded so handlers can extend the +// accept set (e.g. REMOVE accepts 404 for idempotent-delete semantics). +func statusInSuccessRange(status int) bool { + return status == 200 || status == 201 || status == 202 || status == 204 } func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { if localFile == "" { return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil) } - client := &http.Client{} dat, err := os.Open(localFile) //nolint:gosec // localFile is provided by the application, not user input if err != nil { @@ -669,73 +674,222 @@ func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, header if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading local file info", err) } - - req, _ := http.NewRequest("PUT", presignedUrl, dat) - req.ContentLength = info.Size() // backend actually requires content length to be known - - for k, v := range headers { - req.Header.Set(k, v) - } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + size := info.Size() + + // Each retry attempt needs a fresh request because http.Client.Do consumes + // the request body. Rewind the *os.File between attempts so the server + // receives the full payload on every retry, not just attempt 1. + // + // Wrap the file in io.NopCloser so http.Client.Do can't close it — by + // default it closes any body that implements io.Closer, which would break + // the Seek on the next retry. The outer defer dat.Close() above owns the + // file's lifecycle. + reqFactory := func(attempt int) (*http.Request, error) { + if attempt > 0 { + if _, seekErr := dat.Seek(0, io.SeekStart); seekErr != nil { + return nil, seekErr + } + } + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPut, presignedUrl, io.NopCloser(dat)) + if reqErr != nil { + return nil, reqErr + } + req.ContentLength = size // backend actually requires content length to be known + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) + if _, err := c.doStagingRequestWithRetry(ctx, reqFactory); err != nil { + return err } return nil - } func (c *conn) handleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { if localFile == "" { return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil) } - client := &http.Client{} - req, _ := http.NewRequest("GET", presignedUrl, nil) - for k, v := range headers { - req.Header.Set(k, v) + reqFactory := func(_ int) (*http.Request, error) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, presignedUrl, nil) + if reqErr != nil { + return nil, reqErr + } + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - res, err := client.Do(req) + + content, err := c.doStagingRequestWithRetry(ctx, reqFactory) if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + return err } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) + if writeErr := os.WriteFile(localFile, content, 0644); writeErr != nil { //nolint:gosec + return dbsqlerrint.NewDriverError(ctx, "error writing local file", writeErr) + } + return nil +} - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) +func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { + reqFactory := func(_ int) (*http.Request, error) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodDelete, presignedUrl, nil) + if reqErr != nil { + return nil, reqErr + } + for k, v := range headers { + req.Header.Set(k, v) + } + return req, nil } - err = os.WriteFile(localFile, content, 0644) //nolint:gosec - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error writing local file", err) + // Treat 404 as success on REMOVE: DELETE is idempotent, and a 404 means + // the object is already absent — which is the post-condition the caller + // asked for. This also avoids spurious failures when a successful DELETE + // returns a transient 5xx mid-response and the retry sees 404 from the + // server having already applied the original request. + acceptStatus := func(status int) bool { + return statusInSuccessRange(status) || status == http.StatusNotFound + } + + if _, err := c.doStagingRequestWithRetryAccept(ctx, reqFactory, acceptStatus); err != nil { + return err } return nil } -func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { - client := &http.Client{} - req, _ := http.NewRequest("DELETE", presignedUrl, nil) - for k, v := range headers { - req.Header.Set(k, v) +// maxStagingErrorBodyBytes bounds the response body bytes included in +// terminal staging error messages. Proxies and misconfigured backends can +// return multi-MB error bodies; truncating keeps the driver error readable +// without dropping the typical S3 XML error code that fits well under 512B. +const maxStagingErrorBodyBytes = 512 + +// doStagingRequestWithRetry executes a staging HTTP request with retry on +// transient object-storage failures (ES-1911239). Wraps +// doStagingRequestWithRetryAccept with the default success predicate (2xx +// from statusInSuccessRange / Succeeded). +func (c *conn) doStagingRequestWithRetry(ctx context.Context, reqFactory func(attempt int) (*http.Request, error)) ([]byte, dbsqlerr.DBError) { + return c.doStagingRequestWithRetryAccept(ctx, reqFactory, statusInSuccessRange) +} + +// doStagingRequestWithRetryAccept is the generalized staging retry helper +// used by all three handleStaging* methods. Mirrors the CloudFetch retry +// path (ES-1892645) in semantics — same retryable status set, same +// exponential-backoff-with-jitter schedule, same RetryMax/RetryWaitMin/ +// RetryWaitMax config knobs — so behavior is consistent across the driver's +// two object-storage code paths. +// +// reqFactory must return a fresh *http.Request on each call. Attempt 0 is +// the initial request; attempt N>0 is a retry. The PUT path uses this to +// rewind the file body between attempts; other staging paths just construct +// a new request each time. +// +// acceptStatus reports whether a given HTTP status code should be treated +// as success. Most handlers pass statusInSuccessRange. The REMOVE handler +// extends this to also accept 404 (idempotent-delete semantics). +// +// On success returns the response body bytes. On terminal failure (non- +// retryable status, exhausted retries, or context cancellation) returns a +// dbsqlerr.DBError describing the final state. +func (c *conn) doStagingRequestWithRetryAccept( + ctx context.Context, + reqFactory func(attempt int) (*http.Request, error), + acceptStatus func(status int) bool, +) ([]byte, dbsqlerr.DBError) { + retryMax := c.cfg.RetryMax + if retryMax < 0 { + retryMax = 0 } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + client := &http.Client{} + + var ( + lastErr error + lastStatus int + lastBody []byte + lastRetryAfter string + ) + + for attempt := 0; attempt <= retryMax; attempt++ { + if attempt > 0 { + wait := retry.Backoff(attempt, c.cfg.RetryWaitMin, c.cfg.RetryWaitMax, lastRetryAfter) + logger.Debug().Msgf( + "staging: retrying HTTP request (attempt %d/%d) in %v; lastStatus=%d lastErr=%v", + attempt, retryMax, wait, lastStatus, lastErr, + ) + t := time.NewTimer(wait) + select { + case <-ctx.Done(): + if !t.Stop() { + <-t.C + } + return nil, dbsqlerrint.NewDriverError(ctx, "staging operation cancelled during retry backoff", ctx.Err()) + case <-t.C: + } + } + + req, reqErr := reqFactory(attempt) + if reqErr != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error building staging http request", reqErr) + } + + res, err := client.Do(req) + if err != nil { + // Caller cancellation is terminal; otherwise treat transport + // errors (TCP RST, TLS timeout, etc.) as transient. + if ctx.Err() != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error sending http request", ctx.Err()) + } + lastErr = err + lastStatus = 0 + lastRetryAfter = "" + continue + } + + body, readErr := io.ReadAll(res.Body) + res.Body.Close() //nolint:errcheck,gosec // G104: close after drain + + if readErr != nil { + if ctx.Err() != nil { + return nil, dbsqlerrint.NewDriverError(ctx, "error reading http response", ctx.Err()) + } + lastErr = readErr + lastStatus = 0 + lastRetryAfter = "" + continue + } + + if acceptStatus(res.StatusCode) { + return body, nil + } + + lastStatus = res.StatusCode + lastErr = nil + lastBody = body + lastRetryAfter = res.Header.Get("Retry-After") + + if !retry.IsRetryableStatus(res.StatusCode) { + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, truncateErrorBody(body)), nil) + } } - defer res.Body.Close() //nolint:errcheck - content, err := io.ReadAll(res.Body) - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil) + if lastStatus != 0 { + // lastErr is nil here by construction: the HTTP-status branch above + // explicitly clears it on every iteration. The status code and body + // are captured in msg, so there's no underlying error to wrap. + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s (after %d retries)", lastStatus, truncateErrorBody(lastBody), retryMax), nil) } + return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation HTTP request failed: %v (after %d retries)", lastErr, retryMax), lastErr) +} - return nil +// truncateErrorBody caps b at maxStagingErrorBodyBytes for inclusion in error +// messages, appending an indicator when truncation occurred. +func truncateErrorBody(b []byte) string { + if len(b) <= maxStagingErrorBodyBytes { + return string(b) + } + return fmt.Sprintf("%s... (%d bytes total, truncated)", b[:maxStagingErrorBodyBytes], len(b)) } func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool { diff --git a/connection_test.go b/connection_test.go index badf6a7..92f0837 100644 --- a/connection_test.go +++ b/connection_test.go @@ -4,6 +4,13 @@ import ( "context" "database/sql/driver" "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" "testing" "time" @@ -2035,3 +2042,317 @@ func getTestSession() *cli_service.TOpenSessionResp { }, }} } + +// TestConn_handleStagingRetry verifies that the staging-operation HTTP wrappers +// (handleStagingPut/Get/Remove) retry transient S3 errors. ES-1911239: FactSet +// hit intermittent HTTP 503 SlowDown on PUT against a Unity Catalog external +// volume; pre-fix any single 5xx failed the entire SQL statement. +// +// Retry behavior mirrors the CloudFetch fix from ES-1892645 / PR #355: +// - Retryable statuses: 408/429/500/502/503/504. +// - Exponential backoff with equal jitter, honoring RetryMax/RetryWaitMin/ +// RetryWaitMax from the connection config. +// - Integer Retry-After response header is honored (capped at RetryWaitMax). +// - Non-retryable statuses (e.g. 403) fail on the first attempt. +// - Context cancellation aborts backoff promptly. +func TestConn_handleStagingRetry(t *testing.T) { + // retryCfg returns a fast-backoff config so tests don't burn wall-clock + // on sleeps. RetryMax leaves room for several retries; RetryWaitMin/Max + // keep the worst-case test runtime under a second. + retryCfg := func() *config.Config { + cfg := config.WithDefaults() + cfg.RetryMax = 4 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + return cfg + } + + t.Run("PUT retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("SlowDown")) + return + } + // Drain the body so we exercise the retry-replay path for PUTs. + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("parquet bytes"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts), "expected 2 retries before success") + }) + + t.Run("GET retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + body := []byte("downloaded bytes") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "out.bin") + + c := &conn{cfg: retryCfg()} + err := c.handleStagingGet(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + + got, readErr := os.ReadFile(localFile) + assert.Nil(t, readErr) + assert.Equal(t, body, got, "GET should write the final-attempt body to local file") + }) + + t.Run("REMOVE retries transient 503 and eventually succeeds", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("REMOVE treats 503-then-404 as success (idempotent delete)", func(t *testing.T) { + // The first DELETE may have applied server-side even though the + // response was 503 (load balancer dies mid-response, etc.). The + // retry then sees 404 — the object is already gone. The caller's + // post-condition ("object absent") is satisfied, so this is success. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err, "503 then 404 on REMOVE should succeed: the object is absent, which is the caller's intent") + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("REMOVE treats first-attempt 404 as success (idempotent delete)", func(t *testing.T) { + // DELETE on a non-existent object is success: the post-condition + // ("object absent") is already true. Documents the behavior change + // vs. the pre-retry implementation, which surfaced 404 as failure. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + c := &conn{cfg: retryCfg()} + err := c.handleStagingRemove(context.Background(), server.URL, nil) + assert.Nil(t, err, "404 on REMOVE should always be success") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts), "404 must not trigger a retry") + }) + + t.Run("PUT first-attempt 404 still fails (only REMOVE treats 404 as success)", func(t *testing.T) { + // Guard against the 404-as-success behavior leaking into the other + // handlers. PUT/GET must still treat 404 as a terminal failure. + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.NotNil(t, err, "404 on PUT must remain a terminal failure") + assert.ErrorContains(t, err, "404") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT retries transient HTTP 500", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusInternalServerError) + return + } + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT fails after exhausting retries on persistent 503", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + cfg := retryCfg() + cfg.RetryMax = 2 + c := &conn{cfg: cfg} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "503") + // initial attempt + RetryMax retries + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts)) + }) + + t.Run("PUT does not retry non-retryable status (403)", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + started := time.Now() + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + elapsed := time.Since(started) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "403") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts), "non-retryable status must fail on first attempt") + // retryCfg's RetryWaitMin is 1ms; if a backoff fired by mistake we'd + // observe at least that. 50ms gives headroom for slow CI without + // masking an accidental retry. + assert.Less(t, elapsed, 50*time.Millisecond, "non-retryable status must not trigger backoff") + }) + + t.Run("PUT replays the file body on each retry", func(t *testing.T) { + // Verifies that the retry implementation correctly handles the request + // body lifecycle: an os.File consumed by attempt N must be rewound or + // re-opened before attempt N+1, otherwise the server sees a zero-length + // body on retries. + var ( + mu sync.Mutex + receivedSizes []int64 + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + mu.Lock() + receivedSizes = append(receivedSizes, int64(len(body))) + n := len(receivedSizes) + mu.Unlock() + if n < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + payload := []byte("important parquet data that must be re-sent on each retry") + if err := os.WriteFile(localFile, payload, 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + c := &conn{cfg: retryCfg()} + err := c.handleStagingPut(context.Background(), server.URL, nil, localFile) + assert.Nil(t, err) + mu.Lock() + defer mu.Unlock() + assert.Equal(t, 3, len(receivedSizes), "expected 3 PUT attempts") + for i, sz := range receivedSizes { + assert.Equal(t, int64(len(payload)), sz, "attempt %d received %d bytes, expected full payload of %d bytes", i+1, sz, len(payload)) + } + }) + + t.Run("PUT respects context cancellation during backoff", func(t *testing.T) { + var attempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + tmpDir := t.TempDir() + localFile := filepath.Join(tmpDir, "payload.parquet") + if err := os.WriteFile(localFile, []byte("data"), 0600); err != nil { + t.Fatalf("write local file: %v", err) + } + + cfg := retryCfg() + cfg.RetryMax = 5 + cfg.RetryWaitMin = 500 * time.Millisecond + cfg.RetryWaitMax = 1 * time.Second + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + c := &conn{cfg: cfg} + started := time.Now() + err := c.handleStagingPut(ctx, server.URL, nil, localFile) + elapsed := time.Since(started) + + assert.NotNil(t, err) + // Cancellation should land well before the full retry budget elapses. + // 5 retries * 500ms+ minimum backoff = 2.5s+ without cancellation. + // 2s gives generous headroom on slow CI runners without masking a + // regression where cancellation is honored only at retry boundaries. + assert.Less(t, elapsed, 2*time.Second, "context cancel should abort PUT retry backoff promptly") + }) +} diff --git a/internal/retry/retry.go b/internal/retry/retry.go new file mode 100644 index 0000000..87fc9ad --- /dev/null +++ b/internal/retry/retry.go @@ -0,0 +1,67 @@ +// Package retry provides shared HTTP retry/backoff helpers for transient +// object-storage failures (S3 SlowDown, 5xx, etc.). Used by the CloudFetch +// download path and the staging-operation (PUT/GET/REMOVE) handlers so both +// share a single retryable-status set and backoff schedule. +package retry + +import ( + "math" + "math/rand" + "net/http" + "strconv" + "time" +) + +// RetryableStatuses lists HTTP status codes from object storage that indicate +// transient conditions and warrant a retry. Mirrors AWS S3 guidance for +// SlowDown (503) / InternalError (500) plus the general 408/429/502/504. +var RetryableStatuses = map[int]struct{}{ + http.StatusRequestTimeout: {}, // 408 + http.StatusTooManyRequests: {}, // 429 + http.StatusInternalServerError: {}, // 500 + http.StatusBadGateway: {}, // 502 + http.StatusServiceUnavailable: {}, // 503 + http.StatusGatewayTimeout: {}, // 504 +} + +// IsRetryableStatus reports whether the given HTTP status code is a transient +// object-storage failure that warrants a retry. +func IsRetryableStatus(status int) bool { + _, ok := RetryableStatuses[status] + return ok +} + +// Backoff returns the wait before retry attempt N (1-based). The base delay +// is exponential — waitMin * 2^(attempt-1) capped at waitMax — with equal +// jitter applied: the actual sleep is uniformly distributed in +// [base/2, base]. Equal jitter (rather than no jitter) spreads synchronized +// retries across concurrent callers, which would otherwise hammer the storage +// endpoint in lockstep after a region-wide blip. If the server returned a +// parseable integer Retry-After header, that value (in seconds) is honored +// instead, capped at waitMax. HTTP-date Retry-After values are ignored — +// same as the Thrift client's backoff. +func Backoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration { + if retryAfter != "" { + if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 { + d := time.Duration(secs) * time.Second + if d > waitMax { + return waitMax + } + return d + } + } + + expo := float64(waitMin) * math.Pow(2, float64(attempt-1)) + if expo > float64(waitMax) || math.IsInf(expo, 0) { + expo = float64(waitMax) + } + base := time.Duration(expo) + if base <= 0 { + return 0 + } + half := base / 2 + if half <= 0 { + return base + } + return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic +} diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go new file mode 100644 index 0000000..a4934dd --- /dev/null +++ b/internal/retry/retry_test.go @@ -0,0 +1,69 @@ +package retry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBackoff(t *testing.T) { + t.Run("retry-after integer seconds is honored", func(t *testing.T) { + got := Backoff(1, 100*time.Millisecond, 60*time.Second, "2") + assert.Equal(t, 2*time.Second, got) + }) + + t.Run("retry-after is capped at waitMax", func(t *testing.T) { + got := Backoff(1, 100*time.Millisecond, 1*time.Second, "100") + assert.Equal(t, 1*time.Second, got) + }) + + t.Run("retry-after http-date is ignored, falls back to exponential", func(t *testing.T) { + minWait := 100 * time.Millisecond + got := Backoff(1, minWait, 10*time.Second, "Tue, 15 Nov 1994 08:12:31 GMT") + // attempt=1 base = minWait; equal jitter in [minWait/2, minWait] + assert.GreaterOrEqual(t, got, minWait/2) + assert.LessOrEqual(t, got, minWait) + }) + + t.Run("exponential is capped at waitMax", func(t *testing.T) { + maxWait := 200 * time.Millisecond + // 100ms * 2^9 = 51200ms, capped at 200ms; equal jitter -> [100ms, 200ms] + for i := 0; i < 50; i++ { + got := Backoff(10, 100*time.Millisecond, maxWait, "") + assert.GreaterOrEqual(t, got, maxWait/2) + assert.LessOrEqual(t, got, maxWait) + } + }) + + t.Run("base grows exponentially with attempt", func(t *testing.T) { + minWait, maxWait := 100*time.Millisecond, 10*time.Second + // attempt=1 -> base 100ms, jitter [50ms,100ms] + // attempt=3 -> base 400ms, jitter [200ms,400ms] + for i := 0; i < 50; i++ { + got1 := Backoff(1, minWait, maxWait, "") + got3 := Backoff(3, minWait, maxWait, "") + assert.GreaterOrEqual(t, got1, 50*time.Millisecond) + assert.LessOrEqual(t, got1, 100*time.Millisecond) + assert.GreaterOrEqual(t, got3, 200*time.Millisecond) + assert.LessOrEqual(t, got3, 400*time.Millisecond) + } + }) + + t.Run("zero waitMin returns zero", func(t *testing.T) { + got := Backoff(1, 0, 0, "") + assert.Equal(t, time.Duration(0), got) + }) +} + +func TestIsRetryableStatus(t *testing.T) { + retryable := []int{408, 429, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 302, 400, 401, 403, 404, 409, 410, 501} + + for _, s := range retryable { + assert.True(t, IsRetryableStatus(s), "%d should be retryable", s) + } + for _, s := range notRetryable { + assert.False(t, IsRetryableStatus(s), "%d should not be retryable", s) + } +} diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 67dfd25..889524d 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -5,13 +5,11 @@ import ( "context" "fmt" "io" - "math" - "math/rand" - "strconv" "strings" "time" "github.com/databricks/databricks-sql-go/internal/config" + "github.com/databricks/databricks-sql-go/internal/retry" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/pierrec/lz4/v4" "github.com/pkg/errors" @@ -402,7 +400,7 @@ func fetchBatchBytes( for attempt := 0; attempt <= retryMax; attempt++ { if attempt > 0 { - wait := cloudFetchBackoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter) + wait := retry.Backoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter) logger.Debug().Msgf( "CloudFetch: retrying download of link at offset %d (attempt %d/%d) in %v; lastStatus=%d lastErr=%v", link.StartRowOffset, attempt, retryMax, wait, lastStatus, lastErr, @@ -475,7 +473,7 @@ func fetchBatchBytes( lastErr = nil lastRetryAfter = res.Header.Get("Retry-After") - if !isCloudFetchRetryableStatus(res.StatusCode) { + if !retry.IsRetryableStatus(res.StatusCode) { msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode) return nil, dbsqlerrint.NewDriverError(ctx, msg, nil) } @@ -491,59 +489,6 @@ func fetchBatchBytes( return nil, dbsqlerrint.NewDriverError(ctx, msg, lastErr) } -// cloudFetchRetryableStatuses lists HTTP status codes from object storage that -// indicate transient conditions and warrant a retry. Mirrors AWS S3 guidance -// for SlowDown (503) / InternalError (500) plus the general 408/429/502/504. -var cloudFetchRetryableStatuses = map[int]struct{}{ - http.StatusRequestTimeout: {}, // 408 - http.StatusTooManyRequests: {}, // 429 - http.StatusInternalServerError: {}, // 500 - http.StatusBadGateway: {}, // 502 - http.StatusServiceUnavailable: {}, // 503 - http.StatusGatewayTimeout: {}, // 504 -} - -func isCloudFetchRetryableStatus(status int) bool { - _, ok := cloudFetchRetryableStatuses[status] - return ok -} - -// cloudFetchBackoff returns the wait before retry attempt N (1-based). The -// base delay is exponential — waitMin * 2^(attempt-1) capped at waitMax — with -// equal jitter applied: the actual sleep is uniformly distributed in -// [base/2, base]. Equal jitter (rather than no jitter) is used to spread -// synchronized retries across the up-to-MaxDownloadThreads concurrent -// downloads, which would otherwise hammer the storage endpoint in lockstep -// after a region-wide blip. If the server returned a parseable integer -// Retry-After header, that value (in seconds) is honored instead, capped at -// waitMax. HTTP-date Retry-After values are ignored — same as the Thrift -// client's backoff. -func cloudFetchBackoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration { - if retryAfter != "" { - if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 { - d := time.Duration(secs) * time.Second - if d > waitMax { - return waitMax - } - return d - } - } - - expo := float64(waitMin) * math.Pow(2, float64(attempt-1)) - if expo > float64(waitMax) || math.IsInf(expo, 0) { - expo = float64(waitMax) - } - base := time.Duration(expo) - if base <= 0 { - return 0 - } - half := base / 2 - if half <= 0 { - return base - } - return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic -} - func getReader(r io.Reader, useLz4Compression bool) io.Reader { if useLz4Compression { return lz4.NewReader(r) diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index d8d942b..d66ded3 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -697,67 +697,6 @@ func TestCloudFetchIterator(t *testing.T) { }) } -func TestCloudFetchBackoff(t *testing.T) { - t.Run("retry-after integer seconds is honored", func(t *testing.T) { - got := cloudFetchBackoff(1, 100*time.Millisecond, 60*time.Second, "2") - assert.Equal(t, 2*time.Second, got) - }) - - t.Run("retry-after is capped at waitMax", func(t *testing.T) { - got := cloudFetchBackoff(1, 100*time.Millisecond, 1*time.Second, "100") - assert.Equal(t, 1*time.Second, got) - }) - - t.Run("retry-after http-date is ignored, falls back to exponential", func(t *testing.T) { - minWait := 100 * time.Millisecond - got := cloudFetchBackoff(1, minWait, 10*time.Second, "Tue, 15 Nov 1994 08:12:31 GMT") - // attempt=1 base = minWait; equal jitter in [minWait/2, minWait] - assert.GreaterOrEqual(t, got, minWait/2) - assert.LessOrEqual(t, got, minWait) - }) - - t.Run("exponential is capped at waitMax", func(t *testing.T) { - maxWait := 200 * time.Millisecond - // 100ms * 2^9 = 51200ms, capped at 200ms; equal jitter -> [100ms, 200ms] - for i := 0; i < 50; i++ { - got := cloudFetchBackoff(10, 100*time.Millisecond, maxWait, "") - assert.GreaterOrEqual(t, got, maxWait/2) - assert.LessOrEqual(t, got, maxWait) - } - }) - - t.Run("base grows exponentially with attempt", func(t *testing.T) { - minWait, maxWait := 100*time.Millisecond, 10*time.Second - // attempt=1 -> base 100ms, jitter [50ms,100ms] - // attempt=3 -> base 400ms, jitter [200ms,400ms] - for i := 0; i < 50; i++ { - got1 := cloudFetchBackoff(1, minWait, maxWait, "") - got3 := cloudFetchBackoff(3, minWait, maxWait, "") - assert.GreaterOrEqual(t, got1, 50*time.Millisecond) - assert.LessOrEqual(t, got1, 100*time.Millisecond) - assert.GreaterOrEqual(t, got3, 200*time.Millisecond) - assert.LessOrEqual(t, got3, 400*time.Millisecond) - } - }) - - t.Run("zero waitMin returns zero", func(t *testing.T) { - got := cloudFetchBackoff(1, 0, 0, "") - assert.Equal(t, time.Duration(0), got) - }) -} - -func TestCloudFetchRetryableStatus(t *testing.T) { - retryable := []int{408, 429, 500, 502, 503, 504} - notRetryable := []int{200, 201, 301, 302, 400, 401, 403, 404, 409, 410, 501} - - for _, s := range retryable { - assert.True(t, isCloudFetchRetryableStatus(s), "%d should be retryable", s) - } - for _, s := range notRetryable { - assert.False(t, isCloudFetchRetryableStatus(s), "%d should not be retryable", s) - } -} - func TestCloudFetchSchemaOverride(t *testing.T) { // Reproduces ES-1804970: When the server result cache serves Arrow IPC files // from a prior query, the embedded schema has stale column names. The