diff --git a/cmd/cache-proxy/cache_test.go b/cmd/cache-proxy/cache_test.go
index de4b45b7..1de0a5ca 100644
--- a/cmd/cache-proxy/cache_test.go
+++ b/cmd/cache-proxy/cache_test.go
@@ -24,6 +24,17 @@ func counterValue(t *testing.T, c prometheus.Counter) float64 {
 	return m.GetCounter().GetValue()
 }
 
+// gaugeValue returns the current value of g by writing it into a dto.Metric.
+// It fails the test immediately if the gauge cannot be read.
+func gaugeValue(t *testing.T, g prometheus.Gauge) float64 {
+	t.Helper()
+	var m dto.Metric
+	if err := g.Write(&m); err != nil {
+		t.Fatalf("read gauge: %v", err)
+	}
+	return m.GetGauge().GetValue()
+}
+
 // errAfterReader yields n bytes then returns a non-EOF error, simulating an
 // origin/peer connection dropping mid-body.
 type errAfterReader struct {
diff --git a/cmd/cache-proxy/proxy.go b/cmd/cache-proxy/proxy.go
index 1313084a..3c229963 100644
--- a/cmd/cache-proxy/proxy.go
+++ b/cmd/cache-proxy/proxy.go
@@ -372,10 +372,15 @@ func (p *CacheProxy) fetchDedup(cacheKey string, r *http.Request, rangeHeader st
 				return fetchResult{size: n, source: "peer"}, nil
 			}
 		}
+		// Peer hits return above, so the gauge and outcome counter below cover origin fetches only.
+		originFetchInFlight.Inc()
+		defer originFetchInFlight.Dec()
 		size, ct, err := p.fetchOrigin(cacheKey, r)
 		if err != nil {
+			originFetchesTotal.WithLabelValues(originFetchOutcome(err)).Inc()
 			return fetchResult{}, err
 		}
+		originFetchesTotal.WithLabelValues("success").Inc()
 		cacheBytesServed.WithLabelValues("s3").Add(float64(size))
 		return fetchResult{size: size, contentType: ct, source: "miss"}, nil
 	})
