diff --git a/cmd/server/process_unix.go b/cmd/server/process_unix.go new file mode 100644 index 00000000..d7995212 --- /dev/null +++ b/cmd/server/process_unix.go @@ -0,0 +1,19 @@ +//go:build unix + +package server + +import ( + "errors" + "os" + "syscall" +) + +// isProcessRunning checks if a process with the given PID is running. +func isProcessRunning(pid int) bool { + process, err := os.FindProcess(pid) + if err != nil { + return false + } + err = process.Signal(syscall.Signal(0)) + return err == nil || errors.Is(err, syscall.EPERM) +} diff --git a/cmd/server/process_windows.go b/cmd/server/process_windows.go new file mode 100644 index 00000000..b4d8ba42 --- /dev/null +++ b/cmd/server/process_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package server + +// isProcessRunning checks if a process with the given PID is running. +// On Windows, Signal(0) is not supported, so this always returns false. +// PID file liveness detection is best-effort on this platform. +func isProcessRunning(_ int) bool { + return false +} diff --git a/cmd/server/server.go b/cmd/server/server.go index e578a17d..168ea157 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -12,7 +12,6 @@ import ( "sort" "strconv" "strings" - "syscall" "time" "github.com/coder/agentapi/lib/screentracker" @@ -292,8 +291,21 @@ func writePIDFile(pidFile string, logger *slog.Logger) error { return nil } -// cleanupPIDFile removes the PID file if it exists +// cleanupPIDFile removes the PID file if it was written by this process. func cleanupPIDFile(pidFile string, logger *slog.Logger) { + data, err := os.ReadFile(pidFile) + if err != nil { + if !os.IsNotExist(err) { + logger.Error("Failed to read PID file for cleanup", "pidFile", pidFile, "error", err) + } + return + } + pidStr := strings.TrimSpace(string(data)) + filePID, err := strconv.Atoi(pidStr) + if err != nil || filePID != os.Getpid() { + logger.Info("PID file belongs to another process, skipping cleanup", "pidFile", pidFile, "filePID", pidStr) + return + } if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) } else if err == nil { @@ -301,16 +313,6 @@ func cleanupPIDFile(pidFile string, logger *slog.Logger) { } } -// isProcessRunning checks if a process with the given PID is running -func isProcessRunning(pid int) bool { - process, err := os.FindProcess(pid) - if err != nil { - return false - } - err = process.Signal(syscall.Signal(0)) - return err == nil || errors.Is(err, syscall.EPERM) -} - type flagSpec struct { name string shorthand string diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index 7b9372c1..29eb65b4 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -641,8 +641,9 @@ func TestPIDFileOperations(t *testing.T) { tmpDir := t.TempDir() pidFile := tmpDir + "/test.pid" - // Write initial PID file - err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + // Write a non-numeric PID so strconv.Atoi fails and the liveness + // check is skipped, avoiding flakes when a real PID matches. + err := os.WriteFile(pidFile, []byte("not-a-pid\n"), 0o644) require.NoError(t, err) // Overwrite with current PID @@ -657,12 +658,25 @@ func TestPIDFileOperations(t *testing.T) { assert.Equal(t, expectedPID, string(data)) }) + t.Run("writePIDFile detects running process", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write the current process PID so isProcessRunning returns true. + err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0o644) + require.NoError(t, err) + + err = writePIDFile(pidFile, discardLogger) + require.Error(t, err) + assert.Contains(t, err.Error(), "another instance is already running") + }) + t.Run("cleanupPIDFile removes file", func(t *testing.T) { tmpDir := t.TempDir() pidFile := tmpDir + "/test.pid" - // Create PID file - err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + // Create PID file with current process PID so ownership check passes + err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0o644) require.NoError(t, err) // Cleanup diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go index b15b5b2b..6a8012ad 100644 --- a/cmd/server/signals_unix.go +++ b/cmd/server/signals_unix.go @@ -18,7 +18,7 @@ import ( func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) shutdownCh := make(chan os.Signal, 1) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index c92bb48f..c47f6801 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -94,6 +94,9 @@ func convertStatus(status st.ConversationStatus) AgentStatus { const defaultSubscriptionBufSize uint = 1024 +// maxStoredErrors caps the number of errors retained for late subscribers. +const maxStoredErrors = 100 + type EventEmitterOption func(*EventEmitter) func WithSubscriptionBufSize(size uint) EventEmitterOption { @@ -224,8 +227,11 @@ func (e *EventEmitter) EmitError(message string, level st.ErrorLevel) { Time: e.clock.Now(), } - // Store the error so new subscribers can receive all errors + // Store the error so new subscribers can receive recent errors. e.errors = append(e.errors, errorBody) + if len(e.errors) > maxStoredErrors { + e.errors = e.errors[len(e.errors)-maxStoredErrors:] + } e.notifyChannels(EventTypeError, errorBody) } diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index 106766af..a93bde05 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -99,6 +99,70 @@ func TestEventEmitter(t *testing.T) { } }) + t.Run("error-cap", func(t *testing.T) { + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) + + for i := range 150 { + emitter.EmitError(fmt.Sprintf("error %d", i), st.ErrorLevelError) + } + + _, _, stateEvents := emitter.Subscribe() + + var errorEvents []Event + for _, ev := range stateEvents { + if ev.Type == EventTypeError { + errorEvents = append(errorEvents, ev) + } + } + + assert.Len(t, errorEvents, maxStoredErrors) + + // Errors should be the last 100: "error 50" through "error 149". + for i, ev := range errorEvents { + body, ok := ev.Payload.(ErrorBody) + assert.True(t, ok) + assert.Equal(t, fmt.Sprintf("error %d", i+50), body.Message) + } + }) + + t.Run("error-events-in-initial-state", func(t *testing.T) { + mockClock := quartz.NewMock(t) + fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + mockClock.Set(fixedTime) + + emitter := NewEventEmitter(WithClock(mockClock), WithSubscriptionBufSize(10)) + + emitter.EmitError("err1", st.ErrorLevelError) + mockClock.Set(fixedTime.Add(1 * time.Second)) + emitter.EmitError("err2", st.ErrorLevelWarning) + mockClock.Set(fixedTime.Add(2 * time.Second)) + emitter.EmitError("err3", st.ErrorLevelError) + + _, _, stateEvents := emitter.Subscribe() + + var errorEvents []Event + for _, ev := range stateEvents { + if ev.Type == EventTypeError { + errorEvents = append(errorEvents, ev) + } + } + + assert.Len(t, errorEvents, 3) + + expected := []ErrorBody{ + {Message: "err1", Level: st.ErrorLevelError, Time: fixedTime}, + {Message: "err2", Level: st.ErrorLevelWarning, Time: fixedTime.Add(1 * time.Second)}, + {Message: "err3", Level: st.ErrorLevelError, Time: fixedTime.Add(2 * time.Second)}, + } + for i, ev := range errorEvents { + body, ok := ev.Payload.(ErrorBody) + assert.True(t, ok) + assert.Equal(t, expected[i].Message, body.Message) + assert.Equal(t, expected[i].Level, body.Level) + assert.Equal(t, expected[i].Time, body.Time) + } + }) + t.Run("clock-injection", func(t *testing.T) { mockClock := quartz.NewMock(t) fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 37e5c374..84c20cdf 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -200,10 +200,11 @@ func (c *PTYConversation) Start(ctx context.Context) { c.initialPromptReady = true } + var loadErr string if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState { if err := c.loadStateLocked(); err != nil { c.cfg.Logger.Error("Failed to load state", "error", err) - c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), ErrorLevelWarning) + loadErr = fmt.Sprintf("Failed to restore previous session: %v", err) c.loadStateStatus = LoadStateFailed } else { c.loadStateStatus = LoadStateSucceeded @@ -211,6 +212,9 @@ func (c *PTYConversation) Start(ctx context.Context) { } if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent { + // Safe to send under lock: the queue is guaranteed empty here because + // statusLocked blocks Send until the snapshot buffer fills, which + // cannot happen before this first enqueue completes. c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil} c.initialPromptSent = true c.dirty = true @@ -226,6 +230,9 @@ func (c *PTYConversation) Start(ctx context.Context) { } c.lock.Unlock() + if loadErr != "" { + c.emitter.EmitError(loadErr, ErrorLevelWarning) + } c.emitter.EmitStatus(status) c.emitter.EmitMessages(messages) c.emitter.EmitScreen(screen) @@ -292,7 +299,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } - if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 { + if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 && + c.messages[len(c.messages)-1].Role == ConversationRoleAgent { agentMessage = c.messages[len(c.messages)-1].Message } if c.cfg.FormatToolCall != nil { @@ -605,6 +613,12 @@ func (c *PTYConversation) SaveState() error { return xerrors.Errorf("failed to encode state: %w", err) } + // Flush to disk before rename for crash safety + if err := f.Sync(); err != nil { + _ = f.Close() + return xerrors.Errorf("failed to sync state file: %w", err) + } + // Close file before rename if err := f.Close(); err != nil { return xerrors.Errorf("failed to close temp state file: %w", err) @@ -668,7 +682,10 @@ func (c *PTYConversation) loadStateLocked() error { c.initialPromptSent = agentState.InitialPromptSent if len(c.cfg.InitialPrompt) > 0 { isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt - c.initialPromptSent = !isDifferent + if isDifferent { + c.initialPromptSent = false + } + // If same prompt, keep agentState.InitialPromptSent } else if agentState.InitialPrompt != "" { c.cfg.InitialPrompt = []MessagePart{MessagePartText{ Content: agentState.InitialPrompt, diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 6342bd74..0e7d9635 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -937,6 +937,139 @@ func TestStatePersistence(t *testing.T) { messages := c.Messages() assert.Len(t, messages, 1) }) + + t.Run("LoadState_last_message_is_user_role", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file where the last message is a user message. + // Without the role check in updateLastAgentMessageLocked, the + // user message content would be used as the new agent message. + testState := st.AgentState{ + Version: 1, + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent greeting", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user question", Role: st.ConversationRoleUser, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance past stability so state loads and a new agent message + // is created from the current screen content. + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + messages := c.Messages() + require.True(t, len(messages) >= 3, "expected at least 3 messages, got %d", len(messages)) + // The new agent message should derive from screen content ("ready"), + // NOT from the last loaded message ("user question"). + lastMsg := messages[len(messages)-1] + assert.Equal(t, st.ConversationRoleAgent, lastMsg.Role) + assert.NotEqual(t, "user question", lastMsg.Message, + "agent message must not contain the user message content") + assert.Contains(t, lastMsg.Message, "ready") + }) + + t.Run("LoadState_preserves_unsent_initial_prompt_status", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create state where the initial prompt was NOT sent (e.g. previous crash). + testState := st.AgentState{ + Version: 1, + InitialPrompt: "test prompt", + InitialPromptSent: false, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent greeting", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + writeCounter := 0 + agent := &testAgent{screen: "ready"} + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + mClock := quartz.NewMock(t) + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + // Same initial prompt as saved state. + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until we see a user message with the initial prompt. + advanceUntil(ctx, t, mClock, func() bool { + for _, m := range c.Messages() { + if m.Role == st.ConversationRoleUser && m.Message == "test prompt" { + return true + } + } + return false + }) + + // Verify the initial prompt was sent as a user message. + found := false + for _, m := range c.Messages() { + if m.Role == st.ConversationRoleUser && m.Message == "test prompt" { + found = true + break + } + } + assert.True(t, found, "initial prompt should have been sent since InitialPromptSent was false in saved state") + }) } func TestInitialPromptReadiness(t *testing.T) {