diff --git a/api.go b/api.go index 897c401..acc789e 100644 --- a/api.go +++ b/api.go @@ -17,6 +17,7 @@ import ( const ( ProviderAnthropic = config.ProviderAnthropic ProviderOpenAI = config.ProviderOpenAI + ProviderCopilot = config.ProviderCopilot ) type ( @@ -35,6 +36,7 @@ type ( AnthropicConfig = config.Anthropic AWSBedrockConfig = config.AWSBedrock OpenAIConfig = config.OpenAI + CopilotConfig = config.Copilot ) func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) context.Context { @@ -49,6 +51,10 @@ func NewOpenAIProvider(cfg config.OpenAI) provider.Provider { return provider.NewOpenAI(cfg) } +func NewCopilotProvider(cfg config.Copilot) provider.Provider { + return provider.NewCopilot(cfg) +} + func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { return metrics.NewMetrics(reg) } diff --git a/config/config.go b/config/config.go index e9cc526..370a68b 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import "time" const ( ProviderAnthropic = "anthropic" ProviderOpenAI = "openai" + ProviderCopilot = "copilot" ) type Anthropic struct { @@ -31,6 +32,7 @@ type OpenAI struct { APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool + ExtraHeaders map[string]string } // CircuitBreaker holds configuration for circuit breakers. @@ -60,3 +62,9 @@ func DefaultCircuitBreaker() CircuitBreaker { MaxRequests: 3, } } + +type Copilot struct { + BaseURL string + APIDumpDir string + CircuitBreaker *CircuitBreaker +} diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 02e7f7b..ac7476b 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -39,6 +39,12 @@ type interceptionBase struct { func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + // Add extra headers if configured. + // Some providers require additional headers that are not added by the SDK. + for key, value := range i.cfg.ExtraHeaders { + opts = append(opts, option.WithHeader(key, value)) + } + // Add API dump middleware if configured if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/provider/copilot.go b/provider/copilot.go new file mode 100644 index 0000000..dae8f54 --- /dev/null +++ b/provider/copilot.go @@ -0,0 +1,192 @@ +package provider + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/intercept/chatcompletions" + "github.com/coder/aibridge/intercept/responses" + "github.com/coder/aibridge/tracing" + "github.com/google/uuid" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +const ( + copilotBaseURL = "https://api.individual.githubcopilot.com" + + // Copilot exposes an OpenAI-compatible API, including for Anthropic models. + routeCopilotChatCompletions = "/copilot/chat/completions" + routeCopilotResponses = "/copilot/responses" +) + +var copilotOpenErrorResponse = func() []byte { + return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`) +} + +// Headers that need to be forwarded to Copilot API. +// These were determined through manual testing as there is no reference +// of the headers in the official documentation. +// LiteLLM uses the same headers: +// https://docs.litellm.ai/docs/providers/github_copilot +var copilotForwardHeaders = []string{ + "Editor-Version", + "Copilot-Integration-Id", +} + +// Copilot implements the Provider interface for GitHub Copilot. +// Unlike other providers, Copilot uses per-user API keys that are passed through +// the request headers rather than configured statically. +type Copilot struct { + cfg config.Copilot + circuitBreaker *config.CircuitBreaker +} + +var _ Provider = &Copilot{} + +func NewCopilot(cfg config.Copilot) *Copilot { + if cfg.BaseURL == "" { + cfg.BaseURL = copilotBaseURL + } + if cfg.APIDumpDir == "" { + cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") + } + if cfg.CircuitBreaker != nil { + cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse + } + return &Copilot{ + cfg: cfg, + circuitBreaker: cfg.CircuitBreaker, + } +} + +func (p *Copilot) Name() string { + return config.ProviderCopilot +} + +func (p *Copilot) BaseURL() string { + return p.cfg.BaseURL +} + +func (p *Copilot) BridgedRoutes() []string { + return []string{ + routeCopilotChatCompletions, + routeCopilotResponses, + } +} + +func (p *Copilot) PassthroughRoutes() []string { + return []string{ + "/models", + "/models/", + "/agents/", + "/mcp/", + } +} + +func (p *Copilot) AuthHeader() string { + return "Authorization" +} + +// InjectAuthHeader is a no-op for Copilot. +// Copilot uses per-user tokens passed in the original Authorization header, +// rather than a global key configured at the provider level. +// The original Authorization header flows through untouched from the client. +func (p *Copilot) InjectAuthHeader(_ *http.Header) {} + +func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker { + return p.circuitBreaker +} + +func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") + defer tracing.EndSpanErr(span, &outErr) + + // Extract the per-user Copilot key from the Authorization header. + key := extractBearerToken(r.Header.Get("Authorization")) + if key == "" { + span.SetStatus(codes.Error, "missing authorization") + return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid") + } + + id := uuid.New() + + // Build config for the interceptor using the per-request key. + // Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors + // that require a config.OpenAI. + cfg := config.OpenAI{ + BaseURL: p.cfg.BaseURL, + Key: key, + APIDumpDir: p.cfg.APIDumpDir, + CircuitBreaker: p.cfg.CircuitBreaker, + ExtraHeaders: extractCopilotHeaders(r), + } + + var interceptor intercept.Interceptor + + switch r.URL.Path { + case routeCopilotChatCompletions: + var req chatcompletions.ChatCompletionNewParamsWrapper + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, fmt.Errorf("unmarshal chat completions request body: %w", err) + } + + if req.Stream { + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer) + } else { + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer) + } + + case routeCopilotResponses: + payload, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + var req responses.ResponsesNewParamsWrapper + if err := json.Unmarshal(payload, &req); err != nil { + return nil, fmt.Errorf("unmarshal responses request body: %w", err) + } + + if req.Stream { + interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer) + } else { + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer) + } + + default: + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) + return nil, UnknownRoute + } + + span.SetAttributes(interceptor.TraceAttributes(r)...) + return interceptor, nil +} + +// extractBearerToken extracts the token from a "Bearer " authorization header. +func extractBearerToken(auth string) string { + if auth := strings.TrimSpace(auth); auth != "" { + fields := strings.Fields(auth) + if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") { + return fields[1] + } + } + return "" +} + +// extractCopilotHeaders extracts headers required by the Copilot API from the +// incoming request. Copilot requires certain client headers to be forwarded. +func extractCopilotHeaders(r *http.Request) map[string]string { + headers := make(map[string]string, len(copilotForwardHeaders)) + for _, h := range copilotForwardHeaders { + if v := r.Header.Get(h); v != "" { + headers[h] = v + } + } + return headers +} diff --git a/provider/copilot_test.go b/provider/copilot_test.go new file mode 100644 index 0000000..119c114 --- /dev/null +++ b/provider/copilot_test.go @@ -0,0 +1,347 @@ +package provider + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "cdr.dev/slog/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" +) + +var testTracer = otel.Tracer("copilot_test") + +func TestCopilot_InjectAuthHeader(t *testing.T) { + t.Parallel() + + // Copilot uses per-user key passed in the Authorization header, + // so InjectAuthHeader should not modify any headers. + provider := NewCopilot(config.Copilot{}) + + t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + headers.Set("Authorization", "Bearer user-token") + headers.Set("X-Custom-Header", "custom-value") + + provider.InjectAuthHeader(&headers) + + assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), + "Authorization header should remain unchanged") + assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), + "other headers should remain unchanged") + }) + + t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + + provider.InjectAuthHeader(&headers) + + assert.Empty(t, headers, "no headers should be added") + }) +} + +func TestCopilot_CreateInterceptor(t *testing.T) { + t.Parallel() + + provider := NewCopilot(config.Copilot{}) + + t.Run("MissingAuthorizationHeader", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "missing Copilot authorization: Authorization header not found or invalid") + }) + + t.Run("InvalidAuthorizationFormat", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-haiku-4.5", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "InvalidFormat") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "missing Copilot authorization: Authorization header not found or invalid") + }) + + t.Run("ChatCompletions_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-haiku-4.5", "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.False(t, interceptor.Streaming()) + }) + + t.Run("ChatCompletions_StreamingRequest_StreamingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}], "stream": true}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.True(t, interceptor.Streaming()) + }) + + t.Run("ChatCompletions_InvalidRequestBody", func(t *testing.T) { + t.Parallel() + + body := `invalid json` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "unmarshal chat completions request body") + }) + + t.Run("ChatCompletions_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + // Mock upstream that captures headers + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)) + })) + t.Cleanup(mockUpstream.Close) + + // Create provider with mock upstream URL + provider := NewCopilot(config.Copilot{ + BaseURL: mockUpstream.URL, + }) + + body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Editor-Version", "vscode/1.85.0") + req.Header.Set("Copilot-Integration-Id", "test-integration") + req.Header.Set("X-Custom-Header", "should-not-forward") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + // Setup and process request + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify headers were forwarded + assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) + assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) + + // Verify non-Copilot headers are not forwarded + assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + }) + + t.Run("Responses_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.False(t, interceptor.Streaming()) + }) + + t.Run("Responses_StreamingRequest_StreamingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-5-mini", "input": "hello", "stream": true}` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.True(t, interceptor.Streaming()) + }) + + t.Run("Responses_InvalidRequestBody", func(t *testing.T) { + t.Parallel() + + body := `invalid json` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "unmarshal responses request body") + }) + + t.Run("UnknownRoute", func(t *testing.T) { + t.Parallel() + + body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/copilot/unknown/route", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.ErrorIs(t, err, UnknownRoute) + require.Nil(t, interceptor) + }) +} + +func Test_extractBearerToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "Empty", + input: "", + expected: "", + }, + { + name: "Whitespace", + input: " ", + expected: "", + }, + { + name: "InvalidFormat", + input: "some-token", + expected: "", + }, + { + name: "BearerOnly", + input: "Bearer", + expected: "", + }, + { + name: "Valid", + input: "Bearer my-secret-token", + expected: "my-secret-token", + }, + { + name: "BearerMixedCase", + input: "BeArEr my-secret-token", + expected: "my-secret-token", + }, + { + name: "LeadingWhitespace", + input: " Bearer my-secret-token", + expected: "my-secret-token", + }, + { + name: "TrailingWhitespace", + input: "Bearer my-secret-token ", + expected: "my-secret-token", + }, + { + name: "TooManyParts", + input: "Bearer token extra", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := extractBearerToken(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractCopilotHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers map[string]string + expected map[string]string + }{ + { + name: "all headers present", + headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, + }, + { + name: "some headers present", + headers: map[string]string{"Editor-Version": "vscode/1.85.0"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + }, + { + name: "no headers", + headers: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "ignores other headers", + headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Authorization": "Bearer token"}, + expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, "/", nil) + for header, value := range tc.headers { + req.Header.Set(header, value) + } + + result := extractCopilotHeaders(req) + assert.Equal(t, tc.expected, result) + }) + } +}