From e9bac3b0395ac718dc0baaa35ebd17d4e7a14e8a Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Fri, 23 Jan 2026 17:10:29 +0000 Subject: [PATCH 1/8] fix: handle chatcompletions streaming tool calls with no text preamble and non-zero indices --- bridge_integration_test.go | 108 ++++++++++++++++++ fixtures/fixtures.go | 6 + .../streaming_injected_tool_no_preamble.txtar | 70 ++++++++++++ ...treaming_injected_tool_nonzero_index.txtar | 69 +++++++++++ intercept/chatcompletions/streaming.go | 47 ++++++-- intercept/eventstream/eventstream.go | 9 ++ 6 files changed, 298 insertions(+), 11 deletions(-) create mode 100644 fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar create mode 100644 fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar diff --git a/bridge_integration_test.go b/bridge_integration_test.go index dcaf62f..c8f5e86 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -437,6 +437,114 @@ func TestOpenAIChatCompletions(t *testing.T) { }) } }) + + t.Run("streaming injected tool call edge cases", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + fixture []byte + expectedArgs map[string]any + }{ + { + name: "tool call no preamble", + fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, + expectedArgs: map[string]any{"owner": "me"}, + }, + { + name: "tool call with non-zero index", + fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, + expectedArgs: nil, // No arguments in this fixture + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(tc.fixture) + t.Logf("%s: %s", t.Name(), arc.Comment) + + files := filesMap(arc) + require.Len(t, files, 3) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureStreamingToolResponse) + + reqBody := files[fixtureRequest] + + // Add the stream param to the request. + newBody, err := setJSON(reqBody, "stream", true) + require.NoError(t, err) + reqBody = newBody + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup mock server with response mutator for multi-turn interaction. + srv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte { + if reqCount == 1 { + // First request gets the tool call response + return resp + } + // Second request gets final response + return files[fixtureStreamingToolResponse] + }) + t.Cleanup(srv.Close) + + recorderClient := &testutil.MockRecorder{} + + // Setup MCP proxies with the tool from the fixture + mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + require.NoError(t, mcpMgr.Init(ctx)) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(srv.URL, apiKey))} + b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcpMgr, logger, nil, testTracer) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(b) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + mockSrv.Start() + + req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Consume the full response body to ensure the interception completes + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + // Verify the MCP tool was actually invoked + invocations := mcpCalls.getCallsByTool(mockToolName) + require.Len(t, invocations, 1, "expected MCP tool to be invoked") + + // Verify tool was invoked with the expected args (if specified) + if tc.expectedArgs != nil { + expected, err := json.Marshal(tc.expectedArgs) + require.NoError(t, err) + actual, err := json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + } + + // Verify tool usage was recorded + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, mockToolName, toolUsages[0].Tool) + + recorderClient.VerifyAllInterceptionsEnded(t) + }) + } + }) } func TestSimple(t *testing.T) { diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index 243b506..2370b4e 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -45,6 +45,12 @@ var ( //go:embed openai/chatcompletions/non_stream_error.txtar OaiChatNonStreamError []byte + + //go:embed openai/chatcompletions/streaming_injected_tool_no_preamble.txtar + OaiChatStreamingInjectedToolNoPreamble []byte + + //go:embed openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar + OaiChatStreamingInjectedToolNonzeroIndex []byte ) var ( diff --git a/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar b/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar new file mode 100644 index 0000000..d31b0f8 --- /dev/null +++ b/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar @@ -0,0 +1,70 @@ +Streaming response where the provider returns an injected tool call as the first chunk with no text preamble. +This test ensures tool invocation continues even when no chunks are relayed to the client. + +-- request -- +{ + "messages": [ + { + "content": "2026-01-22T18:35:17.612Z\n\nlist all my coder workspaces", + "role": "user" + } + ], + "model": "claude-haiku-4.5", + "n": 1, + "temperature": 1, + "parallel_tool_calls": false, + "stream_options": { + "include_usage": true + }, + "stream": true +} + +-- streaming -- +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"name":"bmcp_coder_coder_list_workspaces"},"id":"toolu_vrtx_01CvBi1d4qpKTG2PCuc9wDbZ","index":0,"type":"function"}]}}],"created":1769106921,"id":"msg_vrtx_01UoiRJwj3JXcwNYAh3z7ARs","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"arguments":""},"index":0}]}}],"created":1769106921,"id":"msg_vrtx_01UoiRJwj3JXcwNYAh3z7ARs","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"arguments":"{\"own"},"index":0}]}}],"created":1769106921,"id":"msg_vrtx_01UoiRJwj3JXcwNYAh3z7ARs","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"arguments":"er\": \"me\"}"},"index":0}]}}],"created":1769106921,"id":"msg_vrtx_01UoiRJwj3JXcwNYAh3z7ARs","model":"claude-haiku-4.5"} + +data: {"choices":[{"finish_reason":"tool_calls","index":0,"delta":{"content":null}}],"created":1769106921,"id":"msg_vrtx_01UoiRJwj3JXcwNYAh3z7ARs","usage":{"completion_tokens":65,"prompt_tokens":25716,"prompt_tokens_details":{"cached_tokens":20470},"total_tokens":25781},"model":"claude-haiku-4.5"} + +data: [DONE] + +-- streaming/tool-call -- +data: {"choices":[{"index":0,"delta":{"content":"You","role":"assistant"}}],"created":1769198061,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" have one","role":"assistant"}}],"created":1769198061,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" Coder workspace:","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"\n\n**test-scf** (","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"ID: a174a2e5","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"-5050-445d-89","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"ff-dd720e5b442","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"e)\n- Template: docker","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"\n- Template Version","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" ID","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":": ad1b5ab1-","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"fc18-4792-84f","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"7-797787607d30","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"\n- Status","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":": Up","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" to date","role":"assistant"}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} + +data: {"choices":[{"finish_reason":"stop","index":0,"delta":{"content":null}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","usage":{"completion_tokens":85,"prompt_tokens":25989,"prompt_tokens_details":{"cached_tokens":0},"total_tokens":26074},"model":"claude-haiku-4.5"} + +data: [DONE] diff --git a/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar b/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar new file mode 100644 index 0000000..6ffce77 --- /dev/null +++ b/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar @@ -0,0 +1,69 @@ +Streaming response where the provider returns text content followed by an injected tool call at index 1 (instead of index 0). +This can happen when the provider incorrectly continues indexing from a previous response. +This tests that nil entries are removed from the tool calls array caused by non-zero starting indices. + +-- request -- +{ + "messages": [ + { + "content": "2026-01-23T20:22:43.781Z\n\nI want you to do to this in order:\n1) create a file in my current directory with name \"test.txt\"\n2) list all my coder workspaces", + "role": "user" + } + ], + "model": "claude-haiku-4.5", + "n": 1, + "temperature": 1, + "parallel_tool_calls": false, + "stream_options": { + "include_usage": true + }, + "stream": true +} + +-- streaming -- +data: {"choices":[{"index":0,"delta":{"content":"Now","role":"assistant"}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" listing","role":"assistant"}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" your","role":"assistant"}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" C","role":"assistant"}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"oder workspaces:","role":"assistant"}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"name":"bmcp_coder_coder_list_workspaces"},"id":"toolu_vrtx_01DbFqUgk6aAtJ4nDBqzFWDF","index":1,"type":"function"}]}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"arguments":""},"index":1}]}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","model":"claude-haiku-4.5"} + +data: {"choices":[{"finish_reason":"tool_calls","index":0,"delta":{"content":null}}],"created":1769199774,"id":"msg_vrtx_01Fiieb5Z3kqJf9a3FwvLkky","usage":{"completion_tokens":58,"prompt_tokens":25939,"prompt_tokens_details":{"cached_tokens":25429},"total_tokens":25997},"model":"claude-haiku-4.5"} + +data: [DONE] + +-- streaming/tool-call -- +data: {"choices":[{"index":0,"delta":{"content":"Done","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"! I create","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"d `","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"test.txt` in","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" your current directory.","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" You","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" have","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" 1","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" ","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":"Coder workspace:\n\n-","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" **test-scf** (docker","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"index":0,"delta":{"content":" template)","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} + +data: {"choices":[{"finish_reason":"stop","index":0,"delta":{"content":null}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","usage":{"completion_tokens":39,"prompt_tokens":26166,"prompt_tokens_details":{"cached_tokens":25934},"total_tokens":26205},"model":"claude-haiku-4.5"} + +data: [DONE] diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 5193148..f16afc5 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/http" + "slices" "strings" "time" @@ -148,16 +149,24 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re } } - // Builtin tools are not intercepted. - if toolCall != nil && i.getInjectedToolByName(toolCall.Name) == nil { - _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ - InterceptionID: i.ID().String(), - MsgID: processor.getMsgID(), - Tool: toolCall.Name, - Args: i.unmarshalArgs(toolCall.Arguments), - Injected: false, - }) - toolCall = nil + if toolCall != nil { + // Builtin tools are not intercepted. + if i.getInjectedToolByName(toolCall.Name) == nil { + _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ + InterceptionID: i.ID().String(), + MsgID: processor.getMsgID(), + Tool: toolCall.Name, + Args: i.unmarshalArgs(toolCall.Arguments), + Injected: false, + }) + toolCall = nil + } else { + // Injected tools mark the stream as initiated so we continue to tool invocation. + // When the provider responds with a tool call as the first chunk (no text + // preamble), no chunks are relayed to the client. Marking as initiated + // ensures we continue to tool invocation instead of returning early. + events.MarkInitiated() + } } if prompt != nil { @@ -239,7 +248,13 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re // Invoke the injected tool, and use the tool result to make a subsequent request to the upstream. // Append the completion from this stream as context. - i.req.Messages = append(i.req.Messages, processor.getLastCompletion().ToParam()) + // Some providers may return tool calls with non-zero starting indices, + // resulting in nil entries in the array that must be removed. + completion := processor.getLastCompletion() + if completion != nil { + compactToolCalls(completion) + i.req.Messages = append(i.req.Messages, completion.ToParam()) + } id := toolCall.ID args := i.unmarshalArgs(toolCall.Arguments) @@ -486,3 +501,13 @@ func (s *streamProcessor) getLastUsage() openai.CompletionUsage { func (s *streamProcessor) getCumulativeUsage() openai.CompletionUsage { return s.cumulativeUsage } + +// compactToolCalls removes nil/empty tool call entries (without an ID). +func compactToolCalls(msg *openai.ChatCompletionMessage) { + if msg == nil || len(msg.ToolCalls) == 0 { + return + } + msg.ToolCalls = slices.DeleteFunc(msg.ToolCalls, func(tc openai.ChatCompletionMessageToolCallUnion) bool { + return tc.ID == "" + }) +} diff --git a/intercept/eventstream/eventstream.go b/intercept/eventstream/eventstream.go index 361dc21..7762caf 100644 --- a/intercept/eventstream/eventstream.go +++ b/intercept/eventstream/eventstream.go @@ -199,6 +199,15 @@ func (s *EventStream) IsStreaming() bool { return s.initiated.Load() || len(s.eventsCh) > 0 } +// MarkInitiated marks the stream as initiated, even if no events have been +// sent to the client yet. A stream is considered initiated when processing +// injected tool calls that don't relay chunks to the client. +func (s *EventStream) MarkInitiated() { + s.initiateOnce.Do(func() { + s.initiated.Store(true) + }) +} + // IsConnError checks if an error is related to client disconnection or context cancellation. func IsConnError(err error) bool { if err == nil { From 734ddda07ed4e721baf1edda20d88fc667b3af92 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 26 Jan 2026 12:09:29 +0000 Subject: [PATCH 2/8] chore: add newlines at fixture files --- .../chatcompletions/streaming_injected_tool_no_preamble.txtar | 3 +++ .../streaming_injected_tool_nonzero_index.txtar | 3 +++ 2 files changed, 6 insertions(+) diff --git a/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar b/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar index d31b0f8..f39097c 100644 --- a/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar +++ b/fixtures/openai/chatcompletions/streaming_injected_tool_no_preamble.txtar @@ -32,6 +32,7 @@ data: {"choices":[{"finish_reason":"tool_calls","index":0,"delta":{"content":nul data: [DONE] + -- streaming/tool-call -- data: {"choices":[{"index":0,"delta":{"content":"You","role":"assistant"}}],"created":1769198061,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","model":"claude-haiku-4.5"} @@ -68,3 +69,5 @@ data: {"choices":[{"index":0,"delta":{"content":" to date","role":"assistant"}}] data: {"choices":[{"finish_reason":"stop","index":0,"delta":{"content":null}}],"created":1769198062,"id":"msg_vrtx_015B1npskreQgEjMrfsdjH1m","usage":{"completion_tokens":85,"prompt_tokens":25989,"prompt_tokens_details":{"cached_tokens":0},"total_tokens":26074},"model":"claude-haiku-4.5"} data: [DONE] + + diff --git a/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar b/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar index 6ffce77..384d1ee 100644 --- a/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar +++ b/fixtures/openai/chatcompletions/streaming_injected_tool_nonzero_index.txtar @@ -39,6 +39,7 @@ data: {"choices":[{"finish_reason":"tool_calls","index":0,"delta":{"content":nul data: [DONE] + -- streaming/tool-call -- data: {"choices":[{"index":0,"delta":{"content":"Done","role":"assistant"}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","model":"claude-haiku-4.5"} @@ -67,3 +68,5 @@ data: {"choices":[{"index":0,"delta":{"content":" template)","role":"assistant"} data: {"choices":[{"finish_reason":"stop","index":0,"delta":{"content":null}}],"created":1769199776,"id":"msg_vrtx_01RVxamMyw1DBtpoENDpmnQK","usage":{"completion_tokens":39,"prompt_tokens":26166,"prompt_tokens_details":{"cached_tokens":25934},"total_tokens":26205},"model":"claude-haiku-4.5"} data: [DONE] + + From 61ea53a82e37544058510a7d75525161577edce7 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 26 Jan 2026 18:56:22 +0000 Subject: [PATCH 3/8] fix: initiate SSE stream synchronously to prevent race condition and client timeout during tool invocation --- bridge_integration_test.go | 5 ++ intercept/chatcompletions/streaming.go | 14 +++-- intercept/eventstream/eventstream.go | 81 +++++++++++++------------- 3 files changed, 53 insertions(+), 47 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index c8f5e86..ee54bfd 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -518,6 +518,11 @@ func TestOpenAIChatCompletions(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) + // Verify SSE headers are sent correctly + require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(t, "keep-alive", resp.Header.Get("Connection")) + // Consume the full response body to ensure the interception completes _, err = io.ReadAll(resp.Body) require.NoError(t, err) diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index f16afc5..3c65ad7 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -161,11 +161,15 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re }) toolCall = nil } else { - // Injected tools mark the stream as initiated so we continue to tool invocation. - // When the provider responds with a tool call as the first chunk (no text - // preamble), no chunks are relayed to the client. Marking as initiated - // ensures we continue to tool invocation instead of returning early. - events.MarkInitiated() + // When the provider responds with only tool calls (no text content), + // no chunks are relayed to the client, so the stream is not yet + // initiated. Initiate it here so the SSE headers are sent and the + // ping ticker is started, preventing client timeout during tool invocation. + // Only initiate if no stream error, if there's an error, we'll return + // an HTTP error response instead of starting an SSE stream. + if stream.Err() == nil { + events.MarkInitiated(w) + } } } diff --git a/intercept/eventstream/eventstream.go b/intercept/eventstream/eventstream.go index 7762caf..6d454f4 100644 --- a/intercept/eventstream/eventstream.go +++ b/intercept/eventstream/eventstream.go @@ -37,10 +37,18 @@ type EventStream struct { // doneCh is closed when the start loop exits. doneCh chan struct{} + + // tick sends periodic pings to keep the connection alive. + tick *time.Ticker } // NewEventStream creates a new SSE stream, with an optional payload which is used to send pings every [pingInterval]. func NewEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) *EventStream { + // Send periodic pings to keep connections alive. + // The upstream provider may also send their own pings, but we can't rely on this. + tick := time.NewTicker(time.Nanosecond) + tick.Stop() // Ticker will start after stream initiation. + return &EventStream{ ctx: ctx, logger: logger, @@ -49,9 +57,35 @@ func NewEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) eventsCh: make(chan event, 128), // Small buffer to unblock senders; once full, senders will block. doneCh: make(chan struct{}), + tick: tick, } } +// MarkInitiated initiates the SSE stream by sending headers and starting the +// ping ticker. This is safe to call multiple times as only the first call has +// any effect. +func (s *EventStream) MarkInitiated(w http.ResponseWriter) { + s.initiateOnce.Do(func() { + s.initiated.Store(true) + s.logger.Debug(s.ctx, "stream initiated") + + // Send headers for Server-Sent Event stream. + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + // Send initial flush to ensure connection is established. + if err := flush(w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Start ping ticker. + s.tick.Reset(pingInterval) + }) +} + // Start handles sending Server-Sent Event to the client. func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { // Signal completion on exit so senders don't block indefinitely after closure. @@ -59,11 +93,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - // Send periodic pings to keep connections alive. - // The upstream provider may also send their own pings, but we can't rely on this. - tick := time.NewTicker(time.Nanosecond) - tick.Stop() // Ticker will start after stream initiation. - defer tick.Stop() + defer s.tick.Stop() for { var ( @@ -83,33 +113,9 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { return } - // Initiate the stream once the first event is received. - s.initiateOnce.Do(func() { - s.initiated.Store(true) - s.logger.Debug(ctx, "stream initiated") - - // Send headers for Server-Sent Event stream. - // - // We only send these once an event is processed because an error can occur in the upstream - // request prior to the stream starting, in which case the SSE headers are inappropriate to - // send to the client. - // - // See use of IsStreaming(). - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - - // Send initial flush to ensure connection is established. - if err := flush(w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Start ping ticker. - tick.Reset(pingInterval) - }) - case <-tick.C: + // Initiate the stream on first event (if not already initiated). + s.MarkInitiated(w) + case <-s.tick.C: ev = s.pingPayload if ev == nil { continue @@ -132,7 +138,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { // Reset the timer once we've flushed some data to the stream, since it's already fresh. // No need to ping in that case. - tick.Reset(pingInterval) + s.tick.Reset(pingInterval) } } @@ -199,15 +205,6 @@ func (s *EventStream) IsStreaming() bool { return s.initiated.Load() || len(s.eventsCh) > 0 } -// MarkInitiated marks the stream as initiated, even if no events have been -// sent to the client yet. A stream is considered initiated when processing -// injected tool calls that don't relay chunks to the client. -func (s *EventStream) MarkInitiated() { - s.initiateOnce.Do(func() { - s.initiated.Store(true) - }) -} - // IsConnError checks if an error is related to client disconnection or context cancellation. func IsConnError(err error) bool { if err == nil { From 8e62b0652471a8dabd3790817c0198a9fe2253f2 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 27 Jan 2026 10:15:39 +0000 Subject: [PATCH 4/8] chore: address comments --- intercept/chatcompletions/streaming.go | 2 +- intercept/eventstream/eventstream.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 3c65ad7..74f7ac3 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -168,7 +168,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re // Only initiate if no stream error, if there's an error, we'll return // an HTTP error response instead of starting an SSE stream. if stream.Err() == nil { - events.MarkInitiated(w) + events.InitiateStream(w) } } } diff --git a/intercept/eventstream/eventstream.go b/intercept/eventstream/eventstream.go index 6d454f4..b3ee96a 100644 --- a/intercept/eventstream/eventstream.go +++ b/intercept/eventstream/eventstream.go @@ -61,10 +61,10 @@ func NewEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) } } -// MarkInitiated initiates the SSE stream by sending headers and starting the +// InitiateStream initiates the SSE stream by sending headers and starting the // ping ticker. This is safe to call multiple times as only the first call has // any effect. -func (s *EventStream) MarkInitiated(w http.ResponseWriter) { +func (s *EventStream) InitiateStream(w http.ResponseWriter) { s.initiateOnce.Do(func() { s.initiated.Store(true) s.logger.Debug(s.ctx, "stream initiated") @@ -114,7 +114,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { } // Initiate the stream on first event (if not already initiated). - s.MarkInitiated(w) + s.InitiateStream(w) case <-s.tick.C: ev = s.pingPayload if ev == nil { From 84c49e658d07ea3091618d51cdd6a156b292cf8d Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Wed, 21 Jan 2026 13:35:38 +0000 Subject: [PATCH 5/8] feat: add GitHub Copilot provider with per-user token authentication --- api.go | 6 + config/config.go | 8 + intercept/chatcompletions/base.go | 6 + provider/copilot.go | 187 ++++++++++++++++ provider/copilot_test.go | 347 ++++++++++++++++++++++++++++++ 5 files changed, 554 insertions(+) create mode 100644 provider/copilot.go create mode 100644 provider/copilot_test.go diff --git a/api.go b/api.go index 897c401..acc789e 100644 --- a/api.go +++ b/api.go @@ -17,6 +17,7 @@ import ( const ( ProviderAnthropic = config.ProviderAnthropic ProviderOpenAI = config.ProviderOpenAI + ProviderCopilot = config.ProviderCopilot ) type ( @@ -35,6 +36,7 @@ type ( AnthropicConfig = config.Anthropic AWSBedrockConfig = config.AWSBedrock OpenAIConfig = config.OpenAI + CopilotConfig = config.Copilot ) func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) context.Context { @@ -49,6 +51,10 @@ func NewOpenAIProvider(cfg config.OpenAI) provider.Provider { return provider.NewOpenAI(cfg) } +func NewCopilotProvider(cfg config.Copilot) provider.Provider { + return provider.NewCopilot(cfg) +} + func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { return metrics.NewMetrics(reg) } diff --git a/config/config.go b/config/config.go index 3387007..27fc0c2 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import "time" const ( ProviderAnthropic = "anthropic" ProviderOpenAI = "openai" + ProviderCopilot = "copilot" ) // CircuitBreaker holds configuration for circuit breakers. @@ -57,4 +58,11 @@ type OpenAI struct { Key string APIDumpDir string CircuitBreaker *CircuitBreaker + ExtraHeaders map[string]string +} + +type Copilot struct { + BaseURL string + APIDumpDir string + CircuitBreaker *CircuitBreaker } diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 02e7f7b..ac7476b 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -39,6 +39,12 @@ type interceptionBase struct { func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + // Add extra headers if configured. + // Some providers require additional headers that are not added by the SDK. + for key, value := range i.cfg.ExtraHeaders { + opts = append(opts, option.WithHeader(key, value)) + } + // Add API dump middleware if configured if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/provider/copilot.go b/provider/copilot.go new file mode 100644 index 0000000..e34f6f6 --- /dev/null +++ b/provider/copilot.go @@ -0,0 +1,187 @@ +package provider + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/intercept/chatcompletions" + "github.com/coder/aibridge/intercept/responses" + "github.com/coder/aibridge/tracing" + "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +const ( + copilotBaseURL = "https://api.individual.githubcopilot.com" + routeCopilotChatCompletions = "/copilot/chat/completions" + routeCopilotResponses = "/copilot/responses" +) + +var copilotOpenErrorResponse = func() []byte { + return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`) +} + +// Headers that need to be forwarded to Copilot API +var copilotForwardHeaders = []string{ + "Editor-Version", + "Copilot-Integration-Id", +} + +// Copilot implements the Provider interface for GitHub Copilot. +// Unlike other providers, Copilot uses per-user API keys that are passed through +// the request headers rather than configured statically. +type Copilot struct { + cfg config.Copilot + circuitBreaker *config.CircuitBreaker +} + +var _ Provider = &Copilot{} + +func NewCopilot(cfg config.Copilot) *Copilot { + if cfg.BaseURL == "" { + cfg.BaseURL = copilotBaseURL + } + if cfg.APIDumpDir == "" { + cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") + } + if cfg.CircuitBreaker != nil { + cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse + } + return &Copilot{ + cfg: cfg, + circuitBreaker: cfg.CircuitBreaker, + } +} + +func (p *Copilot) Name() string { + return config.ProviderCopilot +} + +func (p *Copilot) BaseURL() string { + return p.cfg.BaseURL +} + +func (p *Copilot) BridgedRoutes() []string { + return []string{ + routeCopilotChatCompletions, + routeCopilotResponses, + } +} + +func (p *Copilot) PassthroughRoutes() []string { + return []string{ + "/models", + "/models/", + "/agents/", + "/mcp/", + } +} + +func (p *Copilot) AuthHeader() string { + return "Authorization" +} + +// InjectAuthHeader is a no-op for Copilot. +// Copilot uses per-user tokens passed in the original Authorization header, +// rather than a global key configured at the provider level. +// The original Authorization header flows through untouched from the client. +func (p *Copilot) InjectAuthHeader(_ *http.Header) {} + +func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker { + return p.circuitBreaker +} + +func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer tracing.EndSpanErr(span, &outErr) + + // Extract the per-user Copilot key from the Authorization header. + key := extractBearerToken(r.Header.Get("Authorization")) + if key == "" { + span.SetStatus(codes.Error, "missing authorization") + return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid") + } + + payload, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + + id := uuid.New() + + // Build config for the interceptor using the per-request key. + // Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors + // that require a config.OpenAI. + cfg := config.OpenAI{ + BaseURL: p.cfg.BaseURL, + Key: key, + APIDumpDir: p.cfg.APIDumpDir, + CircuitBreaker: p.cfg.CircuitBreaker, + ExtraHeaders: extractCopilotHeaders(r), + } + + var interceptor intercept.Interceptor + + switch r.URL.Path { + case routeCopilotChatCompletions: + var req chatcompletions.ChatCompletionNewParamsWrapper + if err := json.Unmarshal(payload, &req); err != nil { + return nil, fmt.Errorf("unmarshal chat completions request body: %w", err) + } + + if req.Stream { + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer) + } else { + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer) + } + + case routeCopilotResponses: + var req responses.ResponsesNewParamsWrapper + if err := json.Unmarshal(payload, &req); err != nil { + return nil, fmt.Errorf("unmarshal responses request body: %w", err) + } + + if req.Stream { + interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer) + } else { + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer) + } + + default: + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) + return nil, UnknownRoute + } + + span.SetAttributes(interceptor.TraceAttributes(r)...) + return interceptor, nil +} + +// extractBearerToken extracts the token from a "Bearer " authorization header. +func extractBearerToken(auth string) string { + if auth := strings.TrimSpace(auth); auth != "" { + fields := strings.Fields(auth) + if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") { + return fields[1] + } + } + return "" +} + +// extractCopilotHeaders extracts headers required by the Copilot API from the +// incoming request. Copilot requires certain client headers to be forwarded. +func extractCopilotHeaders(r *http.Request) map[string]string { + headers := make(map[string]string) + for _, h := range copilotForwardHeaders { + if v := r.Header.Get(h); v != "" { + headers[h] = v + } + } + return headers +} diff --git a/provider/copilot_test.go b/provider/copilot_test.go new file mode 100644 index 0000000..119c114 --- /dev/null +++ b/provider/copilot_test.go @@ -0,0 +1,347 @@ +package provider + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "cdr.dev/slog/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" +) + +var testTracer = otel.Tracer("copilot_test") + +func TestCopilot_InjectAuthHeader(t *testing.T) { + t.Parallel() + + // Copilot uses per-user key passed in the Authorization header, + // so InjectAuthHeader should not modify any headers. + provider := NewCopilot(config.Copilot{}) + + t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + headers.Set("Authorization", "Bearer user-token") + headers.Set("X-Custom-Header", "custom-value") + + provider.InjectAuthHeader(&headers) + + assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), + "Authorization header should remain unchanged") + assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), + "other headers should remain unchanged") + }) + + t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + + provider.InjectAuthHeader(&headers) + + assert.Empty(t, headers, "no headers should be added") + }) +} + +func TestCopilot_CreateInterceptor(t *testing.T) { + t.Parallel() + + provider := NewCopilot(config.Copilot{}) + + t.Run("MissingAuthorizationHeader", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "missing Copilot authorization: Authorization header not found or invalid") + }) + + t.Run("InvalidAuthorizationFormat", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-haiku-4.5", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "InvalidFormat") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "missing Copilot authorization: Authorization header not found or invalid") + }) + + t.Run("ChatCompletions_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-haiku-4.5", "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.False(t, interceptor.Streaming()) + }) + + t.Run("ChatCompletions_StreamingRequest_StreamingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}], "stream": true}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.True(t, interceptor.Streaming()) + }) + + t.Run("ChatCompletions_InvalidRequestBody", func(t *testing.T) { + t.Parallel() + + body := `invalid json` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "unmarshal chat completions request body") + }) + + t.Run("ChatCompletions_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + // Mock upstream that captures headers + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)) + })) + t.Cleanup(mockUpstream.Close) + + // Create provider with mock upstream URL + provider := NewCopilot(config.Copilot{ + BaseURL: mockUpstream.URL, + }) + + body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Editor-Version", "vscode/1.85.0") + req.Header.Set("Copilot-Integration-Id", "test-integration") + req.Header.Set("X-Custom-Header", "should-not-forward") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + // Setup and process request + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify headers were forwarded + assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) + assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) + + // Verify non-Copilot headers are not forwarded + assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + }) + + t.Run("Responses_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.False(t, interceptor.Streaming()) + }) + + t.Run("Responses_StreamingRequest_StreamingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-5-mini", "input": "hello", "stream": true}` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.True(t, interceptor.Streaming()) + }) + + t.Run("Responses_InvalidRequestBody", func(t *testing.T) { + t.Parallel() + + body := `invalid json` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "unmarshal responses request body") + }) + + t.Run("UnknownRoute", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/copilot/unknown/route", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.ErrorIs(t, err, UnknownRoute) + require.Nil(t, interceptor) + }) +} + +func Test_extractBearerToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "Empty", + input: "", + expected: "", + }, + { + name: "Whitespace", + input: " ", + expected: "", + }, + { + name: "InvalidFormat", + input: "some-token", + expected: "", + }, + { + name: "BearerOnly", + input: "Bearer", + expected: "", + }, + { + name: "Valid", + input: "Bearer my-secret-token", + expected: "my-secret-token", + }, + { + name: "BearerMixedCase", + input: "BeArEr my-secret-token", + expected: "my-secret-token", + }, + { + name: "LeadingWhitespace", + input: " Bearer my-secret-token", + expected: "my-secret-token", + }, + { + name: "TrailingWhitespace", + input: "Bearer my-secret-token ", + expected: "my-secret-token", + }, + { + name: "TooManyParts", + input: "Bearer token extra", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := extractBearerToken(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractCopilotHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers map[string]string + expected map[string]string + }{ + { + name: "all headers present", + headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, + }, + { + name: "some headers present", + headers: map[string]string{"Editor-Version": "vscode/1.85.0"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + }, + { + name: "no headers", + headers: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "ignores other headers", + headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Authorization": "Bearer token"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, "/", nil) + for header, value := range tc.headers { + req.Header.Set(header, value) + } + + result := extractCopilotHeaders(req) + assert.Equal(t, tc.expected, result) + }) + } +} From 8fec1c29151fc203a9558052ba7c3ef9845f399c Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 26 Jan 2026 19:44:47 +0000 Subject: [PATCH 6/8] chore: address comments --- provider/copilot.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/provider/copilot.go b/provider/copilot.go index e34f6f6..5276931 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -19,7 +19,9 @@ import ( ) const ( - copilotBaseURL = "https://api.individual.githubcopilot.com" + copilotBaseURL = "https://api.individual.githubcopilot.com" + + // Copilot exposes an OpenAI-compatible API, including for Anthropic models. routeCopilotChatCompletions = "/copilot/chat/completions" routeCopilotResponses = "/copilot/responses" ) @@ -109,11 +111,6 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid") } - payload, err := io.ReadAll(r.Body) - if err != nil { - return nil, fmt.Errorf("read body: %w", err) - } - id := uuid.New() // Build config for the interceptor using the per-request key. @@ -132,7 +129,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac switch r.URL.Path { case routeCopilotChatCompletions: var req chatcompletions.ChatCompletionNewParamsWrapper - if err := json.Unmarshal(payload, &req); err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, fmt.Errorf("unmarshal chat completions request body: %w", err) } @@ -143,6 +140,10 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } case routeCopilotResponses: + payload, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } var req responses.ResponsesNewParamsWrapper if err := json.Unmarshal(payload, &req); err != nil { return nil, fmt.Errorf("unmarshal responses request body: %w", err) @@ -177,7 +178,7 @@ func extractBearerToken(auth string) string { // extractCopilotHeaders extracts headers required by the Copilot API from the // incoming request. Copilot requires certain client headers to be forwarded. func extractCopilotHeaders(r *http.Request) map[string]string { - headers := make(map[string]string) + headers := make(map[string]string, len(copilotForwardHeaders)) for _, h := range copilotForwardHeaders { if v := r.Header.Get(h); v != "" { headers[h] = v From ecd5e5085f3419d580ec81b64df3c5b8993357b8 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 27 Jan 2026 10:19:39 +0000 Subject: [PATCH 7/8] chore: address comments --- provider/copilot.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/provider/copilot.go b/provider/copilot.go index 5276931..dae8f54 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -30,7 +30,11 @@ var copilotOpenErrorResponse = func() []byte { return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`) } -// Headers that need to be forwarded to Copilot API +// Headers that need to be forwarded to Copilot API. +// These were determined through manual testing as there is no reference +// of the headers in the official documentation. +// LiteLLM uses the same headers: +// https://docs.litellm.ai/docs/providers/github_copilot var copilotForwardHeaders = []string{ "Editor-Version", "Copilot-Integration-Id", From 2ef645897162b528adcb04c83538019d0d57f9e4 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 27 Jan 2026 10:42:14 +0000 Subject: [PATCH 8/8] chore: fix fmt --- config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index 385bb61..370a68b 100644 --- a/config/config.go +++ b/config/config.go @@ -32,7 +32,7 @@ type OpenAI struct { APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool - ExtraHeaders map[string]string + ExtraHeaders map[string]string } // CircuitBreaker holds configuration for circuit breakers.