Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions pkg/runtime/after_llm_call_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package runtime

import (
"context"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/hooks"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
)

// TestAfterLLMCallHook_PopulatesModelID guards against a regression of
// the doc/impl mismatch in which [hooks.Input.ModelID] was documented as
// populated for after_llm_call while executeAfterLLMCallHooks never set
// it, so handlers reading model_id always observed an empty string. The
// test drives one successful turn and expects the after_llm_call
// dispatch to carry a ModelID equal to the provider's canonical
// "<provider>/<model>" id.
func TestAfterLLMCallHook_PopulatesModelID(t *testing.T) {
	t.Parallel()

	const hookName = "test-after-llm-model-id"
	const modelID = "test/mock-model"

	// Most recent payload handed to the builtin hook; remains nil if the
	// hook is never invoked.
	var seen atomic.Pointer[hooks.Input]
	recordInput := func(_ context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) {
		// Keep a copy of the payload rather than the pointer we were given.
		snapshot := *in
		seen.Store(&snapshot)
		return nil, nil
	}

	// One-shot model response: a single content chunk followed by a stop
	// with usage.
	builder := newStreamBuilder()
	builder = builder.AddContent("ok")
	builder = builder.AddStopWithUsage(1, 1)
	provider := &mockProvider{id: modelID, stream: builder.Build()}

	// Wire the builtin hook, by name, onto the after_llm_call event.
	hooksCfg := &latest.HooksConfig{
		AfterLLMCall: []latest.HookDefinition{
			{Type: "builtin", Command: hookName},
		},
	}
	rootAgent := agent.New(
		"root",
		"test agent",
		agent.WithModel(provider),
		agent.WithHooks(hooksCfg),
	)

	runtime, err := NewLocalRuntime(
		team.New(team.WithAgents(rootAgent)),
		WithSessionCompaction(false),
		WithModelStore(mockModelStore{}),
	)
	require.NoError(t, err)

	err = runtime.hooksRegistry.RegisterBuiltin(hookName, recordInput)
	require.NoError(t, err)

	userSession := session.New(session.WithUserMessage("hi"))
	userSession.Title = "Unit Test"

	// Drain the event stream; only the hook's captured payload matters.
	events := runtime.RunStream(t.Context(), userSession)
	for range events {
	}

	payload := seen.Load()
	require.NotNil(t, payload, "after_llm_call hook must fire on a successful turn")
	assert.Equal(t, modelID, payload.ModelID,
		"after_llm_call payload must include the canonical model id; "+
			"see pkg/hooks/types.go:177-186 for the documented contract")
}
2 changes: 1 addition & 1 deletion pkg/runtime/harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (r *LocalRuntime) runHarnessAgent(ctx context.Context, sess *session.Sessio
content = strings.TrimSpace(finalResult)
}

r.executeAfterLLMCallHooks(ctx, sess, a, content)
r.executeAfterLLMCallHooks(ctx, sess, a, modelID, content)
r.recordHarnessAssistantMessage(sess, a, content, modelID, usage, cost, events)
r.executeStopHooks(ctx, sess, a, content, events)

Expand Down
3 changes: 2 additions & 1 deletion pkg/runtime/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,11 @@ func (r *LocalRuntime) executeBeforeLLMCallHooks(
// stop_response (matching the stop event), so handlers can reuse the
// same parsing logic. Failed model calls fire on_error instead and
// skip this event.
func (r *LocalRuntime) executeAfterLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, responseContent string) {
func (r *LocalRuntime) executeAfterLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID, responseContent string) {
r.dispatchHook(ctx, a, hooks.EventAfterLLMCall, &hooks.Input{
SessionID: sess.ID,
AgentName: a.Name(),
ModelID: modelID,
StopResponse: responseContent,
LastUserMessage: sess.GetLastUserMessageContent(),
}, nil)
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ func (r *LocalRuntime) runTurn(
// fire on_error above. The assistant text content is passed
// via stop_response, matching the stop event's payload, so
// handlers can reuse the same parsing.
r.executeAfterLLMCallHooks(ctx, sess, a, res.Content)
r.executeAfterLLMCallHooks(ctx, sess, a, modelID.String(), res.Content)

if usedModel != nil && usedModel.ID() != model.ID() {
slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID().String(), "used", usedModel.ID().String())
Expand Down
Loading