Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions pkg/chipingress/batch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Client struct {
shutdownTimeout time.Duration
shutdownOnce sync.Once
batcherDone chan struct{}
cancelBatcher context.CancelFunc
started bool
counters sync.Map // map[seqnumKey]*atomic.Uint64 for per-(source,type) seqnum, cleared on Stop()

metrics batchClientMetrics
Expand Down Expand Up @@ -97,21 +97,23 @@ func NewBatchClient(client chipingress.Client, opts ...Opt) (*Client, error) {
return c, nil
}

// Start begins processing messages from the queue and sending them in batches
// Start begins processing messages from the queue and sending them in batches.
// The context is used only for the initial metrics recording call and is NOT
// retained after Start returns. The client manages its own internal lifecycle
// context that is cancelled when Stop is called.
Comment on lines +101 to +103
func (b *Client) Start(ctx context.Context) {
b.metrics.recordConfig(ctx, b)
b.started = true

// Create a cancellable context for the batcher
batcherCtx, cancel := context.WithCancel(ctx)
b.cancelBatcher = cancel
// Detach from the caller's cancellation but keep its values (trace IDs, etc.).
// This avoids retaining a startup context whose cancellation we don't control.
batcherCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))

go func() {
defer close(b.batcherDone)
Comment on lines 104 to 113

go func() {
select {
case <-ctx.Done():
b.Stop()
case <-b.stopCh:
cancel()
}
Expand Down Expand Up @@ -143,15 +145,11 @@ func (b *Client) Stop() {
ctx, cancel := context.WithTimeout(context.Background(), b.shutdownTimeout)
defer cancel()

started := b.cancelBatcher != nil
if started {
b.cancelBatcher()
}
close(b.stopCh)

// Only wait for the batcher goroutine when Start() was called;
// otherwise batcherDone is never closed and we'd block until timeout.
if started {
if b.started {
done := make(chan struct{})
go func() {
<-b.batcherDone
Expand Down
31 changes: 13 additions & 18 deletions pkg/chipingress/batch/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,15 @@ func TestStart(t *testing.T) {
mockClient.AssertExpectations(t)
})

t.Run("context cancellation flushes pending batch", func(t *testing.T) {
t.Run("stop flushes pending batch before batch interval", func(t *testing.T) {
mockClient := mocks.NewClient(t)
mockClient.EXPECT().Close().Return(nil).Maybe()
done := make(chan struct{})

mockClient.
On("PublishBatch",
mock.MatchedBy(func(ctx context.Context) bool {
// Regression guard: flush on cancellation must not use an already-canceled context.
// Regression guard: flush on stop must not use an already-canceled context.
return ctx != nil && ctx.Err() == nil
}),
mock.MatchedBy(func(batch *chipingress.CloudEventBatch) bool {
Expand All @@ -475,21 +475,19 @@ func TestStart(t *testing.T) {
client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second))
require.NoError(t, err)

ctx, cancel := context.WithCancel(t.Context())

client.Start(ctx)
client.Start(t.Context())

_ = client.QueueMessage(&chipingress.CloudEventPb{Id: "test-id-1", Source: "test-source", Type: "test.event.type"}, nil)
_ = client.QueueMessage(&chipingress.CloudEventPb{Id: "test-id-2", Source: "test-source", Type: "test.event.type"}, nil)

time.Sleep(10 * time.Millisecond)

cancel()
client.Stop()

select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for flush on context cancellation")
t.Fatal("timeout waiting for flush on stop")
Comment on lines 454 to +490
}

mockClient.AssertExpectations(t)
Expand Down Expand Up @@ -541,12 +539,11 @@ func TestStart(t *testing.T) {
client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second))
require.NoError(t, err)

ctx, cancel := context.WithCancel(t.Context())
client.Start(ctx)
client.Start(t.Context())

time.Sleep(10 * time.Millisecond)

cancel()
client.Stop()

time.Sleep(50 * time.Millisecond)

Expand Down Expand Up @@ -935,7 +932,7 @@ func TestCallbacks(t *testing.T) {
mockClient.AssertExpectations(t)
})

t.Run("callbacks invoked on context cancellation", func(t *testing.T) {
t.Run("callbacks invoked on stop", func(t *testing.T) {
mockClient := mocks.NewClient(t)
mockClient.EXPECT().Close().Return(nil).Maybe()
done := make(chan struct{})
Expand All @@ -944,7 +941,7 @@ func TestCallbacks(t *testing.T) {
mockClient.
On("PublishBatch",
mock.MatchedBy(func(ctx context.Context) bool {
// Regression guard: flush on cancellation must not use an already-canceled context.
// Regression guard: flush on stop must not use an already-canceled context.
return ctx != nil && ctx.Err() == nil
}),
mock.MatchedBy(func(batch *chipingress.CloudEventBatch) bool {
Expand All @@ -959,9 +956,7 @@ func TestCallbacks(t *testing.T) {
client, err := NewBatchClient(mockClient, WithBatchSize(10), WithBatchInterval(5*time.Second))
require.NoError(t, err)

ctx, cancel := context.WithCancel(t.Context())

client.Start(ctx)
client.Start(t.Context())

_ = client.QueueMessage(&chipingress.CloudEventPb{
Id: "test-id-1",
Expand All @@ -973,13 +968,13 @@ func TestCallbacks(t *testing.T) {

time.Sleep(10 * time.Millisecond)

// cancel context to trigger flush
cancel()
// stop to trigger flush
client.Stop()

select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for flush on cancellation")
t.Fatal("timeout waiting for flush on stop")
}

select {
Expand Down
Loading