Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions server/internal/task/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@ import (
"github.com/pgEdge/control-plane/server/internal/storage"
)

// NewWatcher creates a Watcher that monitors the given task and closes its
// Done channel when the task reaches a terminal state or is deleted. Multiple
// callers watching the same task share a single etcd watch stream. The caller
// must call Close on the returned Watcher when done with it.
func (s *Service) NewWatcher(_ context.Context, scope Scope, entityID string, taskID uuid.UUID) (*Watcher, error) {
return s.registry.acquire(s.Store.Task, scope, entityID, taskID)
}

var ErrTaskNotFound = errors.New("task not found")

type Service struct {
Store *Store
Store *Store
registry *watcherRegistry
}

func NewService(store *Store) *Service {
return &Service{
Store: store,
Store: store,
registry: newWatcherRegistry(),
}
}

Expand Down
5 changes: 5 additions & 0 deletions server/internal/task/task_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,8 @@ func (s *TaskStore) DeleteByEntity(scope Scope, entityID string) storage.DeleteO
prefix := s.EntityPrefix(scope, entityID)
return storage.NewDeletePrefixOp(s.client, prefix)
}

func (s *TaskStore) Watch(scope Scope, entityID string, taskID uuid.UUID) storage.WatchOp[*StoredTask] {
key := s.Key(scope, entityID, taskID)
return storage.NewWatchOp[*StoredTask](s.client, key)
}
240 changes: 240 additions & 0 deletions server/internal/task/watcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package task

import (
"context"
"errors"
"fmt"
"slices"
"sync"

"github.com/google/uuid"

"github.com/pgEdge/control-plane/server/internal/storage"
)

var (
ErrTaskCanceled = errors.New("task was canceled")
ErrTaskFailed = errors.New("task failed")
)

// Watcher is a subscription to a task's terminal state. Multiple Watchers for
// the same task share a single underlying etcd watch stream.
type Watcher struct {
mu sync.Mutex
closed bool
err error
done chan struct{}
errCh chan error
shared *sharedWatcher
}

// Done returns a channel that is closed when the task reaches a terminal state
// or is deleted.
func (w *Watcher) Done() <-chan struct{} {
return w.done
}

// Err returns nil if the task completed successfully, ErrTaskCanceled if it
// was canceled (or is canceling), or ErrTaskFailed if it failed. It is only
// meaningful after Done() is closed.
func (w *Watcher) Err() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.err
}

// Close releases this subscription. When the last subscription for a task is
// closed, the underlying etcd watch stream is stopped.
func (w *Watcher) Close() {
w.shared.release(w)
}

// Error returns a channel that receives an error if the underlying watch
// stream fails. The channel carries at most one value. Callers that select on
// Done should also select on Error so they are not blocked when the watch
// stream dies before the task reaches a terminal state.
func (w *Watcher) Error() <-chan error {
return w.errCh
}

func (w *Watcher) finish(err error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed {
return
}
w.closed = true
w.err = err
close(w.done)
}

// sharedWatcher holds one etcd watch stream for a task and fans events out to
// all active Watcher subscriptions. It is managed by watcherRegistry.
type sharedWatcher struct {
mu sync.Mutex
subscribers []*Watcher
terminal bool
terminalErr error
watchOp storage.WatchOp[*StoredTask]
registry *watcherRegistry
taskID uuid.UUID
shutdownCh chan struct{}
shutdownOnce sync.Once
cancelWatch context.CancelFunc
}

// newSubscription creates and registers a new Watcher. If the task is already
// in a terminal state, the returned Watcher's Done channel is closed immediately.
func (sw *sharedWatcher) newSubscription() *Watcher {
w := &Watcher{
done: make(chan struct{}),
errCh: make(chan error, 1),
shared: sw,
}
sw.mu.Lock()
sw.subscribers = append(sw.subscribers, w)
if sw.terminal {
w.closed = true
w.err = sw.terminalErr
close(w.done)
}
sw.mu.Unlock()
return w
}

