diff --git a/pkg/transport/proxy/transparent/response_processor.go b/pkg/transport/proxy/transparent/response_processor.go index a6e0ca765f..e7fd14f844 100644 --- a/pkg/transport/proxy/transparent/response_processor.go +++ b/pkg/transport/proxy/transparent/response_processor.go @@ -6,11 +6,29 @@ package transparent import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math" + "mime" "net/http" + "strings" "github.com/stacklok/toolhive/pkg/transport/types" ) +// maxJSONRPCResponseBytes caps how much of an upstream JSON-RPC response the proxy +// will buffer for structural validation. Matches existing streamable-HTTP body +// limits elsewhere in the codebase (pkg/vmcp/client, pkg/vmcp/session/internal/backend). +const maxJSONRPCResponseBytes = 100 << 20 // 100 MiB + +// JSON-RPC error code returned to clients when the proxy rejects a malformed +// upstream response. -32000..-32099 is the implementation-defined server-error +// range in the JSON-RPC 2.0 spec; -32603 is reserved for internal JSON-RPC +// implementation errors and is not appropriate for a policy-level rejection. +const jsonRPCInvalidUpstreamCode = -32000 + // ResponseProcessor defines the interface for processing and modifying HTTP responses // based on transport-specific requirements. type ResponseProcessor interface { @@ -22,12 +40,38 @@ type ResponseProcessor interface { ShouldProcess(resp *http.Response) bool } -// NoOpResponseProcessor is a processor that does nothing. -// Used for transports that don't require response processing (e.g., streamable-http). +// NoOpResponseProcessor is the default processor for non-SSE transports. +// It validates JSON-RPC responses for streamable HTTP and otherwise leaves responses unchanged. type NoOpResponseProcessor struct{} -// ProcessResponse is a no-op implementation. -func (*NoOpResponseProcessor) ProcessResponse(_ *http.Response) error { +// ProcessResponse validates JSON-RPC responses when applicable. +func (*NoOpResponseProcessor) ProcessResponse(resp *http.Response) error { + if !shouldValidateJSONRPCResponse(resp) { + return nil + } + + // Read one byte past the cap so we can detect oversize without allocating beyond it. + body, err := io.ReadAll(io.LimitReader(resp.Body, maxJSONRPCResponseBytes+1)) + if err != nil { + return fmt.Errorf("failed to read upstream response body: %w", err) + } + _ = resp.Body.Close() + + if len(body) > maxJSONRPCResponseBytes { + writeInvalidUpstreamJSONRPCResponse(resp, fmt.Errorf( + "upstream JSON-RPC response exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes)) + return nil + } + + if err := validateJSONRPCResponse(body); err != nil { + writeInvalidUpstreamJSONRPCResponse(resp, err) + return nil + } + + // The reverse proxy still needs a readable body after validation. + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) return nil } @@ -36,6 +80,154 @@ func (*NoOpResponseProcessor) ShouldProcess(_ *http.Response) bool { return false } +func shouldValidateJSONRPCResponse(resp *http.Response) bool { + if resp == nil || resp.Body == nil || resp.Request == nil { + return false + } + if resp.Request.Method != http.MethodPost || resp.StatusCode != http.StatusOK { + return false + } + if !hasIdentityContentEncoding(resp.Header.Get("Content-Encoding")) { + // Content-Encoding semantics (RFC 9110): media-type rules apply after decoding. + // Validating a still-encoded body would mis-classify legitimate gzip JSON-RPC + // frames as invalid. Skip rather than introduce decompression here. + return false + } + if !requestLooksLikeMCP(resp.Request) { + // Narrow validation to traffic that carries an MCP streamable-HTTP signal, + // so non-MCP application/json POSTs flowing through the catch-all are not + // rewritten. Backward-compat clients omitting MCP-Protocol-Version on the + // initial initialize will pass through unchanged. + return false + } + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return false + } + return mediaType == "application/json" || mediaType == "application/json-rpc" +} + +func hasIdentityContentEncoding(value string) bool { + v := strings.TrimSpace(strings.ToLower(value)) + return v == "" || v == "identity" +} + +func requestLooksLikeMCP(req *http.Request) bool { + if req == nil { + return false + } + return req.Header.Get("MCP-Protocol-Version") != "" || req.Header.Get("Mcp-Session-Id") != "" +} + +func validateJSONRPCResponse(body []byte) error { + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("invalid JSON body: %w", err) + } + if dec.More() { + return fmt.Errorf("JSON-RPC response must contain a single JSON value") + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + return fmt.Errorf("JSON-RPC response must contain a single JSON value") + } + + switch value := payload.(type) { + case map[string]any: + return validateJSONRPCResponseObject(value) + case []any: + if len(value) == 0 { + return fmt.Errorf("JSON-RPC batch response must not be empty") + } + for i, item := range value { + obj, ok := item.(map[string]any) + if !ok { + return fmt.Errorf("JSON-RPC batch item %d must be an object", i) + } + if err := validateJSONRPCResponseObject(obj); err != nil { + return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", i, err) + } + } + return nil + default: + return fmt.Errorf("JSON-RPC response must be an object or array") + } +} + +func validateJSONRPCResponseObject(obj map[string]any) error { + if obj["jsonrpc"] != "2.0" { + return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`) + } + + if _, ok := obj["id"]; !ok { + return fmt.Errorf("JSON-RPC response must include id") + } + if !isValidJSONRPCID(obj["id"]) { + return fmt.Errorf("JSON-RPC response id must be string, number, or null") + } + + _, hasResult := obj["result"] + _, hasError := obj["error"] + if hasResult == hasError { + return fmt.Errorf("JSON-RPC response must include exactly one of result or error") + } + if hasError { + if errObj, ok := obj["error"].(map[string]any); !ok || !isValidJSONRPCError(errObj) { + return fmt.Errorf("JSON-RPC error response must include error.code and error.message") + } + } + + return nil +} + +func isValidJSONRPCID(id any) bool { + switch id.(type) { + case nil, string, float64: + return true + default: + return false + } +} + +func isValidJSONRPCError(errObj map[string]any) bool { + code, codeOK := errObj["code"].(float64) + if !codeOK || math.Trunc(code) != code { + // JSON-RPC 2.0 requires error.code to be an integer. + return false + } + _, messageOK := errObj["message"].(string) + return messageOK +} + +func writeInvalidUpstreamJSONRPCResponse(resp *http.Response, validationErr error) { + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "error": map[string]any{ + "code": jsonRPCInvalidUpstreamCode, + "message": "Invalid upstream JSON-RPC response", + "data": validationErr.Error(), + }, + "id": nil, + }) + if err != nil { + body = []byte(`{"jsonrpc":"2.0","error":{"code":-32000,"message":"Invalid upstream JSON-RPC response"},"id":null}`) + } + + resp.StatusCode = http.StatusBadGateway + resp.Status = fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)) + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + + // Replace headers wholesale so upstream session/cookie/cache metadata is not + // smuggled into the proxy-generated error. Only carry the fields needed to + // describe this synthetic body. + resp.Header = http.Header{} + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + resp.Trailer = nil +} + // createResponseProcessor is a factory function that creates the appropriate // response processor based on transport type. func createResponseProcessor( diff --git a/pkg/transport/proxy/transparent/response_processor_test.go b/pkg/transport/proxy/transparent/response_processor_test.go new file mode 100644 index 0000000000..fd53560fb7 --- /dev/null +++ b/pkg/transport/proxy/transparent/response_processor_test.go @@ -0,0 +1,391 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transparent + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNoOpResponseProcessorValidatesJSONRPCResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + wantStatus int + wantBody string + }{ + { + name: "valid result response passes through", + body: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`, + }, + { + name: "valid error response passes through", + body: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`, + }, + { + name: "valid batch response passes through", + body: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":"two","result":{}}]`, + wantStatus: http.StatusOK, + wantBody: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":"two","result":{}}]`, + }, + { + name: "valid null result response passes through", + body: `{"jsonrpc":"2.0","id":1,"result":null}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":1,"result":null}`, + }, + { + name: "missing jsonrpc is rejected", + body: `{"id":1,"result":{"ok":true}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"Invalid upstream JSON-RPC response"`, + }, + { + name: "invalid id type is rejected", + body: `{"jsonrpc":"2.0","id":{"nested":true},"result":{}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response id must be string, number, or null"`, + }, + { + name: "non-object body is rejected", + body: `"not an object"`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must be an object or array"`, + }, + { + name: "result and error together are rejected", + body: `{"jsonrpc":"2.0","id":1,"result":{},"error":{"code":-32603,"message":"boom"}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must include exactly one of result or error"`, + }, + { + name: "trailing JSON value is rejected", + body: `{"jsonrpc":"2.0","id":1,"result":{}} {"jsonrpc":"2.0","id":2,"result":{}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must contain a single JSON value"`, + }, + { + name: "fractional error code is rejected", + body: `{"jsonrpc":"2.0","id":1,"error":{"code":1.5,"message":"nope"}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC error response must include error.code and error.message"`, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp := jsonResponse(tt.body) + if tt.wantStatus == http.StatusBadGateway { + // These sensitive headers must not survive a rewrite. Content-Encoding + // is covered separately by TestNoOpResponseProcessorSkipsCompressedResponses; + // setting it here would route through the pass-through gate instead. + resp.Header.Set("Mcp-Session-Id", "upstream-session-leak") + resp.Header.Set("Set-Cookie", "leak=1") + resp.Header.Set("Etag", "\"upstream-etag\"") + resp.Header.Set("Cache-Control", "private, max-age=60") + } + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tt.wantStatus, resp.StatusCode) + assert.Contains(t, string(gotBody), tt.wantBody) + assert.Equal(t, int64(len(gotBody)), resp.ContentLength) + assert.Equal(t, len(gotBody), int(resp.ContentLength)) + if tt.wantStatus == http.StatusBadGateway { + // Wholesale header replacement: only Content-Type and Content-Length remain. + assert.Empty(t, resp.Header.Get("Mcp-Session-Id")) + assert.Empty(t, resp.Header.Get("Set-Cookie")) + assert.Empty(t, resp.Header.Get("Etag")) + assert.Empty(t, resp.Header.Get("Cache-Control")) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + assert.Nil(t, resp.Trailer) + } + }) + } +} + +func TestNoOpResponseProcessorAcceptsJSONContentTypeParameters(t *testing.T) { + t.Parallel() + + resp := jsonResponse(`{"jsonrpc":"2.0","id":1,"result":{}}`) + resp.Header.Set("Content-Type", "application/json; charset=utf-8") + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, `{"jsonrpc":"2.0","id":1,"result":{}}`, string(gotBody)) +} + +func TestNoOpResponseProcessorSkipsNonJSONRPCResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + status int + contentType string + body string + }{ + { + name: "non-post response", + method: http.MethodGet, + status: http.StatusOK, + contentType: "application/json", + body: `{"resource":"https://example.com"}`, + }, + { + name: "non-200 response", + method: http.MethodPost, + status: http.StatusAccepted, + contentType: "application/json", + body: ``, + }, + { + name: "non-json response", + method: http.MethodPost, + status: http.StatusOK, + contentType: "text/plain", + body: `not json`, + }, + { + name: "post response with event stream", + method: http.MethodPost, + status: http.StatusOK, + contentType: "text/event-stream", + body: "event: message\ndata: {}\n\n", + }, + { + name: "content type containing application/json is not enough", + method: http.MethodPost, + status: http.StatusOK, + contentType: "application/jsonsomethingelse", + body: `not json`, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := mcpRequest(tt.method) + resp := &http.Response{ + StatusCode: tt.status, + Status: http.StatusText(tt.status), + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(tt.body)), + ContentLength: int64(len(tt.body)), + Request: req, + } + resp.Header.Set("Content-Type", tt.contentType) + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tt.status, resp.StatusCode) + assert.Equal(t, tt.body, string(gotBody)) + }) + } +} + +// TestNoOpResponseProcessorSkipsCompressedResponses verifies that responses +// carrying a non-identity Content-Encoding are passed through unchanged. +// Decoding here would either reject legitimate compressed JSON-RPC frames or +// open a decompression-bomb amplification path. +func TestNoOpResponseProcessorSkipsCompressedResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + contentEncoding string + body string + }{ + { + name: "gzip valid json is left alone", + contentEncoding: "gzip", + body: gzipBytes(t, `{"jsonrpc":"2.0","id":1,"result":{}}`), + }, + { + name: "gzip malformed body is left alone (no false reject)", + contentEncoding: "gzip", + body: "not really gzip, but encoding header is set", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp := jsonResponse(tt.body) + resp.Header.Set("Content-Encoding", tt.contentEncoding) + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, tt.body, string(gotBody)) + }) + } +} + +// TestNoOpResponseProcessorValidatesUnderIdentityEncoding proves that an +// explicit Content-Encoding: identity does not bypass validation: a malformed +// JSON-RPC body must still produce a 502 rewrite. +func TestNoOpResponseProcessorValidatesUnderIdentityEncoding(t *testing.T) { + t.Parallel() + + resp := jsonResponse(`{"id":1,"result":{"ok":true}}`) // missing jsonrpc → invalid + resp.Header.Set("Content-Encoding", "identity") + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), `"Invalid upstream JSON-RPC response"`) +} + +// TestNoOpResponseProcessorRequiresMCPSignal narrows validation to traffic that +// carries an MCP streamable-HTTP signal on the request. application/json POST +// 200 responses from non-MCP traffic flowing through the catch-all proxy must +// not be rewritten. +func TestNoOpResponseProcessorRequiresMCPSignal(t *testing.T) { + t.Parallel() + + body := `{"id":1,"result":{"ok":true}}` // missing jsonrpc — would be rejected if validated + + tests := []struct { + name string + headers map[string]string + validate bool + }{ + { + name: "no MCP headers — pass through", + headers: nil, + validate: false, + }, + { + name: "MCP-Protocol-Version header — validated", + headers: map[string]string{"MCP-Protocol-Version": "2025-06-18"}, + validate: true, + }, + { + name: "Mcp-Session-Id header — validated", + headers: map[string]string{"Mcp-Session-Id": "session-abc"}, + validate: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodPost, "http://example.com/mcp", nil) + require.NoError(t, err) + for k, v := range tt.headers { + req.Header.Set(k, v) + } + resp := &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + Request: req, + } + resp.Header.Set("Content-Type", "application/json") + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + if tt.validate { + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), `"Invalid upstream JSON-RPC response"`) + } else { + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, body, string(gotBody)) + } + }) + } +} + +// TestNoOpResponseProcessorRejectsOversizeResponse verifies the bounded read. +// The proxy is a security boundary; an unbounded io.ReadAll on attacker- +// controlled upstream bodies would amplify a malicious server into a memory +// DoS against the proxy. +func TestNoOpResponseProcessorRejectsOversizeResponse(t *testing.T) { + t.Parallel() + + // Produce a body strictly larger than the cap. Content does not need to be + // valid JSON-RPC — the size check fires before validation. + oversize := strings.Repeat("a", maxJSONRPCResponseBytes+1) + resp := jsonResponse(oversize) + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), fmt.Sprintf("exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes)) +} + +func jsonResponse(body string) *http.Response { + resp := &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + Request: mcpRequest(http.MethodPost), + } + resp.Header.Set("Content-Type", "application/json") + return resp +} + +func mcpRequest(method string) *http.Request { + req, _ := http.NewRequest(method, "http://example.com/mcp", nil) + req.Header.Set("MCP-Protocol-Version", "2025-06-18") + return req +} + +func gzipBytes(t *testing.T, payload string) string { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + _, err := gw.Write([]byte(payload)) + require.NoError(t, err) + require.NoError(t, gw.Close()) + return buf.String() +}