diff --git a/README.md b/README.md index a45d333..cd2ac5d 100644 --- a/README.md +++ b/README.md @@ -823,8 +823,7 @@ This enables servers to leverage the client's LLM capabilities without needing d **Using Sampling in Tools:** Tools that accept a `server_context:` parameter can call `create_sampling_message` on it. -The request is automatically routed to the correct client session. -Set `server.server_context = server` so that `server_context.create_sampling_message` delegates to the server: +The request is automatically routed to the correct client session: ```ruby class SummarizeTool < MCP::Tool @@ -852,7 +851,6 @@ class SummarizeTool < MCP::Tool end server = MCP::Server.new(name: "my_server", tools: [SummarizeTool]) -server.server_context = server ``` **Parameters:** @@ -873,86 +871,8 @@ Optional: - `tools:` (Array) - Tools available to the LLM (requires `sampling.tools` capability) - `tool_choice:` (Hash) - Tool selection mode (e.g., `{ mode: "auto" }`) -**Direct Usage:** - -`Server#create_sampling_message` can also be called directly outside of tools: - -```ruby -result = server.create_sampling_message( - messages: [ - { role: "user", content: { type: "text", text: "What is the capital of France?" } } - ], - max_tokens: 100, - system_prompt: "You are a helpful assistant.", - temperature: 0.7 -) -``` - -Result contains the LLM response: - -```ruby -{ - role: "assistant", - content: { type: "text", text: "The capital of France is Paris." }, - model: "claude-3-sonnet-20240307", - stopReason: "endTurn" -} -``` - -For multi-client transports (e.g., `StreamableHTTPTransport`), use `server_context.create_sampling_message` inside tools -to route the request to the correct client session. - -**Tool Use in Sampling:** - -When tools are provided in a sampling request, the LLM can call them during generation. -The server must handle tool calls and continue the conversation with tool results: - -```ruby -result = server.create_sampling_message( - messages: [ - { role: "user", content: { type: "text", text: "What's the weather in Paris?" } } - ], - max_tokens: 1000, - tools: [ - { - name: "get_weather", - description: "Get weather for a city", - inputSchema: { - type: "object", - properties: { city: { type: "string" } }, - required: ["city"] - } - } - ], - tool_choice: { mode: "auto" } -) - -if result[:stopReason] == "toolUse" - tool_results = result[:content].map do |tool_use| - weather_data = get_weather(tool_use[:input][:city]) - - { - type: "tool_result", - toolUseId: tool_use[:id], - content: [{ type: "text", text: weather_data.to_json }] - } - end - - final_result = server.create_sampling_message( - messages: [ - { role: "user", content: { type: "text", text: "What's the weather in Paris?" } }, - { role: "assistant", content: result[:content] }, - { role: "user", content: tool_results } - ], - max_tokens: 1000, - tools: [...] - ) -end -``` - **Error Handling:** -- Raises `RuntimeError` if transport is not set - Raises `RuntimeError` if client does not support `sampling` capability - Raises `RuntimeError` if `tools` are used but client lacks `sampling.tools` capability - Raises `StandardError` if client returns an error response diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index d085b8b..0fb4876 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -206,44 +206,6 @@ def notify_log_message(data:, level:, logger: nil) report_exception(e, { notification: "log_message" }) end - # Sends a `sampling/createMessage` request to the client. - # For single-client transports (e.g., `StdioTransport`). For multi-client transports - # (e.g., `StreamableHTTPTransport`), use `ServerSession#create_sampling_message` instead - # to ensure the request is routed to the correct client. - def create_sampling_message( - messages:, - max_tokens:, - system_prompt: nil, - model_preferences: nil, - include_context: nil, - temperature: nil, - stop_sequences: nil, - metadata: nil, - tools: nil, - tool_choice: nil, - related_request_id: nil - ) - unless @transport - raise "Cannot send sampling request without a transport." - end - - params = build_sampling_params( - @client_capabilities, - messages: messages, - max_tokens: max_tokens, - system_prompt: system_prompt, - model_preferences: model_preferences, - include_context: include_context, - temperature: temperature, - stop_sequences: stop_sequences, - metadata: metadata, - tools: tools, - tool_choice: tool_choice, - ) - - @transport.send_request(Methods::SAMPLING_CREATE_MESSAGE, params) - end - # Sets a custom handler for `resources/read` requests. # The block receives the parsed request params and should return resource # contents. The return value is set as the `contents` field of the response. diff --git a/test/mcp/server_sampling_test.rb b/test/mcp/server_sampling_test.rb index 1526f5f..ffb3539 100644 --- a/test/mcp/server_sampling_test.rb +++ b/test/mcp/server_sampling_test.rb @@ -42,21 +42,12 @@ def close; end @mock_transport = MockTransport.new(@server) - # Simulate client initialization with sampling capability. - @server.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 1, - params: { - protocolVersion: "2025-11-25", - capabilities: { sampling: {} }, - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) + @session = ServerSession.new(server: @server, transport: @mock_transport) + @session.store_client_info(client: { name: "test-client" }, capabilities: { sampling: {} }) end test "create_sampling_message sends request with required params" do - result = @server.create_sampling_message( + result = @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, ) @@ -72,7 +63,7 @@ def close; end end test "create_sampling_message sends all optional params" do - @server.create_sampling_message( + @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, system_prompt: "You are helpful", @@ -94,46 +85,11 @@ def close; end assert_equal({ key: "value" }, params[:metadata]) end - test "create_sampling_message raises error when transport is not set" do - server_without_transport = Server.new(name: "test", version: "1.0") - - # Initialize with sampling capability but no transport. - server_without_transport.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 1, - params: { - protocolVersion: "2025-11-25", - capabilities: { sampling: {} }, - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) - - error = assert_raises(RuntimeError) do - server_without_transport.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - ) - end - - assert_equal("Cannot send sampling request without a transport.", error.message) - end - test "create_sampling_message raises error when client does not support sampling" do - # Re-initialize without sampling capability. - @server.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 2, - params: { - protocolVersion: "2025-11-25", - capabilities: {}, - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) + @session.store_client_info(client: { name: "test-client" }, capabilities: {}) error = assert_raises(RuntimeError) do - @server.create_sampling_message( + @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, ) @@ -144,7 +100,7 @@ def close; end test "create_sampling_message raises error when tools used but client lacks sampling.tools" do error = assert_raises(RuntimeError) do - @server.create_sampling_message( + @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, tools: [{ name: "test_tool", inputSchema: { type: "object" } }], @@ -156,7 +112,7 @@ def close; end test "create_sampling_message raises error when tool_choice used alone but client lacks sampling.tools" do error = assert_raises(RuntimeError) do - @server.create_sampling_message( + @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, tool_choice: { mode: "auto" }, @@ -167,19 +123,9 @@ def close; end end test "create_sampling_message allows tools when client has sampling.tools capability" do - # Re-initialize with sampling.tools capability. - @server.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 3, - params: { - protocolVersion: "2025-11-25", - capabilities: { sampling: { tools: {} } }, - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) + @session.store_client_info(client: { name: "test-client" }, capabilities: { sampling: { tools: {} } }) - result = @server.create_sampling_message( + result = @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, tools: [{ name: "test_tool", inputSchema: { type: "object" } }], @@ -194,56 +140,6 @@ def close; end assert_equal "Response from LLM", result[:content][:text] end - test "init with sampling capability allows create_sampling_message" do - server = Server.new(name: "test", version: "1.0") - # Assigns server.transport via Transport#initialize, which create_sampling_message requires. - MockTransport.new(server) - - server.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 1, - params: { - protocolVersion: "2025-11-25", - capabilities: { sampling: { tools: {} } }, - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) - - result = server.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - tools: [{ name: "t", inputSchema: { type: "object" } }], - ) - - assert_equal "assistant", result[:role] - end - - test "init without capabilities rejects create_sampling_message" do - server = Server.new(name: "test", version: "1.0") - # Assigns server.transport via Transport#initialize, which create_sampling_message requires. - MockTransport.new(server) - - server.handle({ - jsonrpc: "2.0", - method: "initialize", - id: 1, - params: { - protocolVersion: "2025-11-25", - clientInfo: { name: "test-client", version: "1.0" }, - }, - }) - - error = assert_raises(RuntimeError) do - server.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - ) - end - - assert_equal("Client does not support sampling.", error.message) - end - test "create_sampling_message uses per-session capabilities via ServerSession" do transport = MCP::Server::Transports::StreamableHTTPTransport.new(@server) @@ -276,11 +172,23 @@ def close; end test "ServerSession#client_capabilities falls back to server global capabilities" do transport = MCP::Server::Transports::StreamableHTTPTransport.new(@server) + # Initialize server with sampling capability. + @server.handle({ + jsonrpc: "2.0", + method: "initialize", + id: 1, + params: { + protocolVersion: "2025-11-25", + capabilities: { sampling: {} }, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }) + # Session without capabilities stored falls back to @server.client_capabilities. session = ServerSession.new(server: @server, transport: transport, session_id: "s3") transport.instance_variable_get(:@sessions)["s3"] = { stream: nil, server_session: session } - # Server was initialized with sampling capability in setup, so fallback should pass validation. + # Server was initialized with sampling capability, so fallback should pass validation. error = assert_raises(RuntimeError) do session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], @@ -292,7 +200,7 @@ def close; end test "session init does not overwrite server global client_capabilities" do server = Server.new(name: "test", version: "1.0") - # Assigns server.transport via Transport#initialize, which create_sampling_message requires. + # Assigns server.transport via Transport#initialize. MockTransport.new(server) # Non-session init sets global capabilities. @@ -333,59 +241,8 @@ def close; end assert_equal({}, session.client_capabilities) end - test "Server#create_sampling_message does not see session-scoped capabilities from HTTP init" do - server = Server.new(name: "test", version: "1.0") - transport = MCP::Server::Transports::StreamableHTTPTransport.new(server) - - # HTTP init stores capabilities on the session, not on the server. - session = ServerSession.new(server: server, transport: transport, session_id: "s1") - server.handle( - { - jsonrpc: "2.0", - method: "initialize", - id: 1, - params: { - protocolVersion: "2025-11-25", - capabilities: { sampling: {} }, - clientInfo: { name: "http-client", version: "1.0" }, - }, - }, - session: session, - ) - - # Server-level API should not see session-scoped capabilities. - error = assert_raises(RuntimeError) do - server.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - ) - end - assert_equal("Client does not support sampling.", error.message) - - # Session-scoped API should work (fails at transport level, not capability). - transport.instance_variable_get(:@sessions)["s1"] = { stream: nil, server_session: session } - error = assert_raises(RuntimeError) do - session.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - ) - end - assert_equal("No active stream for sampling/createMessage request.", error.message) - end - - test "Server#create_sampling_message accepts related_request_id without error" do - @server.create_sampling_message( - messages: [{ role: "user", content: { type: "text", text: "Hello" } }], - max_tokens: 100, - related_request_id: "req-1", - ) - - request = @mock_transport.requests.first - assert_equal "sampling/createMessage", request[:method] - end - test "create_sampling_message omits nil optional params" do - @server.create_sampling_message( + @session.create_sampling_message( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, system_prompt: nil,