Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions intercept/apidump/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
req.Header.Set("User-Agent", "test-client")

// Call middleware with a mock next function
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Expand All @@ -62,6 +62,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp.Body.Close()

// Read the request dump file
modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
Expand Down Expand Up @@ -170,7 +171,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
require.NoError(t, err)

var capturedBody []byte
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp2, err := middleware(req, func(r *http.Request) (*http.Response, error) {
// Read the body in the next handler to verify it's still available
capturedBody, _ = io.ReadAll(r.Body)
return &http.Response{
Expand All @@ -182,6 +183,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp2.Body.Close()

// Verify the body was preserved for the next handler
require.Equal(t, originalBody, string(capturedBody))
Expand All @@ -202,7 +204,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Expand All @@ -212,6 +214,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp3.Body.Close()

// Verify files are created with sanitized model name
modelDir := filepath.Join(tmpDir, "google", "gemini-1.5-pro")
Expand Down Expand Up @@ -290,7 +293,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
req.Header.Set("Proxy-Authorization", "Basic proxy-creds")
req.Header.Set("X-Amz-Security-Token", "aws-security-token")

_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp4, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Expand All @@ -300,6 +303,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp4.Body.Close()

modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
reqDumpPath := findDumpFile(t, modelDir, SuffixRequest)
Expand Down Expand Up @@ -358,7 +362,7 @@ func TestPassthroughMiddleware(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil)
require.NoError(t, err)

resp, err := rt.RoundTrip(req)
resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error
require.ErrorIs(t, err, innerErr)
require.Nil(t, resp)
})
Expand Down
1 change: 1 addition & 0 deletions intercept/apidump/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp.Body.Close()

// Verify the response body is still readable after middleware
capturedBody, err := io.ReadAll(resp.Body)
Expand Down
14 changes: 10 additions & 4 deletions internal/integrationtest/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ func TestAPIDump(t *testing.T) {
withCustomProvider(tc.providerFunc(srv.URL, dumpDir)),
)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)

// Verify dump files were created.
Expand Down Expand Up @@ -187,6 +189,7 @@ func TestAPIDump(t *testing.T) {
// Parse the dumped HTTP response.
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
defer dumpResp.Body.Close()
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)
Expand Down Expand Up @@ -256,12 +259,14 @@ func TestAPIDumpPassthrough(t *testing.T) {
withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)),
)

bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
require.NoError(t, err)
defer resp.Body.Close()

// Find dump files in the passthrough directory.
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough")
var reqDumpFile, respDumpFile string
err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
Expand Down Expand Up @@ -299,6 +304,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
require.NoError(t, err)
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
defer dumpResp.Body.Close()
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)
Expand Down
76 changes: 57 additions & 19 deletions internal/integrationtest/bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ func TestAnthropicMessages(t *testing.T) {
// Make API call to aibridge for Anthropic /v1/messages
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Response-specific checks.
Expand Down Expand Up @@ -220,7 +222,9 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) {

reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

if tc.streaming {
Expand Down Expand Up @@ -258,7 +262,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)),
)

resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -296,7 +302,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
// We override the AWS Bedrock client to route requests through our mock server.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()

// For streaming responses, consume the body to allow the stream to complete.
if streaming {
Expand Down Expand Up @@ -419,9 +427,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
require.NoError(t, err)

// Send with Anthropic-Beta header containing flags that should be filtered.
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
"Anthropic-Beta": {"interleaved-thinking-2025-05-14,effort-2025-11-24,context-management-2025-06-27,prompt-caching-scope-2026-01-05"},
})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down Expand Up @@ -502,7 +512,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
// Make API call to aibridge for OpenAI /v1/chat/completions
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Response-specific checks.
Expand Down Expand Up @@ -583,7 +595,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
// Add the stream param to the request.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Verify SSE headers are sent correctly
Expand Down Expand Up @@ -767,7 +781,9 @@ func TestSimple(t *testing.T) {
// When: calling the "API server" with the fixture's request body.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Then: I expect the upstream request to have the correct path.
Expand Down Expand Up @@ -875,11 +891,13 @@ func TestSessionIDTracking(t *testing.T) {
require.NoError(t, err)
}

resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Drain the body to let the stream complete.
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)

interceptions := bridgeServer.Recorder.RecordedInterceptions()
Expand Down Expand Up @@ -951,7 +969,9 @@ func TestFallthrough(t *testing.T) {
upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL+tc.basePath)

resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)

Expand Down Expand Up @@ -984,6 +1004,7 @@ func TestAnthropicInjectedTools(t *testing.T) {

// Build the requirements & make the assertions which are common to all providers.
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t))
defer resp.Body.Close()

// Ensure expected tool was invoked with expected input.
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
Expand Down Expand Up @@ -1067,6 +1088,7 @@ func TestOpenAIInjectedTools(t *testing.T) {

// Build the requirements & make the assertions which are common to all providers.
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t))
defer resp.Body.Close()

// Ensure expected tool was invoked with expected input.
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
Expand Down Expand Up @@ -1290,7 +1312,9 @@ func TestErrorHandling(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()

tc.responseHandlerFn(resp)
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
Expand Down Expand Up @@ -1357,7 +1381,9 @@ func TestErrorHandling(t *testing.T) {

bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()

tc.responseHandlerFn(resp)
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
Expand Down Expand Up @@ -1416,7 +1442,9 @@ func TestStableRequestEncoding(t *testing.T) {

// Make multiple requests and verify they all have identical payloads.
for range count {
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
}

Expand Down Expand Up @@ -1679,7 +1707,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice)
require.NoError(t, err)

resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Verify tool_choice in the upstream request.
Expand Down Expand Up @@ -1842,7 +1872,9 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
require.NoError(t, err)

resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)

Expand Down Expand Up @@ -1886,7 +1918,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) {
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
require.NoError(t, err)

resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down Expand Up @@ -1949,7 +1983,9 @@ func TestEnvironmentDoNotLeak(t *testing.T) {

bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

// Verify that environment values did not leak.
Expand Down Expand Up @@ -2063,7 +2099,9 @@ func TestActorHeaders(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
// Drain the body so streaming responses complete without
// a "connection reset" error in the mock upstream.
_, err = io.ReadAll(resp.Body)
Expand Down
Loading
Loading