From bbc9ae24efebb485acf6bd62f8a969fdb3925256 Mon Sep 17 00:00:00 2001 From: Steven Miller Date: Thu, 30 Apr 2026 11:44:47 -0400 Subject: [PATCH] Make lifecycle no-op transitions idempotent --- lib/instances/lifecycle_noop_test.go | 214 +++++++++++++++++++++++++++ lib/instances/manager.go | 49 ++++++ 2 files changed, 263 insertions(+) create mode 100644 lib/instances/lifecycle_noop_test.go diff --git a/lib/instances/lifecycle_noop_test.go b/lib/instances/lifecycle_noop_test.go new file mode 100644 index 00000000..41cf252a --- /dev/null +++ b/lib/instances/lifecycle_noop_test.go @@ -0,0 +1,214 @@ +package instances + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/kernel/hypeman/lib/hypervisor" + "github.com/kernel/hypeman/lib/paths" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const lifecycleNoopHypervisorType hypervisor.Type = "lifecycle-noop-test" + +var lifecycleNoopHypervisorStates sync.Map + +func init() { + hypervisor.RegisterClientFactory(lifecycleNoopHypervisorType, func(socketPath string) (hypervisor.Hypervisor, error) { + state, ok := lifecycleNoopHypervisorStates.Load(socketPath) + if !ok { + return nil, errors.New("missing fake hypervisor state") + } + return lifecycleNoopHypervisor{state: state.(hypervisor.VMState)}, nil + }) +} + +type lifecycleNoopHypervisor struct { + state hypervisor.VMState +} + +func (h lifecycleNoopHypervisor) DeleteVM(context.Context) error { return nil } +func (h lifecycleNoopHypervisor) Shutdown(context.Context) error { return nil } +func (h lifecycleNoopHypervisor) GetVMInfo(context.Context) (*hypervisor.VMInfo, error) { + return &hypervisor.VMInfo{State: h.state}, nil +} +func (h lifecycleNoopHypervisor) Pause(context.Context) error { return nil } +func (h lifecycleNoopHypervisor) Resume(context.Context) error { return nil } +func (h lifecycleNoopHypervisor) Snapshot(context.Context, string) error { return nil } +func (h lifecycleNoopHypervisor) ResizeMemory(context.Context, int64) error { + return nil +} +func (h lifecycleNoopHypervisor) ResizeMemoryAndWait(context.Context, int64, time.Duration) error { + return nil +} +func (h lifecycleNoopHypervisor) SetTargetGuestMemoryBytes(context.Context, int64) error { + return nil +} +func (h lifecycleNoopHypervisor) GetTargetGuestMemoryBytes(context.Context) (int64, error) { + return 0, nil +} +func (h lifecycleNoopHypervisor) Capabilities() hypervisor.Capabilities { + return hypervisor.Capabilities{} +} + +func TestLifecycleNoopTransitionsReturnCurrentInstanceWithoutEvent(t *testing.T) { + now := time.Now().UTC() + + tests := []struct { + name string + state State + action func(context.Context, *manager, string) (*Instance, error) + }{ + { + name: "restore running", + state: StateRunning, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.RestoreInstance(ctx, id) + }, + }, + { + name: "restore initializing", + state: StateInitializing, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.RestoreInstance(ctx, id) + }, + }, + { + name: "start running without overrides", + state: StateRunning, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.StartInstance(ctx, id, StartInstanceRequest{}) + }, + }, + { + name: "start initializing without overrides", + state: StateInitializing, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.StartInstance(ctx, id, StartInstanceRequest{}) + }, + }, + { + name: "standby already standby without options", + state: StateStandby, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.StandbyInstance(ctx, id, StandbyInstanceRequest{}) + }, + }, + { + name: "stop already stopped", + state: StateStopped, + action: func(ctx context.Context, m *manager, id string) (*Instance, error) { + return m.StopInstance(ctx, id) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, id := newLifecycleNoopManagerWithInstance(t, tt.state, now) + events, cancel := m.SubscribeLifecycleEvents(LifecycleEventConsumerWaitForState) + defer cancel() + + inst, err := tt.action(context.Background(), m, id) + require.NoError(t, err) + require.NotNil(t, inst) + assert.Equal(t, tt.state, inst.State) + assertNoLifecycleEvent(t, events) + }) + } +} + +func TestLifecycleNoopStartWithOverridesStillRejectsActiveInstance(t *testing.T) { + m, id := newLifecycleNoopManagerWithInstance(t, StateRunning, time.Now().UTC()) + events, cancel := m.SubscribeLifecycleEvents(LifecycleEventConsumerWaitForState) + defer cancel() + + _, err := m.StartInstance(context.Background(), id, StartInstanceRequest{Cmd: []string{"echo", "hello"}}) + require.ErrorIs(t, err, ErrInvalidState) + assertNoLifecycleEvent(t, events) +} + +func TestLifecycleNoopStandbyWithOptionsStillRejectsStandbyInstance(t *testing.T) { + m, id := newLifecycleNoopManagerWithInstance(t, StateStandby, time.Now().UTC()) + events, cancel := m.SubscribeLifecycleEvents(LifecycleEventConsumerWaitForState) + defer cancel() + + delay := time.Second + _, err := m.StandbyInstance(context.Background(), id, StandbyInstanceRequest{CompressionDelay: &delay}) + require.ErrorIs(t, err, ErrInvalidState) + assertNoLifecycleEvent(t, events) +} + +func newLifecycleNoopManagerWithInstance(t *testing.T, state State, now time.Time) (*manager, string) { + t.Helper() + + p := paths.New(t.TempDir()) + m := &manager{ + paths: p, + instanceLocks: sync.Map{}, + bootMarkerScans: sync.Map{}, + now: func() time.Time { + return now + }, + lifecycleEvents: newLifecycleSubscribers(), + } + + id := "inst-" + string(state) + require.NoError(t, m.ensureDirectories(id)) + + stored := StoredMetadata{ + Id: id, + Name: id, + Image: "test-image", + CreatedAt: now, + HypervisorType: lifecycleNoopHypervisorType, + SocketPath: p.InstanceSocket(id, "noop.sock"), + DataDir: p.InstanceDir(id), + } + + switch state { + case StateRunning: + stored.ProgramStartedAt = &now + stored.GuestAgentReadyAt = &now + writeLifecycleNoopSocket(t, stored.SocketPath, hypervisor.StateRunning) + case StateInitializing: + writeLifecycleNoopSocket(t, stored.SocketPath, hypervisor.StateRunning) + case StateStandby: + writeLifecycleNoopSnapshot(t, p, id) + } + + require.NoError(t, m.saveMetadata(&metadata{StoredMetadata: stored})) + t.Cleanup(func() { + lifecycleNoopHypervisorStates.Delete(stored.SocketPath) + }) + return m, id +} + +func writeLifecycleNoopSocket(t *testing.T, socketPath string, state hypervisor.VMState) { + t.Helper() + require.NoError(t, os.MkdirAll(filepath.Dir(socketPath), 0o755)) + require.NoError(t, os.WriteFile(socketPath, []byte("fake socket"), 0o644)) + lifecycleNoopHypervisorStates.Store(socketPath, state) +} + +func writeLifecycleNoopSnapshot(t *testing.T, p *paths.Paths, id string) { + t.Helper() + snapshotDir := p.InstanceSnapshotLatest(id) + require.NoError(t, os.MkdirAll(snapshotDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(snapshotDir, "memory"), []byte("snapshot"), 0o644)) +} + +func assertNoLifecycleEvent(t *testing.T, events <-chan LifecycleEvent) { + t.Helper() + select { + case event := <-events: + t.Fatalf("unexpected lifecycle event: %+v", event) + default: + } +} diff --git a/lib/instances/manager.go b/lib/instances/manager.go index 495d8a4e..abdc5b9a 100644 --- a/lib/instances/manager.go +++ b/lib/instances/manager.go @@ -399,6 +399,15 @@ func (m *manager) StandbyInstance(ctx context.Context, id string, req StandbyIns lock := m.getInstanceLock(id) lock.Lock() defer lock.Unlock() + if !standbyRequestHasOptions(req) { + current, err := m.currentInstanceWithoutHydration(ctx, id) + if err != nil { + return nil, err + } + if current.State == StateStandby { + return current, nil + } + } inst, err := m.standbyInstance(ctx, id, req, false) if err == nil { m.notifyLifecycleEvent(ctx, LifecycleEventStandby, inst) @@ -411,6 +420,13 @@ func (m *manager) RestoreInstance(ctx context.Context, id string) (*Instance, er lock := m.getInstanceLock(id) lock.Lock() defer lock.Unlock() + current, err := m.currentInstanceWithoutHydration(ctx, id) + if err != nil { + return nil, err + } + if current.State == StateRunning || current.State == StateInitializing { + return current, nil + } inst, err := m.restoreInstance(ctx, id) if err == nil { m.notifyLifecycleEvent(ctx, LifecycleEventRestore, inst) @@ -434,6 +450,13 @@ func (m *manager) StopInstance(ctx context.Context, id string) (*Instance, error lock := m.getInstanceLock(id) lock.Lock() defer lock.Unlock() + current, err := m.currentInstanceWithoutHydration(ctx, id) + if err != nil { + return nil, err + } + if current.State == StateStopped { + return current, nil + } inst, err := m.stopInstance(ctx, id) if err == nil { m.notifyLifecycleEvent(ctx, LifecycleEventStop, inst) @@ -446,6 +469,15 @@ func (m *manager) StartInstance(ctx context.Context, id string, req StartInstanc lock := m.getInstanceLock(id) lock.Lock() defer lock.Unlock() + if !startRequestHasOverrides(req) { + current, err := m.currentInstanceWithoutHydration(ctx, id) + if err != nil { + return nil, err + } + if current.State == StateRunning || current.State == StateInitializing { + return current, nil + } + } inst, err := m.startInstance(ctx, id, req) if err == nil { m.notifyLifecycleEvent(ctx, LifecycleEventStart, inst) @@ -453,6 +485,23 @@ func (m *manager) StartInstance(ctx context.Context, id string, req StartInstanc return inst, err } +func (m *manager) currentInstanceWithoutHydration(ctx context.Context, id string) (*Instance, error) { + meta, err := m.loadMetadata(id) + if err != nil { + return nil, err + } + inst := m.toInstanceWithoutHydration(ctx, meta) + return &inst, nil +} + +func startRequestHasOverrides(req StartInstanceRequest) bool { + return len(req.Entrypoint) > 0 || len(req.Cmd) > 0 +} + +func standbyRequestHasOptions(req StandbyInstanceRequest) bool { + return req.Compression != nil || req.CompressionDelay != nil +} + // UpdateInstance updates mutable properties of a running instance func (m *manager) UpdateInstance(ctx context.Context, id string, req UpdateInstanceRequest) (*Instance, error) { lock := m.getInstanceLock(id)