From d12d17380d48fecbc672cbea1aae76ca2704cb8f Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Mon, 15 Jun 2026 22:57:44 -0400 Subject: [PATCH 1/3] feat: add SEP-2243 HTTP standard headers --- crates/rmcp/Cargo.toml | 3 +- crates/rmcp/src/model.rs | 8 +- crates/rmcp/src/transport/common.rs | 3 + .../rmcp/src/transport/common/http_header.rs | 9 + .../rmcp/src/transport/common/mcp_headers.rs | 654 ++++++++++++++++++ .../src/transport/streamable_http_client.rs | 93 ++- .../transport/streamable_http_server/tower.rs | 60 ++ .../test_streamable_http_standard_headers.rs | 147 ++++ 8 files changed, 969 insertions(+), 8 deletions(-) create mode 100644 crates/rmcp/src/transport/common/mcp_headers.rs create mode 100644 crates/rmcp/tests/test_streamable_http_standard_headers.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 638679812..d8997dd73 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -139,12 +139,13 @@ server-side-http = [ "dep:bytes", "dep:sse-stream", "tower", + "base64", ] transport-worker = ["dep:tokio-stream"] # SSE stream parsing utilities (used by streamable HTTP client for SSE-formatted responses) -client-side-sse = ["dep:sse-stream", "dep:http"] +client-side-sse = ["dep:sse-stream", "dep:http", "base64"] # Streamable HTTP client transport-streamable-http-client = ["client-side-sse", "transport-worker"] diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 5fb7eb9ed..28a9071de 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -159,6 +159,9 @@ impl ProtocolVersion { pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05")); pub const LATEST: Self = Self::V_2025_11_25; + /// First protocol version that requires SEP-2243 standard HTTP headers. + pub const STANDARD_HEADERS: Self = Self::V_2026_07_28; + /// All protocol versions known to this SDK. pub const KNOWN_VERSIONS: &[Self] = &[ Self::V_2024_11_05, @@ -503,6 +506,7 @@ pub struct JsonRpcNotification { pub struct ErrorCode(pub i32); impl ErrorCode { + pub const HEADER_MISMATCH: Self = Self(-32001); pub const RESOURCE_NOT_FOUND: Self = Self(-32002); pub const INVALID_REQUEST: Self = Self(-32600); pub const METHOD_NOT_FOUND: Self = Self(-32601); @@ -549,7 +553,9 @@ impl ErrorData { pub fn resource_not_found(message: impl Into>, data: Option) -> Self { Self::new(ErrorCode::RESOURCE_NOT_FOUND, message, data) } - + pub fn header_mismatch(message: impl Into>, data: Option) -> Self { + Self::new(ErrorCode::HEADER_MISMATCH, message, data) + } pub fn parse_error(message: impl Into>, data: Option) -> Self { Self::new(ErrorCode::PARSE_ERROR, message, data) } diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index 3691602b1..8cf00c02f 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -3,6 +3,9 @@ pub mod server_side_http; pub mod http_header; +#[cfg(any(feature = "client-side-sse", feature = "server-side-http"))] +pub mod mcp_headers; + #[cfg(feature = "__reqwest")] mod reqwest; diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 283f0daa7..6c2abbc8d 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -4,6 +4,15 @@ pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version"; pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; pub const JSON_MIME_TYPE: &str = "application/json"; +// SEP-2243 standard headers, gated on protocol version >= 2026-07-28. +pub const HEADER_MCP_METHOD: &str = "Mcp-Method"; +pub const HEADER_MCP_NAME: &str = "Mcp-Name"; +pub const HEADER_MCP_PARAM_PREFIX: &str = "Mcp-Param-"; + +/// Sentinel wrapping a Base64-encoded SEP-2243 header value (`=?base64??=`). +pub const BASE64_HEADER_PREFIX: &str = "=?base64?"; +pub const BASE64_HEADER_SUFFIX: &str = "?="; + /// Reserved headers that must not be overridden by user-supplied custom headers. /// `MCP-Protocol-Version` is in this list but is allowed through because the worker /// injects it after initialization. diff --git a/crates/rmcp/src/transport/common/mcp_headers.rs b/crates/rmcp/src/transport/common/mcp_headers.rs new file mode 100644 index 000000000..fd42ba757 --- /dev/null +++ b/crates/rmcp/src/transport/common/mcp_headers.rs @@ -0,0 +1,654 @@ +//! SEP-2243 HTTP header standardization. +//! +//! Builds and validates the `Mcp-Method`, `Mcp-Name`, and `Mcp-Param-*` headers +//! so middle boxes can route Streamable HTTP traffic without parsing the body. +//! All emission/validation is gated by the negotiated protocol version +//! (`>= ProtocolVersion::STANDARD_HEADERS`) at the call sites. + +// Which helpers are reachable depends on the client/server feature combination, +// mirroring `server_side_http`. +#![allow(dead_code)] + +use serde_json::Value; + +use super::http_header::{ + BASE64_HEADER_PREFIX, BASE64_HEADER_SUFFIX, HEADER_MCP_METHOD, HEADER_MCP_NAME, + HEADER_MCP_PARAM_PREFIX, +}; +use crate::model::JsonObject; + +/// Methods whose `Mcp-Name` is sourced from `params.name`. +const NAME_FROM_NAME: &[&str] = &["tools/call", "prompts/get"]; +/// Methods whose `Mcp-Name` is sourced from `params.uri`. +const NAME_FROM_URI: &[&str] = &[ + "resources/read", + "resources/subscribe", + "resources/unsubscribe", +]; + +/// Returns the `Mcp-Name` value for a request, if the method carries one. +fn extract_name(method: &str, params: Option<&Value>) -> Option { + let params = params?; + let key = if NAME_FROM_NAME.contains(&method) { + "name" + } else if NAME_FROM_URI.contains(&method) { + "uri" + } else { + return None; + }; + params.get(key)?.as_str().map(str::to_owned) +} + +/// Converts a JSON primitive to its SEP-2243 string form. Non-primitives yield `None`. +fn primitive_to_string(value: &Value) -> Option { + match value { + Value::String(s) => Some(s.clone()), + Value::Bool(b) => Some(b.to_string()), + Value::Number(n) => Some(n.to_string()), + _ => None, + } +} + +/// True if `value` must be Base64-wrapped to survive as an HTTP header value: +/// leading/trailing space or tab, control/non-ASCII characters, or a value that +/// already looks like the `=?base64?...?=` sentinel. +#[cfg(feature = "client-side-sse")] +fn requires_base64(value: &str) -> bool { + if value.is_empty() { + return false; + } + let bytes = value.as_bytes(); + if matches!(bytes.first(), Some(b' ' | b'\t')) || matches!(bytes.last(), Some(b' ' | b'\t')) { + return true; + } + if value + .chars() + .any(|c| (c as u32) < 0x20 || (c as u32) > 0x7E) + { + return true; + } + value.starts_with(BASE64_HEADER_PREFIX) && value.ends_with(BASE64_HEADER_SUFFIX) +} + +/// RFC 9110 §5.6.2 token character. +#[cfg(feature = "client-side-sse")] +fn is_tchar(c: char) -> bool { + c.is_ascii_alphanumeric() + || matches!( + c, + '!' | '#' + | '$' + | '%' + | '&' + | '\'' + | '*' + | '+' + | '-' + | '.' + | '^' + | '_' + | '`' + | '|' + | '~' + ) +} + +/// Top-level properties carrying an `x-mcp-header` annotation, as `(property, header)` pairs. +fn param_header_annotations(input_schema: &JsonObject) -> Vec<(String, String)> { + let mut out = Vec::new(); + if let Some(Value::Object(props)) = input_schema.get("properties") { + for (prop, schema) in props { + if let Some(Value::String(header)) = schema.get("x-mcp-header") { + if !header.is_empty() { + out.push((prop.clone(), header.clone())); + } + } + } + } + out +} + +/// Validates the `x-mcp-header` annotations in a tool input schema. +/// +/// Annotations must be non-empty RFC 9110 tokens, case-insensitively unique, +/// applied only to top-level primitive (`string`/`integer`/`boolean`) properties. +/// Returns the offending reason on the first violation. +#[cfg(feature = "client-side-sse")] +pub(crate) fn validate_param_header_annotations(input_schema: &JsonObject) -> Result<(), String> { + let Some(Value::Object(props)) = input_schema.get("properties") else { + return Ok(()); + }; + let mut seen = std::collections::HashSet::new(); + for (prop, schema) in props { + reject_nested_annotations(schema, prop)?; + let Some(raw) = schema.get("x-mcp-header") else { + continue; + }; + let Value::String(header) = raw else { + return Err(format!("property `{prop}`: x-mcp-header must be a string")); + }; + if header.is_empty() { + return Err(format!("property `{prop}`: x-mcp-header must not be empty")); + } + if !header.chars().all(is_tchar) { + return Err(format!( + "property `{prop}`: x-mcp-header `{header}` is not a valid HTTP token" + )); + } + if !seen.insert(header.to_ascii_lowercase()) { + return Err(format!( + "property `{prop}`: duplicate x-mcp-header `{header}` (case-insensitive)" + )); + } + match schema.get("type").and_then(Value::as_str) { + Some("string" | "integer" | "boolean") => {} + other => { + return Err(format!( + "property `{prop}`: x-mcp-header requires a primitive type \ + (string/integer/boolean), got {other:?}" + )); + } + } + } + Ok(()) +} + +/// Rejects `x-mcp-header` on nested properties (only top-level promotion is supported). +#[cfg(feature = "client-side-sse")] +fn reject_nested_annotations(schema: &Value, path: &str) -> Result<(), String> { + if let Some(Value::Object(nested)) = schema.get("properties") { + for (key, value) in nested { + if value.get("x-mcp-header").is_some() { + return Err(format!( + "property `{path}.{key}`: x-mcp-header is not supported on nested properties" + )); + } + reject_nested_annotations(value, &format!("{path}.{key}"))?; + } + } + Ok(()) +} + +/// Wraps a value as `=?base64??=` when it cannot travel as a bare header value. +#[cfg(feature = "client-side-sse")] +fn encode_header_value(value: &str) -> String { + use base64::{Engine, prelude::BASE64_STANDARD}; + if requires_base64(value) { + format!( + "{BASE64_HEADER_PREFIX}{}{BASE64_HEADER_SUFFIX}", + BASE64_STANDARD.encode(value) + ) + } else { + value.to_owned() + } +} + +/// Reverses [`encode_header_value`]. Returns `None` if the sentinel wraps invalid Base64/UTF-8. +#[cfg(feature = "server-side-http")] +fn decode_header_value(value: &str) -> Option { + use base64::{Engine, prelude::BASE64_STANDARD}; + match value + .strip_prefix(BASE64_HEADER_PREFIX) + .and_then(|inner| inner.strip_suffix(BASE64_HEADER_SUFFIX)) + { + Some(inner) => { + let bytes = BASE64_STANDARD.decode(inner).ok()?; + String::from_utf8(bytes).ok() + } + None => Some(value.to_owned()), + } +} + +/// Builds the SEP-2243 headers for an outgoing request from its JSON form. +/// +/// `tool_schema` is the cached input schema of the called tool, used to promote +/// annotated `tools/call` arguments to `Mcp-Param-*` headers. +#[cfg(feature = "client-side-sse")] +pub(crate) fn standard_request_headers( + request: &Value, + tool_schema: Option<&JsonObject>, +) -> Vec<(http::HeaderName, http::HeaderValue)> { + use http::{HeaderName, HeaderValue}; + + let mut out = Vec::new(); + let Some(method) = request.get("method").and_then(Value::as_str) else { + return out; + }; + let params = request.get("params"); + + let mut push = |name: &str, value: &str| { + if let (Ok(name), Ok(value)) = ( + HeaderName::from_bytes(name.as_bytes()), + HeaderValue::from_str(value), + ) { + out.push((name, value)); + } + }; + + push(HEADER_MCP_METHOD, method); + if let Some(name) = extract_name(method, params) { + push(HEADER_MCP_NAME, &encode_header_value(&name)); + } + + if method == "tools/call" { + if let (Some(schema), Some(arguments)) = + (tool_schema, params.and_then(|p| p.get("arguments"))) + { + for (prop, header) in param_header_annotations(schema) { + let Some(arg) = arguments.get(&prop) else { + continue; + }; + let Some(encoded) = primitive_to_string(arg).map(|s| encode_header_value(&s)) + else { + continue; + }; + push(&format!("{HEADER_MCP_PARAM_PREFIX}{header}"), &encoded); + } + } + } + out +} + +/// Validates incoming SEP-2243 headers against the request body. +/// +/// Returns `Err(reason)` when a required header is missing or its value does not +/// match the body; the caller maps this to a JSON-RPC `-32001` error (HTTP 400). +#[cfg(feature = "server-side-http")] +pub(crate) fn validate_request_headers( + headers: &http::HeaderMap, + request: &Value, + tool_schema: Option<&JsonObject>, +) -> Result<(), String> { + let Some(method) = request.get("method").and_then(Value::as_str) else { + return Ok(()); + }; + let params = request.get("params"); + + let header_method = header_str(headers, HEADER_MCP_METHOD); + match header_method { + None => return Err("missing required Mcp-Method header".to_owned()), + Some(value) if value != method => { + return Err(format!( + "Mcp-Method header `{value}` does not match body method `{method}`" + )); + } + Some(_) => {} + } + + if let Some(expected) = extract_name(method, params) { + match header_str(headers, HEADER_MCP_NAME) { + None => return Err(format!("missing required Mcp-Name header for `{method}`")), + Some(raw) => { + let decoded = decode_header_value(raw) + .ok_or_else(|| "Mcp-Name header is not valid Base64".to_owned())?; + if decoded != expected { + return Err(format!( + "Mcp-Name header `{decoded}` does not match body value `{expected}`" + )); + } + } + } + } + + if method == "tools/call" { + if let Some(schema) = tool_schema { + let arguments = params.and_then(|p| p.get("arguments")); + for (prop, header) in param_header_annotations(schema) { + let full = format!("{HEADER_MCP_PARAM_PREFIX}{header}"); + let header_value = header_str(headers, &full); + let arg = arguments.and_then(|a| a.get(&prop)); + let body_value = arg.filter(|v| !v.is_null()).and_then(primitive_to_string); + + match (header_value, body_value) { + (None, None) => {} + (Some(_), None) => { + return Err(format!( + "unexpected {full} header for absent or null `{prop}`" + )); + } + (None, Some(_)) => { + return Err(format!("missing {full} header for `{prop}`")); + } + (Some(raw), Some(expected)) => { + let decoded = decode_header_value(raw) + .ok_or_else(|| format!("{full} header is not valid Base64"))?; + if decoded != expected { + return Err(format!( + "{full} header `{decoded}` does not match body value `{expected}`" + )); + } + } + } + } + } + } + Ok(()) +} + +/// Case-insensitive header lookup returning the value as `&str`, if present and valid UTF-8. +#[cfg(feature = "server-side-http")] +fn header_str<'a>(headers: &'a http::HeaderMap, name: &str) -> Option<&'a str> { + headers.get(name).and_then(|value| value.to_str().ok()) +} + +#[cfg(all(test, feature = "client-side-sse", feature = "server-side-http"))] +mod tests { + use std::collections::HashMap; + + use http::{HeaderMap, HeaderName, HeaderValue}; + use serde_json::json; + + use super::*; + + fn schema_with(properties: serde_json::Value) -> JsonObject { + json!({ "type": "object", "properties": properties }) + .as_object() + .unwrap() + .clone() + } + + fn header_map(pairs: &[(&str, &str)]) -> HeaderMap { + let mut map = HeaderMap::new(); + for (name, value) in pairs { + map.insert( + HeaderName::from_bytes(name.as_bytes()).unwrap(), + HeaderValue::from_str(value).unwrap(), + ); + } + map + } + + fn assert_wrapped(value: &str) { + let encoded = encode_header_value(value); + assert!( + encoded.starts_with(BASE64_HEADER_PREFIX), + "expected {value:?} to be Base64-wrapped, got {encoded:?}" + ); + } + + mod encode_header_value { + use super::*; + + #[test] + fn passes_plain_ascii_through() { + assert_eq!(encode_header_value("us-west1"), "us-west1"); + } + + #[test] + fn passes_internal_spaces_through() { + assert_eq!(encode_header_value("a b c"), "a b c"); + } + + #[test] + fn wraps_non_ascii() { + assert_wrapped("café"); + } + + #[test] + fn wraps_leading_whitespace() { + assert_wrapped(" padded"); + } + + #[test] + fn wraps_trailing_whitespace() { + assert_wrapped("trailing "); + } + + #[test] + fn wraps_control_characters() { + assert_wrapped("line1\nline2"); + } + + #[test] + fn wraps_crlf_injection_attempt() { + assert_wrapped("a\r\nEvil: 1"); + } + + #[test] + fn wraps_sentinel_collision() { + assert_wrapped(&format!("{BASE64_HEADER_PREFIX}x{BASE64_HEADER_SUFFIX}")); + } + } + + mod decode_header_value { + use super::*; + + #[test] + fn round_trips_with_encode() { + for value in ["us-west1", "café", " padded ", "line1\nline2", "true", "42"] { + let encoded = encode_header_value(value); + assert_eq!( + decode_header_value(&encoded).as_deref(), + Some(value), + "round-trip failed for {value:?}" + ); + } + } + + #[test] + fn rejects_invalid_base64() { + let bad = format!("{BASE64_HEADER_PREFIX}!!!not-base64!!!{BASE64_HEADER_SUFFIX}"); + assert_eq!(decode_header_value(&bad), None); + } + } + + mod extract_name { + use super::*; + + #[test] + fn from_name_for_tools_call() { + let params = json!({ "name": "my_tool" }); + assert_eq!( + extract_name("tools/call", Some(¶ms)).as_deref(), + Some("my_tool") + ); + } + + #[test] + fn from_name_for_prompts_get() { + let params = json!({ "name": "my_prompt" }); + assert_eq!( + extract_name("prompts/get", Some(¶ms)).as_deref(), + Some("my_prompt") + ); + } + + #[test] + fn from_uri_for_resources_read() { + let params = json!({ "uri": "file:///x" }); + assert_eq!( + extract_name("resources/read", Some(¶ms)).as_deref(), + Some("file:///x") + ); + } + + #[test] + fn none_for_unrelated_method() { + let params = json!({ "name": "my_tool" }); + assert_eq!(extract_name("ping", Some(¶ms)), None); + } + + #[test] + fn none_when_params_absent() { + assert_eq!(extract_name("tools/call", None), None); + } + } + + mod validate_param_header_annotations { + use super::*; + + #[test] + fn accepts_primitive_types() { + let schema = schema_with(json!({ + "region": { "type": "string", "x-mcp-header": "Region" }, + "count": { "type": "integer", "x-mcp-header": "Count" }, + "flag": { "type": "boolean", "x-mcp-header": "Flag" }, + })); + assert!(validate_param_header_annotations(&schema).is_ok()); + } + + #[test] + fn rejects_number_type() { + let schema = schema_with(json!({ "n": { "type": "number", "x-mcp-header": "N" } })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + + #[test] + fn rejects_complex_type() { + let schema = schema_with(json!({ "a": { "type": "array", "x-mcp-header": "A" } })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + + #[test] + fn rejects_empty_header_name() { + let schema = schema_with(json!({ "r": { "type": "string", "x-mcp-header": "" } })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + + #[test] + fn rejects_non_token_header_name() { + let schema = + schema_with(json!({ "r": { "type": "string", "x-mcp-header": "bad:name" } })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + + #[test] + fn rejects_case_insensitive_duplicate() { + let schema = schema_with(json!({ + "a": { "type": "string", "x-mcp-header": "Region" }, + "b": { "type": "string", "x-mcp-header": "region" }, + })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + + #[test] + fn rejects_nested_annotation() { + let schema = schema_with(json!({ + "outer": { + "type": "object", + "properties": { "inner": { "type": "string", "x-mcp-header": "Inner" } } + } + })); + assert!(validate_param_header_annotations(&schema).is_err()); + } + } + + mod standard_request_headers { + use super::*; + + fn tools_call_headers() -> HashMap { + let schema = schema_with(json!({ + "region": { "type": "string", "x-mcp-header": "Region" }, + })); + let request = json!({ + "jsonrpc": "2.0", "id": 1, "method": "tools/call", + "params": { "name": "deploy", "arguments": { "region": "us-west1" } } + }); + super::super::standard_request_headers(&request, Some(&schema)) + .into_iter() + .map(|(name, value)| (name.as_str().to_owned(), value.to_str().unwrap().to_owned())) + .collect() + } + + #[test] + fn sets_method_header() { + assert_eq!( + tools_call_headers().get("mcp-method").map(String::as_str), + Some("tools/call") + ); + } + + #[test] + fn sets_name_header() { + assert_eq!( + tools_call_headers().get("mcp-name").map(String::as_str), + Some("deploy") + ); + } + + #[test] + fn sets_annotated_param_header() { + assert_eq!( + tools_call_headers() + .get("mcp-param-region") + .map(String::as_str), + Some("us-west1") + ); + } + } + + mod validate_request_headers { + use super::*; + + fn tools_call_request() -> Value { + json!({ + "jsonrpc": "2.0", "id": 1, "method": "tools/call", + "params": { "name": "deploy" } + }) + } + + #[test] + fn accepts_matching_method_and_name() { + let headers = header_map(&[("Mcp-Method", "tools/call"), ("Mcp-Name", "deploy")]); + assert!(validate_request_headers(&headers, &tools_call_request(), None).is_ok()); + } + + #[test] + fn rejects_method_mismatch() { + let headers = header_map(&[("Mcp-Method", "tools/list"), ("Mcp-Name", "deploy")]); + assert!(validate_request_headers(&headers, &tools_call_request(), None).is_err()); + } + + #[test] + fn rejects_missing_method() { + let headers = header_map(&[("Mcp-Name", "deploy")]); + assert!(validate_request_headers(&headers, &tools_call_request(), None).is_err()); + } + + #[test] + fn rejects_name_mismatch() { + let headers = header_map(&[("Mcp-Method", "tools/call"), ("Mcp-Name", "other")]); + assert!(validate_request_headers(&headers, &tools_call_request(), None).is_err()); + } + + #[test] + fn rejects_missing_name() { + let headers = header_map(&[("Mcp-Method", "tools/call")]); + assert!(validate_request_headers(&headers, &tools_call_request(), None).is_err()); + } + + #[test] + fn accepts_matching_param() { + let schema = schema_with(json!({ + "region": { "type": "string", "x-mcp-header": "Region" }, + })); + let request = json!({ + "jsonrpc": "2.0", "id": 1, "method": "tools/call", + "params": { "name": "deploy", "arguments": { "region": "us-west1" } } + }); + let headers = header_map(&[ + ("Mcp-Method", "tools/call"), + ("Mcp-Name", "deploy"), + ("Mcp-Param-Region", "us-west1"), + ]); + assert!(validate_request_headers(&headers, &request, Some(&schema)).is_ok()); + } + + #[test] + fn rejects_param_mismatch() { + let schema = schema_with(json!({ + "region": { "type": "string", "x-mcp-header": "Region" }, + })); + let request = json!({ + "jsonrpc": "2.0", "id": 1, "method": "tools/call", + "params": { "name": "deploy", "arguments": { "region": "us-west1" } } + }); + let headers = header_map(&[ + ("Mcp-Method", "tools/call"), + ("Mcp-Name", "deploy"), + ("Mcp-Param-Region", "eu-central1"), + ]); + assert!(validate_request_headers(&headers, &request, Some(&schema)).is_err()); + } + } +} diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index a2c1a7b19..326582e1e 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -12,17 +12,67 @@ use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStre use crate::{ RoleClient, model::{ - ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage, - ServerResult, + ClientJsonRpcMessage, ClientNotification, InitializedNotification, JsonObject, + ProtocolVersion, ServerJsonRpcMessage, ServerResult, }, transport::{ - common::client_side_sse::SseAutoReconnectStream, + common::{client_side_sse::SseAutoReconnectStream, mcp_headers}, worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, }, }; type BoxedSseStream = BoxStream<'static, Result>; +/// Clones `base` and, when the negotiated version requires SEP-2243 headers, adds the +/// `Mcp-Method` / `Mcp-Name` / `Mcp-Param-*` headers derived from the outgoing message. +fn build_request_headers( + base: &HashMap, + message: &ClientJsonRpcMessage, + tool_cache: &HashMap>, + version: &ProtocolVersion, +) -> HashMap { + use serde_json::Value; + + let mut headers = base.clone(); + if *version >= ProtocolVersion::STANDARD_HEADERS { + if let Ok(value) = serde_json::to_value(message) { + let schema = value + .get("method") + .and_then(Value::as_str) + .filter(|method| *method == "tools/call") + .and_then(|_| value.get("params")) + .and_then(|params| params.get("name")) + .and_then(Value::as_str) + .and_then(|name| tool_cache.get(name)) + .map(Arc::as_ref); + for (name, val) in mcp_headers::standard_request_headers(&value, schema) { + headers.insert(name, val); + } + } + } + headers +} + +/// Caches tool input schemas from a `tools/list` response for later `Mcp-Param-*` emission. +fn cache_tools_from_response( + cache: &mut HashMap>, + message: &ServerJsonRpcMessage, +) { + if let ServerJsonRpcMessage::Response(response) = message { + if let ServerResult::ListToolsResult(list) = &response.result { + for tool in &list.tools { + if let Err(reason) = + mcp_headers::validate_param_header_annotations(&tool.input_schema) + { + tracing::warn!(tool = %tool.name, "ignoring x-mcp-header annotations: {reason}"); + continue; + } + cache.insert(tool.name.to_string(), tool.input_schema.clone()); + } + } + } +} + #[derive(Debug)] #[non_exhaustive] pub struct AuthRequiredError { @@ -508,10 +558,14 @@ impl Worker for StreamableHttpClientWorker { // Extract the negotiated protocol version from the init response // and build a custom headers map that includes MCP-Protocol-Version // for all subsequent HTTP requests (per MCP 2025-06-18 spec). + // Negotiated protocol version gates SEP-2243 standard headers; default to a + // pre-SEP version so headers are omitted if the version can't be determined. + let mut negotiated_version = ProtocolVersion::default(); let mut protocol_headers = { let mut headers = config.custom_headers.clone(); if let ServerJsonRpcMessage::Response(response) = &message { if let ServerResult::InitializeResult(init_result) = &response.result { + negotiated_version = init_result.protocol_version.clone(); if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { // HeaderName::from_static requires lowercase headers.insert(HeaderName::from_static("mcp-protocol-version"), hv); @@ -520,6 +574,9 @@ impl Worker for StreamableHttpClientWorker { } headers }; + // SEP-2243: tool input schemas (name -> schema) cached from tools/list responses, + // used to promote annotated tools/call arguments to Mcp-Param-* headers. + let mut tool_header_cache: HashMap> = HashMap::new(); // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns) let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo { @@ -533,13 +590,19 @@ impl Worker for StreamableHttpClientWorker { context.send_to_handler(message).await?; let initialized_notification = context.recv_from_handler().await?; // expect a initialized response + let initialized_headers = build_request_headers( + &protocol_headers, + &initialized_notification.message, + &tool_header_cache, + &negotiated_version, + ); self.client .post_message( config.uri.clone(), initialized_notification.message, session_id.clone(), config.auth_header.clone(), - protocol_headers.clone(), + initialized_headers, ) .await .map_err(WorkerQuitReason::fatal_context( @@ -649,6 +712,12 @@ impl Worker for StreamableHttpClientWorker { // Pass a clone to the first attempt so `message` is retained for a // potential re-init retry. `post_message` takes ownership and the // trait cannot be changed, so the clone is unavoidable. + let request_headers = build_request_headers( + &protocol_headers, + &message, + &tool_header_cache, + &negotiated_version, + ); let response = self .client .post_message( @@ -656,7 +725,7 @@ impl Worker for StreamableHttpClientWorker { message.clone(), session_id.clone(), config.auth_header.clone(), - protocol_headers.clone(), + request_headers, ) .await; let send_result = match response { @@ -752,6 +821,12 @@ impl Worker for StreamableHttpClientWorker { }); } + let retry_headers = build_request_headers( + &protocol_headers, + &message, + &tool_header_cache, + &negotiated_version, + ); let retry_response = self .client .post_message( @@ -759,7 +834,7 @@ impl Worker for StreamableHttpClientWorker { message, session_id.clone(), config.auth_header.clone(), - protocol_headers.clone(), + retry_headers, ) .await; match retry_response { @@ -771,6 +846,10 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Json(msg, ..)) => { + cache_tools_from_response( + &mut tool_header_cache, + &msg, + ); context.send_to_handler(msg).await?; Ok(()) } @@ -796,6 +875,7 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Json(message, ..)) => { + cache_tools_from_response(&mut tool_header_cache, &message); context.send_to_handler(message).await?; Ok(()) } @@ -813,6 +893,7 @@ impl Worker for StreamableHttpClientWorker { let _ = responder.send(send_result); } Event::ServerMessage(json_rpc_message) => { + cache_tools_from_response(&mut tool_header_cache, &json_rpc_message); // send the message to the handler if let Err(e) = context.send_to_handler(json_rpc_message).await { break 'main_loop Err(e); diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index cd2f5f1e5..a303cc4f7 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -28,6 +28,7 @@ use crate::{ EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, HEADER_SESSION_ID, JSON_MIME_TYPE, }, + mcp_headers, server_side_http::{ BoxResponse, ServerSseMessage, accepted_response, expect_json, internal_error_response, sse_stream_response, unexpected_message_response, @@ -259,6 +260,61 @@ fn validate_header_matches_init_body( Ok(()) } +fn header_mismatch_jsonrpc_response( + id: Option, + message: impl Into>, +) -> BoxResponse { + let err = JsonRpcError::new(id, ErrorData::header_mismatch(message, None)); + let body = serde_json::to_vec(&err).expect("serialize JsonRpcError"); + Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) + .body(Full::new(Bytes::from(body)).boxed()) + .expect("valid response") +} + +/// Validates SEP-2243 `Mcp-Method` / `Mcp-Name` headers against the request body. +/// +/// Only enforced when the request declares a protocol version `>= STANDARD_HEADERS`. +/// The `initialize` handshake is exempt: clients emit these headers only after the +/// version has been negotiated. `Mcp-Param-*` validation is not enforced here because +/// the transport layer has no synchronous access to tool input schemas. +#[expect( + clippy::result_large_err, + reason = "BoxResponse is intentionally large; matches other handlers in this file" +)] +fn validate_standard_headers( + headers: &HeaderMap, + message: &ClientJsonRpcMessage, +) -> Result<(), BoxResponse> { + let version_requires_headers = headers + .get(HEADER_MCP_PROTOCOL_VERSION) + .and_then(|value| value.to_str().ok()) + .is_some_and(|version| version >= ProtocolVersion::STANDARD_HEADERS.as_str()); + if !version_requires_headers { + return Ok(()); + } + + let request_id = match message { + ClientJsonRpcMessage::Request(req) => { + if matches!(&req.request, ClientRequest::InitializeRequest(_)) { + return Ok(()); + } + Some(req.id.clone()) + } + ClientJsonRpcMessage::Notification(_) => None, + _ => return Ok(()), + }; + + let Ok(value) = serde_json::to_value(message) else { + return Ok(()); + }; + if let Err(reason) = mcp_headers::validate_request_headers(headers, &value, None) { + return Err(header_mismatch_jsonrpc_response(request_id, reason)); + } + Ok(()) +} + fn forbidden_response(message: impl Into) -> BoxResponse { Response::builder() .status(http::StatusCode::FORBIDDEN) @@ -1082,6 +1138,8 @@ where // Validate MCP-Protocol-Version header (per 2025-06-18 spec) validate_protocol_version_header(&part.headers)?; + // Validate SEP-2243 standard headers against the body + validate_standard_headers(&part.headers, &message)?; // inject request part to extensions match &mut message { @@ -1238,6 +1296,8 @@ where validate_protocol_version_header(&part.headers)?; } } + // Validate SEP-2243 standard headers against the body + validate_standard_headers(&part.headers, &message)?; let service = self .get_service() .map_err(internal_error_response("get service"))?; diff --git a/crates/rmcp/tests/test_streamable_http_standard_headers.rs b/crates/rmcp/tests/test_streamable_http_standard_headers.rs new file mode 100644 index 000000000..b890c8885 --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_standard_headers.rs @@ -0,0 +1,147 @@ +#![cfg(not(feature = "local"))] +//! SEP-2243 server-side validation of `Mcp-Method` / `Mcp-Name` headers. +use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, +}; +use tokio_util::sync::CancellationToken; + +mod common; +use common::calculator::Calculator; + +const SEP_VERSION: &str = "2026-07-28"; + +fn tools_call_body() -> String { + r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"sum","arguments":{"a":1,"b":2}}}"# + .to_owned() +} + +async fn spawn_server() -> (reqwest::Client, String, CancellationToken) { + let config = StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true) + .with_sse_keep_alive(None) + .with_cancellation_token(CancellationToken::new()); + let ct = config.cancellation_token.clone(); + let service: StreamableHttpService = + StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let client = reqwest::Client::new(); + (client, format!("http://{addr}/mcp"), ct) +} + +/// POSTs a `tools/call` with the given protocol-version and optional SEP-2243 headers. +async fn post_tools_call( + client: &reqwest::Client, + url: &str, + version: &str, + mcp_method: Option<&str>, + mcp_name: Option<&str>, +) -> reqwest::Response { + let mut req = client + .post(url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .header("MCP-Protocol-Version", version) + .body(tools_call_body()); + if let Some(method) = mcp_method { + req = req.header("Mcp-Method", method); + } + if let Some(name) = mcp_name { + req = req.header("Mcp-Name", name); + } + req.send().await.expect("send tools/call request") +} + +#[tokio::test] +async fn accepts_matching_standard_headers() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server().await; + + // Matching headers pass validation and reach dispatch. (Stateless mode without a + // prior initialize yields an unrelated -32601, which still proves -32001 was not raised.) + let response = + post_tools_call(&client, &url, SEP_VERSION, Some("tools/call"), Some("sum")).await; + let body: serde_json::Value = response.json().await?; + assert_ne!( + body["error"]["code"], -32001, + "matching headers must not be rejected as a header mismatch, got: {body}" + ); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn rejects_method_mismatch_with_32001() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server().await; + + let response = + post_tools_call(&client, &url, SEP_VERSION, Some("tools/list"), Some("sum")).await; + assert_eq!(response.status(), 400); + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32001); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn rejects_missing_method_header_with_32001() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server().await; + + let response = post_tools_call(&client, &url, SEP_VERSION, None, Some("sum")).await; + assert_eq!(response.status(), 400); + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32001); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn rejects_name_mismatch_with_32001() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server().await; + + let response = post_tools_call( + &client, + &url, + SEP_VERSION, + Some("tools/call"), + Some("product"), + ) + .await; + assert_eq!(response.status(), 400); + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32001); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn skips_validation_for_pre_sep_version() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server().await; + + // Older version: headers are not enforced even when absent. + let response = post_tools_call(&client, &url, "2025-11-25", None, None).await; + let body: serde_json::Value = response.json().await?; + assert_ne!( + body["error"]["code"], -32001, + "pre-SEP versions must skip header validation, got: {body}" + ); + + ct.cancel(); + Ok(()) +} From fd9d7d9910755a9aa4578fc377a2a3ef2ca973ff Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Tue, 16 Jun 2026 22:21:43 -0400 Subject: [PATCH 2/3] feat: validate Mcp-Param-* headers on server --- .../transport/streamable_http_server/tower.rs | 59 +++++++-- .../test_streamable_http_standard_headers.rs | 124 +++++++++++++++++- 2 files changed, 169 insertions(+), 14 deletions(-) diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index a303cc4f7..d8423c5b6 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -17,7 +17,8 @@ use crate::{ RoleServer, model::{ ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions, - InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId, + InitializeRequest, InitializedNotification, JsonObject, JsonRpcError, ProtocolVersion, + RequestId, }, serve_server, service::serve_directly, @@ -273,12 +274,12 @@ fn header_mismatch_jsonrpc_response( .expect("valid response") } -/// Validates SEP-2243 `Mcp-Method` / `Mcp-Name` headers against the request body. +/// Validates SEP-2243 `Mcp-Method` / `Mcp-Name` / `Mcp-Param-*` headers against the body. /// /// Only enforced when the request declares a protocol version `>= STANDARD_HEADERS`. /// The `initialize` handshake is exempt: clients emit these headers only after the -/// version has been negotiated. `Mcp-Param-*` validation is not enforced here because -/// the transport layer has no synchronous access to tool input schemas. +/// version has been negotiated. `tool_schema` supplies the called tool's input schema +/// so annotated `Mcp-Param-*` headers can be checked (no schema => those are skipped). #[expect( clippy::result_large_err, reason = "BoxResponse is intentionally large; matches other handlers in this file" @@ -286,6 +287,7 @@ fn header_mismatch_jsonrpc_response( fn validate_standard_headers( headers: &HeaderMap, message: &ClientJsonRpcMessage, + tool_schema: impl Fn(&str) -> Option>, ) -> Result<(), BoxResponse> { let version_requires_headers = headers .get(HEADER_MCP_PROTOCOL_VERSION) @@ -309,7 +311,16 @@ fn validate_standard_headers( let Ok(value) = serde_json::to_value(message) else { return Ok(()); }; - if let Err(reason) = mcp_headers::validate_request_headers(headers, &value, None) { + // For tools/call, look up the tool schema so Mcp-Param-* headers are validated. + let schema = value + .get("method") + .and_then(|method| method.as_str()) + .filter(|method| *method == "tools/call") + .and_then(|_| value.get("params")) + .and_then(|params| params.get("name")) + .and_then(|name| name.as_str()) + .and_then(tool_schema); + if let Err(reason) = mcp_headers::validate_request_headers(headers, &value, schema.as_deref()) { return Err(header_mismatch_jsonrpc_response(request_id, reason)); } Ok(()) @@ -610,6 +621,10 @@ pub struct StreamableHttpService { pending_restores: Option< Arc>>>>, >, + /// Caches tool input schemas by name for SEP-2243 `Mcp-Param-*` validation. + /// Populated lazily via `get_tool` so the service factory runs at most once + /// per tool name. `None` value means the tool exposes no schema. + tool_schemas: Arc>>>>, } impl Clone for StreamableHttpService { @@ -619,6 +634,7 @@ impl Clone for StreamableHttpService { session_manager: self.session_manager.clone(), service_factory: self.service_factory.clone(), pending_restores: self.pending_restores.clone(), + tool_schemas: self.tool_schemas.clone(), } } } @@ -626,7 +642,7 @@ impl Clone for StreamableHttpService { impl tower_service::Service> for StreamableHttpService where RequestBody: Body + Send + 'static, - S: crate::Service + Send + 'static, + S: crate::ServerHandler + Send + 'static, M: SessionManager, RequestBody::Error: Display, RequestBody::Data: Send + 'static, @@ -680,7 +696,7 @@ impl Drop for PendingRestoreGuard { impl StreamableHttpService where - S: crate::Service + Send + 'static, + S: crate::ServerHandler + Send + 'static, M: SessionManager, { pub fn new( @@ -699,12 +715,33 @@ where session_manager, service_factory: Arc::new(service_factory), pending_restores, + tool_schemas: Arc::new(std::sync::RwLock::new(HashMap::new())), } } fn get_service(&self) -> Result { (self.service_factory)() } + /// Returns the cached input schema for `name`, constructing a service once + /// per name to read its `ServerHandler::get_tool` definition. Used to + /// validate SEP-2243 `Mcp-Param-*` headers against the request body. + fn tool_schema(&self, name: &str) -> Option> { + if let Ok(cache) = self.tool_schemas.read() { + if let Some(schema) = cache.get(name) { + return schema.clone(); + } + } + let schema = self + .get_service() + .ok() + .and_then(|service| service.get_tool(name)) + .map(|tool| tool.input_schema); + if let Ok(mut cache) = self.tool_schemas.write() { + cache.insert(name.to_owned(), schema.clone()); + } + schema + } + /// Spawn a task that runs `serve_server` for the given session, waits for /// it to finish, and then calls `close_session`. /// @@ -719,7 +756,7 @@ where transport: M::Transport, init_done_tx: Option>, ) where - S: crate::Service + Send + 'static, + S: crate::ServerHandler + Send + 'static, M: SessionManager, { tokio::spawn(async move { @@ -762,7 +799,7 @@ where parts: &http::request::Parts, ) -> Result where - S: crate::Service + Send + 'static, + S: crate::ServerHandler + Send + 'static, M: SessionManager, { // Both fields are Some iff a session store is configured. @@ -1139,7 +1176,7 @@ where // Validate MCP-Protocol-Version header (per 2025-06-18 spec) validate_protocol_version_header(&part.headers)?; // Validate SEP-2243 standard headers against the body - validate_standard_headers(&part.headers, &message)?; + validate_standard_headers(&part.headers, &message, |name| self.tool_schema(name))?; // inject request part to extensions match &mut message { @@ -1297,7 +1334,7 @@ where } } // Validate SEP-2243 standard headers against the body - validate_standard_headers(&part.headers, &message)?; + validate_standard_headers(&part.headers, &message, |name| self.tool_schema(name))?; let service = self .get_service() .map_err(internal_error_response("get service"))?; diff --git a/crates/rmcp/tests/test_streamable_http_standard_headers.rs b/crates/rmcp/tests/test_streamable_http_standard_headers.rs index b890c8885..48521dba7 100644 --- a/crates/rmcp/tests/test_streamable_http_standard_headers.rs +++ b/crates/rmcp/tests/test_streamable_http_standard_headers.rs @@ -1,7 +1,13 @@ #![cfg(not(feature = "local"))] -//! SEP-2243 server-side validation of `Mcp-Method` / `Mcp-Name` headers. -use rmcp::transport::streamable_http_server::{ - StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, +//! SEP-2243 server-side validation of `Mcp-Method` / `Mcp-Name` / `Mcp-Param-*` headers. +use std::sync::Arc; + +use rmcp::{ + ServerHandler, + model::{ServerCapabilities, ServerInfo, Tool}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, }; use tokio_util::sync::CancellationToken; @@ -145,3 +151,115 @@ async fn skips_validation_for_pre_sep_version() -> anyhow::Result<()> { ct.cancel(); Ok(()) } + +/// Server exposing one tool whose `region` argument is promoted to `Mcp-Param-Region`. +#[derive(Clone, Default)] +struct ParamHeaderServer; + +impl ServerHandler for ParamHeaderServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + } + + fn get_tool(&self, name: &str) -> Option { + if name != "deploy" { + return None; + } + let schema = serde_json::json!({ + "type": "object", + "properties": { "region": { "type": "string", "x-mcp-header": "Region" } } + }); + let schema = schema.as_object().expect("object schema").clone(); + Some(Tool::new("deploy", "deploy a thing", Arc::new(schema))) + } +} + +async fn spawn_param_server() -> (reqwest::Client, String, CancellationToken) { + let config = StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true) + .with_sse_keep_alive(None) + .with_cancellation_token(CancellationToken::new()); + let ct = config.cancellation_token.clone(); + let service: StreamableHttpService = + StreamableHttpService::new(|| Ok(ParamHeaderServer), Default::default(), config); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + (reqwest::Client::new(), format!("http://{addr}/mcp"), ct) +} + +/// POSTs a `deploy` tools/call with a `region` argument and optional `Mcp-Param-Region` header. +async fn post_deploy( + client: &reqwest::Client, + url: &str, + region_arg: &str, + param_region: Option<&str>, +) -> reqwest::Response { + let body = format!( + r#"{{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{{"name":"deploy","arguments":{{"region":"{region_arg}"}}}}}}"# + ); + let mut req = client + .post(url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .header("MCP-Protocol-Version", SEP_VERSION) + .header("Mcp-Method", "tools/call") + .header("Mcp-Name", "deploy") + .body(body); + if let Some(region) = param_region { + req = req.header("Mcp-Param-Region", region); + } + req.send().await.expect("send deploy request") +} + +#[tokio::test] +async fn accepts_matching_param_header() -> anyhow::Result<()> { + let (client, url, ct) = spawn_param_server().await; + + let response = post_deploy(&client, &url, "us-west1", Some("us-west1")).await; + let body: serde_json::Value = response.json().await?; + assert_ne!( + body["error"]["code"], -32001, + "matching Mcp-Param-* must not be rejected, got: {body}" + ); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn rejects_param_mismatch_with_32001() -> anyhow::Result<()> { + let (client, url, ct) = spawn_param_server().await; + + let response = post_deploy(&client, &url, "us-west1", Some("eu-central1")).await; + assert_eq!(response.status(), 400); + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32001); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn rejects_missing_param_header_with_32001() -> anyhow::Result<()> { + let (client, url, ct) = spawn_param_server().await; + + // `region` argument is present but the annotated `Mcp-Param-Region` header is absent. + let response = post_deploy(&client, &url, "us-west1", None).await; + assert_eq!(response.status(), 400); + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32001); + + ct.cancel(); + Ok(()) +} From 4641ea220a0292bc355f2f0d841915fec6ca9e1e Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Wed, 17 Jun 2026 11:52:02 -0400 Subject: [PATCH 3/3] fix: emit Mcp-Method on reinit initialized POST --- crates/rmcp/Cargo.toml | 5 + .../src/transport/streamable_http_client.rs | 65 +++--- .../test_streamable_http_standard_headers.rs | 221 ++++++++++-------- 3 files changed, 163 insertions(+), 128 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index d8997dd73..12483b5e9 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -283,6 +283,11 @@ name = "test_streamable_http_protocol_version" required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] path = "tests/test_streamable_http_protocol_version.rs" +[[test]] +name = "test_streamable_http_standard_headers" +required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +path = "tests/test_streamable_http_standard_headers.rs" + [[test]] name = "test_streamable_http_4xx_error_body" required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"] diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 326582e1e..661f2f40a 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -73,6 +73,28 @@ fn cache_tools_from_response( } } +/// Derives the negotiated protocol version and base request headers from an +/// initialize response: returns `base` with `MCP-Protocol-Version` injected +/// (per MCP 2025-06-18) plus the version used to gate SEP-2243 headers, +/// defaulting to a pre-SEP version when the response can't be parsed. +fn negotiate_version_headers( + init_response: &ServerJsonRpcMessage, + base: HashMap, +) -> (ProtocolVersion, HashMap) { + let mut version = ProtocolVersion::default(); + let mut headers = base; + if let ServerJsonRpcMessage::Response(response) = init_response { + if let ServerResult::InitializeResult(init_result) = &response.result { + version = init_result.protocol_version.clone(); + // HeaderName::from_static requires lowercase + if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { + headers.insert(HeaderName::from_static("mcp-protocol-version"), hv); + } + } + } + (version, headers) +} + #[derive(Debug)] #[non_exhaustive] pub struct AuthRequiredError { @@ -455,17 +477,8 @@ impl StreamableHttpClientWorker { let new_session_id: Option> = new_session_id_str.map(|s| Arc::from(s.as_str())); - // Start from custom_headers, then inject the negotiated MCP-Protocol-Version - // so all subsequent requests carry the right version (MCP 2025-06-18 spec). - let mut new_protocol_headers = custom_headers; - if let ServerJsonRpcMessage::Response(response) = &init_msg { - if let ServerResult::InitializeResult(init_result) = &response.result { - if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { - new_protocol_headers - .insert(HeaderName::from_static("mcp-protocol-version"), hv); - } - } - } + let (negotiated_version, new_protocol_headers) = + negotiate_version_headers(&init_msg, custom_headers); let initialized_notification = ClientJsonRpcMessage::notification( ClientNotification::InitializedNotification(InitializedNotification { @@ -473,13 +486,20 @@ impl StreamableHttpClientWorker { extensions: Default::default(), }), ); + // SEP-2243: notifications carry no Mcp-Param-*, so an empty tool cache suffices. + let initialized_headers = build_request_headers( + &new_protocol_headers, + &initialized_notification, + &HashMap::new(), + &negotiated_version, + ); client .post_message( uri, initialized_notification, new_session_id.clone(), auth_header, - new_protocol_headers.clone(), + initialized_headers, ) .await? .expect_accepted_or_json::()?; @@ -555,25 +575,8 @@ impl Worker for StreamableHttpClientWorker { } None }; - // Extract the negotiated protocol version from the init response - // and build a custom headers map that includes MCP-Protocol-Version - // for all subsequent HTTP requests (per MCP 2025-06-18 spec). - // Negotiated protocol version gates SEP-2243 standard headers; default to a - // pre-SEP version so headers are omitted if the version can't be determined. - let mut negotiated_version = ProtocolVersion::default(); - let mut protocol_headers = { - let mut headers = config.custom_headers.clone(); - if let ServerJsonRpcMessage::Response(response) = &message { - if let ServerResult::InitializeResult(init_result) = &response.result { - negotiated_version = init_result.protocol_version.clone(); - if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { - // HeaderName::from_static requires lowercase - headers.insert(HeaderName::from_static("mcp-protocol-version"), hv); - } - } - } - headers - }; + let (negotiated_version, mut protocol_headers) = + negotiate_version_headers(&message, config.custom_headers.clone()); // SEP-2243: tool input schemas (name -> schema) cached from tools/list responses, // used to promote annotated tools/call arguments to Mcp-Param-* headers. let mut tool_header_cache: HashMap> = HashMap::new(); diff --git a/crates/rmcp/tests/test_streamable_http_standard_headers.rs b/crates/rmcp/tests/test_streamable_http_standard_headers.rs index 48521dba7..095c35f44 100644 --- a/crates/rmcp/tests/test_streamable_http_standard_headers.rs +++ b/crates/rmcp/tests/test_streamable_http_standard_headers.rs @@ -11,14 +11,28 @@ use rmcp::{ }; use tokio_util::sync::CancellationToken; -mod common; -use common::calculator::Calculator; - const SEP_VERSION: &str = "2026-07-28"; -fn tools_call_body() -> String { - r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"sum","arguments":{"a":1,"b":2}}}"# - .to_owned() +/// Server exposing one tool whose `region` argument is promoted to `Mcp-Param-Region`. +#[derive(Clone, Default)] +struct HeaderValidationServer; + +impl ServerHandler for HeaderValidationServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + } + + fn get_tool(&self, name: &str) -> Option { + if name != "deploy" { + return None; + } + let schema = serde_json::json!({ + "type": "object", + "properties": { "region": { "type": "string", "x-mcp-header": "Region" } } + }); + let schema = schema.as_object().expect("object schema").clone(); + Some(Tool::new("deploy", "deploy a thing", Arc::new(schema))) + } } async fn spawn_server() -> (reqwest::Client, String, CancellationToken) { @@ -28,13 +42,12 @@ async fn spawn_server() -> (reqwest::Client, String, CancellationToken) { .with_sse_keep_alive(None) .with_cancellation_token(CancellationToken::new()); let ct = config.cancellation_token.clone(); - let service: StreamableHttpService = - StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + let service: StreamableHttpService = + StreamableHttpService::new(|| Ok(HeaderValidationServer), Default::default(), config); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = tcp_listener.local_addr().unwrap(); - tokio::spawn({ let ct = ct.clone(); async move { @@ -43,31 +56,44 @@ async fn spawn_server() -> (reqwest::Client, String, CancellationToken) { .await; } }); - - let client = reqwest::Client::new(); - (client, format!("http://{addr}/mcp"), ct) + (reqwest::Client::new(), format!("http://{addr}/mcp"), ct) } /// POSTs a `tools/call` with the given protocol-version and optional SEP-2243 headers. -async fn post_tools_call( +async fn post_tool_call( client: &reqwest::Client, url: &str, version: &str, + tool_name: &str, + arguments: serde_json::Value, mcp_method: Option<&str>, mcp_name: Option<&str>, + param_region: Option<&str>, ) -> reqwest::Response { + let body = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments, + } + }); let mut req = client .post(url) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") .header("MCP-Protocol-Version", version) - .body(tools_call_body()); + .body(body.to_string()); if let Some(method) = mcp_method { req = req.header("Mcp-Method", method); } if let Some(name) = mcp_name { req = req.header("Mcp-Name", name); } + if let Some(region) = param_region { + req = req.header("Mcp-Param-Region", region); + } req.send().await.expect("send tools/call request") } @@ -77,8 +103,17 @@ async fn accepts_matching_standard_headers() -> anyhow::Result<()> { // Matching headers pass validation and reach dispatch. (Stateless mode without a // prior initialize yields an unrelated -32601, which still proves -32001 was not raised.) - let response = - post_tools_call(&client, &url, SEP_VERSION, Some("tools/call"), Some("sum")).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "sum", + serde_json::json!({ "a": 1, "b": 2 }), + Some("tools/call"), + Some("sum"), + None, + ) + .await; let body: serde_json::Value = response.json().await?; assert_ne!( body["error"]["code"], -32001, @@ -93,8 +128,17 @@ async fn accepts_matching_standard_headers() -> anyhow::Result<()> { async fn rejects_method_mismatch_with_32001() -> anyhow::Result<()> { let (client, url, ct) = spawn_server().await; - let response = - post_tools_call(&client, &url, SEP_VERSION, Some("tools/list"), Some("sum")).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "sum", + serde_json::json!({ "a": 1, "b": 2 }), + Some("tools/list"), + Some("sum"), + None, + ) + .await; assert_eq!(response.status(), 400); let body: serde_json::Value = response.json().await?; assert_eq!(body["error"]["code"], -32001); @@ -107,7 +151,17 @@ async fn rejects_method_mismatch_with_32001() -> anyhow::Result<()> { async fn rejects_missing_method_header_with_32001() -> anyhow::Result<()> { let (client, url, ct) = spawn_server().await; - let response = post_tools_call(&client, &url, SEP_VERSION, None, Some("sum")).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "sum", + serde_json::json!({ "a": 1, "b": 2 }), + None, + Some("sum"), + None, + ) + .await; assert_eq!(response.status(), 400); let body: serde_json::Value = response.json().await?; assert_eq!(body["error"]["code"], -32001); @@ -120,12 +174,15 @@ async fn rejects_missing_method_header_with_32001() -> anyhow::Result<()> { async fn rejects_name_mismatch_with_32001() -> anyhow::Result<()> { let (client, url, ct) = spawn_server().await; - let response = post_tools_call( + let response = post_tool_call( &client, &url, SEP_VERSION, + "sum", + serde_json::json!({ "a": 1, "b": 2 }), Some("tools/call"), Some("product"), + None, ) .await; assert_eq!(response.status(), 400); @@ -141,7 +198,17 @@ async fn skips_validation_for_pre_sep_version() -> anyhow::Result<()> { let (client, url, ct) = spawn_server().await; // Older version: headers are not enforced even when absent. - let response = post_tools_call(&client, &url, "2025-11-25", None, None).await; + let response = post_tool_call( + &client, + &url, + "2025-11-25", + "sum", + serde_json::json!({ "a": 1, "b": 2 }), + None, + None, + None, + ) + .await; let body: serde_json::Value = response.json().await?; assert_ne!( body["error"]["code"], -32001, @@ -152,81 +219,21 @@ async fn skips_validation_for_pre_sep_version() -> anyhow::Result<()> { Ok(()) } -/// Server exposing one tool whose `region` argument is promoted to `Mcp-Param-Region`. -#[derive(Clone, Default)] -struct ParamHeaderServer; - -impl ServerHandler for ParamHeaderServer { - fn get_info(&self) -> ServerInfo { - ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) - } - - fn get_tool(&self, name: &str) -> Option { - if name != "deploy" { - return None; - } - let schema = serde_json::json!({ - "type": "object", - "properties": { "region": { "type": "string", "x-mcp-header": "Region" } } - }); - let schema = schema.as_object().expect("object schema").clone(); - Some(Tool::new("deploy", "deploy a thing", Arc::new(schema))) - } -} - -async fn spawn_param_server() -> (reqwest::Client, String, CancellationToken) { - let config = StreamableHttpServerConfig::default() - .with_stateful_mode(false) - .with_json_response(true) - .with_sse_keep_alive(None) - .with_cancellation_token(CancellationToken::new()); - let ct = config.cancellation_token.clone(); - let service: StreamableHttpService = - StreamableHttpService::new(|| Ok(ParamHeaderServer), Default::default(), config); - - let router = axum::Router::new().nest_service("/mcp", service); - let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = tcp_listener.local_addr().unwrap(); - tokio::spawn({ - let ct = ct.clone(); - async move { - let _ = axum::serve(tcp_listener, router) - .with_graceful_shutdown(async move { ct.cancelled_owned().await }) - .await; - } - }); - (reqwest::Client::new(), format!("http://{addr}/mcp"), ct) -} - -/// POSTs a `deploy` tools/call with a `region` argument and optional `Mcp-Param-Region` header. -async fn post_deploy( - client: &reqwest::Client, - url: &str, - region_arg: &str, - param_region: Option<&str>, -) -> reqwest::Response { - let body = format!( - r#"{{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{{"name":"deploy","arguments":{{"region":"{region_arg}"}}}}}}"# - ); - let mut req = client - .post(url) - .header("Content-Type", "application/json") - .header("Accept", "application/json, text/event-stream") - .header("MCP-Protocol-Version", SEP_VERSION) - .header("Mcp-Method", "tools/call") - .header("Mcp-Name", "deploy") - .body(body); - if let Some(region) = param_region { - req = req.header("Mcp-Param-Region", region); - } - req.send().await.expect("send deploy request") -} - #[tokio::test] async fn accepts_matching_param_header() -> anyhow::Result<()> { - let (client, url, ct) = spawn_param_server().await; + let (client, url, ct) = spawn_server().await; - let response = post_deploy(&client, &url, "us-west1", Some("us-west1")).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "deploy", + serde_json::json!({ "region": "us-west1" }), + Some("tools/call"), + Some("deploy"), + Some("us-west1"), + ) + .await; let body: serde_json::Value = response.json().await?; assert_ne!( body["error"]["code"], -32001, @@ -239,9 +246,19 @@ async fn accepts_matching_param_header() -> anyhow::Result<()> { #[tokio::test] async fn rejects_param_mismatch_with_32001() -> anyhow::Result<()> { - let (client, url, ct) = spawn_param_server().await; + let (client, url, ct) = spawn_server().await; - let response = post_deploy(&client, &url, "us-west1", Some("eu-central1")).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "deploy", + serde_json::json!({ "region": "us-west1" }), + Some("tools/call"), + Some("deploy"), + Some("eu-central1"), + ) + .await; assert_eq!(response.status(), 400); let body: serde_json::Value = response.json().await?; assert_eq!(body["error"]["code"], -32001); @@ -252,10 +269,20 @@ async fn rejects_param_mismatch_with_32001() -> anyhow::Result<()> { #[tokio::test] async fn rejects_missing_param_header_with_32001() -> anyhow::Result<()> { - let (client, url, ct) = spawn_param_server().await; + let (client, url, ct) = spawn_server().await; // `region` argument is present but the annotated `Mcp-Param-Region` header is absent. - let response = post_deploy(&client, &url, "us-west1", None).await; + let response = post_tool_call( + &client, + &url, + SEP_VERSION, + "deploy", + serde_json::json!({ "region": "us-west1" }), + Some("tools/call"), + Some("deploy"), + None, + ) + .await; assert_eq!(response.status(), 400); let body: serde_json::Value = response.json().await?; assert_eq!(body["error"]["code"], -32001);