From a70e3d91941837e1bfefe7df0bc10249c3d981a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 23 Jan 2026 13:06:26 +0000 Subject: [PATCH 1/7] fix: consider only last input item when prompt recording for responses API --- intercept/responses/base.go | 76 ++++++++++++++++++----------- intercept/responses/base_test.go | 82 +++++++++++++++++++++++++------- 2 files changed, 113 insertions(+), 45 deletions(-) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e1b2914..ec43c3a 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -139,13 +139,14 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o return opts } -// lastUserPrompt returns last input message with "user" role -func (i *responsesInterceptionBase) lastUserPrompt() (string, error) { +// lastUserPrompt returns input text with "user" role from last input item or string input value if it is present. +// If no such input was found nil is returned. +func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string, error) { if i == nil { - return "", errors.New("cannot get last user prompt: nil struct") + return nil, errors.New("cannot get last user prompt: nil struct") } if i.req == nil { - return "", errors.New("cannot get last user prompt: nil req struct") + return nil, errors.New("cannot get last user prompt: nil request struct") } // 'input' field can be a string or array of objects: @@ -153,7 +154,7 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) { // Check string variant if i.req.Input.OfString.Valid() { - return i.req.Input.OfString.Value, nil + return &i.req.Input.OfString.Value, nil } // Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field. 
@@ -161,41 +162,60 @@ func (i *responsesInterceptionBase) lastUserPrompt() (string, error) { // It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message // example: fixtures/openai/responses/blocking/builtin_tool.txtar inputItems := gjson.GetBytes(i.reqPayload, "input").Array() - for i := len(inputItems) - 1; i >= 0; i-- { - item := inputItems[i] - if item.Get("role").Str == string(constant.ValueOf[constant.User]()) { - var sb strings.Builder - - // content can be a string or array of objects: - // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content - content := item.Get("content") - if content.Str != "" { - return content.Str, nil - } - for _, c := range content.Array() { - if c.Get("type").Str == "input_text" { - sb.WriteString(c.Get("text").Str) - } - } - if sb.Len() > 0 { - return sb.String(), nil + if len(inputItems) == 0 { + return nil, nil + } + + lastItem := inputItems[len(inputItems)-1] + + // Request was likely not human-initiated. 
+ if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) { + return nil, nil + } + + // content can be a string or array of objects: + // https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content + content := lastItem.Get(string(constant.ValueOf[constant.Content]())) + + // non array case, should be string + if !content.IsArray() { + if content.Type == gjson.String { + return &content.Str, nil + } + return nil, fmt.Errorf("unexpected input type: %v", content.Type.String()) + } + + var sb strings.Builder + promptExists := false + for _, c := range content.Array() { + if c.Get(string(constant.ValueOf[constant.Type]())).Str == string(constant.ValueOf[constant.InputText]()) { + text := c.Get(string(constant.ValueOf[constant.Text]())) + if text.Type == gjson.String { + promptExists = true + sb.WriteString(text.Str) + } else { + i.logger.Warn(ctx, fmt.Sprintf("unexpected input array type: %v", text.Type)) } } } - // Request was likely not human-initiated. - return "", nil + if !promptExists { + return nil, nil + } + + prompt := sb.String() + return &prompt, nil } func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) { - prompt, err := i.lastUserPrompt() + prompt, err := i.lastUserPrompt(ctx) if err != nil { i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err)) return } // No prompt found: last request was not human-initiated. 
- if prompt == "" { + if prompt == nil { return } @@ -207,7 +227,7 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon promptUsage := &recorder.PromptUsageRecord{ InterceptionID: i.ID().String(), MsgID: responseID, - Prompt: prompt, + Prompt: *prompt, } if err := i.recorder.RecordPromptUsage(ctx, promptUsage); err != nil { i.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err)) diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 3bd91d5..a30c532 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -20,22 +20,36 @@ func TestLastUserPrompt(t *testing.T) { tests := []struct { name string reqPayload []byte - expected string + expect string }{ + { + name: "empty_string_input_str", + reqPayload: []byte(`{"input": ""}`), + expect: "", + }, + { + name: "empty_string_input_array_content_str", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), + expect: "", + }, + { + name: "empty_string_input_array_content_array", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`), + }, { name: "simple_string_input", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), - expected: "tell me a joke", + expect: "tell me a joke", }, { name: "array_single_input_string", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool), - expected: "Is 3 + 5 a prime number? Use the add function to calculate the sum.", + expect: "Is 3 + 5 a prime number? 
Use the add function to calculate the sum.", }, { name: "array_multiple_items_content_objects", reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex), - expected: "hello", + expect: "hello", }, } @@ -52,51 +66,80 @@ func TestLastUserPrompt(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt() + prompt, err := base.lastUserPrompt(t.Context()) require.NoError(t, err) - require.Equal(t, tc.expected, prompt) + require.NotNil(t, prompt) + require.Equal(t, tc.expect, *prompt) }) } } -func TestLastUserPromptEmptyPrompt(t *testing.T) { +func TestLastUserPromptNil(t *testing.T) { t.Parallel() t.Run("nil_struct", func(t *testing.T) { t.Parallel() var base *responsesInterceptionBase - prompt, err := base.lastUserPrompt() + prompt, err := base.lastUserPrompt(t.Context()) require.Error(t, err) - require.Empty(t, prompt) + require.Nil(t, prompt) require.Contains(t, "cannot get last user prompt: nil struct", err.Error()) }) + t.Run("nil_request", func(t *testing.T) { + t.Parallel() + + base := responsesInterceptionBase{} + prompt, err := base.lastUserPrompt(t.Context()) + require.Error(t, err) + require.Nil(t, prompt) + require.Contains(t, "cannot get last user prompt: nil request struct", err.Error()) + }) + // Other cases where the user prompt might be empty. 
tests := []struct { name string reqPayload []byte + expectErr string }{ { - name: "empty_input", + name: "non_existing_input", + reqPayload: []byte(`{"model": "gpt-4o"}`), + }, + { + name: "input_empty_array", reqPayload: []byte(`{"model": "gpt-4o", "input": []}`), }, { - name: "no_user_role", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`), + name: "input_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`), }, { - name: "user_with_empty_content", - reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), + name: "no_user_role", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`), }, { name: "user_with_empty_content_array", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`), }, + { + name: "input_array_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`), + expectErr: "unexpected input type", + }, { name: "user_with_non_input_text_content", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`), }, + { + name: "user_content_not_last", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`), + }, + { + name: "input_array_content_array_integer", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`), + }, } for _, tc := range tests { @@ -112,9 +155,14 @@ func TestLastUserPromptEmptyPrompt(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt() - require.NoError(t, err) - require.Empty(t, prompt) + prompt, err := base.lastUserPrompt(t.Context()) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + 
} + require.Nil(t, prompt) }) } } From a7aa4654c575dfe220360c7f26c94f1b7fc01330 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 23 Jan 2026 13:44:16 +0000 Subject: [PATCH 2/7] review: added test case with multiple text input items --- intercept/responses/base_test.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index a30c532..3c88861 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -23,19 +23,24 @@ func TestLastUserPrompt(t *testing.T) { expect string }{ { - name: "empty_string_input_str", + name: "input_empty_string", reqPayload: []byte(`{"input": ""}`), expect: "", }, { - name: "empty_string_input_array_content_str", + name: "input_array_content_empty_string", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`), expect: "", }, { - name: "empty_string_input_array_content_array", + name: "input_array_content_array_empty_string", reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`), }, + { + name: "input_array_content_array_multiple_inputs", + reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`), + expect: "ab", + }, { name: "simple_string_input", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), From 0fa97e82a78f38f37a2a784c796b1e606e045789 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 23 Jan 2026 13:46:59 +0000 Subject: [PATCH 3/7] review: fix happy path indent --- intercept/responses/base.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index ec43c3a..1422893 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ 
-188,14 +188,17 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string var sb strings.Builder promptExists := false for _, c := range content.Array() { - if c.Get(string(constant.ValueOf[constant.Type]())).Str == string(constant.ValueOf[constant.InputText]()) { - text := c.Get(string(constant.ValueOf[constant.Text]())) - if text.Type == gjson.String { - promptExists = true - sb.WriteString(text.Str) - } else { - i.logger.Warn(ctx, fmt.Sprintf("unexpected input array type: %v", text.Type)) - } + // ignore inputs of not `input_text` type + if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) { + continue + } + + text := c.Get(string(constant.ValueOf[constant.Text]())) + if text.Type == gjson.String { + promptExists = true + sb.WriteString(text.Str) + } else { + i.logger.Warn(ctx, fmt.Sprintf("unexpected input array type: %v", text.Type)) } } From ff4bd28b726e12af49242ffa66f22420702832cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 23 Jan 2026 14:22:20 +0000 Subject: [PATCH 4/7] review 2: new line between text_input items --- intercept/responses/base.go | 4 ++-- intercept/responses/base_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 1422893..6e80173 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -196,7 +196,7 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string text := c.Get(string(constant.ValueOf[constant.Text]())) if text.Type == gjson.String { promptExists = true - sb.WriteString(text.Str) + sb.WriteString(text.Str + "\n") } else { i.logger.Warn(ctx, fmt.Sprintf("unexpected input array type: %v", text.Type)) } @@ -206,7 +206,7 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string return nil, nil } - prompt := sb.String() + prompt := strings.TrimSuffix(sb.String(), 
"\n") return &prompt, nil } diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 3c88861..080f81e 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -39,7 +39,7 @@ func TestLastUserPrompt(t *testing.T) { { name: "input_array_content_array_multiple_inputs", reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`), - expect: "ab", + expect: "a\nb", }, { name: "simple_string_input", From 18a5f46925e3a8f0ab963c98f5891d4be382fe6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 23 Jan 2026 15:35:42 +0000 Subject: [PATCH 5/7] fix not recording prompt in streaming + MCP call case (TestResponsesInjectedTool test) --- intercept/responses/base.go | 26 ++++++++++++++++---------- intercept/responses/base_test.go | 29 +++++++++++++++++++---------- intercept/responses/blocking.go | 22 ++++++++-------------- intercept/responses/streaming.go | 16 +++++++++------- responses_integration_test.go | 4 ++-- 5 files changed, 54 insertions(+), 43 deletions(-) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 6e80173..d726f76 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -37,16 +37,17 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - req *ResponsesNewParamsWrapper - reqPayload []byte - cfg config.OpenAI - model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + id uuid.UUID + req *ResponsesNewParamsWrapper + reqPayload []byte + promptWasRecorded bool + cfg config.OpenAI + model string + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -211,6 +212,11 @@ func (i 
*responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string } func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) { + if i.promptWasRecorded { + return + } + i.promptWasRecorded = true + prompt, err := i.lastUserPrompt(ctx) if err != nil { i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err)) diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index 080f81e..c2d786a 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -176,11 +176,12 @@ func TestRecordPrompt(t *testing.T) { t.Parallel() tests := []struct { - name string - reqPayload []byte - responseID string - wantRecorded bool - wantPrompt string + name string + promptWasRecorded bool + reqPayload []byte + responseID string + wantRecorded bool + wantPrompt string }{ { name: "records_prompt_successfully", @@ -189,6 +190,13 @@ func TestRecordPrompt(t *testing.T) { wantRecorded: true, wantPrompt: "tell me a joke", }, + { + name: "skips_record_if_prompt_was_recorded_before", + promptWasRecorded: true, + reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), + responseID: "resp_123", + wantRecorded: false, + }, { name: "skips_recording_on_empty_response_id", reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), @@ -214,11 +222,12 @@ func TestRecordPrompt(t *testing.T) { rec := &testutil.MockRecorder{} id := uuid.New() base := &responsesInterceptionBase{ - id: id, - req: req, - reqPayload: tc.reqPayload, - recorder: rec, - logger: slog.Make(), + id: id, + req: req, + reqPayload: tc.reqPayload, + promptWasRecorded: tc.promptWasRecorded, + recorder: rec, + logger: slog.Make(), } base.recordUserPrompt(t.Context(), tc.responseID) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 9074dc6..1497680 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -26,12 +26,13 @@ type 
BlockingResponsesInterceptor struct { func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + promptWasRecorded: false, + model: model, + tracer: tracer, }, } } @@ -67,7 +68,6 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * ) shouldLoop := true - recordPromptOnce := true for shouldLoop { srv := i.newResponsesService() respCopy = responseCopier{} @@ -80,13 +80,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * break } - // Record prompt usage on first successful response. - if recordPromptOnce { - recordPromptOnce = false - i.recordUserPrompt(ctx, response.ID) - } - - // Record token usage for each inner loop iteration + i.recordUserPrompt(ctx, response.ID) i.recordTokenUsage(ctx, response) // Check if there any injected tools to invoke. 
diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 5a8d755..27b47b2 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -33,12 +33,13 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + promptWasRecorded: false, + model: model, + tracer: tracer, }, } } @@ -145,6 +146,8 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } } + + i.recordUserPrompt(ctx, responseID) streamErr = stream.Err() return nil }() @@ -165,7 +168,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } - i.recordUserPrompt(ctx, responseID) i.recordNonInjectedToolUsage(ctx, completedResponse) // On innerLoop error custom error has been already sent, diff --git a/responses_integration_test.go b/responses_integration_test.go index 7d6d4cc..a2bec07 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -692,9 +692,9 @@ func TestUpstreamError(t *testing.T) { } } -// TestResponsesBlockingInjectedTool tests that injected MCP tool calls trigger the inner agentic loop, +// TestResponsesInjectedTool tests that injected MCP tool calls trigger the inner agentic loop, // invoke the tool via MCP, and send the result back to the model. 
-func TestResponsesBlockingInjectedTool(t *testing.T) { +func TestResponsesInjectedTool(t *testing.T) { t.Parallel() tests := []struct { From 77f01457e88b42dd9c4edeb98d146696448b231a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 26 Jan 2026 14:39:31 +0000 Subject: [PATCH 6/7] review: added comment about i.promptWasRecorded --- intercept/responses/base.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index d726f76..e198325 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -212,7 +212,11 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string } func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) { + // User prompt should be recorded only during first inner loop iteration. + // Subsequent inner loop iterations would fail to extract user prompt + // since last input item should be function call result / not contain user prompt. if i.promptWasRecorded { + // Exiting early to avoid confusing log entries. 
return } i.promptWasRecorded = true From 7cd566269d47d93e52f9ff442005ea3f62d8aa9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 26 Jan 2026 16:43:20 +0000 Subject: [PATCH 7/7] review: extract prompt in interceptor + record it once after inner loop --- intercept/responses/base.go | 86 ++++++++++++++------------------ intercept/responses/base_test.go | 65 +++++++++++------------- intercept/responses/blocking.go | 34 ++++++++----- intercept/responses/streaming.go | 27 ++++++---- 4 files changed, 103 insertions(+), 109 deletions(-) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e198325..a0d6068 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -37,17 +37,16 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID - req *ResponsesNewParamsWrapper - reqPayload []byte - promptWasRecorded bool - cfg config.OpenAI - model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + id uuid.UUID + req *ResponsesNewParamsWrapper + reqPayload []byte + cfg config.OpenAI + model string + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -140,14 +139,15 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o return opts } -// lastUserPrompt returns input text with "user" role from last input item or string input value if it is present. -// If no such input was found nil is returned. -func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string, error) { +// lastUserPrompt returns input text with "user" role from last input item +// or string input value if it is present + bool indicating if input was found or not. +// If no such input was found empty string + false is returned. 
+func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) { if i == nil { - return nil, errors.New("cannot get last user prompt: nil struct") + return "", false, errors.New("cannot get last user prompt: nil struct") } if i.req == nil { - return nil, errors.New("cannot get last user prompt: nil request struct") + return "", false, errors.New("cannot get last user prompt: nil request struct") } // 'input' field can be a string or array of objects: @@ -155,23 +155,31 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string // Check string variant if i.req.Input.OfString.Valid() { - return &i.req.Input.OfString.Value, nil + return i.req.Input.OfString.Value, true, nil } // Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field. // If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList' // It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message // example: fixtures/openai/responses/blocking/builtin_tool.txtar - inputItems := gjson.GetBytes(i.reqPayload, "input").Array() - if len(inputItems) == 0 { - return nil, nil + inputItems := gjson.GetBytes(i.reqPayload, "input") + + if !inputItems.IsArray() { + if inputItems.Type == gjson.Null { + return "", false, nil + } + return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String()) } - lastItem := inputItems[len(inputItems)-1] + inputItemsArr := inputItems.Array() + if len(inputItemsArr) == 0 { + return "", false, nil + } + lastItem := inputItemsArr[len(inputItemsArr)-1] // Request was likely not human-initiated. 
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) { - return nil, nil + return "", false, nil } // content can be a string or array of objects: @@ -181,9 +189,9 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string // non array case, should be string if !content.IsArray() { if content.Type == gjson.String { - return &content.Str, nil + return content.Str, true, nil } - return nil, fmt.Errorf("unexpected input type: %v", content.Type.String()) + return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String()) } var sb strings.Builder @@ -199,39 +207,19 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (*string promptExists = true sb.WriteString(text.Str + "\n") } else { - i.logger.Warn(ctx, fmt.Sprintf("unexpected input array type: %v", text.Type)) + i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type)) } } if !promptExists { - return nil, nil + return "", false, nil } prompt := strings.TrimSuffix(sb.String(), "\n") - return &prompt, nil + return prompt, true, nil } -func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) { - // User prompt should be recorded only during first inner loop iteration. - // Subsequent inner loop iterations would fail to extract user prompt - // since last input item should be function call result / not contain user prompt. - if i.promptWasRecorded { - // Exiting early to avoid confusing log entries. - return - } - i.promptWasRecorded = true - - prompt, err := i.lastUserPrompt(ctx) - if err != nil { - i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err)) - return - } - - // No prompt found: last request was not human-initiated. 
- if prompt == nil { - return - } - +func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) { if responseID == "" { i.logger.Warn(ctx, "got empty response ID, skipping prompt recording") return @@ -240,7 +228,7 @@ func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, respon promptUsage := &recorder.PromptUsageRecord{ InterceptionID: i.ID().String(), MsgID: responseID, - Prompt: *prompt, + Prompt: prompt, } if err := i.recorder.RecordPromptUsage(ctx, promptUsage); err != nil { i.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err)) diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index c2d786a..de72010 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -71,24 +71,25 @@ func TestLastUserPrompt(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt(t.Context()) + prompt, promptFound, err := base.lastUserPrompt(t.Context()) require.NoError(t, err) - require.NotNil(t, prompt) - require.Equal(t, tc.expect, *prompt) + require.Equal(t, tc.expect, prompt) + require.True(t, promptFound) }) } } -func TestLastUserPromptNil(t *testing.T) { +func TestLastUserPromptNotFound(t *testing.T) { t.Parallel() t.Run("nil_struct", func(t *testing.T) { t.Parallel() var base *responsesInterceptionBase - prompt, err := base.lastUserPrompt(t.Context()) + prompt, promptFound, err := base.lastUserPrompt(t.Context()) require.Error(t, err) - require.Nil(t, prompt) + require.Empty(t, prompt) + require.False(t, promptFound) require.Contains(t, "cannot get last user prompt: nil struct", err.Error()) }) @@ -96,13 +97,14 @@ func TestLastUserPromptNil(t *testing.T) { t.Parallel() base := responsesInterceptionBase{} - prompt, err := base.lastUserPrompt(t.Context()) + prompt, promptFound, err := base.lastUserPrompt(t.Context()) require.Error(t, err) - require.Nil(t, prompt) + require.Empty(t, prompt) + require.False(t, 
promptFound) require.Contains(t, "cannot get last user prompt: nil request struct", err.Error()) }) - // Other cases where the user prompt might be empty. + // Cases where the user prompt is not found / wrong format. tests := []struct { name string reqPayload []byte @@ -119,6 +121,7 @@ func TestLastUserPromptNil(t *testing.T) { { name: "input_integer", reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`), + expectErr: "unexpected input type", }, { name: "no_user_role", @@ -131,7 +134,7 @@ func TestLastUserPromptNil(t *testing.T) { { name: "input_array_integer", reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`), - expectErr: "unexpected input type", + expectErr: "unexpected input content type", }, { name: "user_with_non_input_text_content", @@ -160,14 +163,15 @@ func TestLastUserPromptNil(t *testing.T) { reqPayload: tc.reqPayload, } - prompt, err := base.lastUserPrompt(t.Context()) + prompt, promptFound, err := base.lastUserPrompt(t.Context()) if tc.expectErr != "" { require.Error(t, err) require.Contains(t, err.Error(), tc.expectErr) } else { require.NoError(t, err) } - require.Nil(t, prompt) + require.Empty(t, prompt) + require.False(t, promptFound) }) } } @@ -178,59 +182,46 @@ func TestRecordPrompt(t *testing.T) { tests := []struct { name string promptWasRecorded bool - reqPayload []byte + prompt string responseID string wantRecorded bool wantPrompt string }{ { name: "records_prompt_successfully", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), + prompt: "tell me a joke", responseID: "resp_123", wantRecorded: true, wantPrompt: "tell me a joke", }, { - name: "skips_record_if_prompt_was_recorded_before", - promptWasRecorded: true, - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), - responseID: "resp_123", - wantRecorded: false, + name: "records_empty_prompt_successfully", + prompt: "", + responseID: "resp_123", + wantRecorded: true, + wantPrompt: "", }, { name: 
"skips_recording_on_empty_response_id", - reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple), + prompt: "tell me a joke", responseID: "", wantRecorded: false, }, - { - name: "skips_recording_on_lastUserPrompt_error", - reqPayload: []byte(`{"model": "gpt-4o", "input": []}`), - responseID: "resp_123", - wantRecorded: false, - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := &ResponsesNewParamsWrapper{} - err := req.UnmarshalJSON(tc.reqPayload) - require.NoError(t, err) - rec := &testutil.MockRecorder{} id := uuid.New() base := &responsesInterceptionBase{ - id: id, - req: req, - reqPayload: tc.reqPayload, - promptWasRecorded: tc.promptWasRecorded, - recorder: rec, - logger: slog.Make(), + id: id, + recorder: rec, + logger: slog.Make(), } - base.recordUserPrompt(t.Context(), tc.responseID) + base.recordUserPrompt(t.Context(), tc.responseID, tc.prompt) prompts := rec.RecordedPromptUsages() if tc.wantRecorded { diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 1497680..e895ad5 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -26,13 +26,12 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - promptWasRecorded: false, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + tracer: tracer, }, } } @@ -61,13 +60,18 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * i.disableParallelToolCalls() var ( - response *responses.Response - err error - upstreamErr error - respCopy responseCopier + response *responses.Response + upstreamErr error + 
respCopy responseCopier + firstResponseID string ) + prompt, promptFound, err := i.lastUserPrompt(ctx) + if err != nil { + i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err)) + } shouldLoop := true + for shouldLoop { srv := i.newResponsesService() respCopy = responseCopier{} @@ -80,7 +84,10 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * break } - i.recordUserPrompt(ctx, response.ID) + if firstResponseID == "" { + firstResponseID = response.ID + } + i.recordTokenUsage(ctx, response) // Check if there any injected tools to invoke. @@ -92,6 +99,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * } } + if promptFound { + i.recordUserPrompt(ctx, firstResponseID, prompt) + } i.recordNonInjectedToolUsage(ctx, response) if upstreamErr != nil && !respCopy.responseReceived.Load() { diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 27b47b2..fcf2efc 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -33,13 +33,12 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - promptWasRecorded: false, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + tracer: tracer, }, } } @@ -80,11 +79,15 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r }() var respCopy responseCopier - var responseID string + var firstResponseID string var completedResponse *responses.Response var innerLoopErr error var streamErr error + prompt, promptFound, err := i.lastUserPrompt(ctx) + if err != nil { + i.logger.Warn(ctx, "failed to get user 
prompt", slog.Error(err)) + } shouldLoop := true srv := i.newResponsesService() @@ -123,8 +126,8 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r // Not every event has response.id set (eg: fixtures/openai/responses/streaming/simple.txtar). // First event should be of 'response.created' type and have response.id set. // Set responseID to the first response.id that is set. - if responseID == "" && ev.Response.ID != "" { - responseID = ev.Response.ID + if firstResponseID == "" && ev.Response.ID != "" { + firstResponseID = ev.Response.ID } // Capture the response from the response.completed event. @@ -147,7 +150,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } - i.recordUserPrompt(ctx, responseID) streamErr = stream.Err() return nil }() @@ -168,6 +170,9 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } } + if promptFound { + i.recordUserPrompt(ctx, firstResponseID, prompt) + } i.recordNonInjectedToolUsage(ctx, completedResponse) // On innerLoop error custom error has been already sent,