@@ -410,7 +415,13 @@ func (p *CacheProxy) fetchOrigin(cacheKey string, r *http.Request) (int64, strin
 		if errors.As(err, &oe) {
 			originSpan.SetAttributes(attribute.Int("http.response.status_code", oe.status))
 		}
-		if attempt == attempts || r.Context().Err() != nil || !isRetriableOriginFetchError(err) {
+		// A canceled or expired request context takes precedence over the origin
+		// error, so the caller (and the outcome metric) sees the context error.
+		if err := r.Context().Err(); err != nil {
+			originSpan.SetStatus(codes.Error, err.Error())
+			return 0, "", err
+		}
+		if attempt == attempts || !isRetriableOriginFetchError(err) {
 			originSpan.SetStatus(codes.Error, err.Error())
 			return 0, "", err
 		}
@@ -424,8 +435,21 @@ func (p *CacheProxy) fetchOrigin(cacheKey string, r *http.Request) (int64, strin
 			"backoff", delay,
 			"error", err)
 		if !sleepContext(r.Context(), delay) {
-			return 0, "", r.Context().Err()
+			err := r.Context().Err()
+			if err == nil {
+				// Guarantee a non-nil error even if the context reports none.
+				err = context.Canceled
+			}
+			originSpan.SetStatus(codes.Error, err.Error())
+			return 0, "", err
+		}
+		if err := r.Context().Err(); err != nil {
+			originSpan.SetStatus(codes.Error, err.Error())
+			return 0, "", err
 		}
+		// Count the retry only after the backoff wait and context checks passed; err
+		// here is the origin error, since the checks above use their own scoped err.
+		originFetchRetriesTotal.WithLabelValues(originRetryReason(err)).Inc()
 		if backoff > 0 {
 			backoff *= 2
 			if p.originRetryMaxBackoff > 0 && backoff > p.originRetryMaxBackoff {
@@ -545,6 +569,66 @@ func isRetriableOriginFetchError(err error) bool {
 		strings.Contains(msg, "timeout")
 }
 
+// originFetchOutcome maps the final result of an origin fetch to the value of
+// the "outcome" label on cache_proxy_origin_fetches_total: "success",
+// "http_error", "canceled", "timeout", or "error". An origin status error is
+// checked first, then context cancellation/deadline, then net.Error timeouts.
+func originFetchOutcome(err error) string {
+	if err == nil {
+		return "success"
+	}
+	var oe *originStatusError
+	if errors.As(err, &oe) {
+		return "http_error"
+	}
+	if errors.Is(err, context.Canceled) {
+		return "canceled"
+	}
+	if errors.Is(err, context.DeadlineExceeded) {
+		return "timeout"
+	}
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		return "timeout"
+	}
+	return "error"
+}
+
+// originRetryReason maps the error that triggered an origin retry to the value
+// of the "reason" label on cache_proxy_origin_fetch_retries_total. Origin
+// status errors become "http_<status>"; other errors are classified by
+// timeout checks and then by substrings of the lowercased error message.
+// A nil error is reported as "unknown" instead of panicking on err.Error().
+func originRetryReason(err error) string {
+	if err == nil {
+		return "unknown"
+	}
+	var oe *originStatusError
+	if errors.As(err, &oe) {
+		return fmt.Sprintf("http_%d", oe.status)
+	}
+	if errors.Is(err, context.DeadlineExceeded) {
+		return "timeout"
+	}
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		return "timeout"
+	}
+	msg := strings.ToLower(err.Error())
+	switch {
+	case strings.Contains(msg, "connection reset by peer"):
+		return "connection_reset"
+	case strings.Contains(msg, "connection refused"):
+		return "connection_refused"
+	case strings.Contains(msg, "unexpected eof"):
+		return "unexpected_eof"
+	case strings.Contains(msg, "timeout"):
+		return "timeout"
+	default:
+		return "transport_error"
+	}
+}
+
 // originErrorBodyCap is the maximum number of bytes we'll buffer from a
 // non-2xx origin response. S3 XML error envelopes are tiny; this is just a
 // safety net.
diff --git a/cmd/cache-proxy/proxy_metrics.go b/cmd/cache-proxy/proxy_metrics.go
new file mode 100644
index 00000000..d520751b
--- /dev/null
+++ b/cmd/cache-proxy/proxy_metrics.go
@@ -0,0 +1,25 @@
+package main
+
+import (
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+// Prometheus metrics describing origin fetches made by the cache proxy.
+var (
+	// originFetchesTotal counts finished origin fetches by "outcome" label.
+	originFetchesTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+		Name: "cache_proxy_origin_fetches_total",
+		Help: "Total origin fetch outcomes for cacheable misses",
+	}, []string{"outcome"})
+	// originFetchRetriesTotal counts origin fetch retries by "reason" label.
+	originFetchRetriesTotal = promauto.NewCounterVec(prometheus.CounterOpts{
+		Name: "cache_proxy_origin_fetch_retries_total",
+		Help: "Total origin fetch retries by reason",
+	}, []string{"reason"})
+	// originFetchInFlight tracks origin fetches currently in progress.
+	originFetchInFlight = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "cache_proxy_origin_fetches_in_flight",
+		Help: "Current number of origin fetches filling the local cache",
+	})
+)
diff --git a/cmd/cache-proxy/proxy_test.go b/cmd/cache-proxy/proxy_test.go
index 99c00ab5..4d2790d4 100644
--- a/cmd/cache-proxy/proxy_test.go
+++ b/cmd/cache-proxy/proxy_test.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"log/slog"
@@ -10,11 +11,83 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
 )
 
