Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/rmcp/src/transport/common/http_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub const JSON_MIME_TYPE: &str = "application/json";
/// Reserved headers that must not be overridden by user-supplied custom headers.
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
/// injects it after initialization.
#[allow(dead_code)]
pub(crate) const RESERVED_HEADERS: &[&str] = &[
"accept",
HEADER_SESSION_ID,
Expand Down Expand Up @@ -36,6 +37,7 @@ pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), Stri

/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value.
/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
#[allow(dead_code)]
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If extract_scope_from_header is only needed under specific transport/client features, prefer feature-gating it similarly rather than using #[allow(dead_code)]. This keeps unused APIs from accumulating silently across feature combinations.

Suggested change
#[allow(dead_code)]
#[cfg(feature = "client-side-sse")]

Copilot uses AI. Check for mistakes.
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {
let header_lowercase = header.to_ascii_lowercase();
let scope_key = "scope=";
Expand Down
111 changes: 109 additions & 2 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};

use bytes::Bytes;
use futures::{StreamExt, future::BoxFuture};
use http::{Method, Request, Response, header::ALLOW};
use http::{HeaderMap, Method, Request, Response, header::ALLOW};
use http_body::Body;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use tokio_stream::wrappers::ReceiverStream;
Expand All @@ -29,8 +29,8 @@ use crate::{
},
};

#[derive(Debug, Clone)]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
/// The ping message duration for SSE connections.
pub sse_keep_alive: Option<Duration>,
Expand All @@ -49,6 +49,16 @@ pub struct StreamableHttpServerConfig {
/// When this token is cancelled, all active sessions are terminated and
/// the server stops accepting new requests.
pub cancellation_token: CancellationToken,
/// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
///
/// By default, Streamable HTTP servers only accept loopback hosts to
/// prevent DNS rebinding attacks against locally running servers. Public
/// deployments should override this list with their own hostnames.
/// examples:
/// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
/// or with ports:
/// allowed_hosts = ["example.com", "example.com:8080"]
pub allowed_hosts: Vec<String>,
Comment on lines +53 to +61
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the new pub allowed_hosts field to StreamableHttpServerConfig is a semver-breaking change for downstream users constructing the config via struct literals (they will now fail to compile). If this is intended, it should be paired with a major version bump / explicit release note; otherwise consider making the config #[non_exhaustive] and/or moving toward a constructor/builder pattern to avoid future breakage.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is especially important given that PR #715 was recently merged to prepare for 1.0 stable release. StreamableHttpServerConfig may have been missed in that effort. Adding #[non_exhaustive] here would be consistent with that direction and prevent this class of breakage going forward.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this security update should be introduced in 1.0 version , and I will add the #[non_exhaustive].

}

impl Default for StreamableHttpServerConfig {
Expand All @@ -59,11 +69,24 @@ impl Default for StreamableHttpServerConfig {
stateful_mode: true,
json_response: false,
cancellation_token: CancellationToken::new(),
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
}
}
}

impl StreamableHttpServerConfig {
pub fn with_allowed_hosts(
mut self,
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
self
}
/// Disable allowed hosts. This will allow requests with any `Host` header, which is NOT recommended for public deployments.
pub fn disable_allowed_hosts(mut self) -> Self {
self.allowed_hosts.clear();
self
}
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
self.sse_keep_alive = duration;
self
Expand Down Expand Up @@ -130,6 +153,87 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
Ok(())
}

fn forbidden_response(message: impl Into<String>) -> BoxResponse {
Response::builder()
.status(http::StatusCode::FORBIDDEN)
.body(Full::new(Bytes::from(message.into())).boxed())
.expect("valid response")
}

fn normalize_host(host: &str) -> String {
host.trim_matches('[')
.trim_matches(']')
.to_ascii_lowercase()
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
host: String,
port: Option<u16>,
}

fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
NormalizedAuthority {
host: normalize_host(host),
port,
}
}

fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
let allowed = allowed.trim();
if allowed.is_empty() {
return None;
}

