diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0065213..0646ec4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,12 +14,12 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: '1.20' + go-version-file: 'go.mod' - name: Build run: go build -v ./... diff --git a/README.md b/README.md index 25d4511..c6e2088 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,8 @@ r := runnable.New(func(ctx context.Context) error { select { case <-ctx.Done(): return nil - default: + case <-time.After(time.Second): } - time.Sleep(1 * time.Second) fmt.Println("Running...") } }) @@ -71,7 +70,8 @@ if err != nil { ### Runnable Function with timeout ```go fmt.Println("Simple function with timeout...") -ctxWithTimeout, _ := context.WithTimeout(context.Background(), 5*time.Second) +ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() err = runnable.New(func(ctx context.Context) error { fmt.Println("Starting...") defer fmt.Println("Stopping...") @@ -80,9 +80,8 @@ err = runnable.New(func(ctx context.Context) error { select { case <-ctx.Done(): return nil - default: + case <-time.After(time.Second): } - time.Sleep(1 * time.Second) fmt.Println("Running...") } }).Run(ctxWithTimeout) @@ -91,36 +90,100 @@ if err != nil { } ``` -### Runnable Function with retry +### Adapters + +Cross-cutting behaviors that aren't part of the core lifecycle live in +the `runnable/adapters` subpackage as chi-style middleware: each +`runnable.Adapter` has the shape `func(next RunFunc) RunFunc`. 
Apply +them with `runnable.WithAdapters` (left-to-right = outermost-to-innermost): + ```go -fmt.Println("Simple function with retry...") -errorReturned := false -err = runnable.New(func(ctx context.Context) error { - fmt.Println("Starting...") - defer fmt.Println("Stopping...") - - if !errorReturned { - errorReturned = true - return fmt.Errorf("error") - } - - // do something - for i := 0; i < 5; i++ { - select { - case <-ctx.Done(): - return nil - default: - } - time.Sleep(1 * time.Second) - fmt.Println("Running...") - } - return nil -}, runnable.WithRetry(3, runnable.ResetNever)).Run(context.Background()) -if err != nil { - fmt.Println(err) -} +r := runnable.New(reconcile, runnable.WithAdapters( + adapters.Draining(10*time.Second), + adapters.Recovering(reportPanic), + adapters.Retry(3, time.Minute), + adapters.Ticker(30*time.Second), +)) ``` +**Draining** — graceful shutdown with a grace window. When the outer +ctx is cancelled, the wrapped work has `timeout` to return via +`adapters.Stopping(ctx)` before its ctx is force-cancelled and +`adapters.ErrDrainTimedOut` is returned. + +**Ticker** — calls the wrapped work once per interval until ctx is +cancelled or the work returns an error. Composes with Draining: an +in-flight tick is allowed to finish before the loop exits. + +**Recovering** — turns panics in the wrapped work into errors and +invokes the optional handler before returning. Place inside Draining +when both are in use. + +**Retry** — runs the wrapped work up to `maxRetries` times in total, +retrying only on non-context errors. If `resetAfter` is non-zero and at least that long +has passed since the previous attempt, the retry budget resets. + +Inside long-running work, always select on both `ctx.Done()` and +`adapters.Stopping(ctx)` — `Stopping` signals drain start, `ctx.Done()` +fires only when the drain timer expires. + +A full SIGTERM-safe service shape lives in +[`examples/ticker-with-drain`](examples/ticker-with-drain/main.go). 
+ +### Migrating from v0.1 to v0.2 + +v0.2 moves drain, ticker, retry, and panic recovery out of the core +package. `WithDrain`, `NewTicker`, `WithRetry`, and `WithRecoverer` +are removed; their replacements live at `runnable/adapters` as +chi-style middleware. + +Before (v0.1): + + r := runnable.NewTicker(30*time.Second, doWork, + runnable.WithDrain(10*time.Second), + runnable.WithRecoverer(reporter, nil), + runnable.WithRetry(3, time.Minute), + ) + +After (v0.2): + + r := runnable.New(doWork, runnable.WithAdapters( + adapters.Draining(10*time.Second), + adapters.Recovering(handler), + adapters.Retry(3, time.Minute), + adapters.Ticker(30*time.Second), + )) + +Symbol mapping: + +- `runnable.WithDrain` → `adapters.Draining` under `runnable.WithAdapters`. +- `runnable.NewTicker` → `adapters.Ticker` under `runnable.WithAdapters` + (no longer takes the work argument; pass work to `runnable.New`). +- `runnable.WithRetry` / `runnable.ResetNever` → `adapters.Retry` / + `adapters.ResetNever`. +- `runnable.WithRecoverer` → `adapters.Recovering` with a single + `PanicHandler` callback (the two-interface `RecoveryReporter` / + `StackPrinter` split is gone). +- `runnable.Stopping` → `adapters.Stopping`. +- `runnable.ErrDrainTimedOut` → `adapters.ErrDrainTimedOut`. + +**Behavioral change:** `Stop(ctx)`'s ctx no longer shortens the drain +window. In v0.1, a caller ctx shorter than `WithDrain`'s timeout would +force-cancel mid-drain. In v0.2, `Stop`'s ctx only governs how long +the caller waits for `Stop` to return; the drain runs on its own +fixed-duration timer regardless. If you need a shorter drain budget, +configure `Draining` with the shorter duration. + +**Status.Restarts removed.** The `Restarts` field on `Status` counted +`WithRetry` re-entries via the deprecated `onStart` coupling; with +retry moved into adapters it had no clean way to surface. Pending a +proper event/observer hook in a later release. 
+ +**NewGroup interaction:** drain-enabled children of `NewGroup` now +drain when the group is stopped (v0.1 silently bypassed the child's +drain). No code change required at call sites — the adapter design +fixes this by construction. + ### Runnable Object ```go package main diff --git a/adapter.go b/adapter.go new file mode 100644 index 0000000..896163c --- /dev/null +++ b/adapter.go @@ -0,0 +1,28 @@ +package runnable + +import "context" + +// RunFunc is the lifecycle function wrapped by runnable.New. +type RunFunc func(ctx context.Context) error + +// Adapter wraps a RunFunc with cross-cutting behavior, mirroring the +// chi middleware shape. Concrete adapters live in runnable/adapters. +type Adapter func(next RunFunc) RunFunc + +type withAdapters struct { + adapters []Adapter +} + +// WithAdapters wraps the runnable's runFunc left-to-right (first listed +// = outermost). Apply order across Options matters. +func WithAdapters(adapters ...Adapter) Option { + return &withAdapters{adapters: adapters} +} + +func (w *withAdapters) apply(r *runnable) { + next := RunFunc(r.runFunc) + for i := len(w.adapters) - 1; i >= 0; i-- { + next = w.adapters[i](next) + } + r.runFunc = next +} diff --git a/adapters/doc.go b/adapters/doc.go new file mode 100644 index 0000000..e4b55ad --- /dev/null +++ b/adapters/doc.go @@ -0,0 +1,9 @@ +// Package adapters provides chi-style middleware around the runnable +// RunFunc signature. 
Each constructor returns a runnable.Adapter; +// compose them via runnable.WithAdapters (first listed = outermost): +// +// r := runnable.New(reconcile, runnable.WithAdapters( +// adapters.Draining(10*time.Second), +// adapters.Ticker(time.Second), +// )) +package adapters diff --git a/adapters/draining.go b/adapters/draining.go new file mode 100644 index 0000000..d8a00b3 --- /dev/null +++ b/adapters/draining.go @@ -0,0 +1,72 @@ +package adapters + +import ( + "context" + "errors" + "fmt" + "runtime/debug" + "time" + + "github.com/0xsequence/runnable" +) + +// ErrDrainTimedOut is returned by Draining when work did not exit +// within the drain timeout and was force-cancelled. +var ErrDrainTimedOut = errors.New("adapters: drain timed out") + +type stoppingKey struct{} + +// Stopping returns a channel that closes when Draining begins shutdown, +// or nil outside Draining. Select on this alongside ctx.Done() — ctx is +// force-cancelled only after the drain timer expires. +func Stopping(ctx context.Context) <-chan struct{} { + ch, _ := ctx.Value(stoppingKey{}).(<-chan struct{}) + return ch +} + +// Draining returns an Adapter that delays cancellation: when outerCtx +// is cancelled, next has up to timeout to return via Stopping(workCtx) +// before workCtx is force-cancelled and ErrDrainTimedOut is returned. +// Panics in next are recovered into an error (they would otherwise +// crash the process, since next runs on its own goroutine). +func Draining(timeout time.Duration) runnable.Adapter { + return func(next runnable.RunFunc) runnable.RunFunc { + return func(outerCtx context.Context) error { + // Decoupled from outerCtx so outer cancellation triggers drain + // rather than aborting next directly. 
+ workCtx, cancelWork := context.WithCancel(context.WithoutCancel(outerCtx)) + defer cancelWork() + + stopping := make(chan struct{}) + workCtx = context.WithValue(workCtx, stoppingKey{}, (<-chan struct{})(stopping)) + + done := make(chan error, 1) + go func() { + defer func() { + if rec := recover(); rec != nil { + done <- fmt.Errorf("adapters: panic in draining work: %v\n%s", rec, debug.Stack()) + } + }() + done <- next(workCtx) + }() + + select { + case err := <-done: + return err + case <-outerCtx.Done(): + close(stopping) + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case err := <-done: + return err + case <-timer.C: + cancelWork() + <-done + return ErrDrainTimedOut + } + } + } +} diff --git a/adapters/draining_test.go b/adapters/draining_test.go new file mode 100644 index 0000000..83f4bbf --- /dev/null +++ b/adapters/draining_test.go @@ -0,0 +1,188 @@ +package adapters_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" +) + +func TestStopping_NilOutsideDraining(t *testing.T) { + ch := adapters.Stopping(context.Background()) + assert.Nil(t, ch, "Stopping(ctx) must be nil when ctx is not inside a Draining adapter") +} + +func TestDraining_WorkReturnsNaturallyViaStopping(t *testing.T) { + started := make(chan struct{}) + + work := func(ctx context.Context) error { + close(started) + select { + case <-adapters.Stopping(ctx): + return nil + case <-ctx.Done(): + return errors.New("ctx cancelled before Stopping") + } + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Draining(1*time.Second))) + runErr := make(chan error, 1) + go func() { runErr <- r.Run(context.Background()) }() + + <-started + + start := time.Now() + require.NoError(t, r.Stop(context.Background())) + elapsed := time.Since(start) + assert.Less(t, elapsed, 
500*time.Millisecond, "Stop returned long after drain completed") + + select { + case err := <-runErr: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("Run did not return") + } +} + +func TestDraining_TimerForcesCancelWhenWorkIgnoresStopping(t *testing.T) { + started := make(chan struct{}) + workErr := make(chan error, 1) + + work := func(ctx context.Context) error { + close(started) + <-ctx.Done() // ignore Stopping; wait for force-cancel + workErr <- ctx.Err() + return ctx.Err() + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Draining(100*time.Millisecond))) + runErr := make(chan error, 1) + go func() { runErr <- r.Run(context.Background()) }() + + <-started + require.NoError(t, r.Stop(context.Background())) + + select { + case e := <-runErr: + require.ErrorIs(t, e, adapters.ErrDrainTimedOut) + case <-time.After(time.Second): + t.Fatal("Run did not return") + } + select { + case e := <-workErr: + require.ErrorIs(t, e, context.Canceled) + case <-time.After(time.Second): + t.Fatal("work did not exit via ctx.Done()") + } +} + +func TestDraining_OuterCtxCancelTriggersDrain(t *testing.T) { + // Same as previous but cancellation comes from outer ctx, not Stop. 
+ started := make(chan struct{}) + + work := func(ctx context.Context) error { + close(started) + select { + case <-adapters.Stopping(ctx): + return nil + case <-ctx.Done(): + return errors.New("ctx cancelled before Stopping") + } + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Draining(1*time.Second))) + ctx, cancel := context.WithCancel(context.Background()) + runErr := make(chan error, 1) + go func() { runErr <- r.Run(ctx) }() + + <-started + cancel() + + select { + case err := <-runErr: + require.NoError(t, err, "work should observe Stopping and exit cleanly") + case <-time.After(time.Second): + t.Fatal("Run did not return after outer ctx cancel") + } +} + +func TestDraining_ConcurrentStopsPreserveDrainSemantics(t *testing.T) { + started := make(chan struct{}) + drainObserved := make(chan struct{}) + var ctxCancelObserved atomic.Bool + + work := func(ctx context.Context) error { + close(started) + select { + case <-adapters.Stopping(ctx): + close(drainObserved) + return nil + case <-ctx.Done(): + ctxCancelObserved.Store(true) + return ctx.Err() + } + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Draining(2*time.Second))) + go func() { _ = r.Run(context.Background()) }() + + <-started + + const callers = 10 + var wg sync.WaitGroup + errs := make([]error, callers) + for i := 0; i < callers; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + errs[i] = r.Stop(context.Background()) + }() + } + wg.Wait() + + for _, err := range errs { + if err != nil { + require.ErrorIs(t, err, runnable.ErrNotRunning) + } + } + + select { + case <-drainObserved: + default: + t.Fatal("work never observed Stopping(ctx); concurrent Stop bypassed drain") + } + assert.False(t, ctxCancelObserved.Load(), "drain bypassed: work saw ctx.Done()") +} + +func TestDraining_WorkErrorPropagatesWithoutDrain(t *testing.T) { + sentinel := errors.New("work failed") + work := func(ctx context.Context) error { return sentinel } + + r := runnable.New(work, 
runnable.WithAdapters(adapters.Draining(1*time.Second))) + err := r.Run(context.Background()) + require.ErrorIs(t, err, sentinel) +} + +func TestDraining_RecoversPanicAsError(t *testing.T) { + // Regression: panics in work run on Draining's spawned goroutine, + // not on the goroutine where outer recover defers live. Without + // internal recovery, a tick panic would crash the process. + work := func(ctx context.Context) error { + panic("boom") + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Draining(1*time.Second))) + err := r.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "boom", "panic value should be embedded in error") + assert.Contains(t, err.Error(), "panic in draining work", "error should identify itself as a recovered panic") +} diff --git a/adapters/recovering.go b/adapters/recovering.go new file mode 100644 index 0000000..9e7c598 --- /dev/null +++ b/adapters/recovering.go @@ -0,0 +1,36 @@ +package adapters + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/0xsequence/runnable" +) + +// PanicHandler observes a panic caught by Recovering. Runs on next's +// goroutine, so must not block. +type PanicHandler func(ctx context.Context, rec any, stack []byte) + +// Recovering returns an Adapter that converts panics from next into +// errors and invokes handler (if non-nil). Place inside Draining when +// both are used, so handler sees the panic before Draining's safety-net +// recovery formats it. 
+func Recovering(handler PanicHandler) runnable.Adapter { + return func(next runnable.RunFunc) runnable.RunFunc { + return func(ctx context.Context) (err error) { + defer func() { + rec := recover() + if rec == nil { + return + } + stack := debug.Stack() + if handler != nil { + handler(ctx, rec, stack) + } + err = fmt.Errorf("adapters: panic: %v", rec) + }() + return next(ctx) + } + } +} diff --git a/adapters/recovering_test.go b/adapters/recovering_test.go new file mode 100644 index 0000000..6dc9ba9 --- /dev/null +++ b/adapters/recovering_test.go @@ -0,0 +1,61 @@ +package adapters_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" +) + +func TestRecovering_TurnsPanicIntoError(t *testing.T) { + var captured any + handler := func(_ context.Context, rec any, _ []byte) { + captured = rec + } + + work := func(ctx context.Context) error { + panic("boom") + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Recovering(handler))) + err := r.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "boom") + assert.Equal(t, "boom", captured) +} + +func TestRecovering_NilHandlerStillRecovers(t *testing.T) { + work := func(ctx context.Context) error { + panic("boom") + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Recovering(nil))) + err := r.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "boom") +} + +func TestRecovering_PassesThroughOnSuccess(t *testing.T) { + called := false + handler := func(_ context.Context, rec any, _ []byte) { + called = true + } + + work := func(ctx context.Context) error { return nil } + + r := runnable.New(work, runnable.WithAdapters(adapters.Recovering(handler))) + require.NoError(t, r.Run(context.Background())) + assert.False(t, called, "handler must not fire when next returns normally") +} + +func 
TestRecovering_PassesThroughError(t *testing.T) { + work := func(ctx context.Context) error { return assert.AnError } + + r := runnable.New(work, runnable.WithAdapters(adapters.Recovering(nil))) + err := r.Run(context.Background()) + require.ErrorIs(t, err, assert.AnError) +} diff --git a/adapters/retry.go b/adapters/retry.go new file mode 100644 index 0000000..4623e5d --- /dev/null +++ b/adapters/retry.go @@ -0,0 +1,42 @@ +package adapters + +import ( + "context" + "errors" + "time" + + "github.com/0xsequence/runnable" +) + +// ResetNever (as resetAfter) disables retry-budget reset. +const ResetNever time.Duration = 0 + +// Retry returns an Adapter that re-invokes next up to maxRetries times +// on non-context errors. If resetAfter > 0 and at least that long has +// passed since the previous attempt, the budget resets. Retry does not +// observe Stopping — wrap it inside Draining if you need both. +func Retry(maxRetries int, resetAfter time.Duration) runnable.Adapter { + return func(next runnable.RunFunc) runnable.RunFunc { + return func(ctx context.Context) error { + // lastTime is per-call: the timer for reset budgets is local + // to this invocation, not shared across runnable cycles. 
+ var lastTime time.Time + var err error + for i := 0; i < maxRetries; i++ { + if resetAfter != ResetNever && time.Since(lastTime) > resetAfter { + i = 0 + } + lastTime = time.Now() + + err = next(ctx) + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return err + } + } + return err + } + } +} diff --git a/adapters/retry_test.go b/adapters/retry_test.go new file mode 100644 index 0000000..807db70 --- /dev/null +++ b/adapters/retry_test.go @@ -0,0 +1,70 @@ +package adapters_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" +) + +func TestRetry_SucceedsOnSecondAttempt(t *testing.T) { + var count atomic.Int32 + work := func(ctx context.Context) error { + if count.Add(1) < 2 { + return assert.AnError + } + return nil + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Retry(3, adapters.ResetNever))) + require.NoError(t, r.Run(context.Background())) + assert.Equal(t, int32(2), count.Load()) +} + +func TestRetry_GivesUpAfterMaxAttempts(t *testing.T) { + var count atomic.Int32 + work := func(ctx context.Context) error { + count.Add(1) + return assert.AnError + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Retry(3, adapters.ResetNever))) + err := r.Run(context.Background()) + require.ErrorIs(t, err, assert.AnError) + assert.Equal(t, int32(3), count.Load()) +} + +func TestRetry_ResetsBudgetAfterQuietPeriod(t *testing.T) { + var count atomic.Int32 + work := func(ctx context.Context) error { + c := count.Add(1) + if c < 5 { + time.Sleep(200 * time.Millisecond) + return assert.AnError + } + return nil + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Retry(3, 100*time.Millisecond))) + require.NoError(t, r.Run(context.Background())) + assert.Equal(t, int32(5), count.Load()) +} + +func 
TestRetry_DoesNotRetryContextErrors(t *testing.T) { + var count atomic.Int32 + work := func(ctx context.Context) error { + count.Add(1) + return context.Canceled + } + + r := runnable.New(work, runnable.WithAdapters(adapters.Retry(3, adapters.ResetNever))) + err := r.Run(context.Background()) + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, int32(1), count.Load()) +} diff --git a/adapters/ticker.go b/adapters/ticker.go new file mode 100644 index 0000000..aff1ec9 --- /dev/null +++ b/adapters/ticker.go @@ -0,0 +1,43 @@ +package adapters + +import ( + "context" + "time" + + "github.com/0xsequence/runnable" +) + +// Ticker returns an Adapter that calls next once per interval until +// ctx is cancelled or next errors. Composes with Draining: an in-flight +// tick is allowed to finish before exit. +func Ticker(interval time.Duration) runnable.Adapter { + return func(next runnable.RunFunc) runnable.RunFunc { + return func(ctx context.Context) error { + t := time.NewTicker(interval) + defer t.Stop() + stopping := Stopping(ctx) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-stopping: + return nil + case <-t.C: + // Re-check shutdown signals: queued ticks during a slow + // tick can race against stopping in select's random pick. 
+ select { + case <-ctx.Done(): + return ctx.Err() + case <-stopping: + return nil + default: + } + if err := next(ctx); err != nil { + return err + } + } + } + } + } +} diff --git a/adapters/ticker_test.go b/adapters/ticker_test.go new file mode 100644 index 0000000..2de83e5 --- /dev/null +++ b/adapters/ticker_test.go @@ -0,0 +1,132 @@ +package adapters_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" +) + +func TestTicker_FiresOnInterval(t *testing.T) { + // Count tick signals on a channel rather than asserting wall-clock + // arithmetic; loaded CI runners would otherwise queue extra ticks + // and bust an upper bound. The behavioral claim is "Ticker fires + // repeatedly on interval" — wait for N ticks, stop, done. + ticks := make(chan struct{}, 8) + tick := func(ctx context.Context) error { + select { + case ticks <- struct{}{}: + default: + } + return nil + } + + r := runnable.New(tick, runnable.WithAdapters(adapters.Ticker(20*time.Millisecond))) + go func() { _ = r.Run(context.Background()) }() + + for i := 0; i < 3; i++ { + select { + case <-ticks: + case <-time.After(time.Second): + t.Fatalf("only %d ticks observed before timeout", i) + } + } + require.NoError(t, r.Stop(context.Background())) +} + +func TestTicker_ComposesWithDraining(t *testing.T) { + tickStarted := make(chan struct{}, 1) + var completed atomic.Int32 + + tick := func(ctx context.Context) error { + select { + case tickStarted <- struct{}{}: + default: + } + time.Sleep(200 * time.Millisecond) + completed.Add(1) + return nil + } + + r := runnable.New(tick, runnable.WithAdapters( + adapters.Draining(1*time.Second), + adapters.Ticker(20*time.Millisecond), + )) + go func() { _ = r.Run(context.Background()) }() + + <-tickStarted + + start := time.Now() + require.NoError(t, r.Stop(context.Background())) + 
elapsed := time.Since(start) + + assert.GreaterOrEqual(t, completed.Load(), int32(1), "in-flight tick should complete") + assert.Less(t, elapsed, 500*time.Millisecond) +} + +func TestTicker_WithoutDrainCancelsInFlightTick(t *testing.T) { + tickStarted := make(chan struct{}, 1) + tickErr := make(chan error, 1) + + tick := func(ctx context.Context) error { + select { + case tickStarted <- struct{}{}: + default: + } + <-ctx.Done() + tickErr <- ctx.Err() + return ctx.Err() + } + + r := runnable.New(tick, runnable.WithAdapters(adapters.Ticker(20*time.Millisecond))) + go func() { _ = r.Run(context.Background()) }() + + <-tickStarted + require.NoError(t, r.Stop(context.Background())) + + select { + case e := <-tickErr: + require.ErrorIs(t, e, context.Canceled) + case <-time.After(time.Second): + t.Fatal("tick did not observe ctx cancellation") + } +} + +func TestTicker_TickErrorAbortsLoop(t *testing.T) { + sentinel := errors.New("boom") + var count atomic.Int32 + + tick := func(ctx context.Context) error { + if count.Add(1) == 2 { + return sentinel + } + return nil + } + + r := runnable.New(tick, runnable.WithAdapters(adapters.Ticker(20*time.Millisecond))) + err := r.Run(context.Background()) + require.ErrorIs(t, err, sentinel) + assert.Equal(t, int32(2), count.Load()) +} + +func TestTicker_RespectsOuterCtxCancel(t *testing.T) { + var count atomic.Int32 + tick := func(ctx context.Context) error { + count.Add(1) + return nil + } + + r := runnable.New(tick, runnable.WithAdapters(adapters.Ticker(20*time.Millisecond))) + ctx, cancel := context.WithTimeout(context.Background(), 75*time.Millisecond) + defer cancel() + + err := r.Run(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) +} diff --git a/examples/main.go b/examples/main.go index e1e710a..eefdfb1 100644 --- a/examples/main.go +++ b/examples/main.go @@ -6,6 +6,7 @@ import ( "time" "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" ) type Monitor struct { @@ -33,7 +34,6 @@ func (m 
*Monitor) run(ctx context.Context) error { time.Sleep(1 * time.Second) fmt.Println("Monitoring...") } - return nil } func main() { @@ -92,7 +92,8 @@ func main() { // simple function with timeout fmt.Println("Simple function with timeout...") - ctxWithTimeout, _ := context.WithTimeout(context.Background(), 5*time.Second) + ctxWithTimeout, cancelTimeout := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelTimeout() err = runnable.New(func(ctx context.Context) error { fmt.Println("Starting...") defer fmt.Println("Stopping...") @@ -123,7 +124,6 @@ func main() { return fmt.Errorf("error") } - // do something for i := 0; i < 5; i++ { select { case <-ctx.Done(): @@ -134,7 +134,7 @@ func main() { fmt.Println("Running...") } return nil - }, runnable.WithRetry(3, runnable.ResetNever)).Run(context.Background()) + }, runnable.WithAdapters(adapters.Retry(3, adapters.ResetNever))).Run(context.Background()) if err != nil { fmt.Println(err) } diff --git a/examples/ticker-with-drain/main.go b/examples/ticker-with-drain/main.go new file mode 100644 index 0000000..49decaf --- /dev/null +++ b/examples/ticker-with-drain/main.go @@ -0,0 +1,72 @@ +// Example: a periodic reconciler that drains gracefully on SIGTERM. +// +// Shape: runnable.WithAdapters composing Draining + Recovering + Ticker, +// driven by signal.NotifyContext. Copy-paste into a service's +// cmd/.../main.go and replace the reconcile body with your work. +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/0xsequence/runnable" + "github.com/0xsequence/runnable/adapters" +) + +func reconcile(ctx context.Context) error { + // Pretend this is an HTTP call to an external system that must not + // be aborted mid-request when SIGTERM fires. Under Draining, this + // tick is allowed to finish before the runnable tears down. 
+	fmt.Println("tick: reconciling...")
+	time.Sleep(500 * time.Millisecond)
+	fmt.Println("tick: done")
+	return nil
+}
+
+func panicHandler(_ context.Context, rec any, stack []byte) {
+	fmt.Fprintf(os.Stderr, "tick panic: %v\n%s", rec, stack)
+}
+
+func main() {
+	sigCtx, stopSig := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
+	defer stopSig()
+
+	// Adapters compose left-to-right (first listed = outermost). Draining
+	// catches outer ctx cancellation and turns it into drain rather than
+	// abort. Recovering sits inside Draining so the handler observes
+	// panics before Draining's safety-net recovery formats them.
+	rc := runnable.New(reconcile, runnable.WithAdapters(
+		adapters.Draining(10*time.Second),
+		adapters.Recovering(panicHandler),
+		adapters.Ticker(2*time.Second),
+	))
+
+	runErr := make(chan error, 1)
+	go func() {
+		runErr <- rc.Run(sigCtx)
+	}()
+
+	select {
+	case <-sigCtx.Done():
+		fmt.Println("shutdown: draining in-flight tick...")
+		stopCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+		defer cancel()
+		if err := rc.Stop(stopCtx); err != nil && !errors.Is(err, runnable.ErrNotRunning) { // Run may have exited already
+			fmt.Fprintf(os.Stderr, "stop: %v\n", err)
+		}
+		if err := <-runErr; err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, adapters.ErrDrainTimedOut) {
+			fmt.Fprintf(os.Stderr, "reconciler stopped: %v\n", err)
+			os.Exit(1)
+		}
+	case err := <-runErr:
+		if err != nil && !errors.Is(err, context.Canceled) {
+			fmt.Fprintf(os.Stderr, "reconciler stopped: %v\n", err)
+			os.Exit(1)
+		}
+	}
+}
diff --git a/group_drain_test.go b/group_drain_test.go
new file mode 100644
index 0000000..f27487d
--- /dev/null
+++ b/group_drain_test.go
@@ -0,0 +1,55 @@
+package runnable_test
+
+import (
+	"context"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/0xsequence/runnable"
+	"github.com/0xsequence/runnable/adapters"
+)
+
+func TestNewGroup_DrainEnabledChild(t *testing.T) {
+	// Load-bearing test: a Draining-wrapped child of a group must
+	// drain when the group is stopped. In v0.1 this was silently
+	// broken — the child observed groupCtx.Done() and exited
+	// without ever seeing its drain signal. The adapter design
+	// fixes this by construction.
+	started := make(chan struct{})
+	drainObserved := make(chan struct{})
+	var ctxCancelObserved atomic.Bool
+
+	drainingChild := runnable.New(func(ctx context.Context) error {
+		close(started)
+		select {
+		case <-adapters.Stopping(ctx):
+			close(drainObserved)
+			return nil
+		case <-ctx.Done():
+			ctxCancelObserved.Store(true)
+			return ctx.Err()
+		}
+	}, runnable.WithAdapters(adapters.Draining(1*time.Second)))
+
+	plainChild := runnable.New(func(ctx context.Context) error {
+		<-ctx.Done()
+		return ctx.Err()
+	})
+
+	group := runnable.NewGroup(drainingChild, plainChild)
+	go func() { _ = group.Run(context.Background()) }()
+
+	<-started
+	require.NoError(t, group.Stop(context.Background()))
+
+	select {
+	case <-drainObserved:
+	default:
+		t.Fatal("draining child never observed Stopping; group did not propagate drain")
+	}
+	assert.False(t, ctxCancelObserved.Load(), "draining child saw ctx.Done() instead of Stopping")
+}
diff --git a/runnable.go b/runnable.go
index ca62fef..8c2fbce 100644
--- a/runnable.go
+++ b/runnable.go
@@ -132,11 +132,14 @@ func (r *runnable) Stop(ctx context.Context) error {
 		r.mu.Unlock()
 		return ErrNotRunning
 	}
