Skip to content

Commit 01cd399

Browse files
committed
test: verify upstream errors are relayed to client in streaming chatcompletions
1 parent a127009 commit 01cd399

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package chatcompletions
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"strconv"
9+
"testing"
10+
11+
"cdr.dev/slog/v3"
12+
"cdr.dev/slog/v3/sloggers/sloghuman"
13+
"github.com/coder/aibridge/config"
14+
"github.com/coder/aibridge/recorder"
15+
"github.com/google/uuid"
16+
"github.com/openai/openai-go/v3"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
19+
"go.opentelemetry.io/otel"
20+
)
21+
22+
// noopRecorder implements recorder.Recorder for testing.
23+
type noopRecorder struct{}
24+
25+
func (n *noopRecorder) RecordInterception(_ context.Context, _ *recorder.InterceptionRecord) error {
26+
return nil
27+
}
28+
29+
func (n *noopRecorder) RecordInterceptionEnded(_ context.Context, _ *recorder.InterceptionRecordEnded) error {
30+
return nil
31+
}
32+
33+
func (n *noopRecorder) RecordTokenUsage(_ context.Context, _ *recorder.TokenUsageRecord) error {
34+
return nil
35+
}
36+
37+
func (n *noopRecorder) RecordPromptUsage(_ context.Context, _ *recorder.PromptUsageRecord) error {
38+
return nil
39+
}
40+
41+
func (n *noopRecorder) RecordToolUsage(_ context.Context, _ *recorder.ToolUsageRecord) error {
42+
return nil
43+
}
44+
45+
// Test that when the upstream provider returns an error before streaming starts,
46+
// the error status code and body are correctly relayed to the client.
47+
func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
48+
t.Parallel()
49+
50+
tests := []struct {
51+
name string
52+
statusCode int
53+
responseBody string
54+
expectedErrStr string
55+
expectedBody string
56+
}{
57+
{
58+
name: "rate limit error",
59+
statusCode: http.StatusTooManyRequests,
60+
responseBody: `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
61+
expectedErrStr: strconv.Itoa(http.StatusTooManyRequests),
62+
expectedBody: "rate_limit",
63+
},
64+
{
65+
name: "internal server error",
66+
statusCode: http.StatusInternalServerError,
67+
responseBody: `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}`,
68+
expectedErrStr: strconv.Itoa(http.StatusInternalServerError),
69+
expectedBody: "server_error",
70+
},
71+
}
72+
73+
for _, tc := range tests {
74+
t.Run(tc.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
// Setup a mock server that returns an error immediately (before any streaming)
78+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79+
w.Header().Set("Content-Type", "application/json")
80+
w.Header().Set("x-should-retry", "false")
81+
w.WriteHeader(tc.statusCode)
82+
_, _ = w.Write([]byte(tc.responseBody))
83+
}))
84+
defer mockServer.Close()
85+
86+
// Create interceptor with mock server URL
87+
cfg := config.OpenAI{
88+
BaseURL: mockServer.URL,
89+
Key: "test-key",
90+
}
91+
92+
req := &ChatCompletionNewParamsWrapper{
93+
ChatCompletionNewParams: openai.ChatCompletionNewParams{
94+
Model: "gpt-4",
95+
Messages: []openai.ChatCompletionMessageParamUnion{
96+
openai.UserMessage("hello"),
97+
},
98+
},
99+
Stream: true,
100+
}
101+
102+
tracer := otel.Tracer("test")
103+
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)
104+
105+
logger := slog.Make(sloghuman.Sink(io.Discard))
106+
interceptor.Setup(logger, &noopRecorder{}, nil)
107+
108+
// Create test request
109+
w := httptest.NewRecorder()
110+
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
111+
112+
// Process the request
113+
err := interceptor.ProcessRequest(w, httpReq)
114+
115+
// Verify error was returned
116+
require.Error(t, err)
117+
assert.Contains(t, err.Error(), tc.expectedErrStr)
118+
119+
// Verify status code was written to response
120+
assert.Equal(t, tc.statusCode, w.Code, "expected status code to be relayed to client")
121+
122+
// Verify error body contains expected error info
123+
body := w.Body.String()
124+
assert.Contains(t, body, tc.expectedBody, "expected error type in response body")
125+
})
126+
}
127+
}

0 commit comments

Comments
 (0)