|
| 1 | +<?php |
| 2 | + |
| 3 | +/* |
| 4 | + * This file is part of the official PHP MCP SDK. |
| 5 | + * |
| 6 | + * A collaboration between Symfony and the PHP Foundation. |
| 7 | + * |
| 8 | + * For the full copyright and license information, please view the LICENSE |
| 9 | + * file that was distributed with this source code. |
| 10 | + */ |
| 11 | + |
| 12 | +namespace Mcp\Server\Transport\Http\Middleware; |
| 13 | + |
| 14 | +use Http\Discovery\Psr17FactoryDiscovery; |
| 15 | +use Mcp\Schema\JsonRpc\Error; |
| 16 | +use Psr\Http\Message\ResponseFactoryInterface; |
| 17 | +use Psr\Http\Message\ResponseInterface; |
| 18 | +use Psr\Http\Message\ServerRequestInterface; |
| 19 | +use Psr\Http\Message\StreamFactoryInterface; |
| 20 | +use Psr\Http\Server\MiddlewareInterface; |
| 21 | +use Psr\Http\Server\RequestHandlerInterface; |
| 22 | + |
| 23 | +/** |
| 24 | + * Protects against DNS rebinding attacks by validating Origin and Host headers. |
| 25 | + * |
| 26 | + * When an Origin header is present, it is validated against the allowed hostnames. |
| 27 | + * Otherwise, the Host header is validated instead. |
| 28 | + * By default, only localhost variants (localhost, 127.0.0.1, [::1], ::1) are allowed. |
| 29 | + * |
| 30 | + * @see https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise |
| 31 | + * @see https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#security-warning |
| 32 | + */ |
| 33 | +final class DnsRebindingProtectionMiddleware implements MiddlewareInterface |
| 34 | +{ |
| 35 | + private ResponseFactoryInterface $responseFactory; |
| 36 | + private StreamFactoryInterface $streamFactory; |
| 37 | + |
| 38 | + /** @var list<string> */ |
| 39 | + private readonly array $allowedHosts; |
| 40 | + |
| 41 | + /** |
| 42 | + * @param string[] $allowedHosts Allowed hostnames (without port). Defaults to localhost variants. |
| 43 | + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory |
| 44 | + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory |
| 45 | + */ |
| 46 | + public function __construct( |
| 47 | + array $allowedHosts = ['localhost', '127.0.0.1', '[::1]', '::1'], |
| 48 | + ?ResponseFactoryInterface $responseFactory = null, |
| 49 | + ?StreamFactoryInterface $streamFactory = null, |
| 50 | + ) { |
| 51 | + $this->allowedHosts = array_values(array_map('strtolower', $allowedHosts)); |
| 52 | + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); |
| 53 | + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); |
| 54 | + } |
| 55 | + |
| 56 | + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface |
| 57 | + { |
| 58 | + $origin = $request->getHeaderLine('Origin'); |
| 59 | + if ('' !== $origin) { |
| 60 | + if (!$this->isAllowedOrigin($origin)) { |
| 61 | + return $this->createForbiddenResponse('Forbidden: Invalid Origin header.'); |
| 62 | + } |
| 63 | + |
| 64 | + return $handler->handle($request); |
| 65 | + } |
| 66 | + |
| 67 | + $host = $request->getHeaderLine('Host'); |
| 68 | + if ('' !== $host && !$this->isAllowedHost($host)) { |
| 69 | + return $this->createForbiddenResponse('Forbidden: Invalid Host header.'); |
| 70 | + } |
| 71 | + |
| 72 | + return $handler->handle($request); |
| 73 | + } |
| 74 | + |
| 75 | + private function isAllowedOrigin(string $origin): bool |
| 76 | + { |
| 77 | + $parsed = parse_url($origin); |
| 78 | + if (false === $parsed || !isset($parsed['host'])) { |
| 79 | + return false; |
| 80 | + } |
| 81 | + |
| 82 | + return \in_array(strtolower($parsed['host']), $this->allowedHosts, true); |
| 83 | + } |
| 84 | + |
| 85 | + /** |
| 86 | + * Validates the Host header value (host or host:port) against the allowed list. |
| 87 | + */ |
| 88 | + private function isAllowedHost(string $host): bool |
| 89 | + { |
| 90 | + // IPv6 host with port: [::1]:8080 |
| 91 | + if (str_starts_with($host, '[')) { |
| 92 | + $closingBracket = strpos($host, ']'); |
| 93 | + if (false === $closingBracket) { |
| 94 | + return false; |
| 95 | + } |
| 96 | + $hostname = substr($host, 0, $closingBracket + 1); |
| 97 | + } else { |
| 98 | + // Strip port if present (host:port) |
| 99 | + $hostname = explode(':', $host, 2)[0]; |
| 100 | + } |
| 101 | + |
| 102 | + return \in_array(strtolower($hostname), $this->allowedHosts, true); |
| 103 | + } |
| 104 | + |
| 105 | + private function createForbiddenResponse(string $message): ResponseInterface |
| 106 | + { |
| 107 | + $body = json_encode(Error::forInvalidRequest($message), \JSON_THROW_ON_ERROR); |
| 108 | + |
| 109 | + return $this->responseFactory |
| 110 | + ->createResponse(403) |
| 111 | + ->withHeader('Content-Type', 'application/json') |
| 112 | + ->withBody($this->streamFactory->createStream($body)); |
| 113 | + } |
| 114 | +} |
0 commit comments