From 4f72de430977393199843f55d8edddc4352ada22 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 21 May 2026 13:14:09 +0200 Subject: [PATCH 1/4] fix: wait for supervisor watcher shutdown --- pkg/tools/lifecycle/supervisor.go | 38 ++++++++++-- pkg/tools/lifecycle/supervisor_test.go | 80 ++++++++++++++++++++++++-- 2 files changed, 109 insertions(+), 9 deletions(-) diff --git a/pkg/tools/lifecycle/supervisor.go b/pkg/tools/lifecycle/supervisor.go index ef3c2c4ae..911e8af41 100644 --- a/pkg/tools/lifecycle/supervisor.go +++ b/pkg/tools/lifecycle/supervisor.go @@ -140,6 +140,10 @@ type Supervisor struct { // fresh channel by Start when transitioning out of a terminal state. done chan struct{} + // watchDone is closed by the current watcher goroutine. Stop waits on it + // after closing the session so no transport goroutines are left behind. + watchDone chan struct{} + // randFloat is the jitter source; tests may override. randFloat func() float64 } @@ -214,6 +218,9 @@ func (s *Supervisor) Start(ctx context.Context) error { } s.session = sess spawnWatcher := !s.watcherAlive + if spawnWatcher { + s.watchDone = make(chan struct{}) + } s.watcherAlive = true // Recovering from a terminal state (Failed → Start, or a watcher // that previously exited): refresh `done` so RestartAndWait callers @@ -244,24 +251,40 @@ func (s *Supervisor) Start(ctx context.Context) error { func (s *Supervisor) Stop(ctx context.Context) error { s.mu.Lock() if s.stopping { + watchDone := s.watchDone s.mu.Unlock() - return nil + return waitForWatcher(ctx, watchDone) } s.stopping = true sess := s.session s.session = nil + watchDone := s.watchDone s.mu.Unlock() s.tracker.Set(StateStopped) s.signalDone() - if sess == nil { + var closeErr error + if sess != nil { + closeErr = sess.Close(context.WithoutCancel(ctx)) + } + waitErr := waitForWatcher(ctx, watchDone) + if closeErr != nil && ctx.Err() == nil { + return closeErr + } + return waitErr +} + +func waitForWatcher(ctx context.Context, done <-chan struct{}) error { + if done == nil { return nil } - if err := sess.Close(context.WithoutCancel(ctx)); err != nil && ctx.Err() == nil { - return err + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() } - return nil } // RestartAndWait closes the current session (if any) so the watcher @@ -326,7 +349,12 @@ func (s *Supervisor) watch(ctx context.Context) { defer func() { s.mu.Lock() s.watcherAlive = false + watchDone := s.watchDone + s.watchDone = nil s.mu.Unlock() + if watchDone != nil { + close(watchDone) + } }() log := s.policy.logger() diff --git a/pkg/tools/lifecycle/supervisor_test.go b/pkg/tools/lifecycle/supervisor_test.go index ac8dfd6ba..2586672a3 100644 --- a/pkg/tools/lifecycle/supervisor_test.go +++ b/pkg/tools/lifecycle/supervisor_test.go @@ -17,20 +17,41 @@ import ( // fakeSession is a controllable session: its Wait blocks until either // Close is called or fail is invoked. type fakeSession struct { - mu sync.Mutex - closed bool - failCh chan error + mu sync.Mutex + closed bool + waitDone atomic.Bool // set true after Wait returns + waiting chan struct{} // closed once Wait has parked on failCh + waitOnce sync.Once + failCh chan error } func newFakeSession() *fakeSession { - return &fakeSession{failCh: make(chan error, 1)} + return &fakeSession{ + waiting: make(chan struct{}), + failCh: make(chan error, 1), + } } func (f *fakeSession) Wait() error { + f.waitOnce.Do(func() { close(f.waiting) }) err := <-f.failCh + f.waitDone.Store(true) return err } +// waitParked blocks until the watcher goroutine has entered sess.Wait(). +// Used by tests that need to exercise Stop against an actively-blocking +// watcher rather than the racy connect-then-stop path where the watcher +// could exit before parking. +func (f *fakeSession) waitParked(t *testing.T) { + t.Helper() + select { + case <-f.waiting: + case <-time.After(time.Second): + t.Fatal("watcher did not enter Wait()") + } +} + func (f *fakeSession) Close(context.Context) error { f.mu.Lock() if !f.closed { @@ -458,3 +479,54 @@ func TestBackoff_Jitter(t *testing.T) { d = lifecycle.ExportedBackoffDelay(b, 0, func() float64 { return 0 }) assert.Check(t, d == 50*time.Millisecond) } + +func TestSupervisor_StopWaitsForWatcher(t *testing.T) { + t.Parallel() + + sess := newFakeSession() + c := newScriptedConnector(scriptStep{session: sess}) + s := lifecycle.New("test", c, lifecycle.Policy{}) + + assert.NilError(t, s.Start(t.Context())) + sess.waitParked(t) + + assert.NilError(t, s.Stop(t.Context())) + assert.Check(t, is.Equal(s.State().State, lifecycle.StateStopped)) + + // Stop must not return until the watcher has observed Wait() unblock. + assert.Check(t, sess.waitDone.Load(), "Stop returned before watcher's Wait() completed") +} + +// TestSupervisor_StopConcurrent exercises the s.stopping guard: several +// goroutines call Stop concurrently while the watcher is live in +// sess.Wait(). All calls must return without error and observe a +// fully-shut-down supervisor. +func TestSupervisor_StopConcurrent(t *testing.T) { + t.Parallel() + + sess := newFakeSession() + c := newScriptedConnector(scriptStep{session: sess}) + s := lifecycle.New("test", c, lifecycle.Policy{}) + + assert.NilError(t, s.Start(t.Context())) + sess.waitParked(t) + + const n = 4 + errs := make(chan error, n) + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + errs <- s.Stop(t.Context()) + }() + } + wg.Wait() + close(errs) + + for err := range errs { + assert.NilError(t, err) + } + assert.Check(t, is.Equal(s.State().State, lifecycle.StateStopped)) + assert.Check(t, sess.waitDone.Load(), "a Stop returned before watcher's Wait() completed") +} From 6dd026249c0dc0788ec53cb56972ff46127986d5 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 21 May 2026 13:21:22 +0200 Subject: [PATCH 2/4] fix: spool large mcp media to disk --- pkg/runtime/toolexec/dispatcher.go | 22 ++++-- pkg/tools/mcp/mcp.go | 114 ++++++++++++++++++++++++++--- pkg/tools/mcp/mcp_test.go | 77 ++++++++++++++++++- pkg/tools/tools.go | 7 +- 4 files changed, 200 insertions(+), 20 deletions(-) diff --git a/pkg/runtime/toolexec/dispatcher.go b/pkg/runtime/toolexec/dispatcher.go index bf901243b..4820a878c 100644 --- a/pkg/runtime/toolexec/dispatcher.go +++ b/pkg/runtime/toolexec/dispatcher.go @@ -724,13 +724,21 @@ func buildMultiContent(text string, images []tools.MediaContent) []chat.MessageP parts := make([]chat.MessagePart, 0, 1+len(images)) parts = append(parts, chat.MessagePart{Type: chat.MessagePartTypeText, Text: text}) for _, img := range images { - parts = append(parts, chat.MessagePart{ - Type: chat.MessagePartTypeImageURL, - ImageURL: &chat.MessageImageURL{ - URL: "data:" + img.MimeType + ";base64," + img.Data, - Detail: chat.ImageURLDetailAuto, - }, - }) + switch { + case img.FilePath != "": + parts = append(parts, chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: fmt.Sprintf("[image saved to %s (%s)]", img.FilePath, img.MimeType), + }) + case img.Data != "": + parts = append(parts, chat.MessagePart{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: "data:" + img.MimeType + ";base64," + img.Data, + Detail: chat.ImageURLDetailAuto, + }, + }) + } } return parts } diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 42208e4fc..84734716f 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -160,6 +160,11 @@ type Toolset struct { supervisor *lifecycle.Supervisor + // mediaDir is the toolset-scoped temp dir holding spooled media + // payloads. Created lazily on first spool, removed by Stop. + mediaMu sync.Mutex + mediaDir string + mu sync.Mutex // Cached tools and prompts, invalidated via MCP notifications and @@ -426,6 +431,7 @@ func (ts *Toolset) Start(ctx context.Context) error { // Stop tears the supervisor down. Idempotent. func (ts *Toolset) Stop(ctx context.Context) error { slog.DebugContext(ctx, "Stopping MCP toolset", "server", ts.logID) + defer ts.cleanupMediaDir() if ts.supervisor == nil { return nil } @@ -694,7 +700,7 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool return nil, fmt.Errorf("failed to call tool: %w", err) } - result := processMCPContent(resp) + result := ts.processMCPContent(resp) slog.DebugContext(ctx, "MCP tool call completed", "tool", toolCall.Function.Name, "output_length", len(result.Output)) slog.DebugContext(ctx, result.Output) return result, nil @@ -714,7 +720,13 @@ func isInitNotificationSendError(err error) bool { return false } -func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult { +const maxInlineMediaBytes = 256 * 1024 + +// writeMediaFile is a package-level indirection so tests can simulate +// disk failures without manipulating the filesystem. +var writeMediaFile = defaultWriteMediaFile + +func (ts *Toolset) processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult { var text strings.Builder var images, audios []tools.MediaContent @@ -723,9 +735,9 @@ func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult { case *mcp.TextContent: text.WriteString(c.Text) case *mcp.ImageContent: - images = append(images, encodeMedia(c.Data, c.MIMEType)) + images = append(images, ts.encodeMedia(c.Data, c.MIMEType)) case *mcp.AudioContent: - audios = append(audios, encodeMedia(c.Data, c.MIMEType)) + audios = append(audios, ts.encodeMedia(c.Data, c.MIMEType)) case *mcp.ResourceLink: if c.Name != "" { // Escape ] in name and ) in URI to prevent broken markdown links. @@ -760,12 +772,94 @@ func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult { } } -// encodeMedia re-encodes raw bytes (as decoded by the MCP SDK) back to base64 -// for our internal MediaContent representation. -func encodeMedia(data []byte, mimeType string) tools.MediaContent { - return tools.MediaContent{ - Data: base64.StdEncoding.EncodeToString(data), - MimeType: mimeType, +// encodeMedia keeps small payloads inline and spools larger ones to disk so the +// session and TUI do not retain duplicate base64 copies. Spooled files live +// under a toolset-scoped temp directory removed by Stop. +func (ts *Toolset) encodeMedia(data []byte, mimeType string) tools.MediaContent { + media := tools.MediaContent{MimeType: mimeType} + if len(data) <= maxInlineMediaBytes { + media.Data = base64.StdEncoding.EncodeToString(data) + return media + } + + dir, err := ts.ensureMediaDir() + if err == nil { + var path string + path, err = writeMediaFile(dir, data, mimeType) + if err == nil { + media.FilePath = path + return media + } + } + slog.Warn("failed to spool MCP media to disk", "mime_type", mimeType, "bytes", len(data), "error", err) + media.Data = base64.StdEncoding.EncodeToString(data) + return media +} + +// ensureMediaDir lazily creates the toolset-scoped temp dir for spooled +// media payloads. The directory is removed by Stop. +func (ts *Toolset) ensureMediaDir() (string, error) { + ts.mediaMu.Lock() + defer ts.mediaMu.Unlock() + if ts.mediaDir != "" { + return ts.mediaDir, nil + } + dir, err := os.MkdirTemp("", "docker-agent-mcp-media-*") + if err != nil { + return "", err + } + ts.mediaDir = dir + return dir, nil +} + +// cleanupMediaDir removes the toolset-scoped media spool directory, if any. +func (ts *Toolset) cleanupMediaDir() { + ts.mediaMu.Lock() + dir := ts.mediaDir + ts.mediaDir = "" + ts.mediaMu.Unlock() + if dir == "" { + return + } + if err := os.RemoveAll(dir); err != nil { + slog.Warn("failed to remove MCP media spool directory", "dir", dir, "error", err) + } +} + +func defaultWriteMediaFile(dir string, data []byte, mimeType string) (string, error) { + f, err := os.CreateTemp(dir, "media-*"+mediaExtension(mimeType)) + if err != nil { + return "", err + } + path := f.Name() + if _, err := f.Write(data); err != nil { + _ = f.Close() + _ = os.Remove(path) + return "", err + } + if err := f.Close(); err != nil { + _ = os.Remove(path) + return "", err + } + return path, nil +} + +func mediaExtension(mimeType string) string { + switch mimeType { + case "image/png": + return ".png" + case "image/jpeg": + return ".jpg" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "audio/wav", "audio/wave", "audio/x-wav": + return ".wav" + case "audio/mpeg", "audio/mp3": + return ".mp3" + default: + return ".bin" } } diff --git a/pkg/tools/mcp/mcp_test.go b/pkg/tools/mcp/mcp_test.go index 997ac7c8f..24bdb6a34 100644 --- a/pkg/tools/mcp/mcp_test.go +++ b/pkg/tools/mcp/mcp_test.go @@ -1,9 +1,12 @@ package mcp import ( + "bytes" "context" + "errors" "fmt" "iter" + "os" "sync" "sync/atomic" "testing" @@ -478,7 +481,7 @@ func TestProcessMCPContent(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := processMCPContent(tt.input) + result := (&Toolset{}).processMCPContent(tt.input) assert.Equal(t, tt.wantOutput, result.Output) assert.Equal(t, tt.wantIsError, result.IsError) @@ -536,3 +539,75 @@ func TestCallToolRecoversFromErrSessionMissing(t *testing.T) { assert.Equal(t, "recovered", result.Output) assert.Equal(t, int32(2), callCount.Load(), "expected exactly 2 CallTool invocations (1 failed + 1 retry)") } + +func TestProcessMCPContentSpoolsLargeMedia(t *testing.T) { + ts := &Toolset{} + t.Cleanup(ts.cleanupMediaDir) + + tests := []struct { + name string + size int + wantInline bool + }{ + {"at threshold stays inline", maxInlineMediaBytes, true}, + {"above threshold spools", maxInlineMediaBytes + 1, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := bytes.Repeat([]byte("x"), tt.size) + result := ts.processMCPContent(callToolResult(&mcp.ImageContent{Data: data, MIMEType: "image/png"})) + + require.Len(t, result.Images, 1) + img := result.Images[0] + assert.Equal(t, "image/png", img.MimeType) + + if tt.wantInline { + assert.NotEmpty(t, img.Data) + assert.Empty(t, img.FilePath) + return + } + + assert.Empty(t, img.Data) + require.NotEmpty(t, img.FilePath) + + got, err := os.ReadFile(img.FilePath) + require.NoError(t, err) + assert.Equal(t, data, got) + }) + } +} + +func TestEncodeMediaFallsBackToInlineOnDiskFailure(t *testing.T) { + original := writeMediaFile + t.Cleanup(func() { writeMediaFile = original }) + writeMediaFile = func(string, []byte, string) (string, error) { + return "", errors.New("disk full") + } + + ts := &Toolset{} + t.Cleanup(ts.cleanupMediaDir) + + data := bytes.Repeat([]byte("x"), maxInlineMediaBytes+1) + result := ts.processMCPContent(callToolResult(&mcp.ImageContent{Data: data, MIMEType: "image/png"})) + + require.Len(t, result.Images, 1) + img := result.Images[0] + assert.Empty(t, img.FilePath) + assert.NotEmpty(t, img.Data, "falls back to inline base64 when disk write fails") +} + +func TestToolsetStopRemovesMediaDir(t *testing.T) { + ts := &Toolset{} + data := bytes.Repeat([]byte("x"), maxInlineMediaBytes+1) + media := ts.encodeMedia(data, "image/png") + require.NotEmpty(t, media.FilePath) + + _, err := os.Stat(media.FilePath) + require.NoError(t, err) + + require.NoError(t, ts.Stop(t.Context())) + + _, err = os.Stat(media.FilePath) + assert.True(t, os.IsNotExist(err), "spooled media file should be removed by Stop") +} diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 2d0185943..1cc6f8026 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -73,8 +73,11 @@ type FunctionCall struct { // MediaContent represents base64-encoded binary data (image, audio, etc.) // returned by a tool. type MediaContent struct { - // Data is the base64-encoded payload. - Data string `json:"data"` + // Data is the base64-encoded payload. It is kept only for small media; large + // MCP payloads are spooled to FilePath to avoid retaining duplicate base64. + Data string `json:"data,omitempty"` + // FilePath is an optional local file containing the decoded media payload. + FilePath string `json:"filePath,omitempty"` // MimeType identifies the content type (e.g. "image/png", "audio/wav"). MimeType string `json:"mimeType"` } From d7660f0de256cf43d8ad9fdeb5b198d421a4515c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 21 May 2026 13:23:28 +0200 Subject: [PATCH 3/4] fix: slim retained tui tool results --- pkg/tools/tools.go | 10 ++++++++++ pkg/tools/tools_test.go | 21 +++++++++++++++++++++ pkg/tui/components/messages/messages.go | 4 ++-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 1cc6f8026..ca0cf84e3 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -102,6 +102,16 @@ type ToolCallResult struct { StructuredContent any `json:"structuredContent,omitempty"` } +func (r *ToolCallResult) WithoutPayload() *ToolCallResult { + if r == nil { + return nil + } + return &ToolCallResult{ + IsError: r.IsError, + Meta: r.Meta, + } +} + func ResultError(output string) *ToolCallResult { return &ToolCallResult{ Output: output, diff --git a/pkg/tools/tools_test.go b/pkg/tools/tools_test.go index 1f5338dbe..0ada71e0d 100644 --- a/pkg/tools/tools_test.go +++ b/pkg/tools/tools_test.go @@ -78,3 +78,24 @@ func TestNewHandler_InvalidArguments(t *testing.T) { }) require.Error(t, err) } + +func TestToolCallResultWithoutPayload(t *testing.T) { + result := &ToolCallResult{ + Output: "large output", + IsError: true, + Meta: "metadata", + Images: []MediaContent{{Data: "image", MimeType: "image/png"}}, + Audios: []MediaContent{{Data: "audio", MimeType: "audio/wav"}}, + StructuredContent: map[string]any{"key": "value"}, + } + + slim := result.WithoutPayload() + + require.NotNil(t, slim) + assert.Empty(t, slim.Output) + assert.True(t, slim.IsError) + assert.Equal(t, "metadata", slim.Meta) + assert.Nil(t, slim.Images) + assert.Nil(t, slim.Audios) + assert.Nil(t, slim.StructuredContent) +} diff --git a/pkg/tui/components/messages/messages.go b/pkg/tui/components/messages/messages.go index c5c5748f8..178d7461f 100644 --- a/pkg/tui/components/messages/messages.go +++ b/pkg/tui/components/messages/messages.go @@ -1475,7 +1475,7 @@ func (m *model) AddToolResult(msg *runtime.ToolCallResponseEvent, status types.T if m.messages[i].Type == types.MessageTypeAssistantReasoningBlock { if block, ok := m.views[i].(*reasoningblock.Model); ok { if block.HasToolCall(msg.ToolCallID) { - cmd := block.UpdateToolResult(msg.ToolCallID, msg.Response, status, msg.Result) + cmd := block.UpdateToolResult(msg.ToolCallID, msg.Response, status, msg.Result.WithoutPayload()) m.invalidateItem(i) return cmd } @@ -1489,7 +1489,7 @@ func (m *model) AddToolResult(msg *runtime.ToolCallResponseEvent, status types.T if toolMessage.Type == types.MessageTypeToolCall && toolMessage.ToolCall.ID == msg.ToolCallID { toolMessage.Content = strings.ReplaceAll(msg.Response, "\t", " ") toolMessage.ToolStatus = status - toolMessage.ToolResult = msg.Result + toolMessage.ToolResult = msg.Result.WithoutPayload() m.invalidateItem(i) view := m.createToolCallView(toolMessage) From 6bbc74b8d914ea5446a52aef0ddd668e70fda750 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 21 May 2026 14:08:12 +0200 Subject: [PATCH 4/4] fix: avoid retaining file contents in metadata --- pkg/tools/builtin/filesystem/filesystem.go | 2 -- pkg/tools/builtin/filesystem/filesystem_test.go | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/tools/builtin/filesystem/filesystem.go b/pkg/tools/builtin/filesystem/filesystem.go index 95293737d..811c31b45 100644 --- a/pkg/tools/builtin/filesystem/filesystem.go +++ b/pkg/tools/builtin/filesystem/filesystem.go @@ -267,7 +267,6 @@ type ReadFileArgs struct { type ReadFileMeta struct { Path string `json:"path"` - Content string `json:"content"` LineCount int `json:"lineCount"` Error string `json:"error,omitempty"` } @@ -1086,7 +1085,6 @@ func (t *ToolSet) handleReadMultipleFiles(ctx context.Context, args ReadMultiple Path: path, Content: text, }) - entry.Content = text entry.LineCount = strings.Count(text, "\n") + 1 meta.Files = append(meta.Files, entry) } diff --git a/pkg/tools/builtin/filesystem/filesystem_test.go b/pkg/tools/builtin/filesystem/filesystem_test.go index 9c2139641..f1c7398ee 100644 --- a/pkg/tools/builtin/filesystem/filesystem_test.go +++ b/pkg/tools/builtin/filesystem/filesystem_test.go @@ -104,6 +104,7 @@ func TestFilesystemTool_ReadFile_TildePath(t *testing.T) { require.NoError(t, err) assert.False(t, result.IsError) assert.Equal(t, content, result.Output) + assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta) } func TestFilesystemTool_WriteFile(t *testing.T) { @@ -166,6 +167,7 @@ func TestFilesystemTool_ReadFile(t *testing.T) { }) require.NoError(t, err) assert.Equal(t, content, result.Output) + assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta) result, err = tool.handleReadFile(t.Context(), ReadFileArgs{ Path: "nonexistent.txt",