func (sw *sharedWatcher) finishAll(err error) {
sw.mu.Lock()
sw.terminal = true
sw.terminalErr = err
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, sub := range subs {
sub.finish(err)
}
Comment on lines +105 to +114

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Make terminal completion sticky.

finishAll overwrites terminalErr on later events, so a task seen as StatusCanceling can later be reported as failed/completed/deleted to new subscribers while the same shared watcher remains registered. Return early once sw.terminal is already set.

Proposed fix
 func (sw *sharedWatcher) finishAll(err error) {
 	sw.mu.Lock()
+	if sw.terminal {
+		sw.mu.Unlock()
+		return
+	}
 	sw.terminal = true
 	sw.terminalErr = err
 	subs := make([]*Watcher, len(sw.subscribers))
 	copy(subs, sw.subscribers)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (sw *sharedWatcher) finishAll(err error) {
sw.mu.Lock()
sw.terminal = true
sw.terminalErr = err
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, sub := range subs {
sub.finish(err)
}
func (sw *sharedWatcher) finishAll(err error) {
sw.mu.Lock()
if sw.terminal {
sw.mu.Unlock()
return
}
sw.terminal = true
sw.terminalErr = err
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, sub := range subs {
sub.finish(err)
}
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@server/internal/task/watcher.go` around lines 105 - 114, The shared watcher
completion state is being overwritten on subsequent terminal events, so once
finishAll marks sharedWatcher as terminal it should not update terminalErr
again. Add an early return at the start of sharedWatcher.finishAll when
sw.terminal is already true, and keep the existing subscriber fan-out only for
the first terminal transition so new subscribers don’t see a later status
replace the original canceling state.

}

func (sw *sharedWatcher) handleEvent(e *storage.Event[*StoredTask]) error {
switch e.Type {
case storage.EventTypeDelete:
sw.finishAll(ErrTaskCanceled)
case storage.EventTypeError:
return e.Err
case storage.EventTypePut:
if e.Value == nil || e.Value.Task == nil {
return nil
}
switch e.Value.Task.Status {
case StatusCanceled, StatusCanceling:
sw.finishAll(ErrTaskCanceled)
case StatusFailed:
sw.finishAll(ErrTaskFailed)
case StatusCompleted:
sw.finishAll(nil)
}
}
return nil
}

// propagateErrors forwards watch stream errors to all active subscriptions.
// context.Canceled is filtered out — it indicates normal cleanup when
// cancelWatch is called and should not be surfaced as an error.
func (sw *sharedWatcher) propagateErrors() {
select {
case <-sw.shutdownCh:
case err := <-sw.watchOp.Error():
if errors.Is(err, context.Canceled) {
return
}
sw.mu.Lock()
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, w := range subs {
select {
case w.errCh <- err:
default:
}
}
Comment on lines +142 to +158

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Remove or mark failed shared watchers after stream errors.

After watchOp.Error() fires, propagateErrors notifies only the subscribers copied at that moment and then exits. The dead sharedWatcher remains in registry.entries, so a concurrent/later acquire can attach to a watcher with no active stream and never receive Done or Error.

Possible direction
 	case err := <-sw.watchOp.Error():
 		if errors.Is(err, context.Canceled) {
 			return
 		}
+		sw.shutdown()
 		sw.mu.Lock()
 		subs := make([]*Watcher, len(sw.subscribers))

If there is a race with acquire, also persist the stream error on sharedWatcher so newSubscription can immediately populate errCh for late subscribers before the registry entry is removed.

Also applies to: 213-215

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@server/internal/task/watcher.go` around lines 142 - 158, The shared watcher
error handling in propagateErrors leaves a dead sharedWatcher in
registry.entries after watchOp.Error() fires, so later acquire/newSubscription
calls can attach to a stale watcher and miss Done/Error. Update
sharedWatcher/registry teardown so stream errors are persisted on sharedWatcher
and the registry entry is removed or marked failed immediately after the error
is observed, and make newSubscription populate errCh from that stored error for
late subscribers; also ensure propagateErrors continues to use the sharedWatcher
state consistently for both current and future subscribers.

}
}

// release removes w from the subscriber list. When the last subscriber is
// removed, it stops the underlying watch stream and removes the sharedWatcher
// from the registry.
//
// sw.mu is always released before sw.registry.mu is acquired so that
// watcherRegistry.acquire (which holds registry.mu and may acquire sw.mu via
// newSubscription) cannot deadlock with release.
func (sw *sharedWatcher) release(w *Watcher) {
sw.mu.Lock()
for i, sub := range sw.subscribers {
if sub == w {
sw.subscribers = slices.Delete(sw.subscribers, i, i+1)
break
}
}
remaining := len(sw.subscribers)
sw.mu.Unlock()

if remaining == 0 {
sw.shutdown()
}
}

func (sw *sharedWatcher) shutdown() {
sw.shutdownOnce.Do(func() {
sw.registry.mu.Lock()
delete(sw.registry.entries, sw.taskID)
sw.registry.mu.Unlock()
close(sw.shutdownCh)
sw.cancelWatch()
sw.watchOp.Close()
})
}

// watcherRegistry maintains at most one shared watch stream per task across
// all concurrent callers on the same service instance.
type watcherRegistry struct {
mu sync.Mutex
entries map[uuid.UUID]*sharedWatcher
}

func newWatcherRegistry() *watcherRegistry {
return &watcherRegistry{
entries: make(map[uuid.UUID]*sharedWatcher),
}
}

func (r *watcherRegistry) acquire(store *TaskStore, scope Scope, entityID string, taskID uuid.UUID) (*Watcher, error) {
r.mu.Lock()
defer r.mu.Unlock()

if sw, ok := r.entries[taskID]; ok {
return sw.newSubscription(), nil
}

watchCtx, cancelWatch := context.WithCancel(context.Background())
watchOp := store.Watch(scope, entityID, taskID)
sw := &sharedWatcher{
watchOp: watchOp,
registry: r,
taskID: taskID,
shutdownCh: make(chan struct{}),
cancelWatch: cancelWatch,
}

// Create the first subscription before starting the watch so that
// handleEvent's synchronous load() call can signal it if the task is
// already terminal.
w := sw.newSubscription()

if err := watchOp.Watch(watchCtx, sw.handleEvent); err != nil {
cancelWatch()
return nil, fmt.Errorf("failed to start task watcher: %w", err)
}

r.entries[taskID] = sw
go sw.propagateErrors()
return w, nil
}
22 changes: 22 additions & 0 deletions server/internal/workflows/activities/apply_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ func (a *Activities) ExecuteApplyEvent(
}

func (a *Activities) ApplyEvent(ctx context.Context, input *ApplyEventInput) (*ApplyEventOutput, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if input.TaskID != uuid.Nil {
watcher, err := a.TaskSvc.NewWatcher(ctx, task.ScopeDatabase, input.DatabaseID, input.TaskID)
if err != nil {
activity.Logger(ctx).Warn("failed to start task watcher; activity won't be interrupted on task cancellation", "error", err)
} else {
go func() {
defer watcher.Close()
select {
case <-watcher.Done():
cancel()
case <-watcher.Error():
// Watch stream died; stop monitoring without cancelling
// the activity — we don't know the task's current state.
case <-ctx.Done():
}
}()
}
}

logger := activity.Logger(ctx).With("database_id", input.DatabaseID)
logStart := logger.With(
"event_type", input.Event.Type,
Expand Down