From 568b3fc9e55228d4002a01893ea40bc3fcc50082 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 22 Jan 2026 17:16:30 +0200 Subject: [PATCH 1/2] feat: conditionally add actor headers Signed-off-by: Danny Kopping --- bridge_integration_test.go | 156 +++++++++++++++++++++++++ config/config.go | 50 ++++---- intercept/actor_headers.go | 79 +++++++++++++ intercept/actor_headers_test.go | 55 +++++++++ intercept/chatcompletions/blocking.go | 6 + intercept/chatcompletions/streaming.go | 14 ++- intercept/messages/blocking.go | 5 + intercept/messages/streaming.go | 10 +- intercept/responses/blocking.go | 6 + intercept/responses/streaming.go | 6 + 10 files changed, 359 insertions(+), 28 deletions(-) create mode 100644 intercept/actor_headers.go create mode 100644 intercept/actor_headers_test.go diff --git a/bridge_integration_test.go b/bridge_integration_test.go index dcaf62f..d23a149 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -26,6 +26,7 @@ import ( "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" @@ -1617,6 +1618,161 @@ func TestEnvironmentDoNotLeak(t *testing.T) { } } +func TestActorHeaders(t *testing.T) { + t.Parallel() + + actorUsername := "bob" + + cases := []struct { + name string + createRequest createRequestFunc + createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider + fixture []byte + streaming bool + }{ + { + name: "openai/v1/chat/completions", + createRequest: createOpenAIChatCompletionsReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openaiCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiChatSimple, + streaming: true, + }, + { + name: "openai/v1/chat/completions", + createRequest: createOpenAIChatCompletionsReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openaiCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiChatSimple, + streaming: false, + }, + { + name: "openai/v1/responses", + createRequest: createOpenAIResponsesReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openaiCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiResponsesStreamingSimple, + streaming: true, + }, + { + name: "openai/v1/responses", + createRequest: createOpenAIResponsesReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openaiCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiResponsesBlockingSimple, + streaming: false, + }, + { + name: "anthropic/v1/messages", + createRequest: createAnthropicMessagesReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := anthropicCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewAnthropic(cfg, nil) + }, + fixture: fixtures.AntSimple, + streaming: true, + }, + { + name: "anthropic/v1/messages", + createRequest: createAnthropicMessagesReq, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := anthropicCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewAnthropic(cfg, nil) + }, + fixture: fixtures.AntSimple, + streaming: false, + }, + } + + for _, tc := range cases { + for _, send := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + reqBody := files[fixtureRequest] + + // Add the stream param to the request. + newBody, err := setJSON(reqBody, "stream", tc.streaming) + require.NoError(t, err) + reqBody = newBody + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Track headers received by the upstream server. + var receivedHeaders http.Header + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusTeapot) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.Close) + + rec := &testutil.MockRecorder{} + provider := tc.createProviderFn(srv.URL, apiKey, send) + + b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err, "failed to create handler") + + mockSrv := httptest.NewUnstartedServer(b) + t.Cleanup(mockSrv.Close) + + metadataKey := "Username" + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + // Attach an actor to the request context. + return aibcontext.AsActor(ctx, userID, recorder.Metadata{ + metadataKey: actorUsername, + }) + } + mockSrv.Start() + + req := tc.createRequest(t, mockSrv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify that the actor headers were only received if intended. + found := make(map[string][]string) + for k, v := range receivedHeaders { + k = strings.ToLower(k) + if intercept.IsActorHeader(k) { + found[k] = v + } + } + + if send { + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{userID}) + require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) + } else { + require.Empty(t, found) + } + }) + } + } +} + func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 { var total int64 for _, el := range in { diff --git a/config/config.go b/config/config.go index 3387007..e9cc526 100644 --- a/config/config.go +++ b/config/config.go @@ -7,6 +7,32 @@ const ( ProviderOpenAI = "openai" ) +type Anthropic struct { + BaseURL string + Key string + APIDumpDir string + CircuitBreaker *CircuitBreaker + SendActorHeaders bool +} + +type AWSBedrock struct { + Region string + AccessKey, AccessKeySecret string + Model, SmallFastModel string + // If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint + // (https://bedrock-runtime.{region}.amazonaws.com). + // This is useful for routing requests through a proxy or for testing. + BaseURL string +} + +type OpenAI struct { + BaseURL string + Key string + APIDumpDir string + CircuitBreaker *CircuitBreaker + SendActorHeaders bool +} + // CircuitBreaker holds configuration for circuit breakers. type CircuitBreaker struct { // MaxRequests is the maximum number of requests allowed in half-open state. @@ -34,27 +60,3 @@ func DefaultCircuitBreaker() CircuitBreaker { MaxRequests: 3, } } - -type Anthropic struct { - BaseURL string - Key string - APIDumpDir string - CircuitBreaker *CircuitBreaker -} - -type AWSBedrock struct { - Region string - AccessKey, AccessKeySecret string - Model, SmallFastModel string - // If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint - // (https://bedrock-runtime.{region}.amazonaws.com). - // This is useful for routing requests through a proxy or for testing. - BaseURL string -} - -type OpenAI struct { - BaseURL string - Key string - APIDumpDir string - CircuitBreaker *CircuitBreaker -} diff --git a/intercept/actor_headers.go b/intercept/actor_headers.go new file mode 100644 index 0000000..2d94503 --- /dev/null +++ b/intercept/actor_headers.go @@ -0,0 +1,79 @@ +package intercept + +import ( + "fmt" + "strings" + + ant_option "github.com/anthropics/anthropic-sdk-go/option" + "github.com/coder/aibridge/context" + oai_option "github.com/openai/openai-go/v3/option" +) + +const ( + prefix = "X-AI-Bridge-Actor" +) + +func ActorIDHeader() string { + return fmt.Sprintf("%s-ID", prefix) +} + +func ActorMetadataHeader(name string) string { + return fmt.Sprintf("%s-Metadata-%s", prefix, name) +} + +func IsActorHeader(name string) bool { + return strings.HasPrefix(strings.ToLower(name), strings.ToLower(prefix)) +} + +// ActorHeadersAsOpenAIOpts produces a slice of headers using OpenAI's RequestOption type. +func ActorHeadersAsOpenAIOpts(actor *context.Actor) []oai_option.RequestOption { + var opts []oai_option.RequestOption + + headers := headersFromActor(actor) + if len(headers) == 0 { + return nil + } + + for k, v := range headers { + // [k] will be canonicalized, see [http.Header]'s [Add] method. + opts = append(opts, oai_option.WithHeaderAdd(k, v)) + } + + return opts +} + +// ActorHeadersAsAnthropicOpts produces a slice of headers using Anthropic's RequestOption type. +func ActorHeadersAsAnthropicOpts(actor *context.Actor) []ant_option.RequestOption { + var opts []ant_option.RequestOption + + headers := headersFromActor(actor) + if len(headers) == 0 { + return nil + } + + for k, v := range headers { + // [k] will be canonicalized, see [http.Header]'s [Add] method. + opts = append(opts, ant_option.WithHeaderAdd(k, v)) + } + + return opts +} + +// headersFromActor produces a map of headers from a given [context.Actor]. +func headersFromActor(actor *context.Actor) map[string]string { + if actor == nil { + return nil + } + + headers := make(map[string]string, len(actor.Metadata)+1) + + // Add actor ID. + headers[ActorIDHeader()] = actor.ID + + // Add headers for provided metadata. + for k, v := range actor.Metadata { + headers[ActorMetadataHeader(k)] = fmt.Sprintf("%v", v) + } + + return headers +} diff --git a/intercept/actor_headers_test.go b/intercept/actor_headers_test.go new file mode 100644 index 0000000..e9f80a8 --- /dev/null +++ b/intercept/actor_headers_test.go @@ -0,0 +1,55 @@ +package intercept + +import ( + "testing" + + "github.com/coder/aibridge/context" + "github.com/coder/aibridge/recorder" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestNilActor(t *testing.T) { + t.Parallel() + + require.Nil(t, ActorHeadersAsOpenAIOpts(nil)) + require.Nil(t, ActorHeadersAsAnthropicOpts(nil)) +} + +func TestBasic(t *testing.T) { + t.Parallel() + + actorID := uuid.NewString() + actor := &context.Actor{ + ID: actorID, + } + + // We can't peek inside since these opts require an internal type to apply onto. + // All we can do is check the length. + // See TestActorHeaders for an integration test. + oaiOpts := ActorHeadersAsOpenAIOpts(actor) + require.Len(t, oaiOpts, 1) + antOpts := ActorHeadersAsAnthropicOpts(actor) + require.Len(t, antOpts, 1) +} + +func TestBasicAndMetadata(t *testing.T) { + t.Parallel() + + actorID := uuid.NewString() + actor := &context.Actor{ + ID: actorID, + Metadata: recorder.Metadata{ + "This": "That", + "And": "The other", + }, + } + + // We can't peek inside since these opts require an internal type to apply onto. + // All we can do is check the length. + // See TestActorHeaders for an integration test. + oaiOpts := ActorHeadersAsOpenAIOpts(actor) + require.Len(t, oaiOpts, 1+len(actor.Metadata)) + antOpts := ActorHeadersAsAnthropicOpts(actor) + require.Len(t, antOpts, 1+len(actor.Metadata)) +} diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 54a4c16..c650ade 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -9,6 +9,8 @@ import ( "time" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -73,8 +75,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) + var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*600)) + if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) + } completion, err = i.newChatCompletion(ctx, svc, opts) if err != nil { diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 5193148..fcfffc4 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -11,12 +11,15 @@ import ( "time" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/google/uuid" "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/packages/ssestream" "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" @@ -116,7 +119,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re ) for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - stream = i.newStream(streamCtx, svc) + var opts []option.RequestOption + if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) + } + + stream = i.newStream(streamCtx, svc, opts) processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName) var toolCall *openai.FinishedChatCompletionToolCall @@ -349,11 +357,11 @@ func (i *StreamingInterception) encodeForStream(payload []byte) []byte { } // newStream traces svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams) call -func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCompletionService) *ssestream.Stream[openai.ChatCompletionChunk] { +func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) *ssestream.Stream[openai.ChatCompletionChunk] { _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams) + return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams, opts...) } type streamProcessor struct { diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 9e86c09..f027e49 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -9,6 +9,8 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -71,6 +73,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} + if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) + } svc, err := i.newMessagesService(ctx, opts...) if err != nil { diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 83ef086..31d0ea9 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -11,9 +11,12 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -106,7 +109,12 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re streamCtx, streamCancel := context.WithCancelCause(ctx) defer streamCancel(errors.New("deferred")) - svc, err := i.newMessagesService(streamCtx) + var opts []option.RequestOption + if actor := aibcontext.ActorFromContext(ctx); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) + } + + svc, err := i.newMessagesService(streamCtx, opts...) if err != nil { err = fmt.Errorf("create anthropic client: %w", err) http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index a4b415b..4ec1f63 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -9,6 +9,8 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" @@ -72,6 +74,10 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * opts := i.requestOptions(&respCopy) opts = append(opts, option.WithRequestTimeout(time.Second*600)) + if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) + } + response, upstreamErr = i.newResponse(ctx, srv, opts) if upstreamErr != nil { diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 04d44ca..370c47c 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -9,6 +9,8 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -83,6 +85,10 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r srv := i.newResponsesService() opts := i.requestOptions(&respCopy) + if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { + opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) + } + stream := i.newStream(ctx, srv, opts) defer stream.Close() From 3907941e1bf2a84c475c483bde75e958a2617ee1 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 23 Jan 2026 11:35:23 +0200 Subject: [PATCH 2/2] chore: explicitly check that response headers were captured Signed-off-by: Danny Kopping --- bridge_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index d23a149..dae4fe8 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1751,6 +1751,7 @@ func TestActorHeaders(t *testing.T) { client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) + require.NotEmpty(t, receivedHeaders) defer resp.Body.Close() // Verify that the actor headers were only received if intended.