Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
// for deterministic output.
// `sensitive` and `overrides` must both supply keys in canoncialized form.
// See [textproto.MIMEHeader].
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
func (*dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
// Collect all header keys including overrides.
headerKeys := make([]string, 0, len(headers)+len(overrides))
seen := make(map[string]struct{}, len(headers)+len(overrides))
Expand Down
10 changes: 5 additions & 5 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ func (i *interceptionBase) CorrelatingToolCallID() *string {
return &msg.OfTool.ToolCallID
}

func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.String(tracing.Provider, i.providerName),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
}
Expand All @@ -118,7 +118,7 @@ func (i *interceptionBase) Model() string {
return i.req.Model
}

func (i *interceptionBase) newErrorResponse(err error) map[string]any {
func (*interceptionBase) newErrorResponse(err error) map[string]any {
return map[string]any{
"error": true,
"message": err.Error(),
Expand Down
10 changes: 5 additions & 5 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ func NewBlockingInterceptor(
}}
}

func (s *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
s.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
}

func (s *BlockingInterception) Streaming() bool {
func (*BlockingInterception) Streaming() bool {
return false
}

func (s *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, false)
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, false)
}

func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
Expand Down
8 changes: 4 additions & 4 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ func (i *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Reco
i.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
}

func (i *StreamingInterception) Streaming() bool {
func (*StreamingInterception) Streaming() bool {
return true
}

func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, true)
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, true)
}

// ProcessRequest handles a request to /v1/chat/completions.
Expand Down Expand Up @@ -389,7 +389,7 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) {
return i.encodeForStream(data), nil
}

func (i *StreamingInterception) encodeForStream(payload []byte) []byte {
func (*StreamingInterception) encodeForStream(payload []byte) []byte {
var buf bytes.Buffer
buf.WriteString("data: ")
buf.Write(payload)
Expand Down
1 change: 0 additions & 1 deletion intercept/eventstream/eventstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ type EventStream struct {
initiated atomic.Bool
initiateOnce sync.Once

closeOnce sync.Once
shutdownOnce sync.Once
eventsCh chan event

Expand Down
14 changes: 7 additions & 7 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ func (i *interceptionBase) Model() string {
return i.reqPayload.model()
}

func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.String(tracing.Provider, i.providerName),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil),
attribute.Bool(tracing.IsBedrock, i.bedrockCfg != nil),
}
}

Expand Down Expand Up @@ -176,7 +176,7 @@ func (i *interceptionBase) disableParallelToolCalls() {
}

// extractModelThoughts returns any thinking blocks that were returned in the response.
func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
func (*interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
if msg == nil {
return nil
}
Expand Down Expand Up @@ -264,7 +264,7 @@ func (i *interceptionBase) withBody() option.RequestOption {
return option.WithRequestBody("application/json", []byte(i.reqPayload))
}

func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
if cfg == nil {
return nil, xerrors.New("nil config given")
}
Expand Down
6 changes: 3 additions & 3 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,11 +777,11 @@ type mockServerProxier struct {
tools []*mcp.Tool
}

func (m *mockServerProxier) Init(context.Context) error {
func (*mockServerProxier) Init(context.Context) error {
return nil
}

func (m *mockServerProxier) Shutdown(context.Context) error {
func (*mockServerProxier) Shutdown(context.Context) error {
return nil
}

Expand All @@ -798,7 +798,7 @@ func (m *mockServerProxier) GetTool(id string) *mcp.Tool {
return nil
}

func (m *mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
return nil, nil
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyV
return i.interceptionBase.baseTraceAttributes(r, false)
}

func (s *BlockingInterception) Streaming() bool {
func (*BlockingInterception) Streaming() bool {
return false
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/reqpayload.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
return p.resultToRawMessage(tools.Array()), nil
}

func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
func (MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
// gjson.Result conversion to json.RawMessage is needed because
// gjson.Result does not implement json.Marshaler — would
// serialize its struct fields instead of the raw JSON it represents.
Expand Down
32 changes: 16 additions & 16 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ func NewStreamingInterceptor(
}}
}

func (s *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
s.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
func (i *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
}

func (s *StreamingInterception) Streaming() bool {
func (*StreamingInterception) Streaming() bool {
return true
}

func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, true)
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, true)
}

// ProcessRequest handles a request to /v1/messages.
Expand Down Expand Up @@ -534,8 +534,8 @@ newStream:
return interceptionErr
}

func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String())
func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String())
if err != nil {
return nil, xerrors.Errorf("marshal event id failed: %w", err)
}
Expand All @@ -545,10 +545,10 @@ func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventU
return nil, xerrors.Errorf("marshal event usage failed: %w", err)
}

