diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c882ab7..fc0fd7a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes to `mcp/sdk` will be documented in this file. 0.5.0 ----- +* **[BC BREAK]** Extract CORS handling from `StreamableHttpTransport` into `CorsMiddleware`. The `$corsHeaders` constructor parameter has been removed — pass a configured `CorsMiddleware` via the `$middleware` array instead (a default is prepended automatically if omitted). Default `Access-Control-Allow-Origin` is no longer set (was `*`). * Add built-in authentication middleware for HTTP transport using OAuth * Add client component for building MCP clients * Add `Builder::setReferenceHandler()` to allow custom `ReferenceHandlerInterface` implementations (e.g. authorization decorators) diff --git a/composer.json b/composer.json index 5d1c5297..e99bfbb4 100644 --- a/composer.json +++ b/composer.json @@ -53,6 +53,9 @@ "symfony/http-client": "^5.4 || ^6.4 || ^7.3 || ^8.0", "symfony/process": "^5.4 || ^6.4 || ^7.3 || ^8.0" }, + "conflict": { + "php-cs-fixer/shim": "3.95.0" + }, "autoload": { "psr-4": { "Mcp\\": "src/" diff --git a/docs/transports.md b/docs/transports.md index a68875d9..5bcf7573 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -139,45 +139,55 @@ $transport = new StreamableHttpTransport($request, $psr17Factory, $psr17Factory) ### CORS Configuration -The transport sets secure CORS defaults that can be customized or disabled: +CORS is handled by the `CorsMiddleware`, which is automatically prepended to the middleware chain. By default, +no `Access-Control-Allow-Origin` header is set, which effectively blocks cross-origin browser requests. ```php -// Default CORS headers (backward compatible) -$transport = new StreamableHttpTransport($request, $responseFactory, $streamFactory); +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\StreamableHttpTransport; + +// Default: cross-origin requests are blocked (no Access-Control-Allow-Origin header) +$transport = new StreamableHttpTransport($request); -// Restrict to specific origin +// Allow specific origins $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => 'https://myapp.com'] + middleware: [ + new CorsMiddleware( + allowedOrigins: ['https://myapp.com', 'https://staging.myapp.com'], + ), + ], ); -// Disable CORS for proxy scenarios +// Allow all origins (e.g. for development) $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => ''] + middleware: [new CorsMiddleware(allowedOrigins: ['*'])], ); -// Custom headers with logger +// Full configuration $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [ - 'Access-Control-Allow-Origin' => 'https://api.example.com', - 'Access-Control-Max-Age' => '86400' + middleware: [ + new CorsMiddleware( + allowedOrigins: ['https://myapp.com'], + allowedMethods: ['GET', 'POST', 'DELETE', 'OPTIONS'], + allowedHeaders: ['Accept', 'Authorization', 'Content-Type', 'Last-Event-ID', 'Mcp-Protocol-Version', 'Mcp-Session-Id'], + exposedHeaders: ['Mcp-Session-Id'], + ), ], - $logger ); ``` -Default CORS headers: -- `Access-Control-Allow-Origin: *` +Default CORS headers (always set unless overridden by middleware): - `Access-Control-Allow-Methods: GET, POST, DELETE, OPTIONS` -- `Access-Control-Allow-Headers: Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept` +- `Access-Control-Allow-Headers: Accept, Authorization, Content-Type, Last-Event-ID, Mcp-Protocol-Version, Mcp-Session-Id` +- `Access-Control-Expose-Headers: Mcp-Session-Id` + +If no `CorsMiddleware` is provided, a default instance is automatically prepended — ensuring CORS headers are applied +to all responses, including those from other middleware that short-circuit (e.g. an auth middleware returning `401`). +When you provide your own `CorsMiddleware` in the array, it is used at the position you place it and no default is added. +The transport itself handles `OPTIONS` preflight requests by returning a `204` response. ### PSR-15 Middleware @@ -209,15 +219,13 @@ final class AuthMiddleware implements MiddlewareInterface $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [], - $logger, - [new AuthMiddleware($responseFactory)], + logger: $logger, + middleware: [new AuthMiddleware($responseFactory)], ); ``` -If middleware returns a response, the transport will still ensure CORS headers are present unless you set them yourself. +If you don't include a `CorsMiddleware` in your middleware array, a default one is automatically prepended, +so CORS headers are applied to all responses even when middleware short-circuits. ### Architecture diff --git a/src/Server/Transport/Http/Middleware/CorsMiddleware.php b/src/Server/Transport/Http/Middleware/CorsMiddleware.php new file mode 100644 index 00000000..299fa37b --- /dev/null +++ b/src/Server/Transport/Http/Middleware/CorsMiddleware.php @@ -0,0 +1,86 @@ + + */ +final class CorsMiddleware implements MiddlewareInterface +{ + /** + * @param list $allowedOrigins Origins to allow (empty = no Access-Control-Allow-Origin header). Use ['*'] to allow all origins. + * @param list $allowedMethods HTTP methods for Access-Control-Allow-Methods + * @param list $allowedHeaders Request headers for Access-Control-Allow-Headers + * @param list $exposedHeaders Response headers for Access-Control-Expose-Headers + */ + public function __construct( + private readonly array $allowedOrigins = [], + private readonly array $allowedMethods = ['GET', 'POST', 'DELETE', 'OPTIONS'], + private readonly array $allowedHeaders = ['Accept', 'Authorization', 'Content-Type', 'Last-Event-ID', 'Mcp-Protocol-Version', 'Mcp-Session-Id'], + private readonly array $exposedHeaders = ['Mcp-Session-Id'], + ) { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $response = $handler->handle($request); + + $origin = $request->getHeaderLine('Origin'); + $allowedOrigin = $this->resolveAllowedOrigin($origin); + + if (null !== $allowedOrigin && !$response->hasHeader('Access-Control-Allow-Origin')) { + $response = $response->withHeader('Access-Control-Allow-Origin', $allowedOrigin); + } + + if (!$response->hasHeader('Access-Control-Allow-Methods')) { + $response = $response->withHeader('Access-Control-Allow-Methods', implode(', ', $this->allowedMethods)); + } + + if (!$response->hasHeader('Access-Control-Allow-Headers')) { + $response = $response->withHeader('Access-Control-Allow-Headers', implode(', ', $this->allowedHeaders)); + } + + if ([] !== $this->exposedHeaders && !$response->hasHeader('Access-Control-Expose-Headers')) { + $response = $response->withHeader('Access-Control-Expose-Headers', implode(', ', $this->exposedHeaders)); + } + + return $response; + } + + private function resolveAllowedOrigin(string $origin): ?string + { + if ([] === $this->allowedOrigins) { + return null; + } + + if (\in_array('*', $this->allowedOrigins, true)) { + return '*'; + } + + if ('' !== $origin && \in_array($origin, $this->allowedOrigins, true)) { + return $origin; + } + + return null; + } +} diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index 3c9b2f67..e073bfb8 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -14,6 +14,7 @@ use Http\Discovery\Psr17FactoryDiscovery; use Mcp\Exception\InvalidArgumentException; use Mcp\Schema\JsonRpc\Error; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; use Mcp\Server\Transport\Http\MiddlewareRequestHandler; use Psr\Http\Message\ResponseFactoryInterface; use Psr\Http\Message\ResponseInterface; @@ -32,36 +33,22 @@ class StreamableHttpTransport extends BaseTransport { private const SESSION_HEADER = 'Mcp-Session-Id'; - private const ALLOWED_HEADER = [ - 'Accept', - 'Authorization', - 'Content-Type', - 'Last-Event-ID', - 'Mcp-Protocol-Version', - self::SESSION_HEADER, - ]; - private ResponseFactoryInterface $responseFactory; private StreamFactoryInterface $streamFactory; private ?string $immediateResponse = null; private ?int $immediateStatusCode = null; - /** @var array */ - private array $corsHeaders; - /** @var list */ private array $middleware = []; /** - * @param array $corsHeaders * @param iterable $middleware */ public function __construct( private ServerRequestInterface $request, ?ResponseFactoryInterface $responseFactory = null, ?StreamFactoryInterface $streamFactory = null, - array $corsHeaders = [], ?LoggerInterface $logger = null, iterable $middleware = [], ) { @@ -70,19 +57,21 @@ public function __construct( $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); - $this->corsHeaders = array_merge([ - 'Access-Control-Allow-Origin' => '*', - 'Access-Control-Allow-Methods' => 'GET, POST, DELETE, OPTIONS', - 'Access-Control-Allow-Headers' => implode(',', self::ALLOWED_HEADER), - 'Access-Control-Expose-Headers' => self::SESSION_HEADER, - ], $corsHeaders); + $hasCorsMiddleware = false; foreach ($middleware as $m) { if (!$m instanceof MiddlewareInterface) { throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); } + if ($m instanceof CorsMiddleware) { + $hasCorsMiddleware = true; + } $this->middleware[] = $m; } + + if (!$hasCorsMiddleware) { + array_unshift($this->middleware, new CorsMiddleware()); + } } public function send(string $data, array $context): void @@ -98,7 +87,7 @@ public function listen(): ResponseInterface \Closure::fromCallable([$this, 'handleRequest']), ); - return $this->withCorsHeaders($handler->handle($this->request)); + return $handler->handle($this->request); } protected function handleOptionsRequest(): ResponseInterface @@ -273,17 +262,6 @@ protected function createErrorResponse(Error $jsonRpcError, int $statusCode): Re return $response; } - protected function withCorsHeaders(ResponseInterface $response): ResponseInterface - { - foreach ($this->corsHeaders as $name => $value) { - if (!$response->hasHeader($name)) { - $response = $response->withHeader($name, $value); - } - } - - return $response; - } - private function handleRequest(ServerRequestInterface $request): ResponseInterface { $this->request = $request; diff --git a/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php new file mode 100644 index 00000000..e5a086f5 --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php @@ -0,0 +1,148 @@ +factory = new Psr17Factory(); + $this->handler = new class($this->factory) implements RequestHandlerInterface { + public function __construct(private ResponseFactoryInterface $factory) + { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + return $this->factory->createResponse(200); + } + }; + } + + #[TestDox('delegates to handler and adds CORS headers')] + public function testDelegatesAndAddsCorsHeaders(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame(200, $response->getStatusCode()); + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + } + + #[TestDox('default configuration does not set Access-Control-Allow-Origin')] + public function testDefaultDoesNotSetAllowOrigin(): void + { + $middleware = new CorsMiddleware(); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); + $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); + } + + #[TestDox('wildcard origin sets Access-Control-Allow-Origin to *')] + public function testWildcardOrigin(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('matching origin is reflected in Access-Control-Allow-Origin')] + public function testMatchingOriginIsReflected(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['https://myapp.com', 'https://staging.myapp.com'], + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://myapp.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame('https://myapp.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('non-matching origin does not set Access-Control-Allow-Origin')] + public function testNonMatchingOriginIsNotSet(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://myapp.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + } + + #[TestDox('pre-existing CORS headers on response are not overwritten')] + public function testPreExistingHeadersNotOverwritten(): void + { + $handler = new class($this->factory) implements RequestHandlerInterface { + public function __construct(private ResponseFactoryInterface $factory) + { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + return $this->factory->createResponse(200) + ->withHeader('Access-Control-Allow-Origin', 'https://custom.com'); + } + }; + + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $handler); + + $this->assertSame('https://custom.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('custom allowed methods and headers are applied')] + public function testCustomMethodsAndHeaders(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['*'], + allowedMethods: ['POST'], + allowedHeaders: ['Content-Type'], + exposedHeaders: [], + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->handler); + + $this->assertSame('POST', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertSame('Content-Type', $response->getHeaderLine('Access-Control-Allow-Headers')); + $this->assertFalse($response->hasHeader('Access-Control-Expose-Headers')); + } +} diff --git a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php index 7d9cd484..27121a45 100644 --- a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php +++ b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php @@ -11,6 +11,7 @@ namespace Mcp\Tests\Unit\Server\Transport; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; use Mcp\Server\Transport\StreamableHttpTransport; use Nyholm\Psr7\Factory\Psr17Factory; use PHPUnit\Framework\Attributes\DataProvider; @@ -26,17 +27,16 @@ final class StreamableHttpTransportTest extends TestCase { public static function corsHeaderProvider(): iterable { - yield 'GET (middleware returns 401)' => ['GET', false, 401]; yield 'POST (middleware returns 401)' => ['POST', false, 401]; yield 'DELETE (middleware returns 401)' => ['DELETE', false, 401]; yield 'OPTIONS (middleware delegates -> transport handles preflight)' => ['OPTIONS', true, 204]; - yield 'GET (middleware delegates -> transport handles preflight)' => ['GET', true, 405]; - yield 'POST (middleware delegates -> transport handles preflight)' => ['POST', true, 202]; - yield 'DELETE (middleware delegates -> transport handles preflight)' => ['DELETE', true, 400]; + yield 'GET (middleware delegates -> transport returns 405)' => ['GET', true, 405]; + yield 'POST (middleware delegates -> transport returns 202)' => ['POST', true, 202]; + yield 'DELETE (middleware delegates -> transport returns 400)' => ['DELETE', true, 400]; } #[DataProvider('corsHeaderProvider')] - #[TestDox('CORS headers are always applied')] + #[TestDox('CORS headers are applied by default CorsMiddleware')] public function testCorsHeader(string $method, bool $middlewareDelegatesToTransport, int $expectedStatusCode): void { $factory = new Psr17Factory(); @@ -64,7 +64,6 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $request, $factory, $factory, - [], null, [$middleware], ); @@ -72,21 +71,34 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $response = $transport->listen(); $this->assertSame($expectedStatusCode, $response->getStatusCode(), $response->getBody()->getContents()); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Origin')); $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); + // Default CorsMiddleware has no allowed origins, so no Access-Control-Allow-Origin + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + } + + #[TestDox('CORS headers include Access-Control-Allow-Origin when CorsMiddleware is configured with origins')] + public function testCorsHeaderWithConfiguredOrigins(): void + { + $factory = new Psr17Factory(); + $request = $factory->createServerRequest('POST', 'https://example.com'); + + $transport = new StreamableHttpTransport( + $request, + $factory, + $factory, + null, + [new CorsMiddleware(allowedOrigins: ['*'])], + ); + + $response = $transport->listen(); $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') - ); - $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); } - #[TestDox('transport replaces existing CORS headers on the response')] + #[TestDox('middleware can set CORS headers that CorsMiddleware will not overwrite')] public function testCorsHeadersAreReplacedWhenAlreadyPresent(): void { $factory = new Psr17Factory(); @@ -108,7 +120,6 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $request, $factory, $factory, - [], null, [$middleware], ); @@ -119,18 +130,13 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $this->assertSame('https://another.com', $response->getHeaderLine('Access-Control-Allow-Origin')); $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') - ); - $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); } #[TestDox('middleware runs before transport handles the request')] public function testMiddlewareRunsBeforeTransportHandlesRequest(): void { $factory = new Psr17Factory(); - $request = $factory->createServerRequest('OPTIONS', 'https://example.com'); + $request = $factory->createServerRequest('POST', 'https://example.com'); $state = new \stdClass(); $state->called = false; @@ -151,7 +157,6 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $request, $factory, $factory, - [], null, [$middleware], ); @@ -159,6 +164,6 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $response = $transport->listen(); $this->assertTrue($state->called); - $this->assertSame(204, $response->getStatusCode()); + $this->assertSame(202, $response->getStatusCode()); } }