Skip to content

Commit 3744945

Browse files
committed
feat: add GitHub Copilot provider with per-user token authentication
1 parent b4479ce commit 3744945

5 files changed

Lines changed: 445 additions & 0 deletions

File tree

api.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
const (
1818
ProviderAnthropic = config.ProviderAnthropic
1919
ProviderOpenAI = config.ProviderOpenAI
20+
ProviderCopilot = config.ProviderCopilot
2021
)
2122

2223
type (
@@ -35,6 +36,7 @@ type (
3536
AnthropicConfig = config.Anthropic
3637
AWSBedrockConfig = config.AWSBedrock
3738
OpenAIConfig = config.OpenAI
39+
CopilotConfig = config.Copilot
3840
)
3941

4042
func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) context.Context {
@@ -49,6 +51,12 @@ func NewOpenAIProvider(cfg config.OpenAI) provider.Provider {
4951
return provider.NewOpenAI(cfg)
5052
}
5153

54+
func NewCopilotProvider(cfg config.Copilot) provider.Provider {
55+
return provider.NewCopilot(provider.CopilotConfig{
56+
CircuitBreaker: cfg.CircuitBreaker,
57+
})
58+
}
59+
5260
func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
5361
return metrics.NewMetrics(reg)
5462
}

config/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "time"
55
const (
66
ProviderAnthropic = "anthropic"
77
ProviderOpenAI = "openai"
8+
ProviderCopilot = "copilot"
89
)
910

1011
// CircuitBreaker holds configuration for circuit breakers.
@@ -57,4 +58,9 @@ type OpenAI struct {
5758
Key string
5859
APIDumpDir string
5960
CircuitBreaker *CircuitBreaker
61+
ExtraHeaders map[string]string
62+
}
63+
64+
type Copilot struct {
65+
CircuitBreaker *CircuitBreaker
6066
}

intercept/chatcompletions/base.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ type interceptionBase struct {
3939
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
4040
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
4141

42+
// Add extra headers if configured
43+
for key, value := range i.cfg.ExtraHeaders {
44+
opts = append(opts, option.WithHeader(key, value))
45+
}
46+
4247
// Add API dump middleware if configured
4348
if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
4449
opts = append(opts, option.WithMiddleware(mw))

provider/copilot.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package provider
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"os"
9+
"strings"
10+
11+
"github.com/coder/aibridge/config"
12+
"github.com/coder/aibridge/intercept"
13+
"github.com/coder/aibridge/intercept/chatcompletions"
14+
"github.com/coder/aibridge/intercept/responses"
15+
"github.com/coder/aibridge/tracing"
16+
"github.com/google/uuid"
17+
"go.opentelemetry.io/otel/codes"
18+
"go.opentelemetry.io/otel/trace"
19+
)
20+
21+
const (
22+
copilotBaseURL = "https://api.individual.githubcopilot.com"
23+
routeCopilotChatCompletions = "/copilot/chat/completions"
24+
routeCopilotResponses = "/copilot/responses"
25+
)
26+
27+
var copilotOpenErrorResponse = func() []byte {
28+
return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`)
29+
}
30+
31+
// CopilotConfig configures the Copilot provider.
32+
type CopilotConfig struct {
33+
CircuitBreaker *config.CircuitBreaker
34+
}
35+
36+
// Copilot implements the Provider interface for GitHub Copilot.
37+
// Unlike other providers, Copilot uses per-user API keys that are passed through
38+
// the request headers rather than configured statically.
39+
type Copilot struct {
40+
circuitBreaker *config.CircuitBreaker
41+
}
42+
43+
var _ Provider = &Copilot{}
44+
45+
func NewCopilot(cfg CopilotConfig) *Copilot {
46+
if cfg.CircuitBreaker != nil {
47+
cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse
48+
}
49+
return &Copilot{
50+
circuitBreaker: cfg.CircuitBreaker,
51+
}
52+
}
53+
54+
func (p *Copilot) Name() string {
55+
return config.ProviderCopilot
56+
}
57+
58+
func (p *Copilot) BaseURL() string {
59+
return copilotBaseURL
60+
}
61+
62+
func (p *Copilot) BridgedRoutes() []string {
63+
return []string{
64+
routeCopilotChatCompletions,
65+
routeCopilotResponses,
66+
}
67+
}
68+
69+
func (p *Copilot) PassthroughRoutes() []string {
70+
return []string{
71+
"/models",
72+
"/models/",
73+
"/agents/",
74+
"/mcp/",
75+
}
76+
}
77+
78+
func (p *Copilot) AuthHeader() string {
79+
return "Authorization"
80+
}
81+
82+
// InjectAuthHeader is a no-op for Copilot.
83+
// Copilot uses per-user tokens passed in the original Authorization header,
84+
// rather than a global key configured at the provider level.
85+
// The original Authorization header flows through untouched from the client.
86+
func (p *Copilot) InjectAuthHeader(_ *http.Header) {}
87+
88+
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
89+
return p.circuitBreaker
90+
}
91+
92+
func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
93+
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
94+
defer tracing.EndSpanErr(span, &outErr)
95+
96+
// Extract the per-user Copilot key from the Authorization header.
97+
key := extractBearerToken(r.Header.Get("Authorization"))
98+
if key == "" {
99+
span.SetStatus(codes.Error, "missing authorization")
100+
return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid")
101+
}
102+
103+
payload, err := io.ReadAll(r.Body)
104+
if err != nil {
105+
return nil, fmt.Errorf("read body: %w", err)
106+
}
107+
108+
id := uuid.New()
109+
110+
// Capture headers that need to be forwarded to Copilot API
111+
extraHeaders := make(map[string]string)
112+
if editorVersion := r.Header.Get("Editor-Version"); editorVersion != "" {
113+
extraHeaders["Editor-Version"] = editorVersion
114+
}
115+
if copilotIntegrationID := r.Header.Get("Copilot-Integration-Id"); copilotIntegrationID != "" {
116+
extraHeaders["Copilot-Integration-Id"] = copilotIntegrationID
117+
}
118+
119+
// Build config for the interceptor using the per-request key.
120+
// Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors
121+
// that require a config.OpenAI.
122+
cfg := config.OpenAI{
123+
BaseURL: copilotBaseURL,
124+
Key: key,
125+
APIDumpDir: os.Getenv("BRIDGE_DUMP_DIR"),
126+
CircuitBreaker: p.circuitBreaker,
127+
ExtraHeaders: extraHeaders,
128+
}
129+
130+
var interceptor intercept.Interceptor
131+
132+
switch r.URL.Path {
133+
case routeCopilotChatCompletions:
134+
var req chatcompletions.ChatCompletionNewParamsWrapper
135+
if err := json.Unmarshal(payload, &req); err != nil {
136+
return nil, fmt.Errorf("unmarshal chat completions request body: %w", err)
137+
}
138+
139+
if req.Stream {
140+
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer)
141+
} else {
142+
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer)
143+
}
144+
145+
case routeCopilotResponses:
146+
var req responses.ResponsesNewParamsWrapper
147+
if err := json.Unmarshal(payload, &req); err != nil {
148+
return nil, fmt.Errorf("unmarshal responses request body: %w", err)
149+
}
150+
151+
if req.Stream {
152+
interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer)
153+
} else {
154+
interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer)
155+
}
156+
157+
default:
158+
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
159+
return nil, UnknownRoute
160+
}
161+
162+
span.SetAttributes(interceptor.TraceAttributes(r)...)
163+
return interceptor, nil
164+
}
165+
166+
// extractBearerToken extracts the token from a "Bearer <token>" authorization header.
167+
func extractBearerToken(auth string) string {
168+
if auth := strings.TrimSpace(auth); auth != "" {
169+
fields := strings.Fields(auth)
170+
if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
171+
return fields[1]
172+
}
173+
}
174+
return ""
175+
}

0 commit comments

Comments
 (0)