diff --git a/mcp/client.go b/mcp/client.go index c82e189d..6e24c5a3 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -160,6 +160,12 @@ type ClientOptions struct { KeepAlive time.Duration } +// toolContextKeyType is the context key type for passing tool definitions +// from CallTool to the transport layer. +type toolContextKeyType struct{} + +var toolContextKey = toolContextKeyType{} + // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { @@ -318,6 +324,13 @@ type ClientSession struct { // Pending URL elicitations waiting for completion notifications. pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} + + // toolCacheMu guards toolCache. + toolCacheMu sync.RWMutex + // toolCache stores tool definitions keyed by name. + // It is used to look up x-mcp-header annotations when + // constructing Mcp-Param-* headers for tools/call requests. + toolCache map[string]*Tool } type clientSessionState struct { @@ -363,6 +376,21 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +func (cs *ClientSession) cacheTools(tools []*Tool) { + cs.toolCacheMu.Lock() + defer cs.toolCacheMu.Unlock() + cs.toolCache = make(map[string]*Tool, len(tools)) + for _, tool := range tools { + cs.toolCache[tool.Name] = tool + } +} + +func (cs *ClientSession) getCachedTool(name string) *Tool { + cs.toolCacheMu.RLock() + defer cs.toolCacheMu.RUnlock() + return cs.toolCache[name] +} + // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup @@ -981,7 +1009,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + if err != nil { + return nil, err + } + result.Tools = filterValidTools(cs.client.opts.Logger, result.Tools) + cs.cacheTools(result.Tools) + return result, nil } // CallTool calls the tool with the given parameters. @@ -995,6 +1029,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } + if tool := cs.getCachedTool(params.Name); tool != nil { + ctx = context.WithValue(ctx, toolContextKey, tool) + } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } diff --git a/mcp/client_test.go b/mcp/client_test.go index fc37c3eb..609fd501 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) { } } +func TestToolCache(t *testing.T) { + tool1 := &Tool{Name: "tool1", Description: "first"} + tool2 := &Tool{Name: "tool2", Description: "second"} + tool1Updated := &Tool{Name: "tool1", Description: "updated"} + + testCases := []struct { + name string + cacheBatches [][]*Tool + lookup string + want *Tool + }{ + { + name: "empty cache", + lookup: "tool1", + want: nil, + }, + { + name: "single tool found", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "tool1", + want: tool1, + }, + { + name: "unknown tool", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "nonexistent", + want: nil, + }, + { + name: "multiple tools single batch", + cacheBatches: [][]*Tool{{tool1, tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "replace clears old entries", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool1", + want: nil, + }, + { + name: "replace keeps new entries", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "overwrite existing entry", + cacheBatches: [][]*Tool{{tool1}, {tool1Updated}}, + lookup: "tool1", + want: tool1Updated, + }, + { + name: "empty batch no-op", + cacheBatches: [][]*Tool{{}}, + lookup: "tool1", + want: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := &ClientSession{} + for _, batch := range tc.cacheBatches { + cs.cacheTools(batch) + } + got := cs.getCachedTool(tc.lookup) + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff) + } + }) + } +} + func TestClientCapabilitiesOverWire(t *testing.T) { testCases := []struct { name string diff --git a/mcp/server.go b/mcp/server.go index 16d06ca8..d25c7922 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -280,6 +280,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { } } } + if err := validateParamHeaderAnnotations(t); err != nil { + panic(fmt.Errorf("AddTool %q: invalid parameter header annotations: %v", t.Name, err)) + } st := &serverTool{tool: t, handler: h} // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) @@ -753,10 +756,15 @@ func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListTools }) } -func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { +// getServerTool looks up a server tool by name. +func (s *Server) getServerTool(name string) (*serverTool, bool) { s.mu.Lock() - st, ok := s.tools.get(req.Params.Name) - s.mu.Unlock() + defer s.mu.Unlock() + return s.tools.get(name) +} + +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + st, ok := s.getServerTool(req.Params.Name) if !ok { return nil, &jsonrpc.Error{ Code: jsonrpc.CodeInvalidParams, diff --git a/mcp/streamable.go b/mcp/streamable.go index 7b053737..da135374 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -491,6 +491,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "failed connection", http.StatusInternalServerError) return } + transport.connection.toolLookup = server.getServerTool // Capture the user ID from the token info to enable session hijacking // prevention on subsequent requests. var userID string @@ -669,6 +670,8 @@ type streamableServerConn struct { logger *slog.Logger + toolLookup func(name string) (*serverTool, bool) + incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex // guards all fields below @@ -1202,9 +1205,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // Validate MCP standard headers (Mcp-Method, Mcp-Name) + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { - if err := validateMcpHeaders(req.Header, incoming[0]); err != nil { + if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil { resp := &jsonrpc.Response{ Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), } @@ -1829,7 +1832,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } // Keep this after the setMCPHeaders call to ensure that the // protocol version header is set. - setStandardHeaders(req.Header, msg) + setStandardHeaders(ctx, req.Header, msg) resp, err := c.client.Do(req) if err != nil { // Any error from client.Do means the request didn't reach the server. diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index ac9a4121..fba1b4d9 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -5,10 +5,14 @@ package mcp import ( + "context" + "encoding/base64" "encoding/json" "errors" "fmt" + "log/slog" "net/http" + "strings" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -20,7 +24,10 @@ const ( lastEventIDHeader = "Last-Event-ID" methodHeader = "Mcp-Method" nameHeader = "Mcp-Name" + paramHeaderPrefix = "Mcp-Param-" minVersionForStandardHeaders = protocolVersion20260630 + base64Prefix = "=?base64?" + base64Suffix = "?=" ) func extractName(method string, params json.RawMessage) (string, bool) { @@ -45,9 +52,79 @@ func extractName(method string, params json.RawMessage) (string, bool) { return "", false } +// headerSchemaProperty captures the fields needed for x-mcp-header processing. +type headerSchemaProperty struct { + Type string `json:"type"` + XMCPHeader json.RawMessage `json:"x-mcp-header,omitempty"` + Properties map[string]headerSchemaProperty `json:"properties,omitempty"` +} + +// unmarshalSchemaProperties normalizes any InputSchema type +// (*jsonschema.Schema, map[string]any, or json.RawMessage) into a common +// representation by marshaling to JSON and unmarshaling only the fields we need. +func unmarshalSchemaProperties(schema any) map[string]headerSchemaProperty { + var s headerSchemaProperty + if err := remarshal(schema, &s); err != nil { + return nil + } + return s.Properties +} + +// extractParamHeaderAnnotations returns a map of parameter name to header name +// for all properties in the tool's InputSchema that have an x-mcp-header +// annotation. +func extractParamHeaderAnnotations(tool *Tool) map[string]string { + props := unmarshalSchemaProperties(tool.InputSchema) + if len(props) == 0 { + return nil + } + result := make(map[string]string) + for propName, prop := range props { + var headerName string + if err := json.Unmarshal(prop.XMCPHeader, &headerName); err != nil || headerName == "" { + continue + } + result[propName] = headerName + } + if len(result) == 0 { + return nil + } + return result +} + +// primitiveToString conversion. +// Returns false in the second return value if the argument is not a primitive value. +func primitiveToString(value any) (string, bool) { + switch v := value.(type) { + case string: + return v, true + case float64: + return fmt.Sprintf("%g", v), true + case bool: + return fmt.Sprintf("%t", v), true + default: + return "", false + } +} + +// unmarshalPrimitive unmarshals a JSON value into a Go primitive +// (string, float64, or bool). Returns nil for non-primitive types. +func unmarshalPrimitive(raw json.RawMessage) any { + var val any + if err := internaljson.Unmarshal(raw, &val); err != nil { + return nil + } + switch val.(type) { + case string, float64, bool: + return val + default: + return nil + } +} + // setStandardHeaders populates standard MCP headers. // It requires the protocol version header to be set. -func setStandardHeaders(header http.Header, msg jsonrpc.Message) { +func setStandardHeaders(ctx context.Context, header http.Header, msg jsonrpc.Message) { if msg == nil { return } @@ -61,10 +138,128 @@ func setStandardHeaders(header http.Header, msg jsonrpc.Message) { if name, ok := extractName(msg.Method, msg.Params); ok { header.Set(nameHeader, name) } + if msg.Method == "tools/call" { + if tool, ok := ctx.Value(toolContextKey).(*Tool); ok && tool != nil { + for k, v := range generateParamHeaders(tool, msg.Params) { + header.Set(k, v) + } + } + } + } +} + +// generateParamHeaders reads x-mcp-header annotations from the tool's InputSchema +// and returns the Mcp-Param-{Name} headers to be set on the HTTP request. +func generateParamHeaders(tool *Tool, params json.RawMessage) map[string]string { + paramHeaders := extractParamHeaderAnnotations(tool) + if len(paramHeaders) == 0 { + return nil + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := internaljson.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { + return nil + } + + res := make(map[string]string) + for paramName, headerName := range paramHeaders { + argRaw, ok := raw.Arguments[paramName] + if !ok { + continue + } + if string(argRaw) == "null" { + continue + } + val := unmarshalPrimitive(argRaw) + if val == nil { + continue + } + encoded, ok := encodeHeaderValue(val) + if !ok { + continue + } + res[paramHeaderPrefix+headerName] = encoded + } + return res +} + +// filterValidTools returns only tools that have valid +// x-mcp-header annotations. Invalid tools are logged and excluded. +func filterValidTools(logger *slog.Logger, tools []*Tool) []*Tool { + logger = ensureLogger(logger) + result := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if err := validateParamHeaderAnnotations(tool); err != nil { + logger.Error("excluding tool from tools/list", "tool", tool.Name, "error", err) + continue + } + result = append(result, tool) + } + return result +} + +// validateParamHeaderAnnotations checks that a tool's x-mcp-header annotations +// are valid. +func validateParamHeaderAnnotations(tool *Tool) error { + props := unmarshalSchemaProperties(tool.InputSchema) + if len(props) == 0 { + return nil + } + + seen := make(map[string]bool) + for propName, prop := range props { + if err := checkForNestedHeaders(prop, propName); err != nil { + return err + } + if prop.XMCPHeader == nil { + continue + } + var headerName string + if err := json.Unmarshal(prop.XMCPHeader, &headerName); err != nil || headerName == "" { + return fmt.Errorf("property %q: x-mcp-header must be a non-empty string", propName) + } + if err := validateHeaderName(headerName); err != nil { + return fmt.Errorf("property %q: %w", propName, err) + } + lower := strings.ToLower(headerName) + if seen[lower] { + return fmt.Errorf("property %q: duplicate x-mcp-header value %q (case-insensitive)", propName, headerName) + } + seen[lower] = true + + if prop.Type != "string" && prop.Type != "number" && prop.Type != "integer" && prop.Type != "boolean" { + return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %v", propName, prop.Type) + } } + return nil +} + +func checkForNestedHeaders(prop headerSchemaProperty, path string) error { + for propName, nested := range prop.Properties { + if nested.XMCPHeader != nil { + return fmt.Errorf("property %q: x-mcp-header cannot be applied to nested properties", path+"."+propName) + } + if err := checkForNestedHeaders(nested, path+"."+propName); err != nil { + return err + } + } + return nil } -func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { +// validateHeaderName checks that a header name contains only valid +// ASCII characters (excluding space and ':'). +func validateHeaderName(name string) error { + for _, c := range name { + if c <= 0x20 || c > 0x7E || c == ':' { + return fmt.Errorf("x-mcp-header value %q contains invalid character %q", name, c) + } + } + return nil +} + +func validateMcpHeaders(header http.Header, msg jsonrpc.Message, toolLookup func(string) (*serverTool, bool)) error { protocolVersion := header.Get(protocolVersionHeader) if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { return nil @@ -80,12 +275,14 @@ func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { return fmt.Errorf("header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method) } + var nameInBody string if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" { nameInHeader := header.Get(nameHeader) if nameInHeader == "" { return fmt.Errorf("missing required Mcp-Name header for method %q", msg.Method) } - nameInBody, ok := extractName(msg.Method, msg.Params) + var ok bool + nameInBody, ok = extractName(msg.Method, msg.Params) if !ok { return fmt.Errorf("failed to extract name from parameters for method %q", msg.Method) } @@ -93,6 +290,126 @@ func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { return fmt.Errorf("header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) } } + + if msg.Method == "tools/call" && toolLookup != nil { + if st, ok := toolLookup(nameInBody); ok && st != nil { + if err := validateParamHeaders(header, msg, st.tool); err != nil { + return err + } + } + } + } + return nil +} + +func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool) error { + paramHeaders := extractParamHeaderAnnotations(tool) + if len(paramHeaders) == 0 { + return nil + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := internaljson.Unmarshal(msg.Params, &raw); err != nil { + return nil + } + + for paramName, headerName := range paramHeaders { + fullHeader := paramHeaderPrefix + headerName + headerVal := header.Get(fullHeader) + argRaw, argExists := raw.Arguments[paramName] + + if !argExists || string(argRaw) == "null" { + if headerVal != "" { + return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, paramName) + } + continue + } + + if headerVal == "" { + return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, paramName) + } + + decoded, ok := decodeHeaderValue(headerVal) + if !ok { + return fmt.Errorf("header mismatch: %s header contains invalid Base64 encoding", fullHeader) + } + + bodyVal := unmarshalPrimitive(argRaw) + if bodyVal == nil { + return fmt.Errorf("header mismatch: %s header present but body parameter %q is not a primitive type", fullHeader, paramName) + } + expected, ok := primitiveToString(bodyVal) + if !ok { + return fmt.Errorf("header mismatch: %s header present but body parameter %q is not a primitive type", fullHeader, paramName) + } + + // TODO: String comparison may not work ideally for numbers + if decoded != expected { + return fmt.Errorf("header mismatch: %s header value '%s' does not match body value", fullHeader, headerVal) + } } return nil } + +// encodeHeaderValue converts a parameter value to an HTTP header-safe string +// per the SEP-2243 encoding rules: +// - string: used as-is if safe ASCII, otherwise Base64 encoded +// - number (float64): decimal string representation +// - bool: lowercase "true" or "false" +// +// Values that contain non-ASCII characters, control characters, or +// leading/trailing whitespace are Base64-encoded with the =?base64?...?= wrapper. +// +// The second return value is false if the value is not a supported primitive type. +func encodeHeaderValue(value any) (string, bool) { + s, ok := primitiveToString(value) + if !ok { + return "", false + } + if requiresBase64Encoding(s) { + return encodeBase64(s), true + } + return s, true +} + +// decodeHeaderValue decodes a header value that may be Base64-encoded +// with the =?base64?...?= wrapper. +// +// The second return value is false if the header value is not a valid Base64 encoded value. +func decodeHeaderValue(headerValue string) (string, bool) { + if len(headerValue) == 0 { + return headerValue, true + } + + if strings.HasPrefix(strings.ToLower(headerValue), base64Prefix) && + strings.HasSuffix(headerValue, base64Suffix) { + encoded := headerValue[len(base64Prefix) : len(headerValue)-len(base64Suffix)] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", false + } + return string(decoded), true + } + return headerValue, true +} + +func requiresBase64Encoding(s string) bool { + if len(s) == 0 { + return false + } + if s[0] == ' ' || s[0] == '\t' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' { + return true + } + for _, c := range s { + if c < 0x20 || c > 0x7E { + return true + } + } + return false +} + +func encodeBase64(s string) string { + return base64Prefix + base64.StdEncoding.EncodeToString([]byte(s)) + base64Suffix +} diff --git a/mcp/streamable_headers_test.go b/mcp/streamable_headers_test.go index 0abf1b28..88e9f467 100644 --- a/mcp/streamable_headers_test.go +++ b/mcp/streamable_headers_test.go @@ -5,11 +5,14 @@ package mcp import ( + "context" "encoding/json" + "fmt" "net/http" "strings" "testing" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -198,7 +201,7 @@ func TestSetStandardHeaders(t *testing.T) { header.Set(protocolVersionHeader, tt.protocolVersion) } - setStandardHeaders(header, tt.msg) + setStandardHeaders(context.Background(), header, tt.msg) if got := header.Get(methodHeader); got != tt.wantMethodHeader { t.Errorf("MethodHeader = %q, want %q", got, tt.wantMethodHeader) @@ -399,7 +402,7 @@ func TestValidateMcpHeaders(t *testing.T) { header.Set(nameHeader, tt.nameHeader) } - err := validateMcpHeaders(header, tt.msg) + err := validateMcpHeaders(header, tt.msg, nil) if tt.wantErr { if err == nil { t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain) @@ -413,3 +416,827 @@ func TestValidateMcpHeaders(t *testing.T) { }) } } + +func TestValidateToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + wantErr bool + wantErrSub string + }{ + { + name: "valid tool with x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + { + name: "tool with no x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + }, + { + name: "empty x-mcp-header value", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "non-empty string", + }, + { + name: "x-mcp-header with space", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "My Region", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with colon", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region:Primary", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with non-ASCII", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Région", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "duplicate header names same case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "duplicate header names different case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "REGION"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "x-mcp-header on array type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on object type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "nested": map[string]any{ + "type": "object", + "x-mcp-header": "Nested", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on number type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "number", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on integer type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "integer", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on boolean type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on nested property inside object", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "x-mcp-header on deeply nested property", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "outer": map[string]any{ + "type": "object", + "properties": map[string]any{ + "inner": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + "x-mcp-header": "Value", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "object property without nested x-mcp-header is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + }, + }, + }, + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, + { + name: "jsonschema.Schema valid x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "region": { + Type: "string", + Extra: map[string]any{"x-mcp-header": "Region"}, + }, + }, + }, + }, + }, + { + name: "jsonschema.Schema x-mcp-header on array type", + tool: &Tool{ + Name: "test", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "items": { + Type: "array", + Extra: map[string]any{"x-mcp-header": "Items"}, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "jsonschema.Schema nested x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "config": { + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "region": { + Type: "string", + Extra: map[string]any{"x-mcp-header": "Region"}, + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "json.RawMessage valid x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: json.RawMessage(`{"type":"object","properties":{"region":{"type":"string","x-mcp-header":"Region"}}}`), + }, + }, + { + name: "json.RawMessage x-mcp-header on object type", + tool: &Tool{ + Name: "test", + InputSchema: json.RawMessage(`{"type":"object","properties":{"nested":{"type":"object","x-mcp-header":"Nested"}}}`), + }, + wantErr: true, + wantErrSub: "primitive types", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateParamHeaderAnnotations(tt.tool) + if tt.wantErr { + if err == nil { + t.Fatal("validateToolParamHeaders() = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Errorf("error = %q, want substring %q", err.Error(), tt.wantErrSub) + } + } else if err != nil { + t.Errorf("validateToolParamHeaders() = %v, want nil", err) + } + }) + } +} + +func TestFilterValidTools(t *testing.T) { + valid := &Tool{ + Name: "valid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + } + invalid := &Tool{ + Name: "invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": ""}, + }, + }, + } + noAnnotation := &Tool{ + Name: "plain", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, + } + nestedInvalid := &Tool{ + Name: "nested-invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + }, + } + + validJsonSchema := &Tool{ + Name: "valid-jsonschema", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "region": { + Type: "string", + Extra: map[string]any{"x-mcp-header": "Region"}, + }, + }, + }, + } + invalidJsonSchema := &Tool{ + Name: "invalid-jsonschema", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "items": { + Type: "array", + Extra: map[string]any{"x-mcp-header": "Items"}, + }, + }, + }, + } + + result := filterValidTools(nil, []*Tool{valid, invalid, noAnnotation, nestedInvalid, validJsonSchema, invalidJsonSchema}) + if len(result) != 3 { + t.Fatalf("filterValidTools returned %d tools, want 3", len(result)) + } + if result[0].Name != "valid" || result[1].Name != "plain" || result[2].Name != "valid-jsonschema" { + t.Errorf("filterValidTools returned [%s, %s, %s], want [valid, plain, valid-jsonschema]", result[0].Name, result[1].Name, result[2].Name) + } +} + +func TestSetStandardHeadersWithParamHeaders(t *testing.T) { + toolSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + "priority": map[string]any{ + "type": "string", + "x-mcp-header": "Priority", + }, + }, + } + tool := &Tool{Name: "execute_sql", InputSchema: toolSchema} + + tests := []struct { + name string + tool *Tool + params any + wantHeaders map[string]string + }{ + { + name: "sets param headers from arguments", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1", "priority": "high"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "us-west1", + "Mcp-Param-Priority": "high", + }, + }, + { + name: "omits header when argument is missing", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"query": "SELECT 1"}, + }, + wantHeaders: nil, + }, + { + name: "omits header when argument is null", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": nil, "query": "SELECT 1"}, + }, + wantHeaders: nil, + }, + { + name: "encodes non-ASCII value", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "日本", "query": "SELECT 1"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "=?base64?5pel5pys?=", + }, + }, + { + name: "handles boolean argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean", "x-mcp-header": "Flag"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"flag": true}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Flag": "true", + }, + }, + { + name: "handles number argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "number", "x-mcp-header": "Count"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"count": float64(42)}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Count": "42", + }, + }, + { + name: "no tool in extra does not add param headers", + tool: nil, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + }, + wantHeaders: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + ctx := context.Background() + if tt.tool != nil { + ctx = context.WithValue(ctx, toolContextKey, tt.tool) + } + + msg := &jsonrpc.Request{ + Method: "tools/call", + Params: mustMarshal(tt.params), + } + + setStandardHeaders(ctx, header, msg) + + if got := header.Get(methodHeader); got != "tools/call" { + t.Errorf("MethodHeader = %q, want %q", got, "tools/call") + } + + for h, want := range tt.wantHeaders { + if got := header.Get(h); got != want { + t.Errorf("%s = %q, want %q", h, got, want) + } + } + + if got := header.Get("Mcp-Param-query"); got != "" { + t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got) + } + }) + } +} + +func TestExtractToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + want map[string]string + }{ + { + name: "extracts x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "query": map[string]any{"type": "string"}, + "tenant_id": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, + }, + }, + }, + want: map[string]string{"region": "Region", "tenant_id": "TenantId"}, + }, + { + name: "returns nil for tool without properties", + tool: &Tool{Name: "test", InputSchema: map[string]any{"type": "object"}}, + want: nil, + }, + { + name: "returns nil for non-map schema", + tool: &Tool{Name: "test", InputSchema: "not a map"}, + want: nil, + }, + { + name: "returns nil when no annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"q": map[string]any{"type": "string"}}, + }, + }, + want: nil, + }, + { + name: "jsonschema.Schema with x-mcp-header in Extra", + tool: &Tool{ + Name: "test", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "region": { + Type: "string", + Extra: map[string]any{"x-mcp-header": "Region"}, + }, + "query": {Type: "string"}, + }, + }, + }, + want: map[string]string{"region": "Region"}, + }, + { + name: "json.RawMessage with x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: json.RawMessage(`{"type":"object","properties":{"region":{"type":"string","x-mcp-header":"Region"},"query":{"type":"string"}}}`), + }, + want: map[string]string{"region": "Region"}, + }, + { + name: "jsonschema.Schema without x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": {Type: "string"}, + }, + }, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractParamHeaderAnnotations(tt.tool) + if tt.want == nil { + if got != nil { + t.Errorf("extractToolParamHeaders() = %v, want nil", got) + } + return + } + if len(got) != len(tt.want) { + t.Fatalf("extractToolParamHeaders() returned %d entries, want %d", len(got), len(tt.want)) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("extractToolParamHeaders()[%q] = %q, want %q", k, got[k], v) + } + } + }) + } +} + +func TestUnmarshalPrimitive(t *testing.T) { + tests := []struct { + name string + raw string + want any + }{ + {"string", `"hello"`, "hello"}, + {"number", `42`, float64(42)}, + {"float", `3.14`, float64(3.14)}, + {"true", `true`, true}, + {"false", `false`, false}, + {"null", `null`, nil}, + {"array", `[1,2]`, nil}, + {"object", `{"a":1}`, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unmarshalPrimitive(json.RawMessage(tt.raw)) + if fmt.Sprintf("%v", got) != fmt.Sprintf("%v", tt.want) { + t.Errorf("unmarshalPrimitive(%s) = %v (%T), want %v (%T)", tt.raw, got, got, tt.want, tt.want) + } + }) + } +} + +func TestEncodeHeaderValue(t *testing.T) { + tests := []struct { + name string + value any + want string + wantOK bool + }{ + // Strings + {"plain ASCII", "us-west1", "us-west1", true}, + {"empty string", "", "", true}, + {"string with internal spaces", "us west 1", "us west 1", true}, + {"string with leading space", " us-west1", "=?base64?IHVzLXdlc3Qx?=", true}, + {"string with trailing space", "us-west1 ", "=?base64?dXMtd2VzdDEg?=", true}, + {"string with both spaces", " us-west1 ", "=?base64?IHVzLXdlc3QxIA==?=", true}, + {"non-ASCII", "日本語", "=?base64?5pel5pys6Kqe?=", true}, + {"mixed ASCII and non-ASCII", "Hello, 世界", "=?base64?SGVsbG8sIOS4lueVjA==?=", true}, + {"string with newline", "line1\nline2", "=?base64?bGluZTEKbGluZTI=?=", true}, + {"string with carriage return", "line1\r\nline2", "=?base64?bGluZTENCmxpbmUy?=", true}, + {"string with leading tab", "\tindented", "=?base64?CWluZGVudGVk?=", true}, + + // Numbers + {"integer", float64(42), "42", true}, + {"float", float64(3.14159), "3.14159", true}, + + // Booleans + {"true", true, "true", true}, + {"false", false, "false", true}, + + // Unsupported types + {"nil", nil, "", false}, + {"slice", []string{"a"}, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := encodeHeaderValue(tt.value) + if ok != tt.wantOK { + t.Fatalf("encodeHeaderValue(%v) ok = %v, want %v", tt.value, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("encodeHeaderValue(%v) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestDecodeHeaderValue(t *testing.T) { + tests := []struct { + name string + input string + want string + wantOK bool + }{ + {"plain value", "us-west1", "us-west1", true}, + {"empty value", "", "", true}, + {"valid base64", "=?base64?SGVsbG8=?=", "Hello", true}, + {"non-ASCII decoded", "=?base64?5pel5pys6Kqe?=", "日本語", true}, + {"leading space decoded", "=?base64?IHVzLXdlc3Qx?=", " us-west1", true}, + {"case-insensitive prefix", "=?BASE64?SGVsbG8=?=", "Hello", true}, + {"invalid base64 chars", "=?base64?SGVs!!!bG8=?=", "", false}, + // Missing prefix or suffix: treated as literal values, not base64 + {"missing prefix", "SGVsbG8=", "SGVsbG8=", true}, + {"missing suffix", "=?base64?SGVsbG8=", "=?base64?SGVsbG8=", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := decodeHeaderValue(tt.input) + if ok != tt.wantOK { + t.Fatalf("decodeHeaderValue(%q) ok = %v, want %v", tt.input, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("decodeHeaderValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEncodeDecodeRoundTrip(t *testing.T) { + values := []string{ + "us-west1", + "", + " leading", + "trailing ", + "Hello, 世界", + "line1\nline2", + "\ttab", + } + for _, v := range values { + encoded, ok := encodeHeaderValue(v) + if !ok { + t.Fatalf("encodeHeaderValue(%q) failed", v) + } + decoded, ok := decodeHeaderValue(encoded) + if !ok { + t.Fatalf("decodeHeaderValue(%q) failed", encoded) + } + if decoded != v { + t.Errorf("round-trip failed: %q -> %q -> %q", v, encoded, decoded) + } + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 118d05df..77eed1b0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1915,9 +1915,9 @@ func TestStreamableGET(t *testing.T) { } } -// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method and -// Mcp-Name header validation through the full HTTP handler, as specified -// in SEP-2243. +// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method, +// Mcp-Name, and Mcp-Param header validation through the full HTTP handler, +// as specified in SEP-2243. func TestStreamableMcpHeaderValidation(t *testing.T) { // Temporarily register the future version so the handler accepts it. orig := supportedProtocolVersions @@ -1930,6 +1930,25 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { return &CallToolResult{}, nil }) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() @@ -2019,6 +2038,50 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"us-west1"}, + }, + messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + })}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"eu-central1"}, + }, + messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + }, + messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "missing", + }, }) } @@ -2170,6 +2233,174 @@ func TestStreamableMcpHeaderVersionGating(t *testing.T) { }) } +// TestStreamableParamHeadersClientSetsHeaders verifies that the client sets +// Mcp-Param-* headers on tool calls when the tool has x-mcp-header annotations. +func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + var capturedHeaders http.Header + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header.Get(methodHeader) == "tools/call" { + capturedHeaders = req.Header.Clone() + } + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + } + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, clientTransport, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + // ListTools to populate the tool cache (needed for param headers). + if _, err := session.ListTools(ctx, nil); err != nil { + t.Fatal(err) + } + + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + }) + if err != nil { + t.Fatal(err) + } + + if capturedHeaders == nil { + t.Fatal("no tool call headers captured") + } + if got := capturedHeaders.Get(methodHeader); got != "tools/call" { + t.Errorf("Mcp-Method = %q, want %q", got, "tools/call") + } + if got := capturedHeaders.Get(nameHeader); got != "execute_sql" { + t.Errorf("Mcp-Name = %q, want %q", got, "execute_sql") + } + if got := capturedHeaders.Get(paramHeaderPrefix + "Region"); got != "us-west1" { + t.Errorf("Mcp-Param-Region = %q, want %q", got, "us-west1") + } +} + +// TestStreamableFilterValidToolsIntegration verifies that AddTool rejects +// tools with invalid x-mcp-header annotations at registration time, and +// that valid tools are returned by ListTools. +func TestStreamableFilterValidToolsIntegration(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + noop := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + } + + // Valid tool with correct x-mcp-header annotation. + server.AddTool(&Tool{ + Name: "valid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, noop) + + // Invalid tool: x-mcp-header on an array type should panic. + func() { + defer func() { + r := recover() + if r == nil { + t.Fatal("AddTool with invalid x-mcp-header annotation did not panic") + } + }() + server.AddTool(&Tool{ + Name: "invalid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, noop) + }() + + // Tool with no x-mcp-header annotations (always valid). + server.AddTool(&Tool{ + Name: "plain-tool", + InputSchema: &jsonschema.Schema{Type: "object"}, + }, noop) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + result, err := session.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + + toolNames := make([]string, len(result.Tools)) + for i, tool := range result.Tools { + toolNames[i] = tool.Name + } + sort.Strings(toolNames) + + wantNames := []string{"plain-tool", "valid-tool"} + if !slices.Equal(toolNames, wantNames) { + t.Errorf("ListTools returned %v, want %v", toolNames, wantNames) + } +} + // TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance: // 405 Method Not Allowed responses MUST include an Allow header. func TestStreamable405AllowHeader(t *testing.T) {