Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC

// We execute this before CreateInterceptor since the interceptors
// read the request body and don't reset them.
client := guessClient(r)
sessionID := guessSessionID(client, r)
client := GuessClient(r)
sessionID := GuessSessionID(client, r)

interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
if err != nil {
Expand Down
95 changes: 38 additions & 57 deletions bridge_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package aibridge
package aibridge_test

import (
"net/http"
Expand All @@ -7,16 +7,22 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"

"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/provider"
)

func TestValidateProvider_Names(t *testing.T) {
var bridgeTestTracer = otel.Tracer("bridge_test")

func TestValidateProviders(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, nil)

tests := []struct {
name string
providers []provider.Provider
Expand All @@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) {
{
name: "all_supported_providers",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "default_names_and_base_urls",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{}),
NewAnthropicProvider(config.Anthropic{}, nil),
NewCopilotProvider(config.Copilot{}),
aibridge.NewOpenAIProvider(config.OpenAI{}),
aibridge.NewAnthropicProvider(config.Anthropic{}, nil),
aibridge.NewCopilotProvider(config.Copilot{}),
},
},
{
name: "multiple_copilot_instances",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "name_with_slashes",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_spaces",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_uppercase",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := validateProviders(tc.providers)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidateProvider_DuplicateNames(t *testing.T) {
t.Parallel()

tests := []struct {
name string
providers []provider.Provider
expectErr string
}{
{
name: "unique_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
},
},
{
name: "duplicate_base_url_different_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
},
},
{
name: "duplicate_name",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "duplicate provider name",
},
Expand All @@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := validateProviders(tc.providers)
_, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
Expand All @@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
name: "openAI_no_base_path",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/conversations",
},
Expand All @@ -157,15 +138,15 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/v1/conversations",
},
{
name: "anthropic_no_base_path",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/models",
},
Expand All @@ -174,15 +155,15 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/v1/models",
},
{
name: "copilot_no_base_path",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/models",
},
Expand All @@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/v1/models",
},
Expand All @@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
}))
t.Cleanup(upstream.Close)

recorder := testutil.MockRecorder{}
rec := testutil.MockRecorder{}
prov := tc.provider(upstream.URL + tc.baseURLPath)
bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer)
bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer)
require.NoError(t, err)

req := httptest.NewRequest("", tc.requestPath, nil)
resp := httptest.NewRecorder()
bridge.mux.ServeHTTP(resp, req)
bridge.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)
assert.Contains(t, resp.Body.String(), upstreamRespBody)
Expand Down
25 changes: 13 additions & 12 deletions circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package circuitbreaker
package circuitbreaker_test

import (
"errors"
Expand All @@ -11,6 +11,7 @@ import (
"github.com/sony/gobreaker/v2"
"github.com/stretchr/testify/assert"

"github.com/coder/aibridge/circuitbreaker"
"github.com/coder/aibridge/config"
)

Expand All @@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
sonnetCalls := atomic.Int32{}
haikuCalls := atomic.Int32{}

cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
Expand Down Expand Up @@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)

Expand All @@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
messagesCalls := atomic.Int32{}
completionsCalls := atomic.Int32{}

cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
Expand All @@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)

Expand All @@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
var calls atomic.Int32

// Custom IsFailure that treats 502 as failure
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
Expand All @@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), calls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
}
Expand All @@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) {
to gobreaker.State
}

cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
Expand Down Expand Up @@ -209,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) {
}

for _, tt := range tests {
assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
}
}

func TestStateToGaugeValue(t *testing.T) {
t.Parallel()

assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed))
assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen))
assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen))
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
}
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ const (
ClientUnknown Client = "Unknown"
)

// guessClient attempts to guess the client application from the request headers.
// GuessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) Client {
func GuessClient(r *http.Request) Client {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")

Expand Down
Loading
Loading