Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions server/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,13 @@ func defaultAddresses() ([]string, error) {
return []string{ip.String()}, nil
}

// getFirstIP gets the first non-loopback IPv4 address
// getFirstIP gets the first non-loopback IP address, preferring IPv4.
func getFirstIP() (net.IP, error) {
interfaces, err := net.Interfaces()
if err != nil {
return net.IPv4zero, fmt.Errorf("failed to list interfaces: %w", err)
}
var ipv6 net.IP
for _, iface := range interfaces {
// Skip loopback and down interfaces
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
Expand All @@ -448,7 +449,6 @@ func getFirstIP() (net.IP, error) {
if err != nil {
continue
}

for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
Expand All @@ -458,21 +458,24 @@ func getFirstIP() (net.IP, error) {
ip = v.IP
}

// Check if it is a valid IPv4 address and not a loopback address
// Check if it is a valid address and not a loopback address
if ip == nil || ip.IsLoopback() {
continue
}

// To4() returns nil if the IP is not an IPv4 address
ip = ip.To4()
if ip == nil {
continue
if v4 := ip.To4(); v4 != nil {
return v4, nil
}
// Store first IPv6 address found, keep searching for IPv4
if ipv6 == nil {
ipv6 = ip.To16()
}

// Return the first valid IPv4 address found
return ip, nil
}
}
if ipv6 != nil {
return ipv6, nil
}

return net.IPv4zero, fmt.Errorf("could not find a valid network interface")
}
Expand Down
14 changes: 12 additions & 2 deletions server/internal/task/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@ import (
"github.com/pgEdge/control-plane/server/internal/storage"
)

// NewWatcher creates a Watcher that monitors the given task and closes its
// Done channel when the task reaches a terminal state or is deleted. Multiple
// callers watching the same task share a single etcd watch stream. The caller
// must call Close on the returned Watcher when done with it.
func (s *Service) NewWatcher(_ context.Context, scope Scope, entityID string, taskID uuid.UUID) (*Watcher, error) {
return s.registry.acquire(s.Store.Task, scope, entityID, taskID)
}

var ErrTaskNotFound = errors.New("task not found")

type Service struct {
Store *Store
Store *Store
registry *watcherRegistry
}

func NewService(store *Store) *Service {
return &Service{
Store: store,
Store: store,
registry: newWatcherRegistry(),
}
}

Expand Down
5 changes: 5 additions & 0 deletions server/internal/task/task_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,8 @@ func (s *TaskStore) DeleteByEntity(scope Scope, entityID string) storage.DeleteO
prefix := s.EntityPrefix(scope, entityID)
return storage.NewDeletePrefixOp(s.client, prefix)
}

func (s *TaskStore) Watch(scope Scope, entityID string, taskID uuid.UUID) storage.WatchOp[*StoredTask] {
key := s.Key(scope, entityID, taskID)
return storage.NewWatchOp[*StoredTask](s.client, key)
}
240 changes: 240 additions & 0 deletions server/internal/task/watcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package task

import (
"context"
"errors"
"fmt"
"slices"
"sync"

"github.com/google/uuid"

"github.com/pgEdge/control-plane/server/internal/storage"
)

var (
ErrTaskCanceled = errors.New("task was canceled")
ErrTaskFailed = errors.New("task failed")
)

// Watcher is a subscription to a task's terminal state. Multiple Watchers for
// the same task share a single underlying etcd watch stream.
type Watcher struct {
mu sync.Mutex
closed bool
err error
done chan struct{}
errCh chan error
shared *sharedWatcher
}

// Done returns a channel that is closed when the task reaches a terminal state
// or is deleted.
func (w *Watcher) Done() <-chan struct{} {
return w.done
}

// Err returns nil if the task completed successfully, ErrTaskCanceled if it
// was canceled (or is canceling), or ErrTaskFailed if it failed. It is only
// meaningful after Done() is closed.
func (w *Watcher) Err() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.err
}

// Close releases this subscription. When the last subscription for a task is
// closed, the underlying etcd watch stream is stopped.
func (w *Watcher) Close() {
w.shared.release(w)
}

// Error returns a channel that receives an error if the underlying watch
// stream fails. The channel carries at most one value. Callers that select on
// Done should also select on Error so they are not blocked when the watch
// stream dies before the task reaches a terminal state.
func (w *Watcher) Error() <-chan error {
return w.errCh
}

func (w *Watcher) finish(err error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed {
return
}
w.closed = true
w.err = err
close(w.done)
}

// sharedWatcher holds one etcd watch stream for a task and fans events out to
// all active Watcher subscriptions. It is managed by watcherRegistry.
type sharedWatcher struct {
mu sync.Mutex
subscribers []*Watcher
terminal bool
terminalErr error
watchOp storage.WatchOp[*StoredTask]
registry *watcherRegistry
taskID uuid.UUID
shutdownCh chan struct{}
shutdownOnce sync.Once
cancelWatch context.CancelFunc
}