+// timeoutNetError is a stub error whose Timeout method reports true; the label
+// tests use it to exercise the net.Error timeout branches.
+type timeoutNetError struct{}
+
+func (timeoutNetError) Error() string   { return "network timeout" }
+func (timeoutNetError) Timeout() bool   { return true }
+func (timeoutNetError) Temporary() bool { return true }
+
+// roundTripFunc adapts a plain function to http.RoundTripper so tests can stub
+// the proxy's origin transport.
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+// RoundTrip calls f(r).
+func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+	return f(r)
+}
+
+// retryLogSignalHandler is a slog.Handler that discards every record and closes
+// ch the first time it sees the proxy's origin-retry log message.
+type retryLogSignalHandler struct {
+	ch   chan struct{}
+	once sync.Once
+}
+
+// Enabled reports true so that records at every level reach Handle.
+func (h *retryLogSignalHandler) Enabled(context.Context, slog.Level) bool {
+	return true
+}
+
+// Handle closes ch, at most once, when the retry log message is observed.
+func (h *retryLogSignalHandler) Handle(_ context.Context, r slog.Record) error {
+	if r.Message == "Origin fetch failed with retriable error, retrying." {
+		h.once.Do(func() { close(h.ch) })
+	}
+	return nil
+}
+
+// WithAttrs ignores the attributes and returns the same handler.
+func (h *retryLogSignalHandler) WithAttrs([]slog.Attr) slog.Handler {
+	return h
+}
+
+// WithGroup ignores the group name and returns the same handler.
+func (h *retryLogSignalHandler) WithGroup(string) slog.Handler {
+	return h
+}
+
+// waitForSignal blocks until ch is closed or receives a value, failing the test
+// with msg after two seconds.
+func waitForSignal(t *testing.T, ch <-chan struct{}, msg string) {
+	t.Helper()
+	select {
+	case <-ch:
+	case <-time.After(2 * time.Second):
+		t.Fatal(msg)
+	}
+}
+
+// waitForRecorder returns the next recorder sent on ch, failing the test with
+// msg if none arrives within two seconds.
+func waitForRecorder(t *testing.T, ch <-chan *httptest.ResponseRecorder, msg string) *httptest.ResponseRecorder {
+	t.Helper()
+	select {
+	case rec := <-ch:
+		return rec
+	case <-time.After(2 * time.Second):
+		t.Fatal(msg)
+		return nil
+	}
+}
+
 // captureSlog redirects slog.Default to a buffer for the duration of a test
 // and returns the buffer + a restore function. Used by the forward-uncached
 // logging tests to assert presence of the request/response log lines that
@@ -226,6 +299,355 @@ func TestHandleProxyRetriesOrigin503AndForwardsFinalFailure(t *testing.T) {
 	}
 }
 
+// TestHandleProxyOriginFetchMetrics checks that an origin fetch which succeeds
+// after one 503 increments the "success" outcome and "http_503" retry counters
+// once each and leaves the in-flight gauge at its starting value.
+func TestHandleProxyOriginFetchMetrics(t *testing.T) {
+	proxy := newTestProxy(t)
+
+	successBefore := counterValue(t, originFetchesTotal.WithLabelValues("success"))
+	retryBefore := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503"))
+	inFlightBefore := gaugeValue(t, originFetchInFlight)
+
+	var calls int32
+	_, originURL := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+		n := atomic.AddInt32(&calls, 1)
+		if n == 1 {
+			w.WriteHeader(http.StatusServiceUnavailable)
+			_, _ = w.Write([]byte("retry later"))
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("ok-after-metric-retry"))
+	})
+
+	rec := doForwardProxyRequest(proxy, "GET", originURL+"/bucket/metrics.parquet", http.Header{"Range": []string{"bytes=0-20"}})
+	if rec.Code != http.StatusPartialContent {
+		t.Fatalf("status = %d, want 206 after retry", rec.Code)
+	}
+	if got := counterValue(t, originFetchesTotal.WithLabelValues("success")); got != successBefore+1 {
+		t.Fatalf("origin success metric = %v, want %v", got, successBefore+1)
+	}
+	if got := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503")); got != retryBefore+1 {
+		t.Fatalf("origin retry metric = %v, want %v", got, retryBefore+1)
+	}
+	if got := gaugeValue(t, originFetchInFlight); got != inFlightBefore {
+		t.Fatalf("in-flight origin fetches = %v, want %v after request completes", got, inFlightBefore)
+	}
+}
+
+// TestHandleProxyOriginRetryMetricDoesNotCountCanceledBackoff cancels the
+// request once the retry log line has been seen (with a 10s backoff configured)
+// and checks that no retry is counted and the outcome is "canceled".
+func TestHandleProxyOriginRetryMetricDoesNotCountCanceledBackoff(t *testing.T) {
+	proxy := newTestProxy(t)
+	proxy.originRetryInitialBackoff = 10 * time.Second
+	proxy.originRetryMaxBackoff = 10 * time.Second
+
+	retryBefore := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503"))
+	canceledBefore := counterValue(t, originFetchesTotal.WithLabelValues("canceled"))
+	firstAttemptReturned := make(chan struct{})
+	var originCalls int32
+	_, originURL := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+		if atomic.AddInt32(&originCalls, 1) == 1 {
+			close(firstAttemptReturned)
+		}
+		w.WriteHeader(http.StatusServiceUnavailable)
+		_, _ = w.Write([]byte("retry later"))
+	})
+
+	retryLogSeen := make(chan struct{})
+	prevLogger := slog.Default()
+	slog.SetDefault(slog.New(&retryLogSignalHandler{ch: retryLogSeen}))
+	t.Cleanup(func() { slog.SetDefault(prevLogger) })
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	req := httptest.NewRequest("GET", originURL+"/bucket/cancel-retry.parquet", nil).WithContext(ctx)
+	req.Host = req.URL.Host
+	req.Header.Set("Range", "bytes=0-20")
+	rec := httptest.NewRecorder()
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		proxy.HandleProxy(rec, req)
+	}()
+
+	waitForSignal(t, firstAttemptReturned, "timed out waiting for first origin attempt")
+	waitForSignal(t, retryLogSeen, "timed out waiting for retry backoff path")
+	cancel()
+	waitForSignal(t, done, "timed out waiting for canceled proxy request")
+
+	if rec.Code != http.StatusBadGateway {
+		t.Fatalf("status = %d, want 502 after context cancellation", rec.Code)
+	}
+	if got := atomic.LoadInt32(&originCalls); got != 1 {
+		t.Fatalf("origin calls = %d, want only the initial failed attempt", got)
+	}
+	if got := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503")); got != retryBefore {
+		t.Fatalf("origin retry metric = %v, want unchanged %v when backoff is canceled", got, retryBefore)
+	}
+	if got := counterValue(t, originFetchesTotal.WithLabelValues("canceled")); got != canceledBefore+1 {
+		t.Fatalf("origin canceled metric = %v, want %v", got, canceledBefore+1)
+	}
+}
+
+// TestHandleProxyOriginRetryMetricCountsPreBackoffCancellationAsCanceled
+// cancels the request context from inside the stubbed transport while it
+// returns a 503, and checks that the outcome is "canceled" rather than
+// "http_error" and that no retry is counted.
+func TestHandleProxyOriginRetryMetricCountsPreBackoffCancellationAsCanceled(t *testing.T) {
+	proxy := newTestProxy(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	proxy.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
+		cancel()
+		return &http.Response{
+			StatusCode: http.StatusServiceUnavailable,
+			Header:     make(http.Header),
+			Body:       io.NopCloser(strings.NewReader("retry later")),
+		}, nil
+	})}
+
+	canceledBefore := counterValue(t, originFetchesTotal.WithLabelValues("canceled"))
+	httpErrorBefore := counterValue(t, originFetchesTotal.WithLabelValues("http_error"))
+	retryBefore := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503"))
+
+	req := httptest.NewRequest("GET", "http://origin.test/bucket/pre-backoff-cancel.parquet", nil).WithContext(ctx)
+	req.Host = req.URL.Host
+	req.Header.Set("Range", "bytes=0-20")
+	rec := httptest.NewRecorder()
+	proxy.HandleProxy(rec, req)
+
+	if rec.Code != http.StatusBadGateway {
+		t.Fatalf("status = %d, want 502 after context cancellation", rec.Code)
+	}
+	if got := counterValue(t, originFetchesTotal.WithLabelValues("canceled")); got != canceledBefore+1 {
+		t.Fatalf("origin canceled metric = %v, want %v", got, canceledBefore+1)
+	}
+	if got := counterValue(t, originFetchesTotal.WithLabelValues("http_error")); got != httpErrorBefore {
+		t.Fatalf("origin http_error metric = %v, want unchanged %v", got, httpErrorBefore)
+	}
+	if got := counterValue(t, originFetchRetriesTotal.WithLabelValues("http_503")); got != retryBefore {
+		t.Fatalf("origin retry metric = %v, want unchanged %v", got, retryBefore)
+	}
+}
+
+// TestHandleProxyOriginFetchInFlightMetricDuringRequest holds the origin
+// handler open and checks that the in-flight gauge is one above its baseline
+// during the fetch and back at the baseline after the response completes.
+func TestHandleProxyOriginFetchInFlightMetricDuringRequest(t *testing.T) {
+	proxy := newTestProxy(t)
+	inFlightBefore := gaugeValue(t, originFetchInFlight)
+	originStarted := make(chan struct{})
+	var startedOnce sync.Once
+	releaseOrigin := make(chan struct{})
+	var releaseOnce sync.Once
+	release := func() {
+		releaseOnce.Do(func() { close(releaseOrigin) })
+	}
+	defer release()
+
+	_, originURL := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+		// Guard the close so a second origin request cannot panic on a closed channel.
+		startedOnce.Do(func() { close(originStarted) })
+		<-releaseOrigin
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("ok"))
+	})
+
+	results := make(chan *httptest.ResponseRecorder, 1)
+	go func() {
+		results <- doForwardProxyRequest(proxy, "GET", originURL+"/bucket/in-flight.parquet", http.Header{"Range": []string{"bytes=0-1"}})
+	}()
+	waitForSignal(t, originStarted, "timed out waiting for origin request to start")
+
+	if got := gaugeValue(t, originFetchInFlight); got != inFlightBefore+1 {
+		t.Fatalf("in-flight origin fetches during request = %v, want %v", got, inFlightBefore+1)
+	}
+	release()
+
+	rec := waitForRecorder(t, results, "timed out waiting for proxy response")
+	if rec.Code != http.StatusPartialContent {
+		t.Fatalf("status = %d, want 206", rec.Code)
+	}
+	if got := gaugeValue(t, originFetchInFlight); got != inFlightBefore {
+		t.Fatalf("in-flight origin fetches after request = %v, want %v", got, inFlightBefore)
+	}
+}
+
+// TestHandleProxyOriginFetchFailureOutcomeMetrics checks, with
+// originRetryMaxAttempts set to 1, that each kind of failed origin fetch
+// increments the matching outcome label and yields the expected status.
+func TestHandleProxyOriginFetchFailureOutcomeMetrics(t *testing.T) {
+	tests := []struct {
+		name       string
+		label      string
+		response   *http.Response
+		err        error
+		wantStatus int
+	}{
+		{
+			name:  "http error",
+			label: "http_error",
+			response: &http.Response{
+				StatusCode: http.StatusServiceUnavailable,
+				Header:     make(http.Header),
+				Body:       io.NopCloser(strings.NewReader("retry later")),
+			},
+			wantStatus: http.StatusServiceUnavailable,
+		},
+		{
+			name:       "timeout",
+			label:      "timeout",
+			err:        context.DeadlineExceeded,
+			wantStatus: http.StatusBadGateway,
+		},
+		{
+			name:       "generic error",
+			label:      "error",
+			err:        fmt.Errorf("transport boom"),
+			wantStatus: http.StatusBadGateway,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			proxy := newTestProxy(t)
+			proxy.originRetryMaxAttempts = 1
+			proxy.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
+				return tt.response, tt.err
+			})}
+			before := counterValue(t, originFetchesTotal.WithLabelValues(tt.label))
+
+			rec := doForwardProxyRequest(proxy, "GET", "http://origin.test/bucket/failure.parquet", http.Header{"Range": []string{"bytes=0-20"}})
+			if rec.Code != tt.wantStatus {
+				t.Fatalf("status = %d, want %d", rec.Code, tt.wantStatus)
+			}
+			if got := counterValue(t, originFetchesTotal.WithLabelValues(tt.label)); got != before+1 {
+				t.Fatalf("origin %s metric = %v, want %v", tt.label, got, before+1)
+			}
+		})
+	}
+}
+
+// TestOriginFetchMetricLabels table-tests the label values produced by
+// originFetchOutcome and originRetryReason for representative errors.
+func TestOriginFetchMetricLabels(t *testing.T) {
+	outcomeTests := []struct {
+		name string
+		err  error
+		want string
+	}{
+		{"success", nil, "success"},
+		{"http error", &originStatusError{status: http.StatusServiceUnavailable}, "http_error"},
+		{"canceled", context.Canceled, "canceled"},
+		{"deadline", context.DeadlineExceeded, "timeout"},
+		{"net timeout", timeoutNetError{}, "timeout"},
+		{"generic", fmt.Errorf("boom"), "error"},
+	}
+	for _, tt := range outcomeTests {
+		t.Run("outcome "+tt.name, func(t *testing.T) {
+			if got := originFetchOutcome(tt.err); got != tt.want {
+				t.Fatalf("originFetchOutcome(%v) = %q, want %q", tt.err, got, tt.want)
+			}
+		})
+	}
+
+	reasonTests := []struct {
+		name string
+		err  error
+		want string
+	}{
+		{"http", &originStatusError{status: http.StatusTooManyRequests}, "http_429"},
+		{"deadline", context.DeadlineExceeded, "timeout"},
+		{"net timeout", timeoutNetError{}, "timeout"},
+		{"connection reset", fmt.Errorf("read: connection reset by peer"), "connection_reset"},
+		{"connection refused", fmt.Errorf("dial tcp: connection refused"), "connection_refused"},
+		{"unexpected eof", fmt.Errorf("unexpected EOF"), "unexpected_eof"},
+		{"generic", fmt.Errorf("boom"), "transport_error"},
+	}
+	for _, tt := range reasonTests {
+		t.Run("reason "+tt.name, func(t *testing.T) {
+			if got := originRetryReason(tt.err); got != tt.want {
+				t.Fatalf("originRetryReason(%v) = %q, want %q", tt.err, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestHandleProxyMetricsOnlyDoesNotRejectConcurrentOriginMisses issues 65
+// concurrent misses for distinct objects and checks that all of them are inside
+// the origin handler at the same time and then succeed, i.e. that the metrics
+// add no concurrency limit, queueing, or local rejection.
+func TestHandleProxyMetricsOnlyDoesNotRejectConcurrentOriginMisses(t *testing.T) {
+	proxy := newTestProxy(t)
+
+	const requests = 65
+	var originCalls int32
+	var activeOriginCalls int32
+	var maxActiveOriginCalls int32
+	releaseOrigin := make(chan struct{})
+	var releaseOnce sync.Once
+	release := func() {
+		releaseOnce.Do(func() { close(releaseOrigin) })
+	}
+	defer release()
+
+	_, originURL := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+		current := atomic.AddInt32(&activeOriginCalls, 1)
+		defer atomic.AddInt32(&activeOriginCalls, -1)
+		for {
+			maxActive := atomic.LoadInt32(&maxActiveOriginCalls)
+			if current <= maxActive || atomic.CompareAndSwapInt32(&maxActiveOriginCalls, maxActive, current) {
+				break
+			}
+		}
+		atomic.AddInt32(&originCalls, 1)
+		<-releaseOrigin
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("ok"))
+	})
+
+	results := make(chan *httptest.ResponseRecorder, requests)
+	for i := 0; i < requests; i++ {
+		i := i
+		go func() {
+			url := fmt.Sprintf("%s/bucket/concurrent-%d.parquet", originURL, i)
+			headers := http.Header{"Range": []string{fmt.Sprintf("bytes=%d-%d", i, i+1)}}
+			results <- doForwardProxyRequest(proxy, "GET", url, headers)
+		}()
+	}
+
+	deadline := time.Now().Add(2 * time.Second)
+	for atomic.LoadInt32(&originCalls) < requests && time.Now().Before(deadline) {
+		time.Sleep(10 * time.Millisecond)
+	}
+	if got := atomic.LoadInt32(&originCalls); got != requests {
+		release()
+		t.Fatalf("origin calls before release = %d, want %d; metrics-only PR must not limit origin concurrency", got, requests)
+	}
+	if got := atomic.LoadInt32(&maxActiveOriginCalls); got != requests {
+		release()
+		t.Fatalf("simultaneous origin calls before release = %d, want %d; metrics-only PR must not queue origin concurrency", got, requests)
+	}
+	release()
+
+	var failed []string
+	for i := 0; i < requests; i++ {
+		rec := waitForRecorder(t, results, "timed out waiting for concurrent proxy response")
+		if rec.Code != http.StatusPartialContent {
+			failed = append(failed, fmt.Sprintf("response %d status = %d", i+1, rec.Code))
+		}
+	}
+	if len(failed) > 0 {
+		t.Fatalf("metrics-only PR must not add local rejection behavior: %s", strings.Join(failed, ", "))
+	}
+	if got := atomic.LoadInt32(&originCalls); got != requests {
+		t.Fatalf("origin calls = %d, want %d; metrics-only PR must not limit origin concurrency", got, requests)
+	}
+}
+
 func TestHandleProxyDoesNotRetryTerminalOriginStatuses(t *testing.T) {
 	tests := []struct {
 		name string