diff --git a/pkg/chipingress/batch/client.go b/pkg/chipingress/batch/client.go index 691bb9117..db9f09e29 100644 --- a/pkg/chipingress/batch/client.go +++ b/pkg/chipingress/batch/client.go @@ -45,7 +45,7 @@ type Client struct { shutdownTimeout time.Duration shutdownOnce sync.Once batcherDone chan struct{} - cancelBatcher context.CancelFunc + started bool counters sync.Map // map[seqnumKey]*atomic.Uint64 for per-(source,type) seqnum, cleared on Stop() metrics batchClientMetrics @@ -97,21 +97,23 @@ func NewBatchClient(client chipingress.Client, opts ...Opt) (*Client, error) { return c, nil } -// Start begins processing messages from the queue and sending them in batches +// Start begins processing messages from the queue and sending them in batches. +// The context is used only for the initial metrics recording call and is NOT +// retained after Start returns. The client manages its own internal lifecycle +// context that is cancelled when Stop is called. func (b *Client) Start(ctx context.Context) { b.metrics.recordConfig(ctx, b) + b.started = true - // Create a cancellable context for the batcher - batcherCtx, cancel := context.WithCancel(ctx) - b.cancelBatcher = cancel + // Detach from the caller's cancellation but keep its values (trace IDs, etc.). + // This avoids retaining a startup context whose cancellation we don't control. + batcherCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) go func() { defer close(b.batcherDone) go func() { select { - case <-ctx.Done(): - b.Stop() case <-b.stopCh: cancel() } @@ -143,15 +145,11 @@ func (b *Client) Stop() { ctx, cancel := context.WithTimeout(context.Background(), b.shutdownTimeout) defer cancel() - started := b.cancelBatcher != nil - if started { - b.cancelBatcher() - } close(b.stopCh) // Only wait for the batcher goroutine when Start() was called; // otherwise batcherDone is never closed and we'd block until timeout. - if started { + if b.started { done := make(chan struct{}) go func() { <-b.batcherDone diff --git a/pkg/chipingress/batch/client_test.go b/pkg/chipingress/batch/client_test.go index 899cadd5c..fcf835b3c 100644 --- a/pkg/chipingress/batch/client_test.go +++ b/pkg/chipingress/batch/client_test.go @@ -451,7 +451,7 @@ func TestStart(t *testing.T) { mockClient.AssertExpectations(t) }) - t.Run("context cancellation flushes pending batch", func(t *testing.T) { + t.Run("stop flushes pending batch before batch interval", func(t *testing.T) { mockClient := mocks.NewClient(t) mockClient.EXPECT().Close().Return(nil).Maybe() done := make(chan struct{}) @@ -459,7 +459,7 @@ func TestStart(t *testing.T) { mockClient. On("PublishBatch", mock.MatchedBy(func(ctx context.Context) bool { - // Regression guard: flush on cancellation must not use an already-canceled context. + // Regression guard: flush on stop must not use an already-canceled context. return ctx != nil && ctx.Err() == nil }), mock.MatchedBy(func(batch *chipingress.CloudEventBatch) bool { @@ -475,21 +475,19 @@ func TestStart(t *testing.T) { client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second)) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) - - client.Start(ctx) + client.Start(t.Context()) _ = client.QueueMessage(&chipingress.CloudEventPb{Id: "test-id-1", Source: "test-source", Type: "test.event.type"}, nil) _ = client.QueueMessage(&chipingress.CloudEventPb{Id: "test-id-2", Source: "test-source", Type: "test.event.type"}, nil) time.Sleep(10 * time.Millisecond) - cancel() + client.Stop() select { case <-done: case <-time.After(time.Second): - t.Fatal("timeout waiting for flush on context cancellation") + t.Fatal("timeout waiting for flush on stop") } mockClient.AssertExpectations(t) @@ -541,12 +539,11 @@ func TestStart(t *testing.T) { client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second)) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) - client.Start(ctx) + client.Start(t.Context()) time.Sleep(10 * time.Millisecond) - cancel() + client.Stop() time.Sleep(50 * time.Millisecond) @@ -935,7 +932,7 @@ func TestCallbacks(t *testing.T) { mockClient.AssertExpectations(t) }) - t.Run("callbacks invoked on context cancellation", func(t *testing.T) { + t.Run("callbacks invoked on stop", func(t *testing.T) { mockClient := mocks.NewClient(t) mockClient.EXPECT().Close().Return(nil).Maybe() done := make(chan struct{}) @@ -944,7 +941,7 @@ func TestCallbacks(t *testing.T) { mockClient. On("PublishBatch", mock.MatchedBy(func(ctx context.Context) bool { - // Regression guard: flush on cancellation must not use an already-canceled context. + // Regression guard: flush on stop must not use an already-canceled context. return ctx != nil && ctx.Err() == nil }), mock.MatchedBy(func(batch *chipingress.CloudEventBatch) bool { @@ -959,9 +956,7 @@ func TestCallbacks(t *testing.T) { client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second)) require.NoError(t, err) - ctx, cancel := context.WithCancel(t.Context()) - - client.Start(ctx) + client.Start(t.Context()) _ = client.QueueMessage(&chipingress.CloudEventPb{ Id: "test-id-1", @@ -973,13 +968,13 @@ func TestCallbacks(t *testing.T) { time.Sleep(10 * time.Millisecond) - // cancel context to trigger flush - cancel() + // stop to trigger flush + client.Stop() select { case <-done: case <-time.After(time.Second): - t.Fatal("timeout waiting for flush on cancellation") + t.Fatal("timeout waiting for flush on stop") } select {