diff --git a/.changeset/clever-nights-enjoy.md b/.changeset/clever-nights-enjoy.md new file mode 100644 index 000000000..dd770caef --- /dev/null +++ b/.changeset/clever-nights-enjoy.md @@ -0,0 +1,5 @@ +--- +"@sei-js/mcp-server": patch +--- + +Block wallet mode on HTTP transports to prevent CORS-based attacks diff --git a/packages/mcp-server/src/server/args.ts b/packages/mcp-server/src/server/args.ts index d84e07eb6..6074f662f 100644 --- a/packages/mcp-server/src/server/args.ts +++ b/packages/mcp-server/src/server/args.ts @@ -97,7 +97,7 @@ Examples: Streamable HTTP transport with custom path: $ SERVER_TRANSPORT=streamable-http SERVER_PORT=8080 SERVER_PATH=/api/mcp npx ${packageInfo.name} - With wallet enabled: + With wallet enabled (STDIO transport only): $ WALLET_MODE=private-key PRIVATE_KEY=your_private_key_here npx ${packageInfo.name} Environment Variables: @@ -110,6 +110,10 @@ Environment Variables: MAINNET_RPC_URL Custom RPC URL for Sei mainnet (optional) TESTNET_RPC_URL Custom RPC URL for Sei testnet (optional) DEVNET_RPC_URL Custom RPC URL for Sei devnet (optional) + +Security Note: + Wallet mode is only supported with stdio transport. HTTP transports block + wallet mode to prevent cross-origin attacks from malicious websites. `); program.parse(); diff --git a/packages/mcp-server/src/server/transport/factory.ts b/packages/mcp-server/src/server/transport/factory.ts index b9ca1b2ae..ca1841757 100644 --- a/packages/mcp-server/src/server/transport/factory.ts +++ b/packages/mcp-server/src/server/transport/factory.ts @@ -9,14 +9,12 @@ export const createTransport = (config: TransportConfig): McpTransport => { return new StdioTransport(); case 'streamable-http': - return new StreamableHttpTransport(config.port, config.host, config.path); + return new StreamableHttpTransport(config.port, config.host, config.path, config.walletMode); case 'http-sse': - return new HttpSseTransport(config.port, config.host, config.path); + return new HttpSseTransport(config.port, config.host, config.path, config.walletMode); default: throw new Error(`Unsupported transport mode: ${config.mode}`); } }; - - diff --git a/packages/mcp-server/src/server/transport/http-sse.ts b/packages/mcp-server/src/server/transport/http-sse.ts index 4cdc5b2d3..31c78c5ff 100644 --- a/packages/mcp-server/src/server/transport/http-sse.ts +++ b/packages/mcp-server/src/server/transport/http-sse.ts @@ -1,9 +1,9 @@ import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; -import cors from 'cors'; import express, { type Request, type Response } from 'express'; import type { Server } from 'node:http'; -import type { McpTransport } from './types.js'; +import type { McpTransport, WalletMode } from './types.js'; +import { createCorsMiddleware, validateSecurityConfig } from './security.js'; export class HttpSseTransport implements McpTransport { readonly mode = 'http-sse' as const; @@ -11,12 +11,15 @@ export class HttpSseTransport implements McpTransport { private httpServer: Server | null = null; private connections = new Map(); private mcpServer: McpServer | null = null; + private walletMode: WalletMode; constructor( private port: number, private host: string, - private path: string + private path: string, + walletMode: WalletMode = 'disabled' ) { + this.walletMode = walletMode; this.app = express(); this.setupMiddleware(); this.setupRoutes(); @@ -24,16 +27,9 @@ export class HttpSseTransport implements McpTransport { private setupMiddleware() { this.app.use(express.json()); - this.app.use( - cors({ - origin: '*', - methods: ['GET', 'POST', 'OPTIONS'], - allowedHeaders: ['Content-Type', 'Authorization'], - credentials: true, - exposedHeaders: ['Content-Type', 'Access-Control-Allow-Origin'] - }) - ); - this.app.options('*', cors()); + + // Secure CORS - no cross-origin allowed by default + this.app.use(createCorsMiddleware()); } private setupRoutes() { @@ -82,6 +78,9 @@ export class HttpSseTransport implements McpTransport { } async start(server: McpServer): Promise { + // Block wallet mode on HTTP transports + validateSecurityConfig(this.mode, this.walletMode); + this.mcpServer = server; return new Promise((resolve, reject) => { this.httpServer = this.app.listen(this.port, this.host, () => { diff --git a/packages/mcp-server/src/server/transport/security.ts b/packages/mcp-server/src/server/transport/security.ts new file mode 100644 index 000000000..9c3355344 --- /dev/null +++ b/packages/mcp-server/src/server/transport/security.ts @@ -0,0 +1,45 @@ +import type { Request, Response, NextFunction, RequestHandler } from 'express'; +import type { TransportMode, WalletMode } from './types.js'; + +/** + * Creates CORS middleware with secure defaults. + * By default, no CORS headers are set (same-origin only). + */ +export function createCorsMiddleware(): RequestHandler { + return (req: Request, res: Response, next: NextFunction) => { + // Handle preflight - reject cross-origin by default + if (req.method === 'OPTIONS') { + return res.sendStatus(204); + } + next(); + }; +} + +/** + * Validates that wallet mode is not used with HTTP transports + * Exits the process if unsafe configuration detected + */ +export function validateSecurityConfig( + transportMode: TransportMode, + walletMode: WalletMode +): void { + const isHttpTransport = transportMode === 'streamable-http' || transportMode === 'http-sse'; + const isWalletEnabled = walletMode !== 'disabled'; + + if (isHttpTransport && isWalletEnabled) { + console.error(''); + console.error('╔════════════════════════════════════════════════════════════════╗'); + console.error('║ SECURITY ERROR ║'); + console.error('╠════════════════════════════════════════════════════════════════╣'); + console.error('║ Wallet mode cannot be used with HTTP transports! ║'); + console.error('║ ║'); + console.error('║ HTTP transports expose the server to cross-origin requests, ║'); + console.error('║ allowing malicious websites to steal funds from your wallet. ║'); + console.error('║ ║'); + console.error('║ Use stdio transport instead (default, works with Claude): ║'); + console.error('║ $ WALLET_MODE=private-key PRIVATE_KEY=... npx @sei-js/mcp-server'); + console.error('╚════════════════════════════════════════════════════════════════╝'); + console.error(''); + process.exit(1); + } +} diff --git a/packages/mcp-server/src/server/transport/streamable-http.ts b/packages/mcp-server/src/server/transport/streamable-http.ts index 43b254c73..2e362a822 100644 --- a/packages/mcp-server/src/server/transport/streamable-http.ts +++ b/packages/mcp-server/src/server/transport/streamable-http.ts @@ -2,38 +2,42 @@ import express, { type Request, type Response } from 'express'; import type { Server } from 'node:http'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; -import type { McpTransport, TransportMode } from './types.js'; -import {getServer} from '../server.js'; +import type { McpTransport, TransportMode, WalletMode } from './types.js'; +import { createCorsMiddleware, validateSecurityConfig } from './security.js'; +import { getServer } from '../server.js'; export class StreamableHttpTransport implements McpTransport { public readonly mode: TransportMode = 'streamable-http'; private port: number; private host: string; private path: string; + private walletMode: WalletMode; private app?: express.Express; private server?: Server; - constructor(port = 8080, host = 'localhost', path = '/mcp') { + constructor(port = 8080, host = 'localhost', path = '/mcp', walletMode: WalletMode = 'disabled') { this.port = port; this.host = host; this.path = path; + this.walletMode = walletMode; } // Note: server parameter ignored for now as this is a stateless server // TODO: allow creating both stateless and stateful remote MCP servers async start(_server: McpServer): Promise { + // Block wallet mode on HTTP transports + validateSecurityConfig(this.mode, this.walletMode); + this.app = express(); this.app.use(express.json()); - this.app.use((req, res, next) => { - res.header('Access-Control-Allow-Origin', '*'); - res.header('Access-Control-Allow-Methods', 'POST, OPTIONS'); - res.header('Access-Control-Allow-Headers', 'Content-Type'); - if (req.method === 'OPTIONS') { - return res.sendStatus(200); - } - next(); - }); + + // Secure CORS - no cross-origin allowed by default + this.app.use(createCorsMiddleware()); + // Health check endpoint + this.app.get('/health', (_req: Request, res: Response) => { + res.json({ status: 'ok', timestamp: new Date().toISOString() }); + }); this.app.post(this.path, async (req: Request, res: Response) => { try { diff --git a/packages/mcp-server/src/server/transport/types.ts b/packages/mcp-server/src/server/transport/types.ts index 7b3c7acf9..111f4194d 100644 --- a/packages/mcp-server/src/server/transport/types.ts +++ b/packages/mcp-server/src/server/transport/types.ts @@ -1,6 +1,8 @@ import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import type { WalletMode } from '../../core/config.js'; export type TransportMode = 'stdio' | 'streamable-http' | 'http-sse'; + export interface McpTransport { start(server: McpServer): Promise; stop(): Promise; @@ -9,8 +11,11 @@ export interface McpTransport { export interface TransportConfig { mode: TransportMode; - walletMode: 'disabled' | 'private-key'; + walletMode: WalletMode; port: number; // Required for HTTP-based transports host: string; // Required for HTTP-based transports path: string; // Required for HTTP-based transports } + +// Re-export WalletMode for convenience +export type { WalletMode }; diff --git a/packages/mcp-server/src/tests/server/transport/factory.test.ts b/packages/mcp-server/src/tests/server/transport/factory.test.ts index 5c598791b..db788acff 100644 --- a/packages/mcp-server/src/tests/server/transport/factory.test.ts +++ b/packages/mcp-server/src/tests/server/transport/factory.test.ts @@ -70,7 +70,7 @@ describe('Transport Factory', () => { const transport = createTransport(config); - expect(StreamableHttpTransport).toHaveBeenCalledWith(8080, '0.0.0.0', '/api/mcp'); + expect(StreamableHttpTransport).toHaveBeenCalledWith(8080, '0.0.0.0', '/api/mcp', 'private-key'); expect(transport).toBe(mockStreamableInstance); }); @@ -88,7 +88,7 @@ describe('Transport Factory', () => { const transport = createTransport(config); - expect(HttpSseTransport).toHaveBeenCalledWith(9000, '127.0.0.1', '/sse'); + expect(HttpSseTransport).toHaveBeenCalledWith(9000, '127.0.0.1', '/sse', 'disabled'); expect(transport).toBe(mockSseInstance); }); @@ -123,7 +123,7 @@ describe('Transport Factory', () => { const transport = createTransport(config); - expect(StreamableHttpTransport).toHaveBeenCalledWith(params.port, params.host, params.path); + expect(StreamableHttpTransport).toHaveBeenCalledWith(params.port, params.host, params.path, 'disabled'); expect(transport).toBe(mockInstance); jest.clearAllMocks(); @@ -145,7 +145,7 @@ describe('Transport Factory', () => { const transport1 = createTransport(config1); - expect(HttpSseTransport).toHaveBeenCalledWith(1, '::1', '/'); + expect(HttpSseTransport).toHaveBeenCalledWith(1, '::1', '/', 'private-key'); expect(transport1).toBe(mockInstance1); jest.clearAllMocks(); @@ -164,7 +164,7 @@ describe('Transport Factory', () => { const transport2 = createTransport(config2); - expect(StreamableHttpTransport).toHaveBeenCalledWith(65535, '0.0.0.0', '/very/long/path/to/test/edge/cases'); + expect(StreamableHttpTransport).toHaveBeenCalledWith(65535, '0.0.0.0', '/very/long/path/to/test/edge/cases', 'disabled'); expect(transport2).toBe(mockInstance2); }); }); diff --git a/packages/mcp-server/src/tests/server/transport/http-sse.test.ts b/packages/mcp-server/src/tests/server/transport/http-sse.test.ts index 2a4349974..243e22bca 100644 --- a/packages/mcp-server/src/tests/server/transport/http-sse.test.ts +++ b/packages/mcp-server/src/tests/server/transport/http-sse.test.ts @@ -16,7 +16,10 @@ jest.mock('express', () => { return express; }); -jest.mock('cors', () => jest.fn(() => 'cors-middleware')); +jest.mock('../../../server/transport/security.js', () => ({ + createCorsMiddleware: jest.fn(() => 'cors-middleware'), + validateSecurityConfig: jest.fn() +})); jest.mock('@modelcontextprotocol/sdk/server/sse.js', () => ({ SSEServerTransport: jest.fn() @@ -27,7 +30,8 @@ describe('HttpSseTransport', () => { let mockExpress: jest.MockedFunction; let mockApp: any; let mockServer: any; - let mockCors: jest.MockedFunction; + let mockCreateCorsMiddleware: jest.MockedFunction; + let mockValidateSecurityConfig: jest.MockedFunction; let mockSSEServerTransport: jest.MockedFunction; let mockTransport: any; let mockMcpServer: any; @@ -38,11 +42,12 @@ describe('HttpSseTransport', () => { // Import mocked modules const expressModule = await import('express'); - const corsModule = await import('cors'); + const securityModule = await import('../../../server/transport/security.js'); const { SSEServerTransport } = await import('@modelcontextprotocol/sdk/server/sse.js'); mockExpress = expressModule.default as jest.MockedFunction; - mockCors = corsModule.default as jest.MockedFunction; + mockCreateCorsMiddleware = securityModule.createCorsMiddleware as jest.MockedFunction; + mockValidateSecurityConfig = securityModule.validateSecurityConfig as jest.MockedFunction; mockSSEServerTransport = SSEServerTransport as jest.MockedFunction; // Setup mock objects @@ -70,7 +75,7 @@ describe('HttpSseTransport', () => { // Configure mocks mockExpress.mockReturnValue(mockApp); mockExpress.json = jest.fn().mockReturnValue('json-middleware'); - mockCors.mockReturnValue('cors-middleware'); + mockCreateCorsMiddleware.mockReturnValue('cors-middleware'); mockSSEServerTransport.mockImplementation(() => mockTransport); // Import the class after mocks are set up @@ -96,14 +101,8 @@ describe('HttpSseTransport', () => { expect(mockExpress).toHaveBeenCalled(); expect(mockApp.use).toHaveBeenCalledWith('json-middleware'); - expect(mockCors).toHaveBeenCalledWith({ - origin: '*', - methods: ['GET', 'POST', 'OPTIONS'], - allowedHeaders: ['Content-Type', 'Authorization'], - credentials: true, - exposedHeaders: ['Content-Type', 'Access-Control-Allow-Origin'] - }); - expect(mockApp.options).toHaveBeenCalledWith('*', 'cors-middleware'); + expect(mockCreateCorsMiddleware).toHaveBeenCalled(); + expect(mockApp.use).toHaveBeenCalledWith('cors-middleware'); expect(mockApp.get).toHaveBeenCalledWith('/health', expect.any(Function)); expect(mockApp.get).toHaveBeenCalledWith('/sse', expect.any(Function)); expect(mockApp.post).toHaveBeenCalledWith('/sse/message', expect.any(Function)); diff --git a/packages/mcp-server/src/tests/server/transport/security.test.ts b/packages/mcp-server/src/tests/server/transport/security.test.ts new file mode 100644 index 000000000..24c4cbbfc --- /dev/null +++ b/packages/mcp-server/src/tests/server/transport/security.test.ts @@ -0,0 +1,175 @@ +import { jest } from '@jest/globals'; +import type { Request, Response, NextFunction } from 'express'; + +describe('Security Module', () => { + let createCorsMiddleware: typeof import('../../../server/transport/security.js').createCorsMiddleware; + let validateSecurityConfig: typeof import('../../../server/transport/security.js').validateSecurityConfig; + let consoleErrorSpy: jest.SpyInstance; + let processExitSpy: jest.SpyInstance; + + beforeEach(async () => { + jest.clearAllMocks(); + + // Spy on console.error + consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); + + // Spy on process.exit to prevent actual exit + processExitSpy = jest.spyOn(process, 'exit').mockImplementation((code?: number | string | null | undefined) => { + throw new Error(`process.exit called with code ${code}`); + }); + + // Import the module + const securityModule = await import('../../../server/transport/security.js'); + createCorsMiddleware = securityModule.createCorsMiddleware; + validateSecurityConfig = securityModule.validateSecurityConfig; + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + processExitSpy.mockRestore(); + }); + + describe('createCorsMiddleware', () => { + it('should return a middleware function', () => { + const middleware = createCorsMiddleware(); + expect(typeof middleware).toBe('function'); + }); + + it('should return 204 for OPTIONS preflight requests', () => { + const middleware = createCorsMiddleware(); + + const mockReq = { method: 'OPTIONS' } as Request; + const mockRes = { + sendStatus: jest.fn().mockReturnThis() + } as unknown as Response; + const mockNext = jest.fn() as NextFunction; + + middleware(mockReq, mockRes, mockNext); + + expect(mockRes.sendStatus).toHaveBeenCalledWith(204); + expect(mockNext).not.toHaveBeenCalled(); + }); + + it('should call next() for non-OPTIONS requests', () => { + const middleware = createCorsMiddleware(); + + const mockReq = { method: 'POST' } as Request; + const mockRes = { + sendStatus: jest.fn() + } as unknown as Response; + const mockNext = jest.fn() as NextFunction; + + middleware(mockReq, mockRes, mockNext); + + expect(mockRes.sendStatus).not.toHaveBeenCalled(); + expect(mockNext).toHaveBeenCalled(); + }); + + it('should call next() for GET requests', () => { + const middleware = createCorsMiddleware(); + + const mockReq = { method: 'GET' } as Request; + const mockRes = {} as Response; + const mockNext = jest.fn() as NextFunction; + + middleware(mockReq, mockRes, mockNext); + + expect(mockNext).toHaveBeenCalled(); + }); + }); + + describe('validateSecurityConfig', () => { + describe('safe configurations', () => { + it('should allow stdio transport with wallet enabled', () => { + expect(() => { + validateSecurityConfig('stdio', 'private-key'); + }).not.toThrow(); + + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('should allow streamable-http transport with wallet disabled', () => { + expect(() => { + validateSecurityConfig('streamable-http', 'disabled'); + }).not.toThrow(); + + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('should allow http-sse transport with wallet disabled', () => { + expect(() => { + validateSecurityConfig('http-sse', 'disabled'); + }).not.toThrow(); + + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('should allow stdio transport with wallet disabled', () => { + expect(() => { + validateSecurityConfig('stdio', 'disabled'); + }).not.toThrow(); + + expect(processExitSpy).not.toHaveBeenCalled(); + }); + }); + + describe('unsafe configurations', () => { + it('should exit with code 1 for streamable-http with wallet enabled', () => { + expect(() => { + validateSecurityConfig('streamable-http', 'private-key'); + }).toThrow('process.exit called with code 1'); + + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(consoleErrorSpy).toHaveBeenCalled(); + }); + + it('should exit with code 1 for http-sse with wallet enabled', () => { + expect(() => { + validateSecurityConfig('http-sse', 'private-key'); + }).toThrow('process.exit called with code 1'); + + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(consoleErrorSpy).toHaveBeenCalled(); + }); + + it('should log security error message for unsafe config', () => { + expect(() => { + validateSecurityConfig('streamable-http', 'private-key'); + }).toThrow(); + + // Verify error messages were logged + expect(consoleErrorSpy).toHaveBeenCalledWith(expect.stringContaining('SECURITY ERROR')); + expect(consoleErrorSpy).toHaveBeenCalledWith(expect.stringContaining('Wallet mode cannot be used with HTTP transports')); + }); + }); + + describe('wallet mode variations', () => { + it('should block private-key wallet mode on streamable-http', () => { + expect(() => { + validateSecurityConfig('streamable-http', 'private-key'); + }).toThrow('process.exit called with code 1'); + + expect(processExitSpy).toHaveBeenCalledWith(1); + }); + + it('should block private-key wallet mode on http-sse', () => { + expect(() => { + validateSecurityConfig('http-sse', 'private-key'); + }).toThrow('process.exit called with code 1'); + + expect(processExitSpy).toHaveBeenCalledWith(1); + }); + + it('should allow disabled wallet mode on all transports', () => { + expect(() => { + validateSecurityConfig('stdio', 'disabled'); + validateSecurityConfig('streamable-http', 'disabled'); + validateSecurityConfig('http-sse', 'disabled'); + }).not.toThrow(); + + expect(processExitSpy).not.toHaveBeenCalled(); + }); + }); + }); +}); + diff --git a/packages/mcp-server/src/tests/server/transport/streamable-http.test.ts b/packages/mcp-server/src/tests/server/transport/streamable-http.test.ts index 8ddf9418e..a4ebe4062 100644 --- a/packages/mcp-server/src/tests/server/transport/streamable-http.test.ts +++ b/packages/mcp-server/src/tests/server/transport/streamable-http.test.ts @@ -6,6 +6,7 @@ import type { Server } from 'node:http'; jest.mock('express', () => { const mockApp = { use: jest.fn(), + get: jest.fn(), post: jest.fn(), listen: jest.fn() }; @@ -14,6 +15,11 @@ jest.mock('express', () => { return express; }); +jest.mock('../../../server/transport/security.js', () => ({ + createCorsMiddleware: jest.fn(() => 'cors-middleware'), + validateSecurityConfig: jest.fn() +})); + jest.mock('@modelcontextprotocol/sdk/server/streamableHttp.js', () => ({ StreamableHTTPServerTransport: jest.fn() })); @@ -47,6 +53,7 @@ describe('StreamableHttpTransport', () => { // Setup mock objects mockApp = { use: jest.fn(), + get: jest.fn(), post: jest.fn(), listen: jest.fn() }; @@ -63,9 +70,14 @@ describe('StreamableHttpTransport', () => { close: jest.fn() }; + // Import security mock + const securityModule = await import('../../../server/transport/security.js'); + const mockCreateCorsMiddleware = securityModule.createCorsMiddleware as jest.MockedFunction; + mockCreateCorsMiddleware.mockReturnValue('cors-middleware'); + // Setup default mocks mockExpress.mockReturnValue(mockApp); - mockExpress.json = jest.fn(); + mockExpress.json = jest.fn().mockReturnValue('json-middleware'); mockApp.listen.mockReturnValue(mockServer); mockGetServer.mockResolvedValue(mockMcpServer); (StreamableHTTPServerTransport as jest.Mock).mockImplementation(() => mockStreamableTransport); @@ -106,10 +118,10 @@ describe('StreamableHttpTransport', () => { // Verify express setup expect(mockExpress).toHaveBeenCalled(); expect(mockExpress.json).toHaveBeenCalled(); - expect(mockApp.use).toHaveBeenCalledWith(mockExpress.json()); + expect(mockApp.use).toHaveBeenCalledWith('json-middleware'); // Verify CORS middleware setup - expect(mockApp.use).toHaveBeenCalledWith(expect.any(Function)); + expect(mockApp.use).toHaveBeenCalledWith('cors-middleware'); // Verify POST route setup expect(mockApp.post).toHaveBeenCalledWith('/mcp', expect.any(Function)); @@ -121,49 +133,51 @@ describe('StreamableHttpTransport', () => { expect(mockServer.on).toHaveBeenCalledWith('error', expect.any(Function)); }); - it('should handle CORS preflight requests', () => { + it('should setup CORS middleware', async () => { const transport = new StreamableHttpTransport(); - transport.start(); + await transport.start({ mock: 'server' }); - // Get the middleware function - const middlewareFunction = mockApp.use.mock.calls[1][0]; + // Verify CORS middleware was set up + expect(mockApp.use).toHaveBeenCalledWith('cors-middleware'); + }); - // Mock request and response for OPTIONS request - const mockReq = { method: 'OPTIONS' } as any as Request; - const mockRes = { - header: jest.fn(), - sendStatus: jest.fn() - } as any as Response; - mockRes.sendStatus.mockReturnValue(mockRes); - const mockNext = jest.fn(); - - // Call the middleware - middlewareFunction(mockReq, mockRes, mockNext); - - // Verify CORS headers are set - expect(mockRes.header).toHaveBeenCalledWith('Access-Control-Allow-Origin', '*'); - expect(mockRes.header).toHaveBeenCalledWith('Access-Control-Allow-Methods', 'POST, OPTIONS'); - expect(mockRes.header).toHaveBeenCalledWith('Access-Control-Allow-Headers', 'Content-Type'); - expect(mockRes.sendStatus).toHaveBeenCalledWith(200); - expect(mockNext).not.toHaveBeenCalled(); + it('should call validateSecurityConfig on start', async () => { + const securityModule = await import('../../../server/transport/security.js'); + const mockValidateSecurityConfig = securityModule.validateSecurityConfig as jest.MockedFunction; + + const transport = new StreamableHttpTransport(8080, 'localhost', '/mcp', 'disabled'); + await transport.start({ mock: 'server' }); + + expect(mockValidateSecurityConfig).toHaveBeenCalledWith('streamable-http', 'disabled'); }); - it('should call next for non-OPTIONS requests', async () => { + it('should setup health endpoint', async () => { const transport = new StreamableHttpTransport(); await transport.start({ mock: 'server' }); - // Get the CORS middleware function - const corsMiddleware = mockApp.use.mock.calls.find(call => - typeof call[0] === 'function' && call[0].length === 3 - )?.[0]; + // Verify health endpoint was set up + expect(mockApp.get).toHaveBeenCalledWith('/health', expect.any(Function)); + }); - const mockReq = { method: 'POST' } as Request; - const mockRes = { header: jest.fn() } as any as Response; - const mockNext = jest.fn(); + it('should return health status from health endpoint', async () => { + const transport = new StreamableHttpTransport(); + await transport.start({ mock: 'server' }); - corsMiddleware(mockReq, mockRes, mockNext); + // Get the health endpoint handler + const healthHandler = mockApp.get.mock.calls.find( + (call: any[]) => call[0] === '/health' + )?.[1]; + expect(healthHandler).toBeDefined(); - expect(mockNext).toHaveBeenCalled(); + const mockReq = {}; + const mockRes = { json: jest.fn() }; + + healthHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalledWith({ + status: 'ok', + timestamp: expect.any(String) + }); }); it('should handle successful MCP requests', async () => {