// newSubscription creates and registers a new Watcher. If the task is already
// in a terminal state, the returned Watcher's Done channel is closed immediately.
func (sw *sharedWatcher) newSubscription() *Watcher {
w := &Watcher{
done: make(chan struct{}),
errCh: make(chan error, 1),
shared: sw,
}
sw.mu.Lock()
sw.subscribers = append(sw.subscribers, w)
if sw.terminal {
w.closed = true
w.err = sw.terminalErr
close(w.done)
}
sw.mu.Unlock()
return w
}

func (sw *sharedWatcher) finishAll(err error) {
sw.mu.Lock()
sw.terminal = true
sw.terminalErr = err
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, sub := range subs {
sub.finish(err)
}
}

func (sw *sharedWatcher) handleEvent(e *storage.Event[*StoredTask]) error {
switch e.Type {
case storage.EventTypeDelete:
sw.finishAll(ErrTaskCanceled)
case storage.EventTypeError:
return e.Err
case storage.EventTypePut:
if e.Value == nil || e.Value.Task == nil {
return nil
}
switch e.Value.Task.Status {
case StatusCanceled, StatusCanceling:
sw.finishAll(ErrTaskCanceled)
case StatusFailed:
sw.finishAll(ErrTaskFailed)
case StatusCompleted:
sw.finishAll(nil)
}
}
return nil
}

// propagateErrors forwards watch stream errors to all active subscriptions.
// context.Canceled is filtered out — it indicates normal cleanup when
// cancelWatch is called and should not be surfaced as an error.
func (sw *sharedWatcher) propagateErrors() {
select {
case <-sw.shutdownCh:
case err := <-sw.watchOp.Error():
if errors.Is(err, context.Canceled) {
return
}
sw.mu.Lock()
subs := make([]*Watcher, len(sw.subscribers))
copy(subs, sw.subscribers)
sw.mu.Unlock()
for _, w := range subs {
select {
case w.errCh <- err:
default:
}
}
}
}

// release removes w from the subscriber list. When the last subscriber is
// removed, it stops the underlying watch stream and removes the sharedWatcher
// from the registry.
//
// sw.mu is always released before sw.registry.mu is acquired so that
// watcherRegistry.acquire (which holds registry.mu and may acquire sw.mu via
// newSubscription) cannot deadlock with release.
func (sw *sharedWatcher) release(w *Watcher) {
sw.mu.Lock()
for i, sub := range sw.subscribers {
if sub == w {
sw.subscribers = slices.Delete(sw.subscribers, i, i+1)
break
}
}
remaining := len(sw.subscribers)
sw.mu.Unlock()

if remaining == 0 {
sw.shutdown()
}
}

func (sw *sharedWatcher) shutdown() {
sw.shutdownOnce.Do(func() {
sw.registry.mu.Lock()
delete(sw.registry.entries, sw.taskID)
sw.registry.mu.Unlock()
close(sw.shutdownCh)
sw.cancelWatch()
sw.watchOp.Close()
})
}

// watcherRegistry maintains at most one shared watch stream per task across
// all concurrent callers on the same service instance.
type watcherRegistry struct {
mu sync.Mutex
entries map[uuid.UUID]*sharedWatcher
}

func newWatcherRegistry() *watcherRegistry {
return &watcherRegistry{
entries: make(map[uuid.UUID]*sharedWatcher),
}
}

func (r *watcherRegistry) acquire(store *TaskStore, scope Scope, entityID string, taskID uuid.UUID) (*Watcher, error) {
r.mu.Lock()
defer r.mu.Unlock()

if sw, ok := r.entries[taskID]; ok {
return sw.newSubscription(), nil
}

watchCtx, cancelWatch := context.WithCancel(context.Background())
watchOp := store.Watch(scope, entityID, taskID)
sw := &sharedWatcher{
watchOp: watchOp,
registry: r,
taskID: taskID,
shutdownCh: make(chan struct{}),
cancelWatch: cancelWatch,
}

// Create the first subscription before starting the watch so that
// handleEvent's synchronous load() call can signal it if the task is
// already terminal.
w := sw.newSubscription()

if err := watchOp.Watch(watchCtx, sw.handleEvent); err != nil {
cancelWatch()
return nil, fmt.Errorf("failed to start task watcher: %w", err)
}

r.entries[taskID] = sw
go sw.propagateErrors()
return w, nil
}
22 changes: 22 additions & 0 deletions server/internal/workflows/activities/apply_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ func (a *Activities) ExecuteApplyEvent(
}

func (a *Activities) ApplyEvent(ctx context.Context, input *ApplyEventInput) (*ApplyEventOutput, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if input.TaskID != uuid.Nil {
watcher, err := a.TaskSvc.NewWatcher(ctx, task.ScopeDatabase, input.DatabaseID, input.TaskID)
if err != nil {
activity.Logger(ctx).Warn("failed to start task watcher; activity won't be interrupted on task cancellation", "error", err)
} else {
go func() {
defer watcher.Close()
select {
case <-watcher.Done():
cancel()
case <-watcher.Error():
// Watch stream died; stop monitoring without cancelling
// the activity — we don't know the task's current state.
case <-ctx.Done():
}
}()
}
}

logger := activity.Logger(ctx).With("database_id", input.DatabaseID)
logStart := logger.With(
"event_type", input.Event.Type,
Expand Down