return s.encodeForStream([]byte(sj), event.Type), nil
return i.encodeForStream([]byte(sj), event.Type), nil
}

func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
func (i *StreamingInterception) marshal(payload any) ([]byte, error) {
data, err := json.Marshal(payload)
if err != nil {
return nil, xerrors.Errorf("marshal payload: %w", err)
Expand All @@ -564,15 +564,15 @@ func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
return nil, xerrors.Errorf("could not determine type from payload %q", data)
}

return s.encodeForStream(data, eventType), nil
return i.encodeForStream(data, eventType), nil
}

// https://docs.anthropic.com/en/docs/build-with-claude/streaming#basic-streaming-request
func (s *StreamingInterception) pingPayload() []byte {
return s.encodeForStream([]byte(`{"type": "ping"}`), "ping")
func (i *StreamingInterception) pingPayload() []byte {
return i.encodeForStream([]byte(`{"type": "ping"}`), "ping")
}

func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
var buf bytes.Buffer
buf.WriteString("event: ")
buf.WriteString(typ)
Expand All @@ -584,9 +584,9 @@ func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []by
}

// newStream traces svc.NewStreaming() call.
func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()

return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody())
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody())
}
2 changes: 1 addition & 1 deletion intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
// extractModelThoughts extracts model thoughts from response output items.
// It captures both reasoning summary items and commentary messages (message
// output items with "phase": "commentary") as model thoughts.
func (i *responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
func (*responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
if response == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, recorder record
i.responsesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
}

func (i *BlockingResponsesInterceptor) Streaming() bool {
func (*BlockingResponsesInterceptor) Streaming() bool {
return false
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, recorder recor
i.responsesInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
}

func (i *StreamingResponsesInterceptor) Streaming() bool {
func (*StreamingResponsesInterceptor) Streaming() bool {
return true
}

Expand Down
2 changes: 1 addition & 1 deletion internal/integrationtest/mockmcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
s.AddTool(tool, func(_ context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
return nil, xerrors.New(errMsg)
Expand Down
12 changes: 6 additions & 6 deletions internal/testutil/mock_recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ type MockRecorder struct {
interceptionsEnd map[string]*recorder.InterceptionRecordEnded
}

func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error {
func (m *MockRecorder) RecordInterception(_ context.Context, req *recorder.InterceptionRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.interceptions = append(m.interceptions, req)
return nil
}

func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error {
func (m *MockRecorder) RecordInterceptionEnded(_ context.Context, req *recorder.InterceptionRecordEnded) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.interceptionsEnd == nil {
Expand All @@ -46,28 +46,28 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde
return nil
}

func (m *MockRecorder) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error {
func (m *MockRecorder) RecordPromptUsage(_ context.Context, req *recorder.PromptUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.userPrompts = append(m.userPrompts, req)
return nil
}

func (m *MockRecorder) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error {
func (m *MockRecorder) RecordTokenUsage(_ context.Context, req *recorder.TokenUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.tokenUsages = append(m.tokenUsages, req)
return nil
}

func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error {
func (m *MockRecorder) RecordToolUsage(_ context.Context, req *recorder.ToolUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.toolUsages = append(m.toolUsages, req)
return nil
}

func (m *MockRecorder) RecordModelThought(ctx context.Context, req *recorder.ModelThoughtRecord) error {
func (m *MockRecorder) RecordModelThought(_ context.Context, req *recorder.ModelThoughtRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.modelThoughts = append(m.modelThoughts, req)
Expand Down
20 changes: 10 additions & 10 deletions internal/testutil/mockprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ type MockProvider struct {
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
}

func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (m *MockProvider) AuthHeader() string { return "Authorization" }
func (m *MockProvider) InjectAuthHeader(h *http.Header) {}
func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (m *MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (*MockProvider) InjectAuthHeader(_ *http.Header) {}
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (*MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
if m.InterceptorFunc != nil {
return m.InterceptorFunc(w, r, tracer)
Expand Down
2 changes: 1 addition & 1 deletion mcp/proxy_streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin
return out, nil
}

func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error {
func (p *StreamableHTTPServerProxy) Shutdown(_ context.Context) error {
if p.client == nil {
return nil
}
Expand Down
Loading
Loading