-
 	runStop := r.runStop
+	// Snapshot runCancel under the lock — Run overwrites this field
+	// on each cycle, so reading it without synchronization races with
+	// a concurrent or subsequent Run.
+ runCancel := r.runCancel r.mu.Unlock() - r.runCancel() + runCancel() select { case <-ctx.Done(): diff --git a/runnable_group_test.go b/runnable_group_test.go index 6efa922..fb70ed9 100644 --- a/runnable_group_test.go +++ b/runnable_group_test.go @@ -48,9 +48,7 @@ func TestNewGroup(t *testing.T) { // Create a new group group := NewGroup( New(func(ctx context.Context) error { - select { - case <-ctx.Done(): - } + <-ctx.Done() return nil }), New(func(ctx context.Context) error { diff --git a/runnable_test.go b/runnable_test.go index 045be86..b50359d 100644 --- a/runnable_test.go +++ b/runnable_test.go @@ -43,11 +43,8 @@ func TestRunnable(t *testing.T) { r := New(func(ctx context.Context) error { started <- struct{}{} time.Sleep(2 * time.Second) - - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() }) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -86,13 +83,13 @@ func TestRunnable(t *testing.T) { r := New(func(ctx context.Context) error { started <- struct{}{} + <-ctx.Done() time.Sleep(2 * time.Second) - return nil + return ctx.Err() }) go func() { - err := r.Run(context.Background()) - require.NoError(t, err) + _ = r.Run(context.Background()) }() <-started @@ -101,7 +98,7 @@ func TestRunnable(t *testing.T) { stopCtx, stopCancel := context.WithTimeout(context.Background(), 1*time.Second) defer stopCancel() err := r.Stop(stopCtx) - require.Error(t, err, context.DeadlineExceeded) + require.ErrorIs(t, err, context.DeadlineExceeded) assert.Equal(t, true, r.IsRunning()) }) } diff --git a/with_recoverer.go b/with_recoverer.go deleted file mode 100644 index 5de5eeb..0000000 --- a/with_recoverer.go +++ /dev/null @@ -1,63 +0,0 @@ -package runnable - -import ( - "context" - "fmt" - "runtime/debug" -) - -type RecoveryReporter interface { - Report(ctx context.Context, rec interface{}) -} - -type StackPrinter interface { - Print(ctx context.Context, callstack []byte) -} - -// NoopReporter -// Used to continue 
running go routine and do nothing -type NoopReporter struct{} - -func (*NoopReporter) Report(ctx context.Context, rec interface{}) {} - -type recoverer struct { - reporter RecoveryReporter - stackPrinter StackPrinter -} - -func WithRecoverer(reporter RecoveryReporter, stackPrinter StackPrinter) Option { - return &recoverer{ - reporter: reporter, - stackPrinter: stackPrinter, - } -} - -func (rec *recoverer) apply(r *runnable) { - originalRunFunc := r.runFunc - r.runFunc = func(ctx context.Context) error { - var err error - innerRun := func(ctx context.Context) error { - defer func() { - if recovery := recover(); recovery != nil { - err = fmt.Errorf("panic: %v", recovery) - - if rec.stackPrinter != nil { - rec.stackPrinter.Print(ctx, debug.Stack()) - } - - if rec.reporter != nil { - rec.reporter.Report(ctx, recovery) - } - } - }() - - return originalRunFunc(ctx) - } - - if errInner := innerRun(ctx); errInner != nil { - return errInner - } - - return err - } -} diff --git a/with_recoverer_test.go b/with_recoverer_test.go deleted file mode 100644 index e7cb5dd..0000000 --- a/with_recoverer_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package runnable - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type InMemoryReporter struct { - logs []string -} - -func (i *InMemoryReporter) Report(ctx context.Context, rec interface{}) { - i.logs = append(i.logs, fmt.Sprintf("%s", rec.(string))) -} - -func TestWithRecoverer(t *testing.T) { - t.Run("with recoverer", func(t *testing.T) { - counter := 0 - reporter := InMemoryReporter{} - - fn := func(ctx context.Context) error { - defer func() { counter++ }() - panic("something went wrong") - return nil - } - r := New(fn, WithRecoverer(&reporter, nil)) - - err := r.Run(context.Background()) - require.Error(t, err) - assert.Equal(t, 1, counter) - assert.Equal(t, []string{"something went wrong"}, reporter.logs) - }) - - t.Run("panics as errors", func(t 
*testing.T) { - started, stopped := make(chan struct{}), make(chan struct{}) - reporter := &InMemoryReporter{} - - r := New(func(ctx context.Context) error { - started <- struct{}{} - panic("something went wrong") - return nil - }, WithRecoverer(reporter, nil)) - - go func() { - err := r.Run(context.Background()) - require.Error(t, err) - stopped <- struct{}{} - }() - - <-started - <-stopped - }) - - t.Run("panics as errors, no panic", func(t *testing.T) { - reporter := &InMemoryReporter{} - started, stopped := make(chan struct{}), make(chan struct{}) - - r := New(func(ctx context.Context) error { - started <- struct{}{} - return nil - }, WithRecoverer(reporter, nil)) - - go func() { - err := r.Run(context.Background()) - require.NoError(t, err) - stopped <- struct{}{} - }() - - <-started - <-stopped - }) - - t.Run("panics as errors, with stats", func(t *testing.T) { - reporter := &InMemoryReporter{} - started, stopped := make(chan struct{}), make(chan struct{}) - - store := NewStatusStore() - r := New(func(ctx context.Context) error { - started <- struct{}{} - panic("something went wrong") - return nil - }, WithRecoverer(reporter, nil), WithStatus("test", store)) - - go func() { - err := r.Run(context.Background()) - require.Error(t, err) - stopped <- struct{}{} - }() - - <-started - <-stopped - - s := store.Get() - require.Equal(t, false, s["test"].Running) - require.Error(t, s["test"].LastError) - }) -} diff --git a/with_retry.go b/with_retry.go deleted file mode 100644 index 97e0bbd..0000000 --- a/with_retry.go +++ /dev/null @@ -1,57 +0,0 @@ -package runnable - -import ( - "context" - "errors" - "time" -) - -const ResetNever time.Duration = 0 - -type withRetry struct { - maxRetries int - resetAfter time.Duration - - lastTime time.Time -} - -func WithRetry(maxRetries int, resetAfter time.Duration) Option { - return &withRetry{ - maxRetries: maxRetries, - resetAfter: resetAfter, - } -} - -func (w *withRetry) apply(r *runnable) { - runFunc := r.runFunc - r.runFunc 
= func(ctx context.Context) error { - var err error - for i := 0; i < w.maxRetries; i++ { - if w.resetAfter != ResetNever && time.Since(w.lastTime) > w.resetAfter { - i = 0 - } - w.lastTime = time.Now() - - if i > 0 { - if r.onStart != nil { - r.onStart() - } - } - - err = runFunc(ctx) - if err == nil { - return nil - } - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return err - } - - if i > 0 { - if r.onStop != nil { - r.onStop() - } - } - } - return err - } -} diff --git a/with_retry_test.go b/with_retry_test.go deleted file mode 100644 index 6166473..0000000 --- a/with_retry_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package runnable - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestWithRetry(t *testing.T) { - - t.Run("with retry", func(t *testing.T) { - counter := 0 - - r := New(func(ctx context.Context) error { - defer func() { counter++ }() - if counter < 1 { - return assert.AnError - } - - time.Sleep(500 * time.Millisecond) - return nil - }, WithRetry(3, ResetNever)) - - err := r.Run(context.Background()) - require.NoError(t, err) - assert.Equal(t, 2, counter) - }) - - t.Run("with retry, error", func(t *testing.T) { - counter := 0 - - r := New(func(ctx context.Context) error { - defer func() { counter++ }() - return assert.AnError - }, WithRetry(3, ResetNever)) - - err := r.Run(context.Background()) - require.Error(t, err) - assert.Equal(t, 3, counter) - }) - - t.Run("with retry, reset", func(t *testing.T) { - counter := 0 - - r := New(func(ctx context.Context) error { - defer func() { counter++ }() - if counter < 5 { - time.Sleep(200 * time.Millisecond) - return assert.AnError - } - return nil - }, WithRetry(3, 100*time.Millisecond)) - - err := r.Run(context.Background()) - require.NoError(t, err) - assert.Equal(t, 6, counter) - }) -} diff --git a/with_status.go b/with_status.go index 17e72b4..875e634 100644 --- a/with_status.go +++ 
b/with_status.go @@ -10,7 +10,6 @@ type StatusMap map[string]Status type Status struct { Running bool `json:"running"` - Restarts int `json:"restarts"` StartTime time.Time `json:"start_time"` EndTime *time.Time `json:"end_time,omitempty"` LastError error `json:"last_error"` @@ -18,7 +17,6 @@ type Status struct { type StatusStore struct { running map[string]bool - restarts map[string]int startTime map[string]time.Time endTime map[string]time.Time lastError map[string]error @@ -29,7 +27,6 @@ type StatusStore struct { func NewStatusStore() *StatusStore { return &StatusStore{ running: make(map[string]bool), - restarts: make(map[string]int), startTime: make(map[string]time.Time), endTime: make(map[string]time.Time), lastError: make(map[string]error), @@ -46,10 +43,6 @@ func (s *StatusStore) Get() StatusMap { Running: running, } - if restarts, ok := s.restarts[id]; ok { - st.Restarts = restarts - } - if startTime, ok := s.startTime[id]; ok { st.StartTime = startTime } @@ -99,15 +92,8 @@ func (w *withStatus) apply(r *runnable) { r.onStart = func() { w.store.mu.Lock() - w.store.running[w.runnableID] = true w.store.startTime[w.runnableID] = time.Now() - if _, ok := w.store.restarts[w.runnableID]; !ok { - w.store.restarts[w.runnableID] = 0 - } else { - w.store.restarts[w.runnableID]++ - } - w.store.mu.Unlock() if onStartRunnable != nil { diff --git a/with_status_test.go b/with_status_test.go index c19f45e..81e0244 100644 --- a/with_status_test.go +++ b/with_status_test.go @@ -52,24 +52,4 @@ func TestWithStatus(t *testing.T) { assert.Equal(t, false, s["test"].Running) assert.Equal(t, assert.AnError, s["test"].LastError) }) - - t.Run("with status, restart", func(t *testing.T) { - store := NewStatusStore() - - counter := 0 - r := New(func(ctx context.Context) error { - defer func() { counter++ }() - if counter < 1 { - return assert.AnError - } - return nil - }, WithStatus("test", store), WithRetry(3, ResetNever)) - - err := r.Run(context.Background()) - require.NoError(t, err) 
- - s := store.Get() - assert.Equal(t, false, s["test"].Running) - assert.Equal(t, 1, s["test"].Restarts) - }) }