streaming.go
package runtime

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"

	"github.com/docker/docker-agent/pkg/agent"
	"github.com/docker/docker-agent/pkg/chat"
	"github.com/docker/docker-agent/pkg/modelsdev"
	"github.com/docker/docker-agent/pkg/session"
	"github.com/docker/docker-agent/pkg/telemetry"
	"github.com/docker/docker-agent/pkg/tools"
)

// streamResult holds the aggregated result of processing a single chat
// completion stream: the assistant's textual reply and reasoning content, any
// tool calls requested, and metadata such as token usage and rate limits.
type streamResult struct {
	Calls             []tools.ToolCall
	Content           string
	ReasoningContent  string
	ThinkingSignature string
	ThoughtSignature  []byte
	Stopped           bool
	ActualModel       string
	Usage             *chat.Usage
	RateLimit         *chat.RateLimit
}

// handleStream reads a chat.MessageStream to completion, emitting streaming
// events (content deltas, partial tool calls, reasoning tokens) and returning
// the aggregated streamResult. The caller is responsible for adding the
// resulting assistant message to the session.
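//
// A hypothetical call site (all names below are illustrative, not part of
// this package) might look roughly like:
//
//	stream, err := provider.CreateChatCompletionStream(ctx, request) // hypothetical provider call
//	if err != nil {
//		return err
//	}
//	res, err := r.handleStream(ctx, stream, a, agentTools, sess, model, events)
//	if err != nil {
//		return err
//	}
//	sess.AddMessage(assistantMessageFrom(res)) // hypothetical helper; caller persists the reply
//	if len(res.Calls) > 0 {
//		// execute the requested tool calls, then stream again
//	}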
func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent, agentTools []tools.Tool, sess *session.Session, m *modelsdev.Model, events chan Event) (streamResult, error) {
	defer stream.Close()

	var fullContent strings.Builder
	var fullReasoningContent strings.Builder
	var thinkingSignature string
	var thoughtSignature []byte
	var toolCalls []tools.ToolCall
	var actualModel string
	var messageUsage *chat.Usage
	var messageRateLimit *chat.RateLimit

	toolCallIndex := make(map[string]int)   // toolCallID -> index in toolCalls slice
	emittedPartial := make(map[string]bool) // toolCallID -> whether we've emitted a partial event

	toolDefMap := make(map[string]tools.Tool, len(agentTools))
	for _, t := range agentTools {
		toolDefMap[t.Name] = t
	}

	// recordUsage persists the final token counts and emits telemetry exactly
	// once per stream, after we have the most accurate usage snapshot.
	usageRecorded := false
	recordUsage := func() {
		if usageRecorded || messageUsage == nil {
			return
		}
		usageRecorded = true
		sess.InputTokens = messageUsage.InputTokens + messageUsage.CachedInputTokens + messageUsage.CacheWriteTokens
		sess.OutputTokens = messageUsage.OutputTokens

		modelName := "unknown"
		if m != nil {
			modelName = m.Name
		}
		telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.TotalCost())
	}

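	// Consume the stream chunk by chunk until EOF, an error from the
	// provider, or a terminal finish reason.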
	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return streamResult{Stopped: true}, fmt.Errorf("error receiving from stream: %w", err)
		}

		if response.Usage != nil {
			// Always keep the latest usage snapshot; some providers (e.g.
			// Gemini) emit updated usage on every chunk with cumulative
			// token counts, so the last value is the most accurate.
			messageUsage = response.Usage
		}
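		// Keep the most recent rate limit snapshot reported by the provider.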
		if response.RateLimit != nil {
			messageRateLimit = response.RateLimit
		}

		if len(response.Choices) == 0 {
			continue
		}
		choice := response.Choices[0]
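
		// Capture the raw thought signature bytes when the provider supplies them.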
		if len(choice.Delta.ThoughtSignature) > 0 {
			thoughtSignature = choice.Delta.ThoughtSignature
		}

		// Capture the actual model from the stream response (useful for model routing)
		if actualModel == "" && response.Model != "" {
			actualModel = response.Model
		}
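
		// A terminal finish reason means the model is done: record usage and
		// return the aggregated result immediately.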
		if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength {
			recordUsage()
			return streamResult{
				Calls:             toolCalls,
				Content:           fullContent.String(),
				ReasoningContent:  fullReasoningContent.String(),
				ThinkingSignature: thinkingSignature,
				ThoughtSignature:  thoughtSignature,
				Stopped:           true,
				ActualModel:       actualModel,
				Usage:             messageUsage,
				RateLimit:         messageRateLimit,
			}, nil
		}

		// Handle tool calls
		if len(choice.Delta.ToolCalls) > 0 {
			// Process each tool call delta
			for _, delta := range choice.Delta.ToolCalls {
				idx, exists := toolCallIndex[delta.ID]
				if !exists {
					idx = len(toolCalls)
					toolCallIndex[delta.ID] = idx
					toolCalls = append(toolCalls, tools.ToolCall{
						ID:   delta.ID,
						Type: delta.Type,
					})
				}
				tc := &toolCalls[idx]

				// Track if we're learning the name for the first time
				learningName := delta.Function.Name != "" && tc.Function.Name == ""

				// Update fields from delta
				if delta.Type != "" {
					tc.Type = delta.Type
				}
				if delta.Function.Name != "" {
					tc.Function.Name = delta.Function.Name
				}
				if delta.Function.Arguments != "" {
					tc.Function.Arguments += delta.Function.Arguments
				}

				// Emit PartialToolCall once we have a name, and on subsequent argument deltas.
				// Only the current token (delta.Function.Arguments) is sent, not the
				// full accumulated arguments, to avoid re-transmitting the entire
				// payload on every token.
				if tc.Function.Name != "" && (learningName || delta.Function.Arguments != "") {
					if !emittedPartial[delta.ID] || delta.Function.Arguments != "" {
						partial := tools.ToolCall{
							ID:   tc.ID,
							Type: tc.Type,
							Function: tools.FunctionCall{
								Name:      tc.Function.Name,
								Arguments: delta.Function.Arguments,
							},
						}
						events <- PartialToolCall(partial, toolDefMap[tc.Function.Name], a.Name())
						emittedPartial[delta.ID] = true
					}
				}
			}
			continue
		}
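
		// Forward reasoning deltas as events and accumulate them for the
		// final result.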
		if choice.Delta.ReasoningContent != "" {
			events <- AgentChoiceReasoning(a.Name(), sess.ID, choice.Delta.ReasoningContent)
			fullReasoningContent.WriteString(choice.Delta.ReasoningContent)
		}

		// Capture thinking signature for Anthropic extended thinking
		if choice.Delta.ThinkingSignature != "" {
			thinkingSignature = choice.Delta.ThinkingSignature
		}
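
		// Forward regular content deltas as events and accumulate the full
		// assistant reply.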
		if choice.Delta.Content != "" {
			events <- AgentChoice(a.Name(), sess.ID, choice.Delta.Content)
			fullContent.WriteString(choice.Delta.Content)
		}
	}

	recordUsage()
	// If the stream completed without producing any content or tool calls
	// (most likely because a token limit was hit), report it as stopped so
	// the caller does not keep looping on an empty response.
	// NOTE(krissetto): this can likely be removed once compaction works
	// properly with all providers (aka dmr).
	stoppedDueToNoOutput := fullContent.Len() == 0 && len(toolCalls) == 0

	return streamResult{
		Calls:             toolCalls,
		Content:           fullContent.String(),
		ReasoningContent:  fullReasoningContent.String(),
		ThinkingSignature: thinkingSignature,
		ThoughtSignature:  thoughtSignature,
		Stopped:           stoppedDueToNoOutput,
		ActualModel:       actualModel,
		Usage:             messageUsage,
		RateLimit:         messageRateLimit,
	}, nil
}

// stripImageContent returns a copy of messages with all image-related content
// removed. This is used when the target model doesn't support image input to
// prevent API errors. Text content is preserved; image parts in MultiContent
// are filtered out, and file attachments with image MIME types are dropped.
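//
// A hypothetical call site (SupportsImageInput is illustrative, not part of
// this package) might look roughly like:
//
//	if !model.SupportsImageInput {
//		request.Messages = stripImageContent(request.Messages)
//	}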
func stripImageContent(messages []chat.Message) []chat.Message {
	result := make([]chat.Message, len(messages))
	for i, msg := range messages {
		result[i] = msg
		if len(msg.MultiContent) == 0 {
			continue
		}

		var filtered []chat.MessagePart
		for _, part := range msg.MultiContent {
			switch part.Type {
			case chat.MessagePartTypeImageURL:
				// Drop image URL parts entirely.
				continue
			case chat.MessagePartTypeFile:
				// Drop file parts that are images.
				if part.File != nil && chat.IsImageMimeType(part.File.MimeType) {
					continue
				}
			}
			filtered = append(filtered, part)
		}

		if len(filtered) != len(msg.MultiContent) {
			result[i].MultiContent = filtered
			slog.Debug("Stripped image content from message", "role", msg.Role, "original_parts", len(msg.MultiContent), "remaining_parts", len(filtered))
		}
	}
	return result
}