Skip to content

Commit a2eb2d3

Browse files
committed
feat: add GitHub Copilot provider with per-user token authentication
1 parent efa7ba4 commit a2eb2d3

5 files changed

Lines changed: 554 additions & 0 deletions

File tree

api.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
const (
1818
ProviderAnthropic = config.ProviderAnthropic
1919
ProviderOpenAI = config.ProviderOpenAI
20+
ProviderCopilot = config.ProviderCopilot
2021
)
2122

2223
type (
@@ -35,6 +36,7 @@ type (
3536
AnthropicConfig = config.Anthropic
3637
AWSBedrockConfig = config.AWSBedrock
3738
OpenAIConfig = config.OpenAI
39+
CopilotConfig = config.Copilot
3840
)
3941

4042
func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) context.Context {
@@ -49,6 +51,10 @@ func NewOpenAIProvider(cfg config.OpenAI) provider.Provider {
4951
return provider.NewOpenAI(cfg)
5052
}
5153

54+
func NewCopilotProvider(cfg config.Copilot) provider.Provider {
55+
return provider.NewCopilot(cfg)
56+
}
57+
5258
func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
5359
return metrics.NewMetrics(reg)
5460
}

config/config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "time"
55
const (
66
ProviderAnthropic = "anthropic"
77
ProviderOpenAI = "openai"
8+
ProviderCopilot = "copilot"
89
)
910

1011
// CircuitBreaker holds configuration for circuit breakers.
@@ -57,4 +58,11 @@ type OpenAI struct {
5758
Key string
5859
APIDumpDir string
5960
CircuitBreaker *CircuitBreaker
61+
ExtraHeaders map[string]string
62+
}
63+
64+
type Copilot struct {
65+
BaseURL string
66+
APIDumpDir string
67+
CircuitBreaker *CircuitBreaker
6068
}

intercept/chatcompletions/base.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ type interceptionBase struct {
3939
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
4040
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
4141

42+
// Add extra headers if configured.
43+
// Some providers require additional headers that are not added by the SDK.
44+
for key, value := range i.cfg.ExtraHeaders {
45+
opts = append(opts, option.WithHeader(key, value))
46+
}
47+
4248
// Add API dump middleware if configured
4349
if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
4450
opts = append(opts, option.WithMiddleware(mw))

provider/copilot.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package provider
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"os"
9+
"strings"
10+
11+
"github.com/coder/aibridge/config"
12+
"github.com/coder/aibridge/intercept"
13+
"github.com/coder/aibridge/intercept/chatcompletions"
14+
"github.com/coder/aibridge/intercept/responses"
15+
"github.com/coder/aibridge/tracing"
16+
"github.com/google/uuid"
17+
"go.opentelemetry.io/otel/codes"
18+
"go.opentelemetry.io/otel/trace"
19+
)
20+
21+
const (
22+
copilotBaseURL = "https://api.individual.githubcopilot.com"
23+
routeCopilotChatCompletions = "/copilot/chat/completions"
24+
routeCopilotResponses = "/copilot/responses"
25+
)
26+
27+
var copilotOpenErrorResponse = func() []byte {
28+
return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`)
29+
}
30+
31+
// Headers that need to be forwarded to Copilot API
32+
var copilotForwardHeaders = []string{
33+
"Editor-Version",
34+
"Copilot-Integration-Id",
35+
}
36+
37+
// Copilot implements the Provider interface for GitHub Copilot.
38+
// Unlike other providers, Copilot uses per-user API keys that are passed through
39+
// the request headers rather than configured statically.
40+
type Copilot struct {
41+
cfg config.Copilot
42+
circuitBreaker *config.CircuitBreaker
43+
}
44+
45+
var _ Provider = &Copilot{}
46+
47+
func NewCopilot(cfg config.Copilot) *Copilot {
48+
if cfg.BaseURL == "" {
49+
cfg.BaseURL = copilotBaseURL
50+
}
51+
if cfg.APIDumpDir == "" {
52+
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
53+
}
54+
if cfg.CircuitBreaker != nil {
55+
cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse
56+
}
57+
return &Copilot{
58+
cfg: cfg,
59+
circuitBreaker: cfg.CircuitBreaker,
60+
}
61+
}
62+
63+
func (p *Copilot) Name() string {
64+
return config.ProviderCopilot
65+
}
66+
67+
func (p *Copilot) BaseURL() string {
68+
return p.cfg.BaseURL
69+
}
70+
71+
func (p *Copilot) BridgedRoutes() []string {
72+
return []string{
73+
routeCopilotChatCompletions,
74+
routeCopilotResponses,
75+
}
76+
}
77+
78+
func (p *Copilot) PassthroughRoutes() []string {
79+
return []string{
80+
"/models",
81+
"/models/",
82+
"/agents/",
83+
"/mcp/",
84+
}
85+
}
86+
87+
func (p *Copilot) AuthHeader() string {
88+
return "Authorization"
89+
}
90+
91+
// InjectAuthHeader is a no-op for Copilot.
92+
// Copilot uses per-user tokens passed in the original Authorization header,
93+
// rather than a global key configured at the provider level.
94+
// The original Authorization header flows through untouched from the client.
95+
func (p *Copilot) InjectAuthHeader(_ *http.Header) {}
96+
97+
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
98+
return p.circuitBreaker
99+
}
100+
101+
func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
102+
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
103+
defer tracing.EndSpanErr(span, &outErr)
104+
105+
// Extract the per-user Copilot key from the Authorization header.
106+
key := extractBearerToken(r.Header.Get("Authorization"))
107+
if key == "" {
108+
span.SetStatus(codes.Error, "missing authorization")
109+
return nil, fmt.Errorf("missing Copilot authorization: Authorization header not found or invalid")
110+
}
111+
112+
payload, err := io.ReadAll(r.Body)
113+
if err != nil {
114+
return nil, fmt.Errorf("read body: %w", err)
115+
}
116+
117+
id := uuid.New()
118+
119+
// Build config for the interceptor using the per-request key.
120+
// Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors
121+
// that require a config.OpenAI.
122+
cfg := config.OpenAI{
123+
BaseURL: p.cfg.BaseURL,
124+
Key: key,
125+
APIDumpDir: p.cfg.APIDumpDir,
126+
CircuitBreaker: p.cfg.CircuitBreaker,
127+
ExtraHeaders: extractCopilotHeaders(r),
128+
}
129+
130+
var interceptor intercept.Interceptor
131+
132+
switch r.URL.Path {
133+
case routeCopilotChatCompletions:
134+
var req chatcompletions.ChatCompletionNewParamsWrapper
135+
if err := json.Unmarshal(payload, &req); err != nil {
136+
return nil, fmt.Errorf("unmarshal chat completions request body: %w", err)
137+
}
138+
139+
if req.Stream {
140+
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer)
141+
} else {
142+
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer)
143+
}
144+
145+
case routeCopilotResponses:
146+
var req responses.ResponsesNewParamsWrapper
147+
if err := json.Unmarshal(payload, &req); err != nil {
148+
return nil, fmt.Errorf("unmarshal responses request body: %w", err)
149+
}
150+
151+
if req.Stream {
152+
interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer)
153+
} else {
154+
interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer)
155+
}
156+
157+
default:
158+
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
159+
return nil, UnknownRoute
160+
}
161+
162+
span.SetAttributes(interceptor.TraceAttributes(r)...)
163+
return interceptor, nil
164+
}
165+
166+
// extractBearerToken extracts the token from a "Bearer <token>" authorization header.
167+
func extractBearerToken(auth string) string {
168+
if auth := strings.TrimSpace(auth); auth != "" {
169+
fields := strings.Fields(auth)
170+
if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
171+
return fields[1]
172+
}
173+
}
174+
return ""
175+
}
176+
177+
// extractCopilotHeaders extracts headers required by the Copilot API from the
178+
// incoming request. Copilot requires certain client headers to be forwarded.
179+
func extractCopilotHeaders(r *http.Request) map[string]string {
180+
headers := make(map[string]string)
181+
for _, h := range copilotForwardHeaders {
182+
if v := r.Header.Get(h); v != "" {
183+
headers[h] = v
184+
}
185+
}
186+
return headers
187+
}

0 commit comments

Comments
 (0)