if let Ok(authority) = http::uri::Authority::try_from(allowed) {
return Some(normalize_authority(authority.host(), authority.port_u16()));
}

Some(normalize_authority(allowed, None))
}

fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
if allowed_hosts.is_empty() {
// If the allowed hosts list is empty, allow all hosts (not recommended).
return true;
}
allowed_hosts
.iter()
.filter_map(|allowed| parse_allowed_authority(allowed))
.any(|allowed| {
allowed.host == host.host
&& match allowed.port {
Some(port) => host.port == Some(port),
None => true,
}
})
}

fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Err(forbidden_response("Forbidden:missing_host header"));
};

let host = host
.to_str()
.map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?;
Comment on lines +212 to +221
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Host on an HTTP/1.1 request is a malformed request; returning 403 here is misleading and the error string has a formatting typo ("Forbidden:missing_host header"). Consider using 400 Bad Request for missing/invalid Host (reserving 403 for syntactically valid but disallowed hosts) and standardizing the message formatting.

Suggested change
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Err(forbidden_response("Forbidden:missing_host header"));
};
let host = host
.to_str()
.map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?;
fn bad_request_response(message: &str) -> BoxResponse {
let body = Full::from(message.to_string()).boxed();
http::Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(body)
.expect("failed to build bad request response")
}
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Err(bad_request_response("Bad Request: missing Host header"));
};
let host = host
.to_str()
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;

Copilot uses AI. Check for mistakes.
Comment on lines +217 to +221
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_str() / Authority::try_from() failures indicate an invalid Host header value rather than an authorization failure. Returning a 400 Bad Request for these parse errors would better match HTTP semantics; 403 can remain for “Host not allowed”. Also consider keeping the error messages consistent with other responses in this file (spacing/capitalization).

Copilot uses AI. Check for mistakes.
Ok(normalize_authority(authority.host(), authority.port_u16()))
}

fn validate_dns_rebinding_headers(
headers: &HeaderMap,
config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
let host = parse_host_header(headers)?;
if !host_is_allowed(&host, &config.allowed_hosts) {
return Err(forbidden_response("Forbidden: Host header is not allowed"));
}

Ok(())
}

/// # Streamable HTTP server
///
/// An HTTP service that implements the
Expand Down Expand Up @@ -279,6 +383,9 @@ where
B: Body + Send + 'static,
B::Error: Display,
{
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
return response;
}
let method = request.method().clone();
let allowed_methods = match self.config.stateful_mode {
true => "GET, POST, DELETE",
Expand Down
158 changes: 158 additions & 0 deletions crates/rmcp/tests/test_custom_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

Expand All @@ -785,6 +786,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)
.header("mcp-protocol-version", "2025-03-26")
.body(Full::new(Bytes::from(initialized_body.to_string())))
Expand All @@ -802,6 +804,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)
.header("mcp-protocol-version", "2025-03-26")
.body(Full::new(Bytes::from(valid_body.to_string())))
Expand All @@ -823,6 +826,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)
.header("mcp-protocol-version", "9999-01-01")
.body(Full::new(Bytes::from(invalid_body.to_string())))
Expand All @@ -844,6 +848,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)
.body(Full::new(Bytes::from(no_version_body.to_string())))
.unwrap();
Expand All @@ -870,3 +875,156 @@ fn test_protocol_version_utilities() {
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26));
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18));
}

/// Integration test: Verify server validates only the Host header for DNS rebinding protection
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_validates_host_header_for_dns_rebinding_protection() {
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;

#[derive(Clone)]
struct TestHandler;

impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}

let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);

let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});

let allowed_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("Origin", "http://localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);

let bad_host_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(bad_host_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);

let ignored_origin_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("Origin", "http://attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(ignored_origin_request).await;
assert_eq!(response.status(), http::StatusCode::OK);
}

/// Integration test: Verify server can enforce an allowed Host port when configured
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;

#[derive(Clone)]
struct TestHandler;

impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}

let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default().with_allowed_hosts(["localhost:8080"]),
);

let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});

let allowed_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);

let wrong_port_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:3000")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(wrong_port_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
Loading