Skip to content

Commit 923ba05

Browse files
authored
Merge pull request #1864 from aheritier/fix/trim-messages-preserve-user-messages
fix(#1863): preserve user messages in trimMessages to prevent session derailment
2 parents da2fb08 + fe8ec0b commit 923ba05

3 files changed

Lines changed: 172 additions & 45 deletions

File tree

pkg/session/session.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,9 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message {
691691

692692
// trimMessages ensures we don't exceed the maximum number of messages while maintaining
693693
// consistency between assistant messages and their tool call results.
694-
// System messages are always preserved and not counted against the limit.
694+
// System messages and user messages are always preserved and not counted against the limit.
695+
// User messages are protected from trimming to prevent the model from losing
696+
// track of what was asked in long agentic loops.
695697
func trimMessages(messages []chat.Message, maxItems int) []chat.Message {
696698
// Separate system messages from conversation messages
697699
var systemMessages []chat.Message
@@ -710,15 +712,27 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message {
710712
return messages
711713
}
712714

715+
// Identify user message indices — these are protected from trimming
716+
protected := make(map[int]bool)
717+
for i, msg := range conversationMessages {
718+
if msg.Role == chat.MessageRoleUser {
719+
protected[i] = true
720+
}
721+
}
722+
713723
// Keep track of tool call IDs that need to be removed
714724
toolCallsToRemove := make(map[string]bool)
715725

716726
// Calculate how many conversation messages we need to remove
717727
toRemove := len(conversationMessages) - maxItems
718728

719-
// Start from the beginning (oldest messages)
720-
for i := range toRemove {
721-
// If this is an assistant message with tool calls, mark them for removal
729+
// Mark the oldest non-protected messages for removal
730+
removed := make(map[int]bool)
731+
for i := 0; i < len(conversationMessages) && len(removed) < toRemove; i++ {
732+
if protected[i] {
733+
continue
734+
}
735+
removed[i] = true
722736
if conversationMessages[i].Role == chat.MessageRoleAssistant {
723737
for _, toolCall := range conversationMessages[i].ToolCalls {
724738
toolCallsToRemove[toolCall.ID] = true
@@ -732,11 +746,13 @@ func trimMessages(messages []chat.Message, maxItems int) []chat.Message {
732746
// Add all system messages first
733747
result = append(result, systemMessages...)
734748

735-
// Add the most recent conversation messages
736-
for i := toRemove; i < len(conversationMessages); i++ {
737-
msg := conversationMessages[i]
749+
// Add protected and non-removed conversation messages
750+
for i, msg := range conversationMessages {
751+
if removed[i] {
752+
continue
753+
}
738754

739-
// Skip tool messages that correspond to removed assistant messages
755+
// Skip orphaned tool results whose assistant message was removed
740756
if msg.Role == chat.MessageRoleTool && toolCallsToRemove[msg.ToolCallID] {
741757
continue
742758
}

pkg/session/session_history_test.go

Lines changed: 147 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@ func TestSessionNumHistoryItems(t *testing.T) {
1919
expectedConversationMsgs int
2020
}{
2121
{
22-
name: "limit to 3 conversation messages",
23-
numHistoryItems: 3,
24-
messageCount: 10,
25-
expectedConversationMsgs: 3, // Limited to 3 despite 20 total messages
22+
name: "limit to 3 conversation messages — user messages protected",
23+
numHistoryItems: 3,
24+
messageCount: 10,
25+
// 10 user (all protected) + 10 assistant. Need to remove 17, but only 10 removable.
26+
// Result: 10 users + 0 assistants = 10
27+
expectedConversationMsgs: 10,
2628
},
2729
{
28-
name: "limit to 5 conversation messages",
29-
numHistoryItems: 5,
30-
messageCount: 8,
31-
expectedConversationMsgs: 5, // Limited to 5 out of 16 total messages
30+
name: "limit to 5 conversation messages — user messages protected",
31+
numHistoryItems: 5,
32+
messageCount: 8,
33+
// 8 user (all protected) + 8 assistant. Need to remove 11, but only 8 removable.
34+
// Result: 8 users + 0 assistants = 8
35+
expectedConversationMsgs: 8,
3236
},
3337
{
3438
name: "fewer messages than limit",
@@ -71,9 +75,8 @@ func TestSessionNumHistoryItems(t *testing.T) {
7175
// System messages should always be present (at least the instruction)
7276
assert.Positive(t, systemCount, "Should have system messages")
7377

74-
// Conversation messages should be limited
75-
assert.LessOrEqual(t, conversationCount, tt.expectedConversationMsgs,
76-
"Conversation messages should not exceed the configured limit")
78+
assert.Equal(t, tt.expectedConversationMsgs, conversationCount,
79+
"Conversation messages should match expected count")
7780
})
7881
}
7982
}
@@ -95,22 +98,20 @@ func TestTrimMessagesPreservesSystemMessages(t *testing.T) {
9598

9699
// Count message types
97100
systemCount := 0
98-
conversationCount := 0
101+
userCount := 0
99102
for _, msg := range trimmed {
100103
if msg.Role == chat.MessageRoleSystem {
101104
systemCount++
102-
} else {
103-
conversationCount++
105+
}
106+
if msg.Role == chat.MessageRoleUser {
107+
userCount++
104108
}
105109
}
106110

107111
// All system messages should be preserved
108112
assert.Equal(t, 3, systemCount, "All system messages should be preserved")
109-
assert.Equal(t, 1, conversationCount, "Should have exactly 1 conversation message")
110-
111-
// The preserved conversation message should be the most recent
112-
assert.Equal(t, "Assistant response 3", trimmed[len(trimmed)-1].Content,
113-
"Should preserve the most recent conversation message")
113+
// All user messages should be preserved even with maxItems=1
114+
assert.Equal(t, 3, userCount, "All user messages should be preserved")
114115
}
115116

116117
func TestTrimMessagesConversationLimit(t *testing.T) {
@@ -126,34 +127,45 @@ func TestTrimMessagesConversationLimit(t *testing.T) {
126127
{Role: chat.MessageRoleAssistant, Content: "Response 4"},
127128
}
128129

130+
// 8 conversation messages: 4 user + 4 assistant
131+
// User messages are always protected, so only assistant messages can be trimmed.
129132
testCases := []struct {
130133
limit int
131-
expectedTotal int
132-
expectedConversation int
133134
expectedSystem int
135+
expectedUser int
136+
expectedConversation int // total non-system
134137
}{
135-
{limit: 2, expectedTotal: 3, expectedConversation: 2, expectedSystem: 1},
136-
{limit: 4, expectedTotal: 5, expectedConversation: 4, expectedSystem: 1},
137-
{limit: 8, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1},
138-
{limit: 100, expectedTotal: 9, expectedConversation: 8, expectedSystem: 1},
138+
// limit=2: need to remove 6 of 8, but 4 are protected users → only 4 assistants removable → remove 4
139+
{limit: 2, expectedSystem: 1, expectedUser: 4, expectedConversation: 4},
140+
// limit=4: need to remove 4 of 8, 4 are protected → remove all 4 assistants
141+
{limit: 4, expectedSystem: 1, expectedUser: 4, expectedConversation: 4},
142+
// limit=8: no trimming needed (8 <= 8)
143+
{limit: 8, expectedSystem: 1, expectedUser: 4, expectedConversation: 8},
144+
// limit=100: no trimming needed
145+
{limit: 100, expectedSystem: 1, expectedUser: 4, expectedConversation: 8},
139146
}
140147

141148
for _, tc := range testCases {
142149
t.Run(fmt.Sprintf("limit_%d", tc.limit), func(t *testing.T) {
143150
trimmed := trimMessages(messages, tc.limit)
144151

145152
systemCount := 0
153+
userCount := 0
146154
conversationCount := 0
147155
for _, msg := range trimmed {
148-
if msg.Role == chat.MessageRoleSystem {
156+
switch msg.Role {
157+
case chat.MessageRoleSystem:
149158
systemCount++
150-
} else {
159+
case chat.MessageRoleUser:
160+
userCount++
161+
conversationCount++
162+
default:
151163
conversationCount++
152164
}
153165
}
154166

155-
assert.Len(t, trimmed, tc.expectedTotal, "Total message count")
156167
assert.Equal(t, tc.expectedSystem, systemCount, "System message count")
168+
assert.Equal(t, tc.expectedUser, userCount, "User messages should always be preserved")
157169
assert.Equal(t, tc.expectedConversation, conversationCount, "Conversation message count")
158170
})
159171
}
@@ -190,7 +202,7 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) {
190202
},
191203
}
192204

193-
// Limit to 3 conversation messages (should keep the recent tool interaction)
205+
// Limit to 3 conversation messages
194206
trimmed := trimMessages(messages, 3)
195207

196208
toolCallIDs := make(map[string]bool)
@@ -209,12 +221,113 @@ func TestTrimMessagesWithToolCallsPreservation(t *testing.T) {
209221
}
210222
}
211223

212-
// Should not have the old tool call
213-
hasOldTool := false
224+
// Both user messages should be preserved
225+
userMessages := 0
214226
for _, msg := range trimmed {
215-
if msg.Role == chat.MessageRoleTool && msg.ToolCallID == "old_tool_1" {
216-
hasOldTool = true
227+
if msg.Role == chat.MessageRoleUser {
228+
userMessages++
217229
}
218230
}
219-
assert.False(t, hasOldTool, "Should not have old tool results without their calls")
231+
assert.Equal(t, 2, userMessages, "Both user messages should be preserved")
232+
}
233+
234+
func TestTrimMessagesPreservesUserMessagesInAgenticLoop(t *testing.T) {
235+
// Simulate a single-turn agentic loop: one user message followed by many tool calls
236+
messages := []chat.Message{
237+
{Role: chat.MessageRoleSystem, Content: "System prompt"},
238+
{Role: chat.MessageRoleUser, Content: "Analyze MR #123 and build an integration plan"},
239+
}
240+
241+
for i := range 30 {
242+
toolID := fmt.Sprintf("tool_%d", i)
243+
messages = append(messages, chat.Message{
244+
Role: chat.MessageRoleAssistant,
245+
Content: fmt.Sprintf("Calling tool %d", i),
246+
ToolCalls: []tools.ToolCall{
247+
{ID: toolID, Function: tools.FunctionCall{Name: "shell"}},
248+
},
249+
}, chat.Message{
250+
Role: chat.MessageRoleTool,
251+
Content: fmt.Sprintf("Tool result %d", i),
252+
ToolCallID: toolID,
253+
})
254+
}
255+
256+
// 61 conversation messages (1 user + 30 assistant + 30 tool), limit to 30
257+
trimmed := trimMessages(messages, 30)
258+
259+
// The user message must survive
260+
var userMessages []string
261+
for _, msg := range trimmed {
262+
if msg.Role == chat.MessageRoleUser {
263+
userMessages = append(userMessages, msg.Content)
264+
}
265+
}
266+
267+
assert.Len(t, userMessages, 1, "User message must be preserved")
268+
assert.Equal(t, "Analyze MR #123 and build an integration plan", userMessages[0])
269+
270+
// Tool call consistency: every tool result must have a matching assistant tool call
271+
toolCallIDs := make(map[string]bool)
272+
for _, msg := range trimmed {
273+
if msg.Role == chat.MessageRoleAssistant {
274+
for _, tc := range msg.ToolCalls {
275+
toolCallIDs[tc.ID] = true
276+
}
277+
}
278+
}
279+
for _, msg := range trimmed {
280+
if msg.Role == chat.MessageRoleTool {
281+
assert.True(t, toolCallIDs[msg.ToolCallID],
282+
"Tool result %s should have a corresponding assistant tool call", msg.ToolCallID)
283+
}
284+
}
285+
}
286+
287+
func TestTrimMessagesPreservesAllUserMessages(t *testing.T) {
288+
// Multiple user messages interspersed with tool calls
289+
messages := []chat.Message{
290+
{Role: chat.MessageRoleSystem, Content: "System prompt"},
291+
{Role: chat.MessageRoleUser, Content: "First request"},
292+
}
293+
294+
for i := range 10 {
295+
toolID := fmt.Sprintf("tool_%d", i)
296+
messages = append(messages, chat.Message{
297+
Role: chat.MessageRoleAssistant,
298+
ToolCalls: []tools.ToolCall{{ID: toolID}},
299+
}, chat.Message{
300+
Role: chat.MessageRoleTool,
301+
Content: fmt.Sprintf("result %d", i),
302+
ToolCallID: toolID,
303+
})
304+
}
305+
306+
messages = append(messages, chat.Message{Role: chat.MessageRoleUser, Content: "Follow-up request"})
307+
308+
for i := 10; i < 20; i++ {
309+
toolID := fmt.Sprintf("tool_%d", i)
310+
messages = append(messages, chat.Message{
311+
Role: chat.MessageRoleAssistant,
312+
ToolCalls: []tools.ToolCall{{ID: toolID}},
313+
}, chat.Message{
314+
Role: chat.MessageRoleTool,
315+
Content: fmt.Sprintf("result %d", i),
316+
ToolCallID: toolID,
317+
})
318+
}
319+
320+
// 42 conversation messages (2 user + 20 assistant + 20 tool), limit to 10
321+
trimmed := trimMessages(messages, 10)
322+
323+
var userContents []string
324+
for _, msg := range trimmed {
325+
if msg.Role == chat.MessageRoleUser {
326+
userContents = append(userContents, msg.Content)
327+
}
328+
}
329+
330+
assert.Len(t, userContents, 2, "Both user messages must be preserved")
331+
assert.Equal(t, "First request", userContents[0])
332+
assert.Equal(t, "Follow-up request", userContents[1])
220333
}

pkg/session/session_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ func TestTrimMessagesWithToolCalls(t *testing.T) {
5858

5959
result := trimMessages(messages, maxItems)
6060

61-
// Should keep last 3 messages, but ensure tool call consistency
62-
assert.Len(t, result, maxItems)
63-
61+
// Both user messages are protected, so result includes them plus the most recent assistant/tool pair
6462
toolCalls := make(map[string]bool)
6563
for _, msg := range result {
6664
if msg.Role == chat.MessageRoleAssistant {

0 commit comments

Comments
 (0)