Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package aibridge
import (
"context"

"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/provider"
"github.com/coder/aibridge/recorder"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
)

// Const + Type + function aliases for backwards compatibility.
Expand Down
24 changes: 13 additions & 11 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import (
"sync/atomic"
"time"

"github.com/hashicorp/go-multierror"
"github.com/sony/gobreaker/v2"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/circuitbreaker"
aibcontext "github.com/coder/aibridge/context"
Expand All @@ -19,10 +25,6 @@ import (
"github.com/coder/aibridge/provider"
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/tracing"
"github.com/hashicorp/go-multierror"
"github.com/sony/gobreaker/v2"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

const (
Expand Down Expand Up @@ -67,10 +69,10 @@ func validateProviders(providers []provider.Provider) error {
for _, prov := range providers {
name := prov.Name()
if !validProviderName.MatchString(name) {
return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name)
return xerrors.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name)
}
if names[name] {
return fmt.Errorf("duplicate provider name: %q", name)
return xerrors.Errorf("duplicate provider name: %q", name)
}
names[name] = true
}
Expand Down Expand Up @@ -125,7 +127,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
slog.F("prefix", prov.RoutePrefix()),
slog.F("path", path),
)
return nil, fmt.Errorf("failed to configure provider '%v': failed to join bridged path: %w", providerName, err)
return nil, xerrors.Errorf("failed to configure provider '%v': failed to join bridged path: %w", providerName, err)
}
mux.Handle(route, handler)
}
Expand All @@ -144,7 +146,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
slog.F("prefix", prov.RoutePrefix()),
slog.F("path", path),
)
return nil, fmt.Errorf("failed to configure provider '%v': failed to join passed through path: %w", providerName, err)
return nil, xerrors.Errorf("failed to configure provider '%v': failed to join passed through path: %w", providerName, err)
}
mux.Handle(route, http.StripPrefix(prov.RoutePrefix(), ftr))
}
Expand Down Expand Up @@ -325,7 +327,7 @@ func (b *RequestBridge) Shutdown(ctx context.Context) error {
select {
case <-ctx.Done():
// Cancel all inflight requests, if any are still running.
b.logger.Debug(ctx, "shutdown context canceled; cancelling inflight requests", slog.Error(ctx.Err()))
b.logger.Debug(ctx, "shutdown context canceled; canceling inflight requests", slog.Error(ctx.Err()))
b.inflightCancel()
<-done
err = ctx.Err()
Expand All @@ -347,8 +349,8 @@ func (b *RequestBridge) InflightRequests() int32 {
return b.inflightReqs.Load()
}

// mergeContexts merges two contexts together, so that if either is cancelled
// the returned context is cancelled. The context values will only be used from
// mergeContexts merges two contexts together, so that if either is canceled
// the returned context is canceled. The context values will only be used from
// the first context.
func mergeContexts(base, other context.Context) context.Context {
ctx, cancel := context.WithCancel(base)
Expand Down
5 changes: 3 additions & 2 deletions bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidateProvider_Names(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion buildinfo/buildinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package buildinfo_test
import (
"testing"

"github.com/coder/aibridge/buildinfo"
"github.com/stretchr/testify/assert"

"github.com/coder/aibridge/buildinfo"
)

func TestBuildInfo(t *testing.T) {
Expand Down
8 changes: 5 additions & 3 deletions circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"sync"
"time"

"github.com/sony/gobreaker/v2"
"golang.org/x/xerrors"

"github.com/coder/aibridge/config"
"github.com/coder/aibridge/metrics"
"github.com/sony/gobreaker/v2"
)

// ErrCircuitOpen is returned by Execute when the circuit breaker is open
// and the request was rejected without calling the handler.
var ErrCircuitOpen = errors.New("circuit breaker is open")
var ErrCircuitOpen = xerrors.New("circuit breaker is open")

// DefaultIsFailure returns true for standard HTTP status codes that typically
// indicate upstream overload.
Expand Down Expand Up @@ -153,7 +155,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons
_, err := cb.Execute(func() (struct{}, error) {
handlerErr = handler(sw)
if p.isFailure(sw.statusCode) {
return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode)
return struct{}{}, xerrors.Errorf("upstream error: %d", sw.statusCode)
}
return struct{}{}, nil
})
Expand Down
3 changes: 2 additions & 1 deletion circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
"testing"
"time"

"github.com/coder/aibridge/config"
"github.com/sony/gobreaker/v2"
"github.com/stretchr/testify/assert"

"github.com/coder/aibridge/config"
)

func TestExecute_PerModelIsolation(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
golang.org/x/sync v0.16.0
golang.org/x/tools v0.36.0
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9
)

// AI-related libs.
Expand Down Expand Up @@ -86,7 +87,6 @@ require (
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/term v0.34.0 // indirect
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
3 changes: 2 additions & 1 deletion intercept/actor_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"strings"

ant_option "github.com/anthropics/anthropic-sdk-go/option"
"github.com/coder/aibridge/context"
oai_option "github.com/openai/openai-go/v3/option"

"github.com/coder/aibridge/context"
)

const (
Expand Down
5 changes: 3 additions & 2 deletions intercept/actor_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package intercept
import (
"testing"

"github.com/coder/aibridge/context"
"github.com/coder/aibridge/recorder"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/aibridge/context"
"github.com/coder/aibridge/recorder"
)

func TestNilActor(t *testing.T) {
Expand Down
27 changes: 14 additions & 13 deletions intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"slices"
"strings"

"cdr.dev/slog/v3"
"github.com/google/uuid"
"github.com/tidwall/pretty"
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/utils"
"github.com/coder/quartz"
"github.com/google/uuid"
"github.com/tidwall/pretty"
)

const (
Expand Down Expand Up @@ -72,7 +73,7 @@ type dumper struct {
func (d *dumper) dumpRequest(req *http.Request) error {
dumpPath := d.dumpPath + SuffixRequest
if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil {
return fmt.Errorf("create dump dir: %w", err)
return xerrors.Errorf("create dump dir: %w", err)
}

// Read and restore body
Expand All @@ -81,7 +82,7 @@ func (d *dumper) dumpRequest(req *http.Request) error {
var err error
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return fmt.Errorf("read request body: %w", err)
return xerrors.Errorf("read request body: %w", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
Expand All @@ -92,18 +93,18 @@ func (d *dumper) dumpRequest(req *http.Request) error {
var buf bytes.Buffer
_, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
if err != nil {
return fmt.Errorf("write request uri: %w", err)
return xerrors.Errorf("write request uri: %w", err)
}
err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
"Content-Length": fmt.Sprintf("%d", len(prettyBody)),
})
if err != nil {
return fmt.Errorf("write request headers: %w", err)
return xerrors.Errorf("write request headers: %w", err)
}

_, err = fmt.Fprintf(&buf, "\r\n")
if err != nil {
return fmt.Errorf("write request header terminator: %w", err)
return xerrors.Errorf("write request header terminator: %w", err)
}
buf.Write(prettyBody)
buf.WriteByte('\n')
Expand All @@ -118,15 +119,15 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
var headerBuf bytes.Buffer
_, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
if err != nil {
return fmt.Errorf("write response status: %w", err)
return xerrors.Errorf("write response status: %w", err)
}
err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
if err != nil {
return fmt.Errorf("write response headers: %w", err)
return xerrors.Errorf("write response headers: %w", err)
}
_, err = fmt.Fprintf(&headerBuf, "\r\n")
if err != nil {
return fmt.Errorf("write response header terminator: %w", err)
return xerrors.Errorf("write response header terminator: %w", err)
}

// Wrap the response body to capture it as it streams
Expand Down Expand Up @@ -176,7 +177,7 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
if override, ok := overrides[key]; ok {
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, override)
if err != nil {
return fmt.Errorf("write response header override: %w", err)
return xerrors.Errorf("write response header override: %w", err)
}
}
continue
Expand All @@ -191,7 +192,7 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
}
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, value)
if err != nil {
return fmt.Errorf("write response headers: %w", err)
return xerrors.Errorf("write response headers: %w", err)
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions intercept/apidump/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import (
"strings"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/quartz"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

// findDumpFile finds a dump file matching the pattern in the given directory.
Expand Down
3 changes: 2 additions & 1 deletion intercept/apidump/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (

"cdr.dev/slog/v3"

"github.com/coder/quartz"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/quartz"
)

func TestSensitiveHeaderLists(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions intercept/apidump/streaming.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package apidump

import (
"fmt"
"io"
"os"
"path/filepath"
"sync"

"golang.org/x/xerrors"
)

// streamingBodyDumper wraps an io.ReadCloser and writes all data to a dump file
Expand All @@ -24,18 +25,18 @@ type streamingBodyDumper struct {
func (s *streamingBodyDumper) init() {
s.once.Do(func() {
if err := os.MkdirAll(filepath.Dir(s.dumpPath), 0o755); err != nil {
s.initErr = fmt.Errorf("create dump dir: %w", err)
s.initErr = xerrors.Errorf("create dump dir: %w", err)
return
}
f, err := os.Create(s.dumpPath)
if err != nil {
s.initErr = fmt.Errorf("create dump file: %w", err)
s.initErr = xerrors.Errorf("create dump file: %w", err)
return
}
s.file = f
// Write headers first.
if _, err := s.file.Write(s.headerData); err != nil {
s.initErr = fmt.Errorf("write headers: %w", err)
s.initErr = xerrors.Errorf("write headers: %w", err)
s.file.Close()
s.file = nil
}
Expand Down
5 changes: 3 additions & 2 deletions intercept/apidump/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
"strings"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/quartz"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func TestMiddleware_StreamingResponse(t *testing.T) {
Expand Down
13 changes: 7 additions & 6 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
"net/http"
"strings"

"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/shared"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
Expand All @@ -15,12 +22,6 @@ import (
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/tracing"
"github.com/coder/quartz"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/shared"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog/v3"
)
Expand Down
Loading
Loading