Skip to content

Commit 09d6738

Browse files
committed
chore: lint fixes part 9
1 parent a0dc36f commit 09d6738

30 files changed

Lines changed: 171 additions & 181 deletions

bridge.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
180180

181181
// We execute this before CreateInterceptor since the interceptors
182182
// read the request body and don't reset them.
183-
client := guessClient(r)
184-
sessionID := guessSessionID(client, r)
183+
client := GuessClient(r)
184+
sessionID := GuessSessionID(client, r)
185185

186186
interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
187187
if err != nil {

bridge_test.go

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package aibridge
1+
package aibridge_test
22

33
import (
44
"net/http"
@@ -7,16 +7,22 @@ import (
77

88
"github.com/stretchr/testify/assert"
99
"github.com/stretchr/testify/require"
10+
"go.opentelemetry.io/otel"
1011

1112
"cdr.dev/slog/v3/sloggers/slogtest"
13+
"github.com/coder/aibridge"
1214
"github.com/coder/aibridge/config"
1315
"github.com/coder/aibridge/internal/testutil"
1416
"github.com/coder/aibridge/provider"
1517
)
1618

17-
func TestValidateProvider_Names(t *testing.T) {
19+
var bridgeTestTracer = otel.Tracer("bridge_test")
20+
21+
func TestValidateProviders(t *testing.T) {
1822
t.Parallel()
1923

24+
logger := slogtest.Make(t, nil)
25+
2026
tests := []struct {
2127
name string
2228
providers []provider.Provider
@@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) {
2531
{
2632
name: "all_supported_providers",
2733
providers: []provider.Provider{
28-
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
29-
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
30-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
31-
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
32-
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
34+
aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
35+
aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
36+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
37+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
38+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
3339
},
3440
},
3541
{
3642
name: "default_names_and_base_urls",
3743
providers: []provider.Provider{
38-
NewOpenAIProvider(config.OpenAI{}),
39-
NewAnthropicProvider(config.Anthropic{}, nil),
40-
NewCopilotProvider(config.Copilot{}),
44+
aibridge.NewOpenAIProvider(config.OpenAI{}),
45+
aibridge.NewAnthropicProvider(config.Anthropic{}, nil),
46+
aibridge.NewCopilotProvider(config.Copilot{}),
4147
},
4248
},
4349
{
4450
name: "multiple_copilot_instances",
4551
providers: []provider.Provider{
46-
NewCopilotProvider(config.Copilot{}),
47-
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
48-
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
52+
aibridge.NewCopilotProvider(config.Copilot{}),
53+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
54+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
4955
},
5056
},
5157
{
5258
name: "name_with_slashes",
5359
providers: []provider.Provider{
54-
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
60+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
5561
},
5662
expectErr: "invalid provider name",
5763
},
5864
{
5965
name: "name_with_spaces",
6066
providers: []provider.Provider{
61-
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
67+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
6268
},
6369
expectErr: "invalid provider name",
6470
},
6571
{
6672
name: "name_with_uppercase",
6773
providers: []provider.Provider{
68-
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
74+
aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
6975
},
7076
expectErr: "invalid provider name",
7177
},
72-
}
73-
74-
for _, tc := range tests {
75-
t.Run(tc.name, func(t *testing.T) {
76-
t.Parallel()
77-
78-
err := validateProviders(tc.providers)
79-
if tc.expectErr != "" {
80-
require.Error(t, err)
81-
assert.Contains(t, err.Error(), tc.expectErr)
82-
} else {
83-
require.NoError(t, err)
84-
}
85-
})
86-
}
87-
}
88-
89-
func TestValidateProvider_DuplicateNames(t *testing.T) {
90-
t.Parallel()
91-
92-
tests := []struct {
93-
name string
94-
providers []provider.Provider
95-
expectErr string
96-
}{
9778
{
9879
name: "unique_names",
9980
providers: []provider.Provider{
100-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
101-
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
81+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
82+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
10283
},
10384
},
10485
{
10586
name: "duplicate_base_url_different_names",
10687
providers: []provider.Provider{
107-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
108-
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
88+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
89+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
10990
},
11091
},
11192
{
11293
name: "duplicate_name",
11394
providers: []provider.Provider{
114-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
115-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
95+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
96+
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
11697
},
11798
expectErr: "duplicate provider name",
11899
},
@@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) {
122103
t.Run(tc.name, func(t *testing.T) {
123104
t.Parallel()
124105

125-
err := validateProviders(tc.providers)
106+
_, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer)
126107
if tc.expectErr != "" {
127108
require.Error(t, err)
128109
assert.Contains(t, err.Error(), tc.expectErr)
@@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
148129
name: "openAI_no_base_path",
149130
requestPath: "/openai/v1/conversations",
150131
provider: func(baseURL string) provider.Provider {
151-
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
132+
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
152133
},
153134
expectPath: "/conversations",
154135
},
@@ -157,15 +138,15 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
157138
baseURLPath: "/v1",
158139
requestPath: "/openai/v1/conversations",
159140
provider: func(baseURL string) provider.Provider {
160-
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
141+
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
161142
},
162143
expectPath: "/v1/conversations",
163144
},
164145
{
165146
name: "anthropic_no_base_path",
166147
requestPath: "/anthropic/v1/models",
167148
provider: func(baseURL string) provider.Provider {
168-
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
149+
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
169150
},
170151
expectPath: "/v1/models",
171152
},
@@ -174,15 +155,15 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
174155
baseURLPath: "/v1",
175156
requestPath: "/anthropic/v1/models",
176157
provider: func(baseURL string) provider.Provider {
177-
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
158+
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
178159
},
179160
expectPath: "/v1/v1/models",
180161
},
181162
{
182163
name: "copilot_no_base_path",
183164
requestPath: "/copilot/models",
184165
provider: func(baseURL string) provider.Provider {
185-
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
166+
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
186167
},
187168
expectPath: "/models",
188169
},
@@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
191172
baseURLPath: "/v1",
192173
requestPath: "/copilot/models",
193174
provider: func(baseURL string) provider.Provider {
194-
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
175+
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
195176
},
196177
expectPath: "/v1/models",
197178
},
@@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
210191
}))
211192
t.Cleanup(upstream.Close)
212193

213-
recorder := testutil.MockRecorder{}
194+
rec := testutil.MockRecorder{}
214195
prov := tc.provider(upstream.URL + tc.baseURLPath)
215-
bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer)
196+
bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer)
216197
require.NoError(t, err)
217198

218199
req := httptest.NewRequest("", tc.requestPath, nil)
219200
resp := httptest.NewRecorder()
220-
bridge.mux.ServeHTTP(resp, req)
201+
bridge.ServeHTTP(resp, req)
221202

222203
assert.Equal(t, http.StatusOK, resp.Code)
223204
assert.Contains(t, resp.Body.String(), upstreamRespBody)

circuitbreaker/circuitbreaker_test.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package circuitbreaker
1+
package circuitbreaker_test
22

33
import (
44
"errors"
@@ -11,6 +11,7 @@ import (
1111
"github.com/sony/gobreaker/v2"
1212
"github.com/stretchr/testify/assert"
1313

14+
"github.com/coder/aibridge/circuitbreaker"
1415
"github.com/coder/aibridge/config"
1516
)
1617

@@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
2021
sonnetCalls := atomic.Int32{}
2122
haikuCalls := atomic.Int32{}
2223

23-
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
24+
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
2425
FailureThreshold: 1,
2526
Interval: time.Minute,
2627
Timeout: time.Minute,
@@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
4849
rw.WriteHeader(http.StatusOK)
4950
return nil
5051
})
51-
assert.True(t, errors.Is(err, ErrCircuitOpen))
52+
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
5253
assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call
5354
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
5455

@@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
6970
messagesCalls := atomic.Int32{}
7071
completionsCalls := atomic.Int32{}
7172

72-
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
73+
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
7374
FailureThreshold: 1,
7475
Interval: time.Minute,
7576
Timeout: time.Minute,
@@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
9596
rw.WriteHeader(http.StatusOK)
9697
return nil
9798
})
98-
assert.True(t, errors.Is(err, ErrCircuitOpen))
99+
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
99100
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
100101
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
101102

@@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
116117
var calls atomic.Int32
117118

118119
// Custom IsFailure that treats 502 as failure
119-
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
120+
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
120121
FailureThreshold: 1,
121122
Interval: time.Minute,
122123
Timeout: time.Minute,
@@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
143144
rw.WriteHeader(http.StatusOK)
144145
return nil
145146
})
146-
assert.True(t, errors.Is(err, ErrCircuitOpen))
147+
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
147148
assert.Equal(t, int32(1), calls.Load()) // No new call
148149
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
149150
}
@@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) {
158159
to gobreaker.State
159160
}
160161

161-
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
162+
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
162163
FailureThreshold: 1,
163164
Interval: time.Minute,
164165
Timeout: time.Minute,
@@ -209,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) {
209210
}
210211

211212
for _, tt := range tests {
212-
assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
213+
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
213214
}
214215
}
215216

216217
func TestStateToGaugeValue(t *testing.T) {
217218
t.Parallel()
218219

219-
assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed))
220-
assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen))
221-
assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen))
220+
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
221+
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
222+
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
222223
}

client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ const (
2424
ClientUnknown Client = "Unknown"
2525
)
2626

27-
// guessClient attempts to guess the client application from the request headers.
27+
// GuessClient attempts to guess the client application from the request headers.
2828
// Not all clients set proper user agent headers, so this is a best-effort approach.
2929
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
30-
func guessClient(r *http.Request) Client {
30+
func GuessClient(r *http.Request) Client {
3131
userAgent := strings.ToLower(r.UserAgent())
3232
originator := r.Header.Get("originator")
3333

0 commit comments

Comments
 (0)