Skip to content

Commit c1d483a

Browse files
committed
feat: add GitHub Copilot provider with per-user token authentication
1 parent b4479ce commit c1d483a

5 files changed

Lines changed: 509 additions & 0 deletions

File tree

api.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
const (
1818
ProviderAnthropic = config.ProviderAnthropic
1919
ProviderOpenAI = config.ProviderOpenAI
20+
ProviderCopilot = config.ProviderCopilot
2021
)
2122

2223
type (
@@ -35,6 +36,7 @@ type (
3536
AnthropicConfig = config.Anthropic
3637
AWSBedrockConfig = config.AWSBedrock
3738
OpenAIConfig = config.OpenAI
39+
CopilotConfig = config.Copilot
3840
)
3941

4042
func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) context.Context {
@@ -49,6 +51,12 @@ func NewOpenAIProvider(cfg config.OpenAI) provider.Provider {
4951
return provider.NewOpenAI(cfg)
5052
}
5153

54+
func NewCopilotProvider(cfg config.Copilot) provider.Provider {
55+
return provider.NewCopilot(provider.CopilotConfig{
56+
CircuitBreaker: cfg.CircuitBreaker,
57+
})
58+
}
59+
5260
func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
5361
return metrics.NewMetrics(reg)
5462
}

config/config.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "time"
55
const (
66
ProviderAnthropic = "anthropic"
77
ProviderOpenAI = "openai"
8+
ProviderCopilot = "copilot"
89
)
910

1011
// CircuitBreaker holds configuration for circuit breakers.
@@ -57,4 +58,10 @@ type OpenAI struct {
5758
Key string
5859
APIDumpDir string
5960
CircuitBreaker *CircuitBreaker
61+
ExtraHeaders map[string]string
62+
}
63+
64+
type Copilot struct {
65+
APIDumpDir string
66+
CircuitBreaker *CircuitBreaker
6067
}

intercept/chatcompletions/base.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ type interceptionBase struct {
3939
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
4040
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
4141

42+
// Add extra headers if configured.
43+
// Some providers require additional headers that are not added by the SDK.
44+
for key, value := range i.cfg.ExtraHeaders {
45+
opts = append(opts, option.WithHeader(key, value))
46+
}
47+
4248
// Add API dump middleware if configured
4349
if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
4450
opts = append(opts, option.WithMiddleware(mw))

provider/copilot.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package provider
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"os"
9+
"strings"
10+
11+
"github.com/coder/aibridge/config"
12+
"github.com/coder/aibridge/intercept"
13+
"github.com/coder/aibridge/intercept/chatcompletions"
14+
"github.com/coder/aibridge/intercept/responses"
15+
"github.com/coder/aibridge/tracing"
16+
"github.com/google/uuid"
17+
"go.opentelemetry.io/otel/codes"
18+
"go.opentelemetry.io/otel/trace"
19+
)
20+
21+
const (
22+
copilotBaseURL = "https://api.individual.githubcopilot.com"
23+
routeCopilotChatCompletions = "/copilot/chat/completions"
24+
routeCopilotResponses = "/copilot/responses"
25+
)
26+
27+
var copilotOpenErrorResponse = func() []byte {
28+
return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`)
29+
}
30+
31+
// Headers that need to be forwarded to Copilot API
32+
var copilotForwardHeaders = []string{
33+
"Editor-Version",
34+
"Copilot-Integration-Id",
35+
}
36+
37+
// Copilot implements the Provider interface for GitHub Copilot.
38+
// Unlike other providers, Copilot uses per-user API keys that are passed through
39+
// the request headers rather than configured statically.
40+
type Copilot struct {
41+
cfg config.Copilot
42+
circuitBreaker *config.CircuitBreaker
43+
}
44+
45+
var _ Provider = &Copilot{}
46+
47+
func NewCopilot(cfg config.Copilot) *Copilot {
48+
if cfg.APIDumpDir == "" {
49+
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
50+
}
51+
if cfg.CircuitBreaker != nil {
52+
cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse
53+
}
54+
return &Copilot{
55+
cfg: cfg,
56+
circuitBreaker: cfg.CircuitBreaker,
57+
}
58+
}
59+
60+
func (p *Copilot) Name() string {
61+
return config.ProviderCopilot
62+
}
63+
64+
func (p *Copilot) BaseURL() string {
65+
return copilotBaseURL
66+
}
67+
68+
func (p *Copilot) BridgedRoutes() []string {
69+
return []string{
70+
routeCopilotChatCompletions,
71+
routeCopilotResponses,
72+
}
73+
}
74+
75+
func (p *Copilot) PassthroughRoutes() []string {
76+
return []string{
77+
"/models",
78+
"/models/",
79+
"/agents/",
80+
"/mcp/",
81+
}
82+
}
83+
84+
func (p *Copilot) AuthHeader() string {
85+
return "Authorization"
86+
}
87+
88+
// InjectAuthHeader is a no-op for Copilot.
89+
// Copilot uses per-user tokens passed in the original Authorization header,
90+
// rather than a global key configured at the provider level.
91+
// The original Authorization header flows through untouched from the client.
92+
func (p *Copilot) InjectAuthHeader(_ *http.Header) {}
93+
94+
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
95+
return p.circuitBreaker
96+
}
97+
98+
func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
99+
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
100+
defer tracing.EndSpanErr(span, &outErr)
101+
102+
// Extract the per-user Copilot key from the Authorization header.
103+
key := extractBearerToken(r.Header.Get("Authorization"))
104+
if key == "" {
105+
span.SetStatus(codes.Error, "missing authorization")
106+
return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid")
107+
}
108+
109+
payload, err := io.ReadAll(r.Body)
110+
if err != nil {
111+
return nil, fmt.Errorf("read body: %w", err)
112+
}
113+
114+
id := uuid.New()
115+
116+
// Capture headers that need to be forwarded to Copilot API
117+
extraHeaders := make(map[string]string)
118+
if editorVersion := r.Header.Get("Editor-Version"); editorVersion != "" {
119+
extraHeaders["Editor-Version"] = editorVersion
120+
}
121+
if copilotIntegrationID := r.Header.Get("Copilot-Integration-Id"); copilotIntegrationID != "" {
122+
extraHeaders["Copilot-Integration-Id"] = copilotIntegrationID
123+
}
124+
125+
// Build config for the interceptor using the per-request key.
126+
// Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors
127+
// that require a config.OpenAI.
128+
cfg := config.OpenAI{
129+
BaseURL: copilotBaseURL,
130+
Key: key,
131+
APIDumpDir: p.cfg.APIDumpDir,
132+
CircuitBreaker: p.cfg.CircuitBreaker,
133+
ExtraHeaders: extractCopilotHeaders(r),
134+
}
135+
136+
var interceptor intercept.Interceptor
137+
138+
switch r.URL.Path {
139+
case routeCopilotChatCompletions:
140+
var req chatcompletions.ChatCompletionNewParamsWrapper
141+
if err := json.Unmarshal(payload, &req); err != nil {
142+
return nil, fmt.Errorf("unmarshal chat completions request body: %w", err)
143+
}
144+
145+
if req.Stream {
146+
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer)
147+
} else {
148+
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer)
149+
}
150+
151+
case routeCopilotResponses:
152+
var req responses.ResponsesNewParamsWrapper
153+
if err := json.Unmarshal(payload, &req); err != nil {
154+
return nil, fmt.Errorf("unmarshal responses request body: %w", err)
155+
}
156+
157+
if req.Stream {
158+
interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer)
159+
} else {
160+
interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer)
161+
}
162+
163+
default:
164+
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
165+
return nil, UnknownRoute
166+
}
167+
168+
span.SetAttributes(interceptor.TraceAttributes(r)...)
169+
return interceptor, nil
170+
}
171+
172+
// extractBearerToken extracts the token from a "Bearer <token>" authorization header.
173+
func extractBearerToken(auth string) string {
174+
if auth := strings.TrimSpace(auth); auth != "" {
175+
fields := strings.Fields(auth)
176+
if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
177+
return fields[1]
178+
}
179+
}
180+
return ""
181+
}
182+
183+
func extractCopilotHeaders(r *http.Request) map[string]string {
184+
headers := make(map[string]string)
185+
for _, h := range copilotForwardHeaders {
186+
if v := r.Header.Get(h); v != "" {
187+
headers[h] = v
188+
}
189+
}
190+
return headers
191+
}

0 commit comments

Comments
 (0)