diff --git a/client.go b/client.go index 3f29e545..6417d2dc 100644 --- a/client.go +++ b/client.go @@ -221,6 +221,18 @@ type Config struct { // Jobs may have their own specific hooks by implementing JobArgsWithHooks. Hooks []rivertype.Hook + // HardStopTimeout is the maximum amount of time that the client will wait + // after job contexts are cancelled during shutdown before forcing jobs still + // running to an errored state. This hard stop phase lets jobs be retried + // immediately on the next client start instead of waiting for rescue. + // + // The timer starts only after a soft stop has begun by cancelling job + // contexts, like after SoftStopTimeout elapses, StopAndCancel is called, or + // the Start context is cancelled without SoftStopTimeout configured. + // + // Defaults to no timeout (hard stop disabled). + HardStopTimeout time.Duration + // Logger is the structured logger to use for logging purposes. If none is // specified, logs will be emitted to STDOUT with messages at warn level // or higher. @@ -330,11 +342,9 @@ type Config struct { Schema string // SoftStopTimeout is the maximum amount of time that the client will wait - // for running jobs to finish during a stop before their contexts are - // cancelled. After the timeout elapses, the client escalates to a hard stop - // by cancelling the context of all running jobs. This applies regardless of - // how stop is initiated — whether by calling Stop, StopAndCancel, or by - // cancelling the context passed to Start. + // for running jobs to finish during a graceful stop before entering soft + // stop by cancelling job contexts. This applies when stop is initiated by + // calling Stop or by cancelling the context passed to Start. // // In combination with signal.NotifyContext on the context passed to Start, // this can simplify graceful stop to: @@ -345,12 +355,12 @@ type Config struct { // if err := client.Start(ctx); err != nil { ... 
} // <-client.Stopped() // - // The signal cancels the Start context, which initiates a soft stop. If + // The signal cancels the Start context, which initiates a graceful stop. If // running jobs haven't finished after SoftStopTimeout, their contexts are - // automatically cancelled to trigger a hard stop. + // cancelled. // - // StopAndCancel bypasses the timeout entirely and cancels job contexts - // immediately. + // StopAndCancel cancels job contexts immediately instead of waiting for + // SoftStopTimeout. // // Defaults to no timeout (wait indefinitely for jobs to finish). SoftStopTimeout time.Duration @@ -468,6 +478,7 @@ func (c *Config) WithDefaults() *Config { ErrorHandler: c.ErrorHandler, FetchCooldown: cmp.Or(c.FetchCooldown, FetchCooldownDefault), FetchPollInterval: cmp.Or(c.FetchPollInterval, FetchPollIntervalDefault), + HardStopTimeout: c.HardStopTimeout, ID: valutil.ValOrDefaultFunc(c.ID, func() string { return defaultClientID(time.Now().UTC()) }), Hooks: c.Hooks, JobInsertMiddleware: c.JobInsertMiddleware, @@ -515,6 +526,9 @@ func (c *Config) validate() error { if c.FetchPollInterval < c.FetchCooldown { return fmt.Errorf("FetchPollInterval cannot be shorter than FetchCooldown (%s)", c.FetchCooldown) } + if c.HardStopTimeout < 0 { + return errors.New("HardStopTimeout cannot be less than zero") + } if len(c.ID) > 100 { return errors.New("ID cannot be longer than 100 characters") } @@ -547,6 +561,9 @@ func (c *Config) validate() error { if c.Schema != "" && !postgresSchemaNameRE.MatchString(c.Schema) { return errors.New("Schema name can only contain letters, numbers, and underscores, and must start with a letter or underscore") } + if c.SoftStopTimeout < 0 { + return errors.New("SoftStopTimeout cannot be less than zero") + } for queue, queueConfig := range c.Queues { if err := queueConfig.validate(queue, c.FetchCooldown, c.FetchPollInterval); err != nil { @@ -1048,10 +1065,12 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) 
(*Client // A graceful shutdown stops fetching new jobs but allows any previously fetched // jobs to complete. This can be initiated with the Stop method. // -// A more abrupt shutdown can be achieved by either cancelling the provided -// context or by calling StopAndCancel. This will not only stop fetching new -// jobs, but will also cancel the context for any currently-running jobs. If -// using StopAndCancel, there's no need to also call Stop. +// A soft stop cancels job contexts after fetching has stopped. It can be +// initiated by calling StopAndCancel, by cancelling the provided context when +// SoftStopTimeout is not configured, or by waiting for SoftStopTimeout to elapse +// during graceful stop. If HardStopTimeout is configured, jobs still running +// after that timeout will be forced into an errored state. If using +// StopAndCancel, there's no need to also call Stop. func (c *Client[TTx]) Start(ctx context.Context) error { fetchCtx, shouldStart, started, stopped := c.baseStartStop.StartInit(ctx) if !shouldStart { @@ -1065,9 +1084,13 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // sure to take a channel reference before finishing stopped. c.stopped = c.baseStartStop.StoppedUnsafe() - producersAsServices := func() []startstop.Service { + producers := func() []*producer { + return maputil.Values(c.producersByQueueName) + } + + producersAsServices := func(producers []*producer) []startstop.Service { return sliceutil.Map( - maputil.Values(c.producersByQueueName), + producers, func(p *producer) startstop.Service { return p }, ) } @@ -1121,8 +1144,8 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // We use separate contexts for fetching and working to allow for a // graceful stop. When SoftStopTimeout is configured, the work context // is detached from the start context so that cancelling the start - // context initiates a soft stop (with timeout escalation) rather than - // an immediate hard stop. 
When SoftStopTimeout is not configured, the + // context initiates a graceful stop (with timeout escalation) rather + // than an immediate soft stop. When SoftStopTimeout is not configured, the // work context inherits from the start context to preserve the // existing behavior where cancelling the start context is equivalent // to StopAndCancel. @@ -1145,7 +1168,7 @@ func (c *Client[TTx]) Start(ctx context.Context) error { for _, producer := range c.producersByQueueName { if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil { workCancel(err) - startstop.StopAllParallel(producersAsServices()...) + startstop.StopAllParallel(producersAsServices(producers())...) stopServicesOnError() return err } @@ -1167,7 +1190,7 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // Generate producer services while c.queues.startStopMu.Lock() is still // held. This is used for WaitAllStarted below, but don't use it elsewhere // because new producers may have been added while the client is running. - producerServices := producersAsServices() + producerServices := producersAsServices(producers()) go func() { // Wait for all subservices to start up before signaling our own start. 
@@ -1194,22 +1217,57 @@ func (c *Client[TTx]) Start(ctx context.Context) error { c.queues.startStopMu.Lock() defer c.queues.startStopMu.Unlock() + producerList := producers() + + hardStopTimerCtx, hardStopTimerCancel := context.WithCancel(context.WithoutCancel(ctx)) + defer hardStopTimerCancel() + + startHardStopTimer := sync.OnceFunc(func() { + if c.config.HardStopTimeout <= 0 { + return + } + + go func() { + timer := time.NewTimer(c.config.HardStopTimeout) + defer timer.Stop() + + select { + case <-timer.C: + c.baseService.Logger.WarnContext(ctx, c.baseService.Name+": Hard stop timeout; setting remaining jobs to errored", slog.Duration("hard_stop_timeout", c.config.HardStopTimeout)) + for _, producer := range producerList { + producer.hardStop() + } + case <-hardStopTimerCtx.Done(): + } + }() + }) + + workCtx := c.queues.workCtx + go func() { + select { + case <-workCtx.Done(): + startHardStopTimer() + case <-hardStopTimerCtx.Done(): + } + }() + // If SoftStopTimeout is configured, start a timer that will cancel - // the work context (escalating to a hard stop) if producers don't - // finish in time. StopAndCancel also calls workCancel, in which case - // this timer is a harmless no-op because the context is already done. + // the work context if producers don't finish in time. Once the work + // context is cancelled, the optional hard stop timer starts. if c.config.SoftStopTimeout > 0 { softStopTimer := time.AfterFunc(c.config.SoftStopTimeout, func() { c.baseService.Logger.WarnContext(ctx, c.baseService.Name+": Soft stop timeout; cancelling remaining job contexts", slog.Duration("soft_stop_timeout", c.config.SoftStopTimeout)) c.workCancel(rivercommon.ErrStop) + startHardStopTimer() }) defer softStopTimer.Stop() } // On stop, have the producers stop fetching first of all. c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": Stopping producers") - startstop.StopAllParallel(producersAsServices()...) 
+ startstop.StopAllParallel(producersAsServices(producerList)...) c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": All producers stopped") + hardStopTimerCancel() c.workCancel(rivercommon.ErrStop) @@ -1238,12 +1296,14 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // complete before exiting. If the provided context is done before shutdown has // completed, Stop will return immediately with the context's error. // -// If SoftStopTimeout is configured, running job contexts will be automatically -// cancelled after the timeout elapses, escalating to a hard stop. This also +// If SoftStopTimeout is configured, jobs still running after the timeout +// elapses have their contexts cancelled. If HardStopTimeout is also configured, +// jobs still running after that second timeout are forced into an errored state +// so they can be retried immediately on the next client start. This also // applies when stop is initiated by cancelling the context passed to Start. // -// There's no need to call this method if a hard stop has already been initiated -// by cancelling the context passed to Start or by calling StopAndCancel. +// There's no need to call this method if shutdown has already been initiated by +// cancelling the context passed to Start or by calling StopAndCancel. func (c *Client[TTx]) Stop(ctx context.Context) error { shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() if !shouldStop { @@ -1262,10 +1322,11 @@ func (c *Client[TTx]) Stop(ctx context.Context) error { // StopAndCancel shuts down the client and cancels all work in progress. It is a // more aggressive stop than Stop because the contexts for any in-progress jobs -// are cancelled. However, it still waits for jobs to complete before returning, -// even though their contexts are cancelled. If the provided context is done -// before shutdown has completed, StopAndCancel will return immediately with the -// context's error. +// are cancelled immediately. 
If HardStopTimeout is configured, jobs that still +// remain running after the timeout are hard-stopped; otherwise, StopAndCancel +// waits for jobs to complete even though their contexts are cancelled. If the +// provided context is done before shutdown has completed, StopAndCancel will +// return immediately with the context's error. // // This can also be initiated by cancelling the context passed to Start. There is // no need to call this method if the context passed to Start is cancelled @@ -1277,7 +1338,7 @@ func (c *Client[TTx]) Stop(ctx context.Context) error { // graceful stop semantics without requiring manual orchestration of Stop and // StopAndCancel. func (c *Client[TTx]) StopAndCancel(ctx context.Context) error { - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work") + c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Soft stop started; cancelling all work") c.workCancel(rivercommon.ErrStop) shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() diff --git a/client_test.go b/client_test.go index 9352dcbb..a8a3bfa8 100644 --- a/client_test.go +++ b/client_test.go @@ -2538,6 +2538,98 @@ func Test_Client_StopAndCancel(t *testing.T) { }) } +func Test_Client_HardStopTimeout(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type JobArgs struct { + testutil.JobArgsReflectKind[JobArgs] + } + + setup := func(t *testing.T, configFunc func(config *Config)) (*Client[pgx.Tx], *rivertype.JobRow, chan struct{}, chan struct{}, func()) { + t.Helper() + + config := newTestConfig(t, "") + configFunc(config) + + jobContextDoneChan := make(chan struct{}) + jobReleasedChan := make(chan struct{}) + jobStartedChan := make(chan struct{}) + releaseJobChan := make(chan struct{}) + releaseJob := sync.OnceFunc(func() { close(releaseJobChan) }) + + AddWorker(config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + close(jobStartedChan) + <-ctx.Done() + 
close(jobContextDoneChan) + <-releaseJobChan + close(jobReleasedChan) + return nil + })) + + client := runNewTestClient(ctx, t, config) + t.Cleanup(releaseJob) + + insertRes, err := client.Insert(ctx, JobArgs{}, nil) + require.NoError(t, err) + + riversharedtest.WaitOrTimeout(t, jobStartedChan) + + return client, insertRes.Job, jobContextDoneChan, jobReleasedChan, releaseJob + } + + requireHardStoppedAvailable := func(t *testing.T, client *Client[pgx.Tx], jobID int64) { + t.Helper() + + jobAfter, err := client.JobGet(ctx, jobID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateAvailable, jobAfter.State) + require.Len(t, jobAfter.Errors, 1) + require.Equal(t, producerHardStopError, jobAfter.Errors[0].Error) + require.Equal(t, 1, jobAfter.Errors[0].Attempt) + require.Empty(t, jobAfter.Errors[0].Trace) + require.Nil(t, jobAfter.FinalizedAt) + } + + t.Run("AfterSoftStopTimeout", func(t *testing.T) { + t.Parallel() + + client, job, jobContextDoneChan, jobReleasedChan, releaseJob := setup(t, func(config *Config) { + config.HardStopTimeout = 100 * time.Millisecond + config.SoftStopTimeout = 100 * time.Millisecond + }) + + stopCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + require.NoError(t, client.Stop(stopCtx)) + + riversharedtest.WaitOrTimeout(t, jobContextDoneChan) + requireHardStoppedAvailable(t, client, job.ID) + + releaseJob() + riversharedtest.WaitOrTimeout(t, jobReleasedChan) + }) + + t.Run("AfterStopAndCancel", func(t *testing.T) { + t.Parallel() + + client, job, jobContextDoneChan, jobReleasedChan, releaseJob := setup(t, func(config *Config) { + config.HardStopTimeout = 100 * time.Millisecond + }) + + stopCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + require.NoError(t, client.StopAndCancel(stopCtx)) + + riversharedtest.WaitOrTimeout(t, jobContextDoneChan) + requireHardStoppedAvailable(t, client, job.ID) + + releaseJob() + riversharedtest.WaitOrTimeout(t, jobReleasedChan) + }) +} + func 
Test_Client_SoftStopTimeout(t *testing.T) { t.Parallel() @@ -2547,7 +2639,7 @@ func Test_Client_SoftStopTimeout(t *testing.T) { testutil.JobArgsReflectKind[JobArgs] } - t.Run("EscalatesToHardStopAfterTimeout", func(t *testing.T) { + t.Run("CancelsJobsAfterTimeout", func(t *testing.T) { t.Parallel() config := newTestConfig(t, "") @@ -2569,8 +2661,8 @@ func Test_Client_SoftStopTimeout(t *testing.T) { riversharedtest.WaitOrTimeout(t, jobStartedChan) - // Stop initiates a soft stop. The job won't finish on its own, but - // SoftStopTimeout should escalate to a hard stop after 100ms. + // Stop initiates a graceful stop. The job won't finish on its own, but + // SoftStopTimeout should cancel its context after 100ms. require.NoError(t, client.Stop(ctx)) // Verify the job's context was indeed cancelled. @@ -2605,7 +2697,7 @@ func Test_Client_SoftStopTimeout(t *testing.T) { require.NoError(t, client.Stop(ctx)) }) - t.Run("ContextCancellationEscalatesAfterTimeout", func(t *testing.T) { + t.Run("StartContextCancellationCancelsJobsAfterTimeout", func(t *testing.T) { t.Parallel() config := newTestConfig(t, "") @@ -2640,8 +2732,8 @@ func Test_Client_SoftStopTimeout(t *testing.T) { riversharedtest.WaitOrTimeout(t, jobStartedChan) - // Cancel the start context. This should initiate a soft stop, then - // escalate to hard stop after SoftStopTimeout. + // Cancel the start context. This should initiate a graceful stop, then + // cancel job contexts after SoftStopTimeout. 
startCtxCancel() riversharedtest.WaitOrTimeout(t, client.Stopped()) @@ -8291,6 +8383,22 @@ func Test_NewClient_Validations(t *testing.T) { }, wantErr: fmt.Errorf("FetchPollInterval cannot be shorter than FetchCooldown (%s)", 20*time.Millisecond), }, + { + name: "HardStopTimeout cannot be negative", + configFunc: func(config *Config) { + config.HardStopTimeout = -1 + }, + wantErr: errors.New("HardStopTimeout cannot be less than zero"), + }, + { + name: "HardStopTimeout may be overridden", + configFunc: func(config *Config) { + config.HardStopTimeout = 23 * time.Second + }, + validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper + require.Equal(t, 23*time.Second, client.config.HardStopTimeout) + }, + }, { name: "FetchPollInterval cannot be less than MinFetchPollInterval", configFunc: func(config *Config) { config.FetchPollInterval = time.Millisecond - 1 }, @@ -8486,6 +8594,22 @@ func Test_NewClient_Validations(t *testing.T) { }, wantErr: errors.New("Schema name can only contain letters, numbers, and underscores, and must start with a letter or underscore"), }, + { + name: "SoftStopTimeout cannot be negative", + configFunc: func(config *Config) { + config.SoftStopTimeout = -1 + }, + wantErr: errors.New("SoftStopTimeout cannot be less than zero"), + }, + { + name: "SoftStopTimeout may be overridden", + configFunc: func(config *Config) { + config.SoftStopTimeout = 23 * time.Second + }, + validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper + require.Equal(t, 23*time.Second, client.config.SoftStopTimeout) + }, + }, { name: "Queues can be nil when Workers is also nil", configFunc: func(config *Config) { diff --git a/example_graceful_shutdown_stop_and_cancel_test.go b/example_graceful_shutdown_stop_and_cancel_test.go index 30217c2d..86aea931 100644 --- a/example_graceful_shutdown_stop_and_cancel_test.go +++ b/example_graceful_shutdown_stop_and_cancel_test.go @@ -20,8 +20,8 @@ import ( // Example_gracefulShutdownStopCancel 
demonstrates graceful stop with explicit // fallback to StopAndCancel. When a SIGINT/SIGTERM arrives, Stop initiates a -// soft stop. If running jobs don't finish before the soft stop context expires, -// StopAndCancel cancels their contexts (hard stop). This example is intended to +// graceful stop. If running jobs don't finish before the graceful stop context +// expires, StopAndCancel cancels their contexts. This example is intended to // demonstrate advanced use of StopAndCancel. Generally, prefer the method shown // in Example_gracefulShutdown over the one here. func Example_gracefulShutdownStopAndCancel() { @@ -59,8 +59,8 @@ func Example_gracefulShutdownStopAndCancel() { } // Use signal.NotifyContext to detect SIGINT/SIGTERM, but don't pass the - // signal context to Start. Cancelling the Start context cancels running job - // contexts immediately, which is equivalent to StopAndCancel. + // signal context to Start. Cancelling the Start context would cancel running + // job contexts immediately, which is equivalent to StopAndCancel. signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/example_graceful_shutdown_test.go b/example_graceful_shutdown_test.go index bf99b8ac..72cb4d60 100644 --- a/example_graceful_shutdown_test.go +++ b/example_graceful_shutdown_test.go @@ -43,8 +43,8 @@ func (w *WaitsForCancelOnlyWorker) Work(ctx context.Context, job *river.Job[Wait // Example_gracefulShutdown demonstrates graceful stop using SoftStopTimeout. // When a SIGINT/SIGTERM arrives, the start context is cancelled, which -// initiates a soft stop. If running jobs don't finish within the configured -// SoftStopTimeout, their contexts are automatically cancelled (hard stop). +// initiates a graceful stop. If running jobs don't finish within the configured +// SoftStopTimeout, their contexts are automatically cancelled. 
func Example_gracefulShutdown() { ctx := context.Background() @@ -77,7 +77,7 @@ func Example_gracefulShutdown() { } // Use signal.NotifyContext to cancel the start context on SIGINT/SIGTERM. - // When the signal fires, the client initiates a soft stop. If running jobs + // When the signal fires, the client initiates a graceful stop. If running jobs // don't finish within SoftStopTimeout, their contexts are cancelled. signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/internal/jobexecutor/job_executor.go b/internal/jobexecutor/job_executor.go index dab44f39..8bf9e852 100644 --- a/internal/jobexecutor/job_executor.go +++ b/internal/jobexecutor/job_executor.go @@ -9,6 +9,7 @@ import ( "log/slog" "runtime" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -110,6 +111,7 @@ type JobExecutor struct { ClientRetryPolicy ClientRetryPolicy DefaultClientRetryPolicy ClientRetryPolicy ErrorHandler ErrorHandler + hardStopped atomic.Bool HookLookupByJob *hooklookup.JobHookLookup HookLookupGlobal hooklookup.HookLookupInterface JobRow *rivertype.JobRow @@ -148,6 +150,12 @@ func (e *JobExecutor) Execute(ctx context.Context) { res.Err = context.Cause(ctx) } + // Hard-stopped jobs have already been moved out of running by the producer. + if e.hardStopped.Load() { + e.ProducerCallbacks.JobDone(e.JobRow) + return + } + var multiJobErrors withJobsAndErrorsByID if res.Err != nil { multiJobErrors, _ = res.Err.(withJobsAndErrorsByID) @@ -167,6 +175,12 @@ func (e *JobExecutor) Execute(ctx context.Context) { e.ProducerCallbacks.JobDone(e.JobRow) } +// HardStop suppresses completion reporting for a job that's been forcibly +// errored by its producer during shutdown. +func (e *JobExecutor) HardStop() { + e.hardStopped.Store(true) +} + // Executes the job, handling a panic if necessary (and various other error // conditions). The named return value is so that we can still return a value in // case of a panic. 
diff --git a/producer.go b/producer.go index dbd40653..b365606f 100644 --- a/producer.go +++ b/producer.go @@ -34,6 +34,7 @@ import ( ) const ( + producerHardStopError = "job stopped because River client hard stopped" producerReportIntervalDefault = time.Minute queuePollIntervalDefault = 2 * time.Second queueReportIntervalDefault = 10 * time.Minute @@ -187,8 +188,10 @@ type producer struct { exec riverdriver.Executor errorHandler jobexecutor.ErrorHandler fetchLimiter *chanutil.DebouncedChan - state riverpilot.ProducerState + hardStopCh chan struct{} // signals that a "hard stop" has been initiated (set all running jobs to errored regardless of whether they stopped cleanly) + hardStopOnce *sync.Once // closes hardStopCh exactly once pilot riverpilot.Pilot + state riverpilot.ProducerState workers *Workers // Receives job IDs to cancel. Written by notifier goroutine, only read from @@ -202,7 +205,7 @@ type producer struct { // Receives completed jobs from workers. Written by completed workers, only // read from main goroutine. 
- jobResultCh chan *rivertype.JobRow + jobResultCh chan *producerJobResult jobTimeout time.Duration @@ -240,7 +243,9 @@ func newProducer(archetype *baseservice.Archetype, exec riverdriver.Executor, pi config: config.mustValidate(), exec: exec, errorHandler: errorHandler, - jobResultCh: make(chan *rivertype.JobRow, config.MaxWorkers), + hardStopCh: make(chan struct{}), + hardStopOnce: &sync.Once{}, + jobResultCh: make(chan *producerJobResult, config.MaxWorkers), jobTimeout: config.JobTimeout, pilot: pilot, queueControlCh: make(chan *controlEventPayload, 100), @@ -279,6 +284,9 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return nil } + p.hardStopCh = make(chan struct{}) + p.hardStopOnce = &sync.Once{} + isExpectedShutdownError := func(err error) bool { return errors.Is(err, startstop.ErrStop) || strings.HasSuffix(err.Error(), "conn closed") || fetchCtx.Err() != nil } @@ -415,7 +423,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { p.fetchAndRunLoop(fetchCtx, workCtx) p.Logger.DebugContext(workCtx, p.Name+": Entering shutdown loop", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) - p.executorShutdownLoop() + p.executorShutdownLoop(context.WithoutCancel(fetchCtx)) p.Logger.DebugContext(workCtx, p.Name+": Shutdown loop exited, awaiting subroutines", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) cancelSubroutines(fmt.Errorf("producer stopped: %w", startstop.ErrStop)) @@ -470,6 +478,11 @@ type insertPayload struct { Queue string `json:"queue"` } +type producerJobResult struct { + executor *jobexecutor.JobExecutor + job *rivertype.JobRow +} + func (p *producer) handleControlNotification(workCtx context.Context) func(notifier.NotificationTopic, string) { return func(topic notifier.NotificationTopic, payload string) { var decoded controlEventPayload @@ -663,12 +676,29 @@ func (p *producer) innerFetchLoop(workCtx context.Context, fetchResultCh chan pr } } -func 
(p *producer) executorShutdownLoop() { +func (p *producer) executorShutdownLoop(ctx context.Context) { // No more jobs will be fetched or executed. However, we must wait for all // in-progress jobs to complete. for len(p.activeJobs) != 0 { - result := <-p.jobResultCh - p.removeActiveJob(result) + select { + case result := <-p.jobResultCh: + p.removeActiveJob(result) + case <-p.hardStopCh: + p.drainJobResults() + p.hardStopActiveJobs(ctx) + return + } + } +} + +func (p *producer) drainJobResults() { + for { + select { + case result := <-p.jobResultCh: + p.removeActiveJob(result) + default: + return + } } } @@ -728,11 +758,90 @@ func (p *producer) addActiveJob(id int64, executor *jobexecutor.JobExecutor) { p.activeJobs[id] = executor } -func (p *producer) removeActiveJob(job *rivertype.JobRow) { - delete(p.activeJobs, job.ID) +func (p *producer) hardStop() { + p.hardStopOnce.Do(func() { close(p.hardStopCh) }) +} + +func (p *producer) hardStopActiveJobs(ctx context.Context) { + if len(p.activeJobs) == 0 { + return + } + + now := p.Time.Now() + params := &riverdriver.JobSetStateIfRunningManyParams{ + Attempt: make([]*int, 0, len(p.activeJobs)), + ErrData: make([][]byte, 0, len(p.activeJobs)), + FinalizedAt: make([]*time.Time, 0, len(p.activeJobs)), + ID: make([]int64, 0, len(p.activeJobs)), + MetadataDoMerge: make([]bool, 0, len(p.activeJobs)), + MetadataUpdates: make([][]byte, 0, len(p.activeJobs)), + Now: &now, + ScheduledAt: make([]*time.Time, 0, len(p.activeJobs)), + Schema: p.config.Schema, + State: make([]rivertype.JobState, 0, len(p.activeJobs)), + } + + for _, executor := range p.activeJobs { + executor.HardStop() // mark the executor hard-stopped so its Execute skips completion reporting when the job eventually returns + + job := executor.JobRow + errData, err := json.Marshal(rivertype.AttemptError{ + At: now, + Attempt: job.Attempt, + Error: producerHardStopError, + }) + if err != nil { + panic(fmt.Errorf("error serializing hard stop error: %w", err)) + } + + var setStateParams *riverdriver.JobSetStateIfRunningParams + if job.Attempt >= job.MaxAttempts { + setStateParams 
= riverdriver.JobSetStateDiscarded(job.ID, now, errData, nil) + } else { + setStateParams = riverdriver.JobSetStateErrorAvailable(job.ID, now, errData, nil) + } + + params.Attempt = append(params.Attempt, setStateParams.Attempt) + params.ErrData = append(params.ErrData, setStateParams.ErrData) + params.FinalizedAt = append(params.FinalizedAt, setStateParams.FinalizedAt) + params.ID = append(params.ID, setStateParams.ID) + params.MetadataDoMerge = append(params.MetadataDoMerge, setStateParams.MetadataDoMerge) + params.MetadataUpdates = append(params.MetadataUpdates, setStateParams.MetadataUpdates) + params.ScheduledAt = append(params.ScheduledAt, setStateParams.ScheduledAt) + params.State = append(params.State, setStateParams.State) + } + + timeoutCtx, cancel := context.WithTimeout(ctx, rivercommon.HotOperationTimeout) + defer cancel() + + if _, err := p.pilot.JobSetStateIfRunningMany(timeoutCtx, p.exec, params); err != nil { + p.Logger.ErrorContext(ctx, p.Name+": Error setting hard-stopped jobs to errored", slog.String("err", err.Error()), slog.Int("num_jobs", len(params.ID)), slog.String("queue", p.config.Queue)) + } else { + p.Logger.WarnContext(ctx, p.Name+": Hard-stopped running jobs", slog.Int("num_jobs", len(params.ID)), slog.String("queue", p.config.Queue)) + } + + numActiveJobs := len(p.activeJobs) + for _, executor := range p.activeJobs { + if p.state != nil { + p.state.JobFinish(executor.JobRow) + } + } + p.activeJobs = make(map[int64]*jobexecutor.JobExecutor) + p.numJobsActive.Add(-int32(numActiveJobs)) //nolint:gosec +} + +func (p *producer) removeActiveJob(result *producerJobResult) { + // Ignore stale results from executors hard-stopped out of active tracking. 
+ if activeExecutor := p.activeJobs[result.job.ID]; activeExecutor != result.executor { + return + } + + delete(p.activeJobs, result.job.ID) p.numJobsActive.Add(-1) p.numJobsRan.Add(1) - p.state.JobFinish(job) + if p.state != nil { + p.state.JobFinish(result.job) + } } func (p *producer) maybeCancelJob(ctx context.Context, id int64) { @@ -822,7 +931,8 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. // jobCancel will always be called by the executor to prevent leaks. jobCtx, jobCancel := context.WithCancelCause(workCtx) - executor := baseservice.Init(&p.Archetype, &jobexecutor.JobExecutor{ + var executor *jobexecutor.JobExecutor + executor = baseservice.Init(&p.Archetype, &jobexecutor.JobExecutor{ CancelFunc: jobCancel, ClientJobTimeout: p.jobTimeout, ClientRetryPolicy: p.retryPolicy, @@ -838,7 +948,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. Stuck func() Unstuck func() }{ - JobDone: p.handleWorkerDone, + JobDone: func(jobRow *rivertype.JobRow) { p.handleWorkerDone(executor, jobRow) }, Stuck: func() { p.numJobsStuck.Add(1) }, Unstuck: func() { p.numJobsStuck.Add(-1) }, }, @@ -859,8 +969,11 @@ func (p *producer) maxJobsToFetch() int { return p.config.MaxWorkers - int(p.numJobsActive.Load()) } -func (p *producer) handleWorkerDone(job *rivertype.JobRow) { - p.jobResultCh <- job +func (p *producer) handleWorkerDone(executor *jobexecutor.JobExecutor, job *rivertype.JobRow) { + p.jobResultCh <- &producerJobResult{ + executor: executor, + job: job, + } } func (p *producer) pollForSettingChanges(ctx context.Context, wg *sync.WaitGroup, lastPaused bool, lastMetadata []byte) { diff --git a/producer_test.go b/producer_test.go index c03bd766..8b1464c6 100644 --- a/producer_test.go +++ b/producer_test.go @@ -12,6 +12,7 @@ import ( "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" + "github.com/riverqueue/river/internal/jobexecutor" 
"github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/middlewarelookup" "github.com/riverqueue/river/internal/notifier" @@ -161,6 +162,86 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { } } +func TestProducer_HardStopActiveJobs(t *testing.T) { + t.Parallel() + + ctx := context.Background() + require := require.New(t) + + var ( + archetype = riversharedtest.BaseServiceArchetype(t) + dbPool = riversharedtest.DBPool(ctx, t) + driver = riverpgxv5.New(dbPool) + exec = driver.GetExecutor() + pilot = &riverpilot.StandardPilot{} + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + ) + + completer := jobcompleter.NewInlineCompleter(archetype, schema, exec, pilot, make(chan []jobcompleter.CompleterJobUpdated, 10)) + + producer := newProducer(archetype, exec, pilot, &producerConfig{ + ClientID: testClientID, + Completer: completer, + ErrorHandler: newTestErrorHandler(), + FetchCooldown: FetchCooldownDefault, + FetchPollInterval: FetchPollIntervalDefault, + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), + JobTimeout: JobTimeoutDefault, + MaxWorkers: 10, + MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(nil), + Queue: rivercommon.QueueDefault, + QueuePollInterval: queuePollIntervalDefault, + QueueReportInterval: queueReportIntervalDefault, + RetryPolicy: &DefaultClientRetryPolicy{}, + SchedulerInterval: maintenance.JobSchedulerIntervalDefault, + Schema: schema, + StaleProducerRetentionPeriod: time.Minute, + Workers: NewWorkers(), + }) + + runningState := rivertype.JobStateRunning + retryableJob := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ + Attempt: ptrutil.Ptr(1), + MaxAttempts: ptrutil.Ptr(3), + Schema: schema, + State: &runningState, + }) + discardedJob := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ + Attempt: ptrutil.Ptr(3), + MaxAttempts: ptrutil.Ptr(3), + Schema: schema, + State: &runningState, + }) + + 
producer.addActiveJob(retryableJob.ID, &jobexecutor.JobExecutor{JobRow: retryableJob}) + producer.addActiveJob(discardedJob.ID, &jobexecutor.JobExecutor{JobRow: discardedJob}) + + producer.hardStop() + producer.executorShutdownLoop(ctx) + + require.Empty(producer.activeJobs) + require.Zero(producer.numJobsActive.Load()) + + retryableJobAfter, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: retryableJob.ID, Schema: schema}) + require.NoError(err) + require.Equal(rivertype.JobStateAvailable, retryableJobAfter.State) + require.Len(retryableJobAfter.Errors, 1) + require.Equal(producerHardStopError, retryableJobAfter.Errors[0].Error) + require.Equal(retryableJob.Attempt, retryableJobAfter.Errors[0].Attempt) + require.Empty(retryableJobAfter.Errors[0].Trace) + require.Nil(retryableJobAfter.FinalizedAt) + + discardedJobAfter, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: discardedJob.ID, Schema: schema}) + require.NoError(err) + require.Equal(rivertype.JobStateDiscarded, discardedJobAfter.State) + require.Len(discardedJobAfter.Errors, 1) + require.Equal(producerHardStopError, discardedJobAfter.Errors[0].Error) + require.Equal(discardedJob.Attempt, discardedJobAfter.Errors[0].Attempt) + require.Empty(discardedJobAfter.Errors[0].Trace) + require.NotNil(discardedJobAfter.FinalizedAt) +} + func TestProducer_PollOnly(t *testing.T) { t.Parallel()