Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 56 additions & 35 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,66 +139,87 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
return opts
}

// lastUserPrompt returns last input message with "user" role
func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
// lastUserPrompt returns input text with "user" role from last input item
// or string input value if it is present + bool indicating if input was found or not.
// If no such input was found empty string + false is returned.
func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) {
if i == nil {
return "", errors.New("cannot get last user prompt: nil struct")
return "", false, errors.New("cannot get last user prompt: nil struct")
}
if i.req == nil {
return "", errors.New("cannot get last user prompt: nil req struct")
return "", false, errors.New("cannot get last user prompt: nil request struct")
}

// 'input' field can be a string or array of objects:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input

// Check string variant
if i.req.Input.OfString.Valid() {
return i.req.Input.OfString.Value, nil
return i.req.Input.OfString.Value, true, nil
}

// Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field.
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList'
// It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message
// example: fixtures/openai/responses/blocking/builtin_tool.txtar
inputItems := gjson.GetBytes(i.reqPayload, "input").Array()
for i := len(inputItems) - 1; i >= 0; i-- {
item := inputItems[i]
if item.Get("role").Str == string(constant.ValueOf[constant.User]()) {
var sb strings.Builder

// content can be a string or array of objects:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
content := item.Get("content")
if content.Str != "" {
return content.Str, nil
}
for _, c := range content.Array() {
if c.Get("type").Str == "input_text" {
sb.WriteString(c.Get("text").Str)
}
}
if sb.Len() > 0 {
return sb.String(), nil
}
inputItems := gjson.GetBytes(i.reqPayload, "input")

if !inputItems.IsArray() {
if inputItems.Type == gjson.Null {
return "", false, nil
}
return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String())
}

inputItemsArr := inputItems.Array()
if len(inputItemsArr) == 0 {
return "", false, nil
}
lastItem := inputItemsArr[len(inputItemsArr)-1]

// Request was likely not human-initiated.
return "", nil
}
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would hide the expression like this behind a descriptive function

return "", false, nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) {
prompt, err := i.lastUserPrompt()
if err != nil {
i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err))
return
// content can be a string or array of objects:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a helper function? I see .Get(string(constant.ValueOf[constant.Content]() is repeated a few times


// non array case, should be string
if !content.IsArray() {
if content.Type == gjson.String {
return content.Str, true, nil
}
return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String())
}

// No prompt found: last request was not human-initiated.
if prompt == "" {
return
var sb strings.Builder
promptExists := false
for _, c := range content.Array() {
// ignore inputs of not `input_text` type
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
continue
}

text := c.Get(string(constant.ValueOf[constant.Text]()))
if text.Type == gjson.String {
promptExists = true
sb.WriteString(text.Str + "\n")
} else {
i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type))
}
}

if !promptExists {
return "", false, nil
}

prompt := strings.TrimSuffix(sb.String(), "\n")
return prompt, true, nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {
if responseID == "" {
i.logger.Warn(ctx, "got empty response ID, skipping prompt recording")
return
Expand Down
131 changes: 92 additions & 39 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,41 @@ func TestLastUserPrompt(t *testing.T) {
tests := []struct {
name string
reqPayload []byte
expected string
expect string
}{
{
name: "input_empty_string",
reqPayload: []byte(`{"input": ""}`),
expect: "",
},
{
name: "input_array_content_empty_string",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
expect: "",
},
{
name: "input_array_content_array_empty_string",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`),
},
{
name: "input_array_content_array_multiple_inputs",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`),
expect: "a\nb",
},
{
name: "simple_string_input",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
expected: "tell me a joke",
expect: "tell me a joke",
},
{
name: "array_single_input_string",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
expected: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
},
{
name: "array_multiple_items_content_objects",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex),
expected: "hello",
expect: "hello",
},
}

Expand All @@ -52,51 +71,83 @@ func TestLastUserPrompt(t *testing.T) {
reqPayload: tc.reqPayload,
}

prompt, err := base.lastUserPrompt()
prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.NoError(t, err)
require.Equal(t, tc.expected, prompt)
require.Equal(t, tc.expect, prompt)
require.True(t, promptFound)
})
}
}

func TestLastUserPromptEmptyPrompt(t *testing.T) {
func TestLastUserPromptNotFound(t *testing.T) {
t.Parallel()

t.Run("nil_struct", func(t *testing.T) {
t.Parallel()

var base *responsesInterceptionBase
prompt, err := base.lastUserPrompt()
prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.Error(t, err)
require.Empty(t, prompt)
require.False(t, promptFound)
require.Contains(t, "cannot get last user prompt: nil struct", err.Error())
})

// Other cases where the user prompt might be empty.
t.Run("nil_request", func(t *testing.T) {
t.Parallel()

base := responsesInterceptionBase{}
prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.Error(t, err)
require.Empty(t, prompt)
require.False(t, promptFound)
require.Contains(t, "cannot get last user prompt: nil request struct", err.Error())
})

// Cases where the user prompt is not found / wrong format.
tests := []struct {
name string
reqPayload []byte
expectErr string
}{
{
name: "empty_input",
name: "non_existing_input",
reqPayload: []byte(`{"model": "gpt-4o"}`),
},
{
name: "input_empty_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
},
{
name: "no_user_role",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
name: "input_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`),
expectErr: "unexpected input type",
},
{
name: "user_with_empty_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
name: "no_user_role",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
},
{
name: "user_with_empty_content_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
},
{
name: "input_array_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`),
expectErr: "unexpected input content type",
},
{
name: "user_with_non_input_text_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
},
{
name: "user_content_not_last",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`),
},
{
name: "input_array_content_array_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`),
},
}

for _, tc := range tests {
Expand All @@ -112,9 +163,15 @@ func TestLastUserPromptEmptyPrompt(t *testing.T) {
reqPayload: tc.reqPayload,
}

prompt, err := base.lastUserPrompt()
require.NoError(t, err)
prompt, promptFound, err := base.lastUserPrompt(t.Context())
if tc.expectErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
require.Empty(t, prompt)
require.False(t, promptFound)
})
}
}
Expand All @@ -123,29 +180,31 @@ func TestRecordPrompt(t *testing.T) {
t.Parallel()

tests := []struct {
name string
reqPayload []byte
responseID string
wantRecorded bool
wantPrompt string
name string
promptWasRecorded bool
prompt string
responseID string
wantRecorded bool
wantPrompt string
}{
{
name: "records_prompt_successfully",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
prompt: "tell me a joke",
responseID: "resp_123",
wantRecorded: true,
wantPrompt: "tell me a joke",
},
{
name: "skips_recording_on_empty_response_id",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
responseID: "",
wantRecorded: false,
name: "records_empty_prompt_successfully",
prompt: "",
responseID: "resp_123",
wantRecorded: true,
wantPrompt: "",
},
{
name: "skips_recording_on_lastUserPrompt_error",
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
responseID: "resp_123",
name: "skips_recording_on_empty_response_id",
prompt: "tell me a joke",
responseID: "",
wantRecorded: false,
},
}
Expand All @@ -154,21 +213,15 @@ func TestRecordPrompt(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

req := &ResponsesNewParamsWrapper{}
err := req.UnmarshalJSON(tc.reqPayload)
require.NoError(t, err)

rec := &testutil.MockRecorder{}
id := uuid.New()
base := &responsesInterceptionBase{
id: id,
req: req,
reqPayload: tc.reqPayload,
recorder: rec,
logger: slog.Make(),
id: id,
recorder: rec,
logger: slog.Make(),
}

base.recordUserPrompt(t.Context(), tc.responseID)
base.recordUserPrompt(t.Context(), tc.responseID, tc.prompt)

prompts := rec.RecordedPromptUsages()
if tc.wantRecorded {
Expand Down
Loading