Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cmd/server/process_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//go:build unix

package server

import (
"errors"
"os"
"syscall"
)

// isProcessRunning reports whether a process with the given PID currently
// exists.
//
// The process is probed with signal 0, which asks the kernel to perform its
// existence and permission checks without delivering a signal. An EPERM result
// still counts as running: the process exists, this process is merely not
// allowed to signal it.
//
// Non-positive PIDs always report false. They never identify a single
// process; handed to kill(2), a negative value addresses a whole process
// group, so probing one could report an unrelated group as a live instance.
func isProcessRunning(pid int) bool {
	if pid <= 0 {
		return false
	}
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	err = process.Signal(syscall.Signal(0))
	return err == nil || errors.Is(err, syscall.EPERM)
}
10 changes: 10 additions & 0 deletions cmd/server/process_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build windows

package server

// isProcessRunning is the Windows stand-in for the PID liveness probe.
//
// Signal(0) is not supported on Windows, so the PID argument is ignored and
// every process is reported as not running. PID file liveness detection is
// therefore best-effort on this platform.
func isProcessRunning(_ int) bool {
	const alive = false
	return alive
}
26 changes: 14 additions & 12 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"sort"
"strconv"
"strings"
"syscall"
"time"

"github.com/coder/agentapi/lib/screentracker"
Expand Down Expand Up @@ -292,25 +291,28 @@ func writePIDFile(pidFile string, logger *slog.Logger) error {
return nil
}

// cleanupPIDFile removes the PID file, but only when it was written by this
// process.
//
// The file is read back and its trimmed content is compared with os.Getpid():
//   - a missing file is ignored silently; any other read failure is logged;
//   - content that does not parse as an integer is reported as malformed and
//     the file is left in place;
//   - a PID other than ours means another process owns the file, so it is
//     left in place;
//   - otherwise the file is removed, tolerating it having disappeared in the
//     meantime.
//
// Nothing is returned: every outcome is reported through logger only.
func cleanupPIDFile(pidFile string, logger *slog.Logger) {
	data, err := os.ReadFile(pidFile)
	if err != nil {
		if !os.IsNotExist(err) {
			logger.Error("Failed to read PID file for cleanup", "pidFile", pidFile, "error", err)
		}
		return
	}

	pidStr := strings.TrimSpace(string(data))
	filePID, err := strconv.Atoi(pidStr)
	if err != nil {
		// Unparsable content says nothing about ownership, so do not claim
		// the file belongs to another process; leave it alone and say why.
		logger.Warn("PID file does not contain a valid PID, skipping cleanup", "pidFile", pidFile, "content", pidStr, "error", err)
		return
	}
	if filePID != os.Getpid() {
		logger.Info("PID file belongs to another process, skipping cleanup", "pidFile", pidFile, "filePID", pidStr)
		return
	}

	switch err := os.Remove(pidFile); {
	case err == nil:
		logger.Info("Removed PID file", "pidFile", pidFile)
	case !os.IsNotExist(err):
		logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err)
	}
}

// isProcessRunning reports whether a process with the given PID currently
// exists.
//
// The process is probed with signal 0, which asks the kernel to perform its
// existence and permission checks without delivering a signal. An EPERM result
// still counts as running: the process exists, this process is merely not
// allowed to signal it.
//
// Non-positive PIDs always report false. They never identify a single
// process; handed to kill(2), a negative value addresses a whole process
// group, so probing one could report an unrelated group as a live instance.
func isProcessRunning(pid int) bool {
	if pid <= 0 {
		return false
	}
	process, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	err = process.Signal(syscall.Signal(0))
	return err == nil || errors.Is(err, syscall.EPERM)
}

type flagSpec struct {
name string
shorthand string
Expand Down
22 changes: 18 additions & 4 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,9 @@ func TestPIDFileOperations(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/test.pid"

// Write initial PID file
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
// Write a non-numeric PID so strconv.Atoi fails and the liveness
// check is skipped, avoiding flakes when a real PID matches.
err := os.WriteFile(pidFile, []byte("not-a-pid\n"), 0o644)
require.NoError(t, err)

// Overwrite with current PID
Expand All @@ -657,12 +658,25 @@ func TestPIDFileOperations(t *testing.T) {
assert.Equal(t, expectedPID, string(data))
})

t.Run("writePIDFile detects running process", func(t *testing.T) {
	pidFile := t.TempDir() + "/test.pid"

	// Seed the file with our own PID: this process is certainly alive, so the
	// liveness check must see the recorded PID as running.
	ownPID := fmt.Appendf(nil, "%d\n", os.Getpid())
	require.NoError(t, os.WriteFile(pidFile, ownPID, 0o644))

	// Writing over a PID file held by a live process has to be refused.
	writeErr := writePIDFile(pidFile, discardLogger)
	require.Error(t, writeErr)
	assert.Contains(t, writeErr.Error(), "another instance is already running")
})

t.Run("cleanupPIDFile removes file", func(t *testing.T) {
tmpDir := t.TempDir()
pidFile := tmpDir + "/test.pid"

// Create PID file
err := os.WriteFile(pidFile, []byte("12345\n"), 0o644)
// Create PID file with current process PID so ownership check passes
err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d\n", os.Getpid())), 0o644)
require.NoError(t, err)

// Cleanup
Expand Down
2 changes: 1 addition & 1 deletion cmd/server/signals_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) {
// Handle shutdown signals (SIGTERM, SIGINT, SIGHUP)
shutdownCh := make(chan os.Signal, 1)
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT)
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
go func() {
defer signal.Stop(shutdownCh)
sig := <-shutdownCh
Expand Down
8 changes: 7 additions & 1 deletion lib/httpapi/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ func convertStatus(status st.ConversationStatus) AgentStatus {

// defaultSubscriptionBufSize is the default subscription buffer size —
// presumably the value WithSubscriptionBufSize overrides; confirm against
// NewEventEmitter.
const defaultSubscriptionBufSize uint = 1024

// maxStoredErrors caps the number of errors retained for late subscribers.
// Once the cap is exceeded, EmitError keeps only the most recent
// maxStoredErrors entries.
const maxStoredErrors = 100

type EventEmitterOption func(*EventEmitter)

func WithSubscriptionBufSize(size uint) EventEmitterOption {
Expand Down Expand Up @@ -224,8 +227,11 @@ func (e *EventEmitter) EmitError(message string, level st.ErrorLevel) {
Time: e.clock.Now(),
}

// Store the error so new subscribers can receive all errors
// Store the error so new subscribers can receive recent errors.
e.errors = append(e.errors, errorBody)
if len(e.errors) > maxStoredErrors {
e.errors = e.errors[len(e.errors)-maxStoredErrors:]
}

e.notifyChannels(EventTypeError, errorBody)
}
Expand Down
64 changes: 64 additions & 0 deletions lib/httpapi/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,70 @@ func TestEventEmitter(t *testing.T) {
}
})

t.Run("error-cap", func(t *testing.T) {
	// Emit more errors than the emitter retains so the oldest ones have to be
	// dropped. Both counts derive from maxStoredErrors, so the test keeps
	// holding if the cap is retuned.
	const overflow = 50
	const total = maxStoredErrors + overflow

	emitter := NewEventEmitter(WithSubscriptionBufSize(10))

	for i := range total {
		emitter.EmitError(fmt.Sprintf("error %d", i), st.ErrorLevelError)
	}

	_, _, stateEvents := emitter.Subscribe()

	// Keep only the error events from what Subscribe handed back.
	var errorEvents []Event
	for _, ev := range stateEvents {
		if ev.Type == EventTypeError {
			errorEvents = append(errorEvents, ev)
		}
	}

	assert.Len(t, errorEvents, maxStoredErrors)

	// The survivors are the most recent errors in emission order: the first
	// `overflow` ones were evicted, so numbering starts at `overflow`.
	for i, ev := range errorEvents {
		body, ok := ev.Payload.(ErrorBody)
		assert.True(t, ok)
		assert.Equal(t, fmt.Sprintf("error %d", i+overflow), body.Message)
	}
})

t.Run("error-events-in-initial-state", func(t *testing.T) {
	// Pin the clock so every emitted error carries a known timestamp.
	mockClock := quartz.NewMock(t)
	fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	mockClock.Set(fixedTime)

	emitter := NewEventEmitter(WithClock(mockClock), WithSubscriptionBufSize(10))

	// Three errors, one second apart, with mixed levels.
	emitter.EmitError("err1", st.ErrorLevelError)
	mockClock.Set(fixedTime.Add(1 * time.Second))
	emitter.EmitError("err2", st.ErrorLevelWarning)
	mockClock.Set(fixedTime.Add(2 * time.Second))
	emitter.EmitError("err3", st.ErrorLevelError)

	// Subscribe only after the errors were emitted and keep the error events
	// from what Subscribe handed back.
	_, _, stateEvents := emitter.Subscribe()

	var errorEvents []Event
	for _, ev := range stateEvents {
		if ev.Type == EventTypeError {
			errorEvents = append(errorEvents, ev)
		}
	}

	expected := []ErrorBody{
		{Message: "err1", Level: st.ErrorLevelError, Time: fixedTime},
		{Message: "err2", Level: st.ErrorLevelWarning, Time: fixedTime.Add(1 * time.Second)},
		{Message: "err3", Level: st.ErrorLevelError, Time: fixedTime.Add(2 * time.Second)},
	}

	// assert.Len does not stop the test on failure; bail out here so a
	// surplus event cannot turn into an index-out-of-range panic on
	// expected[i] below.
	if !assert.Len(t, errorEvents, len(expected)) {
		return
	}

	for i, ev := range errorEvents {
		body, ok := ev.Payload.(ErrorBody)
		assert.True(t, ok)
		assert.Equal(t, expected[i].Message, body.Message)
		assert.Equal(t, expected[i].Level, body.Level)
		assert.Equal(t, expected[i].Time, body.Time)
	}
})

t.Run("clock-injection", func(t *testing.T) {
mockClock := quartz.NewMock(t)
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
Expand Down
23 changes: 20 additions & 3 deletions lib/screentracker/pty_conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,21 @@ func (c *PTYConversation) Start(ctx context.Context) {
c.initialPromptReady = true
}

var loadErr string
if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState {
if err := c.loadStateLocked(); err != nil {
c.cfg.Logger.Error("Failed to load state", "error", err)
c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), ErrorLevelWarning)
loadErr = fmt.Sprintf("Failed to restore previous session: %v", err)
c.loadStateStatus = LoadStateFailed
} else {
c.loadStateStatus = LoadStateSucceeded
}
}

if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent {
// Safe to send under lock: the queue is guaranteed empty here because
// statusLocked blocks Send until the snapshot buffer fills, which
// cannot happen before this first enqueue completes.
c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil}
c.initialPromptSent = true
c.dirty = true
Expand All @@ -226,6 +230,9 @@ func (c *PTYConversation) Start(ctx context.Context) {
}
c.lock.Unlock()

if loadErr != "" {
c.emitter.EmitError(loadErr, ErrorLevelWarning)
}
c.emitter.EmitStatus(status)
c.emitter.EmitMessages(messages)
c.emitter.EmitScreen(screen)
Expand Down Expand Up @@ -292,7 +299,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp
if c.cfg.FormatMessage != nil {
agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message)
}
if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 {
if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 &&
c.messages[len(c.messages)-1].Role == ConversationRoleAgent {
agentMessage = c.messages[len(c.messages)-1].Message
}
if c.cfg.FormatToolCall != nil {
Expand Down Expand Up @@ -605,6 +613,12 @@ func (c *PTYConversation) SaveState() error {
return xerrors.Errorf("failed to encode state: %w", err)
}

// Flush to disk before rename for crash safety
if err := f.Sync(); err != nil {
_ = f.Close()
return xerrors.Errorf("failed to sync state file: %w", err)
}

// Close file before rename
if err := f.Close(); err != nil {
return xerrors.Errorf("failed to close temp state file: %w", err)
Expand Down Expand Up @@ -668,7 +682,10 @@ func (c *PTYConversation) loadStateLocked() error {
c.initialPromptSent = agentState.InitialPromptSent
if len(c.cfg.InitialPrompt) > 0 {
isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt
c.initialPromptSent = !isDifferent
if isDifferent {
c.initialPromptSent = false
}
// If same prompt, keep agentState.InitialPromptSent
} else if agentState.InitialPrompt != "" {
c.cfg.InitialPrompt = []MessagePart{MessagePartText{
Content: agentState.InitialPrompt,
Expand Down
Loading