diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e1b2914..a0d6068 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -139,13 +139,15 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o return opts } -// lastUserPrompt returns last input message with "user" role -func (i *responsesInterceptionBase) lastUserPrompt() (string, error) { +// lastUserPrompt returns the user prompt of the request: the string value of 'input', +// or the text of the last input item when that item has the "user" role. +// The bool reports whether a prompt was found; when none is found it returns "" and false. +func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) { if i == nil { - return "", errors.New("cannot get last user prompt: nil struct") + return "", false, errors.New("cannot get last user prompt: nil struct") } if i.req == nil { - return "", errors.New("cannot get last user prompt: nil req struct") + return "", false, errors.New("cannot get last user prompt: nil request struct") } // 'input' field can be a string or array of objects: @@ -153,52 +155,71 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) { // Check string variant if i.req.Input.OfString.Valid() { - return i.req.Input.OfString.Value, nil + return i.req.Input.OfString.Value, true, nil } // Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field. 
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList' // It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message // example: fixtures/openai/responses/blocking/builtin_tool.txtar - inputItems := gjson.GetBytes(i.reqPayload, "input").Array() - for i := len(inputItems) - 1; i >= 0; i-- { - item := inputItems[i] - if item.Get("role").Str == string(constant.ValueOf[constant.User]()) { - var sb strings.Builder - - // content can be a string or array of objects: - // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content - content := item.Get("content") - if content.Str != "" { - return content.Str, nil - } - for _, c := range content.Array() { - if c.Get("type").Str == "input_text" { - sb.WriteString(c.Get("text").Str) - } - } - if sb.Len() > 0 { - return sb.String(), nil - } + inputItems := gjson.GetBytes(i.reqPayload, "input") + + if !inputItems.IsArray() { + if inputItems.Type == gjson.Null { + return "", false, nil } + return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String()) + } + + inputItemsArr := inputItems.Array() + if len(inputItemsArr) == 0 { + return "", false, nil } + lastItem := inputItemsArr[len(inputItemsArr)-1] // Request was likely not human-initiated. 
- return "", nil -} + if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) { + return "", false, nil + } -func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) { - prompt, err := i.lastUserPrompt() - if err != nil { - i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err)) - return + // content can be a string or array of objects: + // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content + content := lastItem.Get(string(constant.ValueOf[constant.Content]())) + + // non array case, should be string + if !content.IsArray() { + if content.Type == gjson.String { + return content.Str, true, nil + } + return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String()) } - // No prompt found: last request was not human-initiated. - if prompt == "" { - return + var sb strings.Builder + promptExists := false + for _, c := range content.Array() { + // ignore inputs of not `input_text` type + if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) { + continue + } + + text := c.Get(string(constant.ValueOf[constant.Text]())) + if text.Type == gjson.String { + promptExists = true + sb.WriteString(text.Str + "\n") + } else { + i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type)) + } } + if !promptExists { + return "", false, nil + } + + prompt := strings.TrimSuffix(sb.String(), "\n") + return prompt, true, nil +} + +func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) { if responseID == "" { i.logger.Warn(ctx, "got empty response ID, skipping prompt recording") return diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 3bd91d5..de72010 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -20,22 
+20,41 @@ func TestLastUserPrompt(t *testing.T) { tests := []struct { name string reqPayload []byte - expected string + expect string }{ + { + name: "input_empty_string", + reqPayload: []byte(`{"input": ""}`), + expect: "", + }, + { + name: "input_array_content_empty_string", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), + expect: "", + }, + { + name: "input_array_content_array_empty_string", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`), + }, + { + name: "input_array_content_array_multiple_inputs", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`), + expect: "a\nb", + }, { name: "simple_string_input", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), - expected: "tell me a joke", + expect: "tell me a joke", }, { name: "array_single_input_string", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool), - expected: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", + expect: "Is 3 + 5 a prime number? 
Use the add function to calculate the sum.", }, { name: "array_multiple_items_content_objects", reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex), - expected: "hello", + expect: "hello", }, } @@ -52,51 +71,83 @@ func TestLastUserPrompt(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt() + prompt, promptFound, err := base.lastUserPrompt(t.Context()) require.NoError(t, err) - require.Equal(t, tc.expected, prompt) + require.Equal(t, tc.expect, prompt) + require.True(t, promptFound) }) } } -func TestLastUserPromptEmptyPrompt(t *testing.T) { +func TestLastUserPromptNotFound(t *testing.T) { t.Parallel() t.Run("nil_struct", func(t *testing.T) { t.Parallel() var base *responsesInterceptionBase - prompt, err := base.lastUserPrompt() + prompt, promptFound, err := base.lastUserPrompt(t.Context()) require.Error(t, err) require.Empty(t, prompt) + require.False(t, promptFound) require.Contains(t, "cannot get last user prompt: nil struct", err.Error()) }) - // Other cases where the user prompt might be empty. + t.Run("nil_request", func(t *testing.T) { + t.Parallel() + + base := responsesInterceptionBase{} + prompt, promptFound, err := base.lastUserPrompt(t.Context()) + require.Error(t, err) + require.Empty(t, prompt) + require.False(t, promptFound) + require.Contains(t, "cannot get last user prompt: nil request struct", err.Error()) + }) + + // Cases where the user prompt is not found / wrong format. 
tests := []struct { name string reqPayload []byte + expectErr string }{ { - name: "empty_input", + name: "non_existing_input", + reqPayload: []byte(`{"model": "gpt-4o"}`), + }, + { + name: "input_empty_array", reqPayload: []byte(`{"model": "gpt-4o", "input": []}`), }, { - name: "no_user_role", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`), + name: "input_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`), + expectErr: "unexpected input type", }, { - name: "user_with_empty_content", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), + name: "no_user_role", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`), }, { name: "user_with_empty_content_array", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`), }, + { + name: "input_array_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`), + expectErr: "unexpected input content type", + }, { name: "user_with_non_input_text_content", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`), }, + { + name: "user_content_not_last", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`), + }, + { + name: "input_array_content_array_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`), + }, } for _, tc := range tests { @@ -112,9 +163,15 @@ func TestLastUserPromptEmptyPrompt(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt() - require.NoError(t, err) + prompt, promptFound, err := base.lastUserPrompt(t.Context()) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } 
else { + require.NoError(t, err) + } require.Empty(t, prompt) + require.False(t, promptFound) }) } } @@ -123,29 +180,31 @@ func TestRecordPrompt(t *testing.T) { t.Parallel() tests := []struct { - name string - reqPayload []byte - responseID string - wantRecorded bool - wantPrompt string + name string + promptWasRecorded bool + prompt string + responseID string + wantRecorded bool + wantPrompt string }{ { name: "records_prompt_successfully", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), + prompt: "tell me a joke", responseID: "resp_123", wantRecorded: true, wantPrompt: "tell me a joke", }, { - name: "skips_recording_on_empty_response_id", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), - responseID: "", - wantRecorded: false, + name: "records_empty_prompt_successfully", + prompt: "", + responseID: "resp_123", + wantRecorded: true, + wantPrompt: "", }, { - name: "skips_recording_on_lastUserPrompt_error", - reqPayload: []byte(`{"model": "gpt-4o", "input": []}`), - responseID: "resp_123", + name: "skips_recording_on_empty_response_id", + prompt: "tell me a joke", + responseID: "", wantRecorded: false, }, } @@ -154,21 +213,15 @@ func TestRecordPrompt(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := &ResponsesNewParamsWrapper{} - err := req.UnmarshalJSON(tc.reqPayload) - require.NoError(t, err) - rec := &testutil.MockRecorder{} id := uuid.New() base := &responsesInterceptionBase{ - id: id, - req: req, - reqPayload: tc.reqPayload, - recorder: rec, - logger: slog.Make(), + id: id, + recorder: rec, + logger: slog.Make(), } - base.recordUserPrompt(t.Context(), tc.responseID) + base.recordUserPrompt(t.Context(), tc.responseID, tc.prompt) prompts := rec.RecordedPromptUsages() if tc.wantRecorded { diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 9074dc6..e895ad5 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -60,14 +60,18 @@ func 
(i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * i.disableParallelToolCalls() var ( - response *responses.Response - err error - upstreamErr error - respCopy responseCopier + response *responses.Response + upstreamErr error + respCopy responseCopier + firstResponseID string ) + prompt, promptFound, err := i.lastUserPrompt(ctx) + if err != nil { + i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err)) + } shouldLoop := true - recordPromptOnce := true + for shouldLoop { srv := i.newResponsesService() respCopy = responseCopier{} @@ -80,13 +84,10 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * break } - // Record prompt usage on first successful response. - if recordPromptOnce { - recordPromptOnce = false - i.recordUserPrompt(ctx, response.ID) + if firstResponseID == "" { + firstResponseID = response.ID } - // Record token usage for each inner loop iteration i.recordTokenUsage(ctx, response) // Check if there any injected tools to invoke. 
@@ -98,6 +99,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * } } + if promptFound { + i.recordUserPrompt(ctx, firstResponseID, prompt) + } i.recordNonInjectedToolUsage(ctx, response) if upstreamErr != nil && !respCopy.responseReceived.Load() { diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 5a8d755..fcf2efc 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -79,11 +79,15 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r }() var respCopy responseCopier - var responseID string + var firstResponseID string var completedResponse *responses.Response var innerLoopErr error var streamErr error + prompt, promptFound, err := i.lastUserPrompt(ctx) + if err != nil { + i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err)) + } shouldLoop := true srv := i.newResponsesService() @@ -122,8 +126,8 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r // Not every event has response.id set (eg: fixtures/openai/responses/streaming/simple.txtar). // First event should be of 'response.created' type and have response.id set. // Set responseID to the first response.id that is set. - if responseID == "" && ev.Response.ID != "" { - responseID = ev.Response.ID + if firstResponseID == "" && ev.Response.ID != "" { + firstResponseID = ev.Response.ID } // Capture the response from the response.completed event. 
@@ -145,6 +149,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } } + streamErr = stream.Err() return nil }() @@ -165,7 +170,9 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } - i.recordUserPrompt(ctx, responseID) + if promptFound { + i.recordUserPrompt(ctx, firstResponseID, prompt) + } i.recordNonInjectedToolUsage(ctx, completedResponse) // On innerLoop error custom error has been already sent, diff --git a/responses_integration_test.go b/responses_integration_test.go index 7d6d4cc..a2bec07 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -692,9 +692,9 @@ func TestUpstreamError(t *testing.T) { } } -// TestResponsesBlockingInjectedTool tests that injected MCP tool calls trigger the inner agentic loop, +// TestResponsesInjectedTool tests that injected MCP tool calls trigger the inner agentic loop, // invoke the tool via MCP, and send the result back to the model. -func TestResponsesBlockingInjectedTool(t *testing.T) { +func TestResponsesInjectedTool(t *testing.T) { t.Parallel() tests := []struct {