diff --git a/bridge.go b/bridge.go index 41f9d547..27bb3645 100644 --- a/bridge.go +++ b/bridge.go @@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC // We execute this before CreateInterceptor since the interceptors // read the request body and don't reset them. - client := guessClient(r) - sessionID := guessSessionID(client, r) + client := GuessClient(r) + sessionID := GuessSessionID(client, r) interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { diff --git a/bridge_test.go b/bridge_test.go index f83fd0b0..c6427077 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge_test import ( "net/http" @@ -7,16 +7,22 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" ) -func TestValidateProvider_Names(t *testing.T) { +var bridgeTestTracer = otel.Tracer("bridge_test") + +func TestValidateProviders(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil) + tests := []struct { name string providers []provider.Provider @@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) { { name: "all_supported_providers", providers: []provider.Provider{ - NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), - NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), + aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), }, }, { name: "default_names_and_base_urls", providers: []provider.Provider{ - NewOpenAIProvider(config.OpenAI{}), - NewAnthropicProvider(config.Anthropic{}, nil), - NewCopilotProvider(config.Copilot{}), + aibridge.NewOpenAIProvider(config.OpenAI{}), + aibridge.NewAnthropicProvider(config.Anthropic{}, nil), + aibridge.NewCopilotProvider(config.Copilot{}), }, }, { name: "multiple_copilot_instances", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), }, }, { name: "name_with_slashes", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, { name: "name_with_spaces", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, { name: "name_with_uppercase", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - err := validateProviders(tc.providers) - if tc.expectErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectErr) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestValidateProvider_DuplicateNames(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - providers []provider.Provider - expectErr string - }{ { name: "unique_names", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), }, }, { name: "duplicate_base_url_different_names", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}), }, }, { name: "duplicate_name", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "duplicate provider name", }, @@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - err := validateProviders(tc.providers) + _, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer) if tc.expectErr != "" { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectErr) @@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "openAI_no_base_path", requestPath: "/openai/v1/conversations", provider: func(baseURL string) provider.Provider { - return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) + return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) }, expectPath: "/conversations", }, @@ -157,7 +138,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/openai/v1/conversations", provider: func(baseURL string) provider.Provider { - return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) + return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) }, expectPath: "/v1/conversations", }, @@ -165,7 +146,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "anthropic_no_base_path", requestPath: "/anthropic/v1/models", provider: func(baseURL string) provider.Provider { - return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) + return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) }, expectPath: "/v1/models", }, @@ -174,7 +155,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/anthropic/v1/models", provider: func(baseURL string) provider.Provider { - return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) + return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) }, expectPath: "/v1/v1/models", }, @@ -182,7 +163,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "copilot_no_base_path", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL}) }, expectPath: "/models", }, @@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL}) }, expectPath: "/v1/models", }, @@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) { })) t.Cleanup(upstream.Close) - recorder := testutil.MockRecorder{} + rec := testutil.MockRecorder{} prov := tc.provider(upstream.URL + tc.baseURLPath) - bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer) require.NoError(t, err) req := httptest.NewRequest("", tc.requestPath, nil) resp := httptest.NewRecorder() - bridge.mux.ServeHTTP(resp, req) + bridge.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) assert.Contains(t, resp.Body.String(), upstreamRespBody) diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index ab744cbb..84fb98ae 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -1,4 +1,4 @@ -package circuitbreaker +package circuitbreaker_test import ( "errors" @@ -11,6 +11,7 @@ import ( "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" + "github.com/coder/aibridge/circuitbreaker" "github.com/coder/aibridge/config" ) @@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { sonnetCalls := atomic.Int32{} haikuCalls := atomic.Int32{} - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { messagesCalls := atomic.Int32{} completionsCalls := atomic.Int32{} - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), messagesCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { var calls atomic.Int32 // Custom IsFailure that treats 502 as failure - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), calls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) } @@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) { to gobreaker.State } - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -209,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) { } for _, tt := range tests { - assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) } } func TestStateToGaugeValue(t *testing.T) { t.Parallel() - assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed)) - assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen)) - assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen)) + assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen)) } diff --git a/client.go b/client.go index a5c84f84..3e9e277b 100644 --- a/client.go +++ b/client.go @@ -24,10 +24,10 @@ const ( ClientUnknown Client = "Unknown" ) -// guessClient attempts to guess the client application from the request headers. +// GuessClient attempts to guess the client application from the request headers. // Not all clients set proper user agent headers, so this is a best-effort approach. // Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101. -func guessClient(r *http.Request) Client { +func GuessClient(r *http.Request) Client { userAgent := strings.ToLower(r.UserAgent()) originator := r.Header.Get("originator") diff --git a/client_test.go b/client_test.go index a33f8459..923e5c0e 100644 --- a/client_test.go +++ b/client_test.go @@ -1,10 +1,12 @@ -package aibridge +package aibridge_test import ( "net/http" "testing" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge" ) func TestGuessClient(t *testing.T) { @@ -14,93 +16,93 @@ func TestGuessClient(t *testing.T) { name string userAgent string headers map[string]string - wantClient Client + wantClient aibridge.Client }{ { name: "mux", userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22", - wantClient: ClientMux, + wantClient: aibridge.ClientMux, }, { name: "claude_code", userAgent: "claude-cli/2.0.67 (external, cli)", - wantClient: ClientClaudeCode, + wantClient: aibridge.ClientClaudeCode, }, { name: "codex_cli", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef", - wantClient: ClientCodex, + wantClient: aibridge.ClientCodex, }, { name: "zed", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", - wantClient: ClientZed, + wantClient: aibridge.ClientZed, }, { name: "github_copilot_vsc", userAgent: "GitHubCopilotChat/0.37.2026011603", - wantClient: ClientCopilotVSC, + wantClient: aibridge.ClientCopilotVSC, }, { name: "github_copilot_cli", userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)", - wantClient: ClientCopilotCLI, + wantClient: aibridge.ClientCopilotCLI, }, { name: "kilo_code_user_agent", userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientKilo, + wantClient: aibridge.ClientKilo, }, { name: "kilo_code_originator", headers: map[string]string{"Originator": "kilo-code"}, - wantClient: ClientKilo, + wantClient: aibridge.ClientKilo, }, { name: "roo_code_user_agent", userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientRoo, + wantClient: aibridge.ClientRoo, }, { name: "roo_code_originator", headers: map[string]string{"Originator": "roo-code"}, - wantClient: ClientRoo, + wantClient: aibridge.ClientRoo, }, { name: "coder_agents", userAgent: "coder-agents/v2.24.0 (linux/amd64)", - wantClient: ClientCoderAgents, + wantClient: aibridge.ClientCoderAgents, }, { name: "coder_agents_dev", userAgent: "coder-agents/v0.0.0-devel (darwin/arm64)", - wantClient: ClientCoderAgents, + wantClient: aibridge.ClientCoderAgents, }, { name: "charm_crush", userAgent: "Charm Crush/0.1.11", - wantClient: ClientCrush, + wantClient: aibridge.ClientCrush, }, { name: "cursor_x_cursor_client_version", userAgent: "connect-es/1.6.1", headers: map[string]string{"X-Cursor-client-version": "0.50.0"}, - wantClient: ClientCursor, + wantClient: aibridge.ClientCursor, }, { name: "cursor_x_cursor_some_other_header", headers: map[string]string{"x-cursor-client-version": "abc123"}, - wantClient: ClientCursor, + wantClient: aibridge.ClientCursor, }, { name: "unknown_client", userAgent: "ccclaude-cli/calude-with-wrong-prefix", - wantClient: ClientUnknown, + wantClient: aibridge.ClientUnknown, }, { name: "empty_user_agent", userAgent: "", - wantClient: ClientUnknown, + wantClient: aibridge.ClientUnknown, }, } @@ -116,7 +118,7 @@ func TestGuessClient(t *testing.T) { req.Header.Set(key, value) } - got := guessClient(req) + got := aibridge.GuessClient(req) require.Equal(t, tt.wantClient, got) }) } diff --git a/context/context_test.go b/context/context_test.go index 22e4d567..e9ba8b07 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -1,4 +1,4 @@ -package context +package context_test import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/recorder" ) @@ -17,10 +18,10 @@ func TestAsActor(t *testing.T) { metadata := recorder.Metadata{"key": "value"} // When: storing an actor in the context - ctx := AsActor(context.Background(), "actor-123", metadata) + ctx := aibcontext.AsActor(context.Background(), "actor-123", metadata) // Then: the actor should be retrievable with correct ID and metadata - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) require.NotNil(t, actor) assert.Equal(t, "actor-123", actor.ID) assert.Equal(t, "value", actor.Metadata["key"]) @@ -33,10 +34,10 @@ func TestActorFromContext(t *testing.T) { t.Parallel() // Given: a context with an actor - ctx := AsActor(context.Background(), "test-id", recorder.Metadata{}) + ctx := aibcontext.AsActor(context.Background(), "test-id", recorder.Metadata{}) // When: extracting the actor from context - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) // Then: the actor should be returned with correct ID require.NotNil(t, actor) @@ -50,7 +51,7 @@ func TestActorFromContext(t *testing.T) { ctx := context.Background() // When: extracting the actor from context - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) // Then: nil should be returned assert.Nil(t, actor) @@ -64,10 +65,10 @@ func TestActorIDFromContext(t *testing.T) { t.Parallel() // Given: a context with an actor - ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{}) + ctx := aibcontext.AsActor(context.Background(), "test-actor-id", recorder.Metadata{}) // When: extracting the actor ID from context - got := ActorIDFromContext(ctx) + got := aibcontext.ActorIDFromContext(ctx) // Then: the actor ID should be returned assert.Equal(t, "test-actor-id", got) @@ -80,7 +81,7 @@ func TestActorIDFromContext(t *testing.T) { ctx := context.Background() // When: extracting the actor ID from context - got := ActorIDFromContext(ctx) + got := aibcontext.ActorIDFromContext(ctx) // Then: an empty string should be returned assert.Empty(t, got) diff --git a/intercept/actor_headers_test.go b/intercept/actor_headers_test.go index f38a7315..080f09b8 100644 --- a/intercept/actor_headers_test.go +++ b/intercept/actor_headers_test.go @@ -1,4 +1,4 @@ -package intercept +package intercept_test import ( "testing" @@ -7,14 +7,15 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/recorder" ) func TestNilActor(t *testing.T) { t.Parallel() - require.Nil(t, ActorHeadersAsOpenAIOpts(nil)) - require.Nil(t, ActorHeadersAsAnthropicOpts(nil)) + require.Nil(t, intercept.ActorHeadersAsOpenAIOpts(nil)) + require.Nil(t, intercept.ActorHeadersAsAnthropicOpts(nil)) } func TestBasic(t *testing.T) { @@ -28,9 +29,9 @@ func TestBasic(t *testing.T) { // We can't peek inside since these opts require an internal type to apply onto. // All we can do is check the length. // See TestActorHeaders for an integration test. - oaiOpts := ActorHeadersAsOpenAIOpts(actor) + oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor) require.Len(t, oaiOpts, 1) - antOpts := ActorHeadersAsAnthropicOpts(actor) + antOpts := intercept.ActorHeadersAsAnthropicOpts(actor) require.Len(t, antOpts, 1) } @@ -49,8 +50,8 @@ func TestBasicAndMetadata(t *testing.T) { // We can't peek inside since these opts require an internal type to apply onto. // All we can do is check the length. // See TestActorHeaders for an integration test. - oaiOpts := ActorHeadersAsOpenAIOpts(actor) + oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor) require.Len(t, oaiOpts, 1+len(actor.Metadata)) - antOpts := ActorHeadersAsAnthropicOpts(actor) + antOpts := intercept.ActorHeadersAsAnthropicOpts(actor) require.Len(t, antOpts, 1+len(actor.Metadata)) } diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index 043ce15c..1aaf56c0 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/intercept/apidump/headers_test.go b/intercept/apidump/headers_test.go index dc2a0216..1c5bb697 100644 --- a/intercept/apidump/headers_test.go +++ b/intercept/apidump/headers_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 87223df6..7bdac2a9 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // shares test helpers with apidump_test.go import ( "bytes" diff --git a/intercept/chatcompletions/base_test.go b/intercept/chatcompletions/base_test.go index 5f83f5c3..7094c2c4 100644 --- a/intercept/chatcompletions/base_test.go +++ b/intercept/chatcompletions/base_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions //nolint:testpackage // tests unexported internals import ( "testing" diff --git a/intercept/chatcompletions/paramswrap_test.go b/intercept/chatcompletions/paramswrap_test.go index 7397e220..1e7c61f3 100644 --- a/intercept/chatcompletions/paramswrap_test.go +++ b/intercept/chatcompletions/paramswrap_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions //nolint:testpackage // tests unexported internals import ( "fmt" diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 52d5baa5..88d50461 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions_test import ( "net/http" @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/intercept/chatcompletions" "github.com/coder/aibridge/internal/testutil" ) @@ -73,7 +74,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { Key: "test-key", } - req := &ChatCompletionNewParamsWrapper{ + req := &chatcompletions.ChatCompletionNewParamsWrapper{ ChatCompletionNewParams: openai.ChatCompletionNewParams{ Model: "gpt-4", Messages: []openai.ChatCompletionMessageParamUnion{ @@ -88,7 +89,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{}) + interceptor := chatcompletions.NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{}) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/client_headers_test.go b/intercept/client_headers_test.go index ecd2f018..918ce64d 100644 --- a/intercept/client_headers_test.go +++ b/intercept/client_headers_test.go @@ -1,4 +1,4 @@ -package intercept +package intercept_test import ( "net/http" @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/intercept" ) func TestPrepareClientHeaders(t *testing.T) { @@ -14,7 +16,7 @@ func TestPrepareClientHeaders(t *testing.T) { t.Run("nil input returns empty header", func(t *testing.T) { t.Parallel() - result := PrepareClientHeaders(nil) + result := intercept.PrepareClientHeaders(nil) require.Empty(t, result) }) @@ -29,7 +31,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Connection")) assert.Empty(t, result.Get("Keep-Alive")) @@ -48,7 +50,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Host")) assert.Empty(t, result.Get("Accept-Encoding")) @@ -65,7 +67,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Authorization")) assert.Empty(t, result.Get("X-Api-Key")) @@ -79,7 +81,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"value-1", "value-2"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"]) }) @@ -93,7 +95,7 @@ func TestPrepareClientHeaders(t *testing.T) { } originalCopy := input.Clone() - _ = PrepareClientHeaders(input) + _ = intercept.PrepareClientHeaders(input) require.Equal(t, originalCopy, input) }) @@ -113,7 +115,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization")) assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) @@ -131,7 +133,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "Anthropic-Beta": {"prompt-caching-2024-07-31"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key")) assert.Empty(t, result.Get("Authorization")) @@ -151,7 +153,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Equal(t, "Bearer sk-key", result.Get("Authorization")) assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id")) @@ -174,7 +176,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Empty(t, result.Get("Connection")) assert.Empty(t, result.Get("Host")) @@ -192,7 +194,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Empty(t, result.Get("Authorization")) assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) @@ -211,7 +213,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { sdkCopy := sdkHeader.Clone() clientCopy := clientHeaders.Clone() - _ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + _ = intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") require.Equal(t, sdkCopy, sdkHeader) require.Equal(t, clientCopy, clientHeaders) diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index ae1ee5be..ff0a20d6 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -1,4 +1,4 @@ -package messages +package messages //nolint:testpackage // tests unexported internals import ( "context" diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go index d1b062f4..a5de61f8 100644 --- a/intercept/messages/reqpayload_test.go +++ b/intercept/messages/reqpayload_test.go @@ -1,4 +1,4 @@ -package messages +package messages //nolint:testpackage // tests unexported internals import ( "testing" diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index cf02738f..ea5c87b5 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -1,4 +1,4 @@ -package responses +package responses //nolint:testpackage // tests unexported internals import ( "net/http" diff --git a/intercept/responses/reqpayload_test.go b/intercept/responses/reqpayload_test.go index df99954f..83115f08 100644 --- a/intercept/responses/reqpayload_test.go +++ b/intercept/responses/reqpayload_test.go @@ -1,4 +1,4 @@ -package responses +package responses //nolint:testpackage // tests unexported internals import ( "encoding/json" diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index e05f9d0d..f2be96f3 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bufio" diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 15623b84..d0fdff16 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 3ad039be..ec619409 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "fmt" diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index f774b046..e941d94b 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 61c885d7..3213e2ff 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "context" diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 164c880e..dc86815f 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "context" diff --git a/passthrough_test.go b/passthrough_test.go index 8f219c7c..85600d65 100644 --- a/passthrough_test.go +++ b/passthrough_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge //nolint:testpackage // tests unexported newPassthroughRouter import ( "net/http" diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index c59fd7dd..bc14296b 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/provider/copilot_test.go b/provider/copilot_test.go index b45da10e..c5d5316f 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/provider/openai_test.go b/provider/openai_test.go index 80e5097e..f53a2830 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/session.go b/session.go index e89ad784..760ea7cf 100644 --- a/session.go +++ b/session.go @@ -14,10 +14,10 @@ import ( var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Legacy format: save compilation on each call. -// guessSessionID attempts to retrieve a session ID which may have been sent by +// GuessSessionID attempts to retrieve a session ID which may have been sent by // the client. We only attempt to retrieve sessions using methods recognized for // the given client. -func guessSessionID(client Client, r *http.Request) *string { +func GuessSessionID(client Client, r *http.Request) *string { switch client { case ClientClaudeCode: // Prefer the dedicated header (added in Claude Code v2.1.86+). diff --git a/session_test.go b/session_test.go index 7cce9e40..244b5cc0 100644 --- a/session_test.go +++ b/session_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge_test import ( "io" @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/aibridge" "github.com/coder/aibridge/utils" ) @@ -16,7 +17,7 @@ func TestGuessSessionID(t *testing.T) { cases := []struct { name string - client Client + client aibridge.Client body string headers map[string]string sessionID *string @@ -24,177 +25,177 @@ func TestGuessSessionID(t *testing.T) { // Claude Code. { name: "claude_code_header_takes_precedence", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": "header-session-id"}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_body-session-id"}}`, sessionID: utils.PtrTo("header-session-id"), }, { name: "claude_code_header_only", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": "aabb-ccdd"}, body: `{"model":"claude-3"}`, sessionID: utils.PtrTo("aabb-ccdd"), }, { name: "claude_code_empty_header_falls_back_to_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": ""}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_whitespace_header_falls_back_to_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": " "}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_with_valid_session", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_with_valid_session_new_format", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"45aa15c8c244ea2582f8144dde91a50ec3815851f6f648abef4ee15b173cc927\",\"account_uuid\":\"\",\"session_id\":\"54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe\"}"}}`, sessionID: utils.PtrTo("54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe"), }, { name: "claude_code_new_format_empty_session_id", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"\"}"}}`, }, { name: "claude_code_new_format_no_session_id_field", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\"}"}}`, }, { name: "claude_code_missing_metadata", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"model":"claude-3"}`, }, { name: "claude_code_missing_user_id", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{}}`, }, { name: "claude_code_user_id_without_session", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"user_abc123_account_456"}}`, }, { name: "claude_code_empty_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: ``, }, { name: "claude_code_invalid_json", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `not json at all`, }, // Codex. { name: "codex_with_session_header", - client: ClientCodex, + client: aibridge.ClientCodex, headers: map[string]string{"session_id": "codex-session-123"}, sessionID: utils.PtrTo("codex-session-123"), }, { name: "codex_with_whitespace_in_header", - client: ClientCodex, + client: aibridge.ClientCodex, headers: map[string]string{"session_id": " codex-session-123 "}, sessionID: utils.PtrTo("codex-session-123"), }, { name: "codex_without_session_header", - client: ClientCodex, + client: aibridge.ClientCodex, }, // Other clients shouldn't use others' logic. { name: "unknown_client_returns_empty", - client: ClientUnknown, + client: aibridge.ClientUnknown, body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, }, { name: "zed_returns_empty", - client: ClientZed, + client: aibridge.ClientZed, headers: map[string]string{"session_id": "zed-session"}, body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, }, // Mux. { name: "mux_with_workspace_header", - client: ClientMux, + client: aibridge.ClientMux, headers: map[string]string{"X-Mux-Workspace-Id": "ws-abc-123"}, sessionID: utils.PtrTo("ws-abc-123"), }, { name: "mux_without_workspace_header", - client: ClientMux, + client: aibridge.ClientMux, }, // Copilot VS Code. { name: "copilot_vsc_with_interaction_id", - client: ClientCopilotVSC, + client: aibridge.ClientCopilotVSC, headers: map[string]string{"x-interaction-id": "interaction-xyz"}, sessionID: utils.PtrTo("interaction-xyz"), }, { name: "copilot_vsc_without_interaction_id", - client: ClientCopilotVSC, + client: aibridge.ClientCopilotVSC, }, // Copilot CLI. { name: "copilot_cli_with_session_header", - client: ClientCopilotCLI, + client: aibridge.ClientCopilotCLI, headers: map[string]string{"X-Client-Session-Id": "cli-sess-456"}, sessionID: utils.PtrTo("cli-sess-456"), }, { name: "copilot_cli_without_session_header", - client: ClientCopilotCLI, + client: aibridge.ClientCopilotCLI, }, // Kilo. { name: "kilo_with_task_id", - client: ClientKilo, + client: aibridge.ClientKilo, headers: map[string]string{"X-KILOCODE-TASKID": "task-789"}, sessionID: utils.PtrTo("task-789"), }, { name: "kilo_without_task_id", - client: ClientKilo, + client: aibridge.ClientKilo, }, // Coder Agents. { name: "coder_agents_with_chat_id", - client: ClientCoderAgents, + client: aibridge.ClientCoderAgents, headers: map[string]string{"X-Coder-Chat-Id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890"}, sessionID: utils.PtrTo("a1b2c3d4-e5f6-7890-abcd-ef1234567890"), }, { name: "coder_agents_without_chat_id", - client: ClientCoderAgents, + client: aibridge.ClientCoderAgents, }, // Roo. { name: "roo_returns_empty", - client: ClientRoo, + client: aibridge.ClientRoo, }, // Cursor. { name: "cursor_returns_empty", - client: ClientCursor, + client: aibridge.ClientCursor, }, // Other cases. { name: "empty session ID value", - client: ClientKilo, + client: aibridge.ClientKilo, headers: map[string]string{"X-KILOCODE-TASKID": " "}, sessionID: nil, }, @@ -212,7 +213,7 @@ func TestGuessSessionID(t *testing.T) { req.Header.Set(key, value) } - got := guessSessionID(tc.client, req) + got := aibridge.GuessSessionID(tc.client, req) require.Equal(t, tc.sessionID, got) // Verify the body was restored and can be read again. @@ -229,7 +230,7 @@ func TestUnreadableBody(t *testing.T) { req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", &errReader{}) require.NoError(t, err) - got := guessSessionID(ClientClaudeCode, req) + got := aibridge.GuessSessionID(aibridge.ClientClaudeCode, req) require.Nil(t, got) }