diff --git a/cmd/cachewd/main.go b/cmd/cachewd/main.go index c1222e8..de02b51 100644 --- a/cmd/cachewd/main.go +++ b/cmd/cachewd/main.go @@ -209,7 +209,7 @@ func newMux(ctx context.Context, cr *cache.Registry, mr *metadatadb.Registry, sr http.DefaultServeMux.ServeHTTP(w, r) })) - handler, _, loaded, err := config.Load(ctx, cr, mr, sr, providersConfigHCL, mux, vars) + handler, loaded, err := config.Load(ctx, cr, mr, sr, providersConfigHCL, mux, vars) if err != nil { return nil, errors.Errorf("load config: %w", err) } diff --git a/internal/config/config.go b/internal/config/config.go index 6e7167c..b2f5fed 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -154,47 +154,49 @@ func Load( ast *hcl.AST, mux *http.ServeMux, vars map[string]string, -) (http.Handler, metadatadb.Backend, []strategy.Readier, error) { +) (http.Handler, []strategy.Readier, error) { logger := logging.FromContext(ctx) expandVars(ast, vars) classified, err := classifyBlocks(ast) if err != nil { - return nil, nil, nil, err + return nil, nil, err } var caches []cache.Cache for _, block := range classified.caches { name, inner, err := unwrapBlock(block) if err != nil { - return nil, nil, nil, err + return nil, nil, err } c, err := cr.Create(ctx, name, inner, vars) if err != nil { - return nil, nil, nil, errors.Errorf("%s: %w", block.Pos, err) + return nil, nil, errors.Errorf("%s: %w", block.Pos, err) } caches = append(caches, c) } if len(caches) == 0 { - return nil, nil, nil, errors.Errorf("%s: expected at least one cache backend", ast.Pos) + return nil, nil, errors.Errorf("%s: expected at least one cache backend", ast.Pos) } if classified.metadata == nil { - return nil, nil, nil, errors.Errorf("%s: expected a metadata backend", ast.Pos) + return nil, nil, errors.Errorf("%s: expected a metadata backend", ast.Pos) } metaName, metaInner, err := unwrapBlock(classified.metadata) if err != nil { - return nil, nil, nil, err + return nil, nil, err } metadata, err := mr.Create(ctx, 
metaName, metaInner, vars) if err != nil { - return nil, nil, nil, errors.Errorf("%s: %w", classified.metadata.Pos, err) + return nil, nil, errors.Errorf("%s: %w", classified.metadata.Pos, err) } cache := cache.MaybeNewTiered(ctx, caches) logger.DebugContext(ctx, "Cache backend", "cache", cache) + metadataStore := metadatadb.New(ctx, metadata) + // Second pass, instantiate strategies and bind them to the mux. // Collect strategies that implement Interceptor separately — they need // to run before ServeMux route matching, not as mux routes. Strategies @@ -207,7 +209,10 @@ func Load( mlog := &loggingMux{logger: slogger, mux: mux} s, err := sr.Create(ctx, name, block, cache, mlog, vars) if err != nil { - return nil, nil, nil, errors.Errorf("%s: %w", block.Pos, err) + return nil, nil, errors.Errorf("%s: %w", block.Pos, err) + } + if mc, ok := s.(strategy.MetadataConsumer); ok { + mc.SetMetadataStore(metadataStore) } if interceptor, ok := s.(strategy.Interceptor); ok { interceptors = append(interceptors, interceptor) @@ -223,7 +228,7 @@ func Load( for i := len(interceptors) - 1; i >= 0; i-- { h = interceptors[i].Intercept(h) } - return h, metadata, readiers, nil + return h, readiers, nil } // expandVars expands environment variable references in HCL `*hcl.String` diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ae41838..01147bc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -234,7 +234,7 @@ func TestLoadRequiresMetadataBackend(t *testing.T) { assert.NoError(t, err) ctx := logging.ContextWithLogger(context.Background(), slog.Default()) - _, _, _, err = Load(ctx, cr, mr, sr, ast, http.NewServeMux(), nil) + _, _, err = Load(ctx, cr, mr, sr, ast, http.NewServeMux(), nil) assert.Error(t, err) assert.Contains(t, err.Error(), "expected a metadata backend") } diff --git a/internal/strategy/api.go b/internal/strategy/api.go index cdc3f46..b2fead3 100644 --- a/internal/strategy/api.go +++ 
b/internal/strategy/api.go @@ -10,6 +10,7 @@ import ( "github.com/alecthomas/hcl/v2" "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/metadatadb" ) // ErrNotFound is returned when a strategy is not found. @@ -121,6 +122,16 @@ type Interceptor interface { Intercept(next http.Handler) http.Handler } +// MetadataConsumer is an optional interface a Strategy may implement to receive +// the metadata store after construction. config.Load invokes SetMetadataStore +// on each consumer once the metadata backend has been built. This avoids the +// construction-order cycle where strategies are built inside config.Load but +// the metadata backend is also created there. +type MetadataConsumer interface { + Strategy + SetMetadataStore(*metadatadb.Store) +} + // Readier is an optional interface a Strategy may implement to gate the // /_readiness probe on background warm-up completing. The HTTP listener and // /_liveness come up immediately so the kubelet doesn't restart the pod, but diff --git a/internal/strategy/git/git.go b/internal/strategy/git/git.go index e76ecba..381253c 100644 --- a/internal/strategy/git/git.go +++ b/internal/strategy/git/git.go @@ -25,6 +25,7 @@ import ( "github.com/block/cachew/internal/githubapp" "github.com/block/cachew/internal/jobscheduler" "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/metadatadb" "github.com/block/cachew/internal/snapshot" "github.com/block/cachew/internal/strategy" ) @@ -59,6 +60,7 @@ type Strategy struct { coldSnapshotMu sync.Map // keyed by upstream URL, values are *coldSnapshotEntry deferredRestoreOnce sync.Map // keyed by upstream URL, ensures at most one deferred restore per repo metrics *gitMetrics + repoCounts *RepoCounts ready atomic.Bool } @@ -181,12 +183,30 @@ func New( var _ strategy.Strategy = (*Strategy)(nil) var _ strategy.Readier = (*Strategy)(nil) +var _ strategy.MetadataConsumer = (*Strategy)(nil) // Ready reports whether startup warm-up has completed. 
func (s *Strategy) Ready() bool { return s.ready.Load() } +// SetMetadataStore enables the per-repo clone histogram and schedules its +// daily reaper. Called by config.Load after the metadata backend is built. +func (s *Strategy) SetMetadataStore(store *metadatadb.Store) { + if store == nil { + return + } + s.repoCounts = NewRepoCounts(store.Namespace("git")) + logging.FromContext(s.ctx).InfoContext(s.ctx, "Per-repo clone histogram enabled", + "retention_days", s.repoCounts.retentionDays) + s.scheduler.SubmitPeriodicJob("repo-counts-reaper", "reap-repo-counts", defaultRepoCountsReapInterval, func(ctx context.Context) error { + if deleted := s.repoCounts.Reap(); deleted > 0 { + logging.FromContext(ctx).InfoContext(ctx, "Reaped stale repo clone counts", "deleted", deleted) + } + return nil + }) +} + func (s *Strategy) warmExistingRepos(ctx context.Context) error { logger := logging.FromContext(ctx) existing, err := s.cloneManager.DiscoverExisting(ctx) @@ -301,6 +321,13 @@ func (s *Strategy) handleGitRequest(w http.ResponseWriter, r *http.Request, host return } + // Increment after GetOrCreate so unvalidated URLs can't bloat the keyspace. 
+ if isClone, cerr := RequestIsClone(pathValue, r); cerr != nil { + logger.WarnContext(ctx, "Failed to inspect upload-pack body for clone counting", "error", cerr) + } else if isClone { + s.repoCounts.IncrementClone(upstreamURL) + } + state := repo.State() isInfoRefs := strings.HasSuffix(pathValue, "/info/refs") diff --git a/internal/strategy/git/git_test.go b/internal/strategy/git/git_test.go index 38cefb4..e2a4ff2 100644 --- a/internal/strategy/git/git_test.go +++ b/internal/strategy/git/git_test.go @@ -16,6 +16,7 @@ import ( "github.com/block/cachew/internal/githubapp" "github.com/block/cachew/internal/jobscheduler" "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/metadatadb" "github.com/block/cachew/internal/strategy/git" ) @@ -324,6 +325,25 @@ func TestNewMissingSnapshotBinaries(t *testing.T) { }) } +// TestSetMetadataStore verifies that wiring a metadata store after construction +// enables the per-repo histogram. The behaviour of the resulting RepoCounts is +// covered in repocounts_test.go. 
+func TestSetMetadataStore(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{}) + + mux := newTestMux() + cm := gitclone.NewManagerProvider(ctx, gitclone.Config{ + MirrorRoot: filepath.Join(t.TempDir(), "clones"), + FetchInterval: 15, + }, nil) + s, err := git.New(ctx, git.Config{}, newTestScheduler(ctx, t), nil, mux, cm, + func() (*githubapp.TokenManager, error) { return nil, nil }) //nolint:nilnil + assert.NoError(t, err) + + store := metadatadb.New(ctx, metadatadb.NewMemoryBackend()) + s.SetMetadataStore(store) +} + func TestParseGitRefs(t *testing.T) { _, ctx := logging.Configure(context.Background(), logging.Config{}) _ = ctx diff --git a/internal/strategy/git/repocounts.go b/internal/strategy/git/repocounts.go new file mode 100644 index 0000000..24e7919 --- /dev/null +++ b/internal/strategy/git/repocounts.go @@ -0,0 +1,208 @@ +package git + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/alecthomas/errors" + + "github.com/block/cachew/internal/metadatadb" +) + +const ( + repoCountsMapName = "repo_clone_counts" + // '|' cannot appear in a Git upstream URL, so the split is unambiguous. + repoCountsKeySeparator = "|" + defaultRepoCountsRetentionDays = 90 + defaultRepoCountsReapInterval = 24 * time.Hour +) + +// RepoCounts tracks per-repository clone counts in a daily-bucketed IntMap. +// All methods are nil-safe. +type RepoCounts struct { + counts *metadatadb.IntMap[string] + now func() time.Time + retentionDays int +} + +// NewRepoCounts returns nil if ns is nil so callers don't need a separate +// "no metadata configured" code path. +func NewRepoCounts(ns *metadatadb.Namespace) *RepoCounts { + if ns == nil { + return nil + } + return &RepoCounts{ + counts: metadatadb.NewIntMap[string](ns, repoCountsMapName), + now: time.Now, + retentionDays: defaultRepoCountsRetentionDays, + } +} + +// IncrementClone bumps today's bucket (UTC) for upstreamURL. 
+func (r *RepoCounts) IncrementClone(upstreamURL string) { + if r == nil || upstreamURL == "" { + return + } + r.counts.Add(repoCountsKey(upstreamURL, r.now()), 1) +} + +// uploadPackBodyInspectLimit bounds CPU spend on hostile bodies; real bodies +// are well under this, including multi-ref fetches against monorepos. +const uploadPackBodyInspectLimit = 64 * 1024 + +// lsRefsLookahead caps the prefix scanned for v2 command=ls-refs. The command +// is in the first pkt-line, so this is generous. +const lsRefsLookahead = 256 + +//nolint:gochecknoglobals // hoisted to avoid per-request []byte allocation +var ( + lsRefsNeedle = []byte("command=ls-refs") + haveNeedle = []byte("have ") +) + +// RequestIsClone reports whether r is an initial clone of a Git repo. +// +// Detection: POST /git-upload-pack whose pkt-line body contains no "have " +// line. v2 command=ls-refs (discovery) is also rejected. The body is buffered +// and replayed via io.NopCloser. +func RequestIsClone(pathValue string, r *http.Request) (bool, error) { + if r.Method != http.MethodPost || !strings.HasSuffix(pathValue, "/git-upload-pack") { + return false, nil + } + if r.Body == nil || r.Body == http.NoBody { + return true, nil + } + prefix := make([]byte, uploadPackBodyInspectLimit) + n, err := io.ReadFull(r.Body, prefix) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + return false, errors.Wrap(err, "read upload-pack body") + } + prefix = prefix[:n] + // Replay: prefix + the rest of the body, untouched. ContentLength stays + // correct since we haven't dropped any bytes. 
+ original := r.Body + r.Body = struct { + io.Reader + io.Closer + }{io.MultiReader(bytes.NewReader(prefix), original), original} + + inspect := prefix + if strings.EqualFold(r.Header.Get("Content-Encoding"), "gzip") { + if gr, gzErr := gzip.NewReader(bytes.NewReader(prefix)); gzErr == nil { + decoded, _ := io.ReadAll(io.LimitReader(gr, uploadPackBodyInspectLimit)) //nolint:errcheck // best-effort + _ = gr.Close() //nolint:errcheck // best-effort + inspect = decoded + } + } + + head := inspect + if len(head) > lsRefsLookahead { + head = head[:lsRefsLookahead] + } + if bytes.Contains(head, lsRefsNeedle) { + return false, nil + } + // Trailing space disambiguates from capability tokens and repo names. + if bytes.Contains(inspect, haveNeedle) { + return false, nil + } + return true, nil +} + +// RepoCount is one row of the histogram. +type RepoCount struct { + Repo string `json:"repo"` + Count int64 `json:"count"` +} + +// TopRepos aggregates buckets over the last windowDays days (UTC) and returns +// rows sorted by count descending then repo name ascending. windowDays <= 0 +// means no window; limit <= 0 means no truncation. 
+func (r *RepoCounts) TopRepos(windowDays, limit int) []RepoCount { + if r == nil { + return nil + } + entries := r.counts.Entries() + var cutoff time.Time + if windowDays > 0 { + cutoff = r.now().UTC().AddDate(0, 0, -windowDays+1).Truncate(24 * time.Hour) + } + agg := make(map[string]int64, len(entries)) + for k, v := range entries { + repo, day, ok := splitRepoCountsKey(k) + if !ok { + continue + } + if !cutoff.IsZero() && day.Before(cutoff) { + continue + } + agg[repo] += v + } + out := make([]RepoCount, 0, len(agg)) + for repo, count := range agg { + out = append(out, RepoCount{Repo: repo, Count: count}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].Count != out[j].Count { + return out[i].Count > out[j].Count + } + return out[i].Repo < out[j].Repo + }) + if limit > 0 && len(out) > limit { + out = out[:limit] + } + return out +} + +// Reap deletes buckets older than the retention window and any malformed keys, +// returning the number of entries deleted. +func (r *RepoCounts) Reap() int { + if r == nil { + return 0 + } + entries := r.counts.Entries() + if len(entries) == 0 { + return 0 + } + cutoff := r.now().UTC().AddDate(0, 0, -r.retentionDays).Truncate(24 * time.Hour) + var deleted int + for k := range entries { + _, day, ok := splitRepoCountsKey(k) + if !ok { + r.counts.Delete(k) + deleted++ + continue + } + if day.Before(cutoff) { + r.counts.Delete(k) + deleted++ + } + } + return deleted +} + +func repoCountsKey(upstreamURL string, now time.Time) string { + return upstreamURL + repoCountsKeySeparator + now.UTC().Format("2006-01-02") +} + +func splitRepoCountsKey(k string) (repo string, day time.Time, ok bool) { + idx := strings.LastIndex(k, repoCountsKeySeparator) + if idx < 0 { + return "", time.Time{}, false + } + repo = k[:idx] + dateStr := k[idx+1:] + if repo == "" || dateStr == "" { + return "", time.Time{}, false + } + d, err := time.Parse("2006-01-02", dateStr) + if err != nil { + return "", time.Time{}, false + } + return repo, d, true +} 
package git //nolint:testpackage // white-box testing required for clock and retention injection

import (
	"bytes"
	"compress/gzip"
	"context"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/alecthomas/assert/v2"

	"github.com/block/cachew/internal/logging"
	"github.com/block/cachew/internal/metadatadb"
)

// newTestRepoCounts builds a RepoCounts backed by an in-memory metadata store,
// optionally overriding its clock for deterministic bucketing.
func newTestRepoCounts(t *testing.T, now func() time.Time) *RepoCounts {
	t.Helper()
	ctx := logging.ContextWithLogger(context.Background(), slog.Default())
	store := metadatadb.New(ctx, metadatadb.NewMemoryBackend())
	rc := NewRepoCounts(store.Namespace("git"))
	if now != nil {
		rc.now = now
	}
	return rc
}

// TestRepoCountsNilSafe verifies that every method is a no-op on a nil
// receiver and that NewRepoCounts(nil) returns nil.
func TestRepoCountsNilSafe(t *testing.T) {
	var rc *RepoCounts
	rc.IncrementClone("https://github.com/foo/bar")
	assert.Equal(t, 0, rc.Reap())
	assert.Zero(t, len(rc.TopRepos(0, 0)))
	assert.Zero(t, NewRepoCounts(nil))
}

func TestRepoCountsReapEmpty(t *testing.T) {
	rc := newTestRepoCounts(t, nil)
	assert.Equal(t, 0, rc.Reap(), "reap on an empty namespace should report zero deletions")
}

// TestRepoCountsIncrementAndAggregate checks that increments across multiple
// day buckets are summed per repo and sorted by count descending.
func TestRepoCountsIncrementAndAggregate(t *testing.T) {
	clock := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc := newTestRepoCounts(t, func() time.Time { return clock })

	for range 5 {
		rc.IncrementClone("https://github.com/foo/popular")
	}
	for range 2 {
		rc.IncrementClone("https://github.com/foo/quiet")
	}
	// Bump the popular repo on a previous day too.
	clock = clock.AddDate(0, 0, -1)
	for range 3 {
		rc.IncrementClone("https://github.com/foo/popular")
	}
	clock = clock.AddDate(0, 0, 1)

	top := rc.TopRepos(0, 0)
	assert.Equal(t, []RepoCount{
		{Repo: "https://github.com/foo/popular", Count: 8},
		{Repo: "https://github.com/foo/quiet", Count: 2},
	}, top)
}

// TestRepoCountsWindowFilter checks that windowDays excludes buckets older
// than the window while windowDays=0 includes everything.
func TestRepoCountsWindowFilter(t *testing.T) {
	clock := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc := newTestRepoCounts(t, func() time.Time { return clock })

	// 10 days ago: only "old" repo gets hits.
	clock = time.Date(2026, 4, 25, 12, 0, 0, 0, time.UTC)
	for range 4 {
		rc.IncrementClone("https://github.com/foo/old")
	}
	// Today: only "new" repo gets hits.
	clock = time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc.IncrementClone("https://github.com/foo/new")

	all := rc.TopRepos(0, 0)
	assert.Equal(t, 2, len(all), "no window includes both repos")

	last7 := rc.TopRepos(7, 0)
	assert.Equal(t, []RepoCount{{Repo: "https://github.com/foo/new", Count: 1}}, last7)
}

// TestRepoCountsLimit checks truncation of the sorted histogram.
func TestRepoCountsLimit(t *testing.T) {
	clock := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc := newTestRepoCounts(t, func() time.Time { return clock })

	// alpha gets 1 hit, bravo 2, charlie 3, delta 4.
	for i, repo := range []string{"alpha", "bravo", "charlie", "delta"} {
		for j := 0; j <= i; j++ {
			rc.IncrementClone("https://github.com/foo/" + repo)
		}
	}

	top2 := rc.TopRepos(0, 2)
	assert.Equal(t, []RepoCount{
		{Repo: "https://github.com/foo/delta", Count: 4},
		{Repo: "https://github.com/foo/charlie", Count: 3},
	}, top2)
}

// TestRepoCountsReap checks that buckets beyond retentionDays are deleted and
// recent buckets survive.
func TestRepoCountsReap(t *testing.T) {
	clock := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc := newTestRepoCounts(t, func() time.Time { return clock })
	rc.retentionDays = 30

	// Old entry: 100 days ago.
	clock = time.Date(2026, 1, 25, 12, 0, 0, 0, time.UTC)
	rc.IncrementClone("https://github.com/foo/ancient")
	// Recent entry: 5 days ago.
	clock = time.Date(2026, 4, 30, 12, 0, 0, 0, time.UTC)
	rc.IncrementClone("https://github.com/foo/fresh")

	clock = time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	assert.Equal(t, 1, rc.Reap(), "one stale bucket should be deleted")

	remaining := rc.TopRepos(0, 0)
	assert.Equal(t, []RepoCount{{Repo: "https://github.com/foo/fresh", Count: 1}}, remaining)
}

// TestRepoCountsReapMalformedKeys checks that keys which fail
// splitRepoCountsKey are treated as reapable garbage.
func TestRepoCountsReapMalformedKeys(t *testing.T) {
	clock := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	rc := newTestRepoCounts(t, func() time.Time { return clock })

	// Inject a malformed key directly via Add.
	rc.counts.Add("not-a-real-key", 7)
	rc.IncrementClone("https://github.com/foo/valid")

	assert.Equal(t, 1, rc.Reap(), "the malformed key should be deleted, the valid one preserved")
	remaining := rc.TopRepos(0, 0)
	assert.Equal(t, []RepoCount{{Repo: "https://github.com/foo/valid", Count: 1}}, remaining)
}

// TestRepoCountsKeyRoundTrip checks repoCountsKey/splitRepoCountsKey are
// inverses and that the date half is truncated to UTC midnight.
func TestRepoCountsKeyRoundTrip(t *testing.T) {
	now := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
	url := "https://github.com/squareup/some-repo"
	k := repoCountsKey(url, now)
	repo, day, ok := splitRepoCountsKey(k)
	assert.True(t, ok)
	assert.Equal(t, url, repo)
	assert.Equal(t, time.Date(2026, 5, 5, 0, 0, 0, 0, time.UTC), day)
}

// TestRequestIsClone table-tests clone-vs-fetch classification across protocol
// v1/v2, gzip encoding, and non-upload-pack requests, and verifies the body is
// replayed intact for downstream handlers.
func TestRequestIsClone(t *testing.T) {
	gzipBytes := func(t *testing.T, in []byte) []byte {
		t.Helper()
		var buf bytes.Buffer
		gw := gzip.NewWriter(&buf)
		_, err := gw.Write(in)
		assert.NoError(t, err)
		assert.NoError(t, gw.Close())
		return buf.Bytes()
	}

	// Hand-built pkt-line bodies: v1 carries capabilities on the first want
	// line; v2 uses command= framing. "have" lines mark incremental fetches.
	v1Clone := []byte("0067want abc123 multi_ack_detailed no-done side-band-64k thin-pack ofs-delta agent=git/2.40\n0009done\n0000")
	v1Fetch := []byte("0067want abc123 multi_ack_detailed no-done side-band-64k thin-pack ofs-delta agent=git/2.40\n0032have def4560000000000000000000000000\n0009done\n0000")
	v2Clone := []byte("0014command=fetch\n0009want abc1230009done0000")
	v2Fetch := []byte("0014command=fetch\n0009want abc1230009have def4560009done0000")

	tests := []struct {
		name            string
		path            string
		method          string
		body            []byte
		contentEncoding string
		want            bool
	}{
		{name: "GETInfoRefs", path: "org/repo.git/info/refs", method: http.MethodGet, want: false},
		{name: "GETUploadPack", path: "org/repo.git/git-upload-pack", method: http.MethodGet, want: false},
		{name: "POSTUnknownPath", path: "org/repo.git/something-else", method: http.MethodPost, body: v1Clone, want: false},
		{name: "POSTLsRefs", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: []byte("0014command=ls-refs\n0000"), want: false},
		{name: "POSTV1Clone", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: v1Clone, want: true},
		{name: "POSTV1Fetch", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: v1Fetch, want: false},
		{name: "POSTV2Clone", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: v2Clone, want: true},
		{name: "POSTV2Fetch", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: v2Fetch, want: false},
		{name: "POSTEmptyBody", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: nil, want: true},
		{name: "POSTLsRefsGzipped", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: gzipBytes(t, []byte("0014command=ls-refs\n0000")), contentEncoding: "gzip", want: false},
		{name: "POSTV2CloneGzipped", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: gzipBytes(t, v2Clone), contentEncoding: "gzip", want: true},
		{name: "POSTV2FetchGzipped", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: gzipBytes(t, v2Fetch), contentEncoding: "gzip", want: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var bodyReader io.Reader
			if tt.body != nil {
				bodyReader = bytes.NewReader(tt.body)
			}
			r := httptest.NewRequest(tt.method, "/"+tt.path, bodyReader)
			if tt.contentEncoding != "" {
				r.Header.Set("Content-Encoding", tt.contentEncoding)
			}
			got, err := RequestIsClone(tt.path, r)
			assert.NoError(t, err)
			assert.Equal(t, tt.want, got)

			// The body must remain readable for downstream handlers.
			if tt.body != nil {
				replayed, err := io.ReadAll(r.Body)
				assert.NoError(t, err)
				assert.Equal(t, tt.body, replayed)
			}
		})
	}
}

// TestRequestIsCloneBoundsBodyRead guards against a regression where the
// entire body was buffered before the inspect cap was applied, allowing a
// large or hostile body to OOM the proxy.
func TestRequestIsCloneBoundsBodyRead(t *testing.T) {
	prefix := []byte("0067want abc123 multi_ack_detailed no-done side-band-64k thin-pack ofs-delta agent=git/2.40\n0009done\n0000")
	tail := bytes.Repeat([]byte("x"), 4*1024*1024) // 4MiB after the prefix
	cr := &countingReader{Reader: io.MultiReader(bytes.NewReader(prefix), bytes.NewReader(tail))}
	r := httptest.NewRequest(http.MethodPost, "/org/repo.git/git-upload-pack", cr)
	r.ContentLength = int64(len(prefix) + len(tail))

	got, err := RequestIsClone("org/repo.git/git-upload-pack", r)
	assert.NoError(t, err)
	assert.True(t, got)
	// Inspection must not have pulled the full body into memory.
	assert.True(t, cr.n <= int64(2*uploadPackBodyInspectLimit),
		"RequestIsClone read %d bytes upfront; expected <= %d", cr.n, 2*uploadPackBodyInspectLimit)

	// Downstream must still be able to consume the entire body.
	replayed, err := io.ReadAll(r.Body)
	assert.NoError(t, err)
	assert.Equal(t, len(prefix)+len(tail), len(replayed))
}

// countingReader tallies how many bytes have been read through it, so tests
// can assert an upper bound on upfront buffering.
type countingReader struct {
	io.Reader
	n int64
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.Reader.Read(p)
	c.n += int64(n)
	return n, err
}

// TestRequestIsCloneFetchWithManyWants guards against a regression where a
// fetch with a long list of "want" lines preceding any "have" line gets
// classified as a clone because the inspect window cuts off before the haves.
func TestRequestIsCloneFetchWithManyWants(t *testing.T) {
	var b strings.Builder
	b.WriteString("0014command=fetch\n")
	for range 500 {
		b.WriteString("0009want abc123\n")
	}
	b.WriteString("0009have def456\n")
	b.WriteString("0000")
	body := []byte(b.String())
	assert.True(t, len(body) > 1024, "body must exceed the prior inspect limit so this test is meaningful")
	r := httptest.NewRequest(http.MethodPost, "/org/repo.git/git-upload-pack", bytes.NewReader(body))
	got, err := RequestIsClone("org/repo.git/git-upload-pack", r)
	assert.NoError(t, err)
	assert.False(t, got, "fetch with many wants must still be classified as a fetch, not a clone")
}