From 97a847e6346456b54d251c3e4918fc89670c2f4f Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 19 Mar 2026 12:19:57 +0000 Subject: [PATCH] mcp: MRTR Tool API PoC --- mcp/client.go | 20 +++++++- mcp/mcp_test.go | 2 + mcp/protocol.go | 59 +++++++++++++++++++++++ mcp/server.go | 117 ++++++++++++++++++++++++++++----------------- mcp/server_test.go | 10 ++-- mcp/tool.go | 10 +++- 6 files changed, 167 insertions(+), 51 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 74900b1c..f1238f88 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -915,8 +915,26 @@ var clientMethodInfos = map[string]methodInfo{ notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), } +var clientSendingMethodInfos = map[string]methodInfo{} + +func init() { + // clientSendingMethodInfos is identical to serverMethodInfos, but tools/call + // produces a *CallToolResult, honoring the client's v1 signature, rather than + // the server's internal v2 *RoundTripCallToolResult. + // This is a temporary workaround to make the tests pass, given that we're not updating the client. + for k, v := range serverMethodInfos { + if k == methodCallTool { + clone := v + clone.newResult = func() Result { return &CallToolResult{} } + clientSendingMethodInfos[k] = clone + } else { + clientSendingMethodInfos[k] = v + } + } +} + func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - return serverMethodInfos + return clientSendingMethodInfos } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 86c05502..ca521553 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2200,6 +2200,8 @@ func TestToolErrorMiddleware(t *testing.T) { if err == nil { if ctr, ok := res.(*CallToolResult); ok { middleErr = ctr.GetError() + } else if rctr, ok := res.(*RoundTripCallToolResult); ok && rctr.Complete != nil { + middleErr = rctr.Complete.GetError() } } return res, err diff --git a/mcp/protocol.go b/mcp/protocol.go index d0717404..d4f3406b 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -45,6 +45,11 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + + // MRTR: InputResponses map from server-issued ID to client-fulfilled states. + InputResponses map[string]any `json:"inputResponses,omitempty"` + // MRTR: RequestState is an opaque token from an earlier IncompleteResult. + RequestState string `json:"requestState,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -60,6 +65,11 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + + // MRTR: InputResponses map from server-issued ID to client-fulfilled states. + InputResponses map[string]any `json:"inputResponses,omitempty"` + // MRTR: RequestState is an opaque token from an earlier IncompleteResult. + RequestState string `json:"requestState,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -148,6 +158,55 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { return nil } +// IncompleteResult corresponds to a CallToolResult structurally +// when ResultType == "incomplete". +type IncompleteResult struct { + Meta `json:"_meta,omitempty"` + ResultType string `json:"result_type"` // Expected to be "incomplete" + InputRequests map[string]any `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` +} + +// RoundTripCallToolResult wraps a standard CallToolResult or an IncompleteResult. +// Handlers populate exactly one of the fields. +// This is not the final choice to represent the union type, but some type will be needed for that. +type RoundTripCallToolResult struct { + Complete *CallToolResult + Incomplete *IncompleteResult +} + +func (r *RoundTripCallToolResult) isResult() {} + +func (r *RoundTripCallToolResult) GetMeta() map[string]any { + if r.Incomplete != nil { + return r.Incomplete.GetMeta() + } + if r.Complete != nil { + return r.Complete.GetMeta() + } + return nil +} + +func (r *RoundTripCallToolResult) SetMeta(m map[string]any) { + if r.Incomplete != nil { + r.Incomplete.SetMeta(m) + } else if r.Complete != nil { + r.Complete.SetMeta(m) + } +} + +// MarshalJSON safely marshals whichever result branch is taken. +func (r *RoundTripCallToolResult) MarshalJSON() ([]byte, error) { + if r.Incomplete != nil { + r.Incomplete.ResultType = "incomplete" + return json.Marshal(r.Incomplete) + } + if r.Complete == nil { + return json.Marshal(&CallToolResult{}) + } + return json.Marshal(r.Complete) +} + func (x *CallToolParams) isParams() {} func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } diff --git a/mcp/server.go b/mcp/server.go index e3c03e27..bba20479 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -235,7 +235,7 @@ func (s *Server) RemovePrompts(names ...string) { // // Most users should use the top-level function [AddTool], which handles all these // responsibilities. -func (s *Server) AddTool(t *Tool, h ToolHandler) { +func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) { if err := validateToolName(t.Name); err != nil { s.opts.Logger.Error(fmt.Sprintf("AddTool: invalid tool name %q: %v", t.Name, err)) } @@ -282,7 +282,20 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true }) } -func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCache) (*Tool, ToolHandler, error) { +// AddTool adds a [Tool] to the server using the v1 ToolHandler API. +// It wraps the provided ToolHandler into a RoundTripToolHandler and calls AddRoundTripTool. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + adapted := func(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) { + res, err := h(ctx, req) + if err != nil { + return nil, err + } + return &RoundTripCallToolResult{Complete: res}, nil + } + s.AddRoundTripTool(t, adapted) +} + +func roundTripToolForErr[In, Out any](t *Tool, h RoundTripToolHandlerFor[In, Out], cache *SchemaCache) (*Tool, RoundTripToolHandler, error) { tt := *t // Special handling for an "any" input: treat as an empty object. @@ -312,7 +325,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa } } - th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + th := func(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) { var input json.RawMessage if req.Params.Arguments != nil { input = req.Params.Arguments @@ -347,46 +360,48 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa // For regular errors, embed them in the tool result as per MCP spec var errRes CallToolResult errRes.SetError(err) - return &errRes, nil + return &RoundTripCallToolResult{Complete: &errRes}, nil } if res == nil { - res = &CallToolResult{} + res = &RoundTripCallToolResult{Complete: &CallToolResult{}} } // Marshal the output and put the RawMessage in the StructuredContent field. - var outval any = out - if elemZero != nil { - // Avoid typed nil, which will serialize as JSON null. - // Instead, use the zero value of the unpointered type. - var z Out - if any(out) == any(z) { // zero is only non-nil if Out is a pointer type - outval = elemZero + if res.Complete != nil { + var outval any = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the unpointered type. + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + outval = elemZero + } } - } - if outval != nil { - outbytes, err := json.Marshal(outval) - if err != nil { - return nil, fmt.Errorf("marshaling output: %w", err) - } - outJSON := json.RawMessage(outbytes) - // Validate the output JSON, and apply defaults. - // - // We validate against the JSON, rather than the output value, as - // some types may have custom JSON marshalling (issue #447). - outJSON, err = applySchema(outJSON, outputResolved) - if err != nil { - return nil, fmt.Errorf("validating tool output: %w", err) - } - res.StructuredContent = outJSON // avoid a second marshal over the wire - - // If the Content field isn't being used, return the serialized JSON in a - // TextContent block, as the spec suggests: - // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. - if res.Content == nil { - res.Content = []Content{&TextContent{ - Text: string(outJSON), - }} + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.Complete.StructuredContent = outJSON // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. + if res.Complete.Content == nil { + res.Complete.Content = []Content{&TextContent{ + Text: string(outJSON), + }} + } } } return res, nil @@ -498,11 +513,27 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa // tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed // description of this automatic behavior. func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - tt, hh, err := toolForErr(t, h, s.opts.SchemaCache) + adapted := func(ctx context.Context, req *CallToolRequest, in In) (*RoundTripCallToolResult, Out, error) { + res, out, err := h(ctx, req, in) + if err != nil { + var zero Out + return nil, zero, err + } + if res == nil { + res = &CallToolResult{} + } + return &RoundTripCallToolResult{Complete: res}, out, nil + } + AddRoundTripTool(s, t, adapted) +} + +// AddRoundTripTool adds a tool and typed MRTR tool handler to the server. +func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) { + tt, hh, err := roundTripToolForErr(t, h, s.opts.SchemaCache) if err != nil { - panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + panic(fmt.Sprintf("AddRoundTripTool: tool %q: %v", t.Name, err)) } - s.AddTool(tt, hh) + s.AddRoundTripTool(tt, hh) } // RemoveTools removes the tools with the given names. @@ -729,7 +760,7 @@ func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListTools }) } -func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) { s.mu.Lock() st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() @@ -740,10 +771,8 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR } } res, err := st.handler(ctx, req) - if err == nil && res != nil && res.Content == nil { - res2 := *res - res2.Content = []Content{} // avoid "null" - res = &res2 + if err == nil && res != nil && res.Complete != nil && res.Complete.Content == nil { + res.Complete.Content = []Content{} // avoid "null" } return res, err } diff --git a/mcp/server_test.go b/mcp/server_test.go index e57af1e2..58457e4f 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -656,10 +656,10 @@ type schema = jsonschema.Schema func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut any, wantErrContaining string) { t.Helper() - th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { + th := func(context.Context, *CallToolRequest, In) (*RoundTripCallToolResult, Out, error) { return nil, out, nil } - gott, goth, err := toolForErr(tool, th, nil) + gott, goth, err := roundTripToolForErr(tool, th, nil) if err != nil { t.Fatal(err) } @@ -687,10 +687,10 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out t.Errorf("got error %v, want no error", err) } - if gott.OutputSchema != nil && err == nil && !result.IsError { + if gott.OutputSchema != nil && err == nil && result.Complete != nil && !result.Complete.IsError { // Check that structured content matches exactly. - unstructured := result.Content[0].(*TextContent).Text - structured := string(result.StructuredContent.(json.RawMessage)) + unstructured := result.Complete.Content[0].(*TextContent).Text + structured := string(result.Complete.StructuredContent.(json.RawMessage)) if diff := cmp.Diff(unstructured, structured); diff != "" { t.Errorf("Unstructured content does not match structured content exactly (-unstructured +structured):\n%s", diff) } diff --git a/mcp/tool.go b/mcp/tool.go index 3ecb59d3..8c42fe2d 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -56,10 +56,18 @@ type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error // or error. The effective result will be populated as described above. type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) +// RoundTripToolHandler handles tools/call requests taking advantage of Multi Round-Trip Requests. +// It returns a *RoundTripCallToolResult which encapsulates either a complete or incomplete result. +type RoundTripToolHandler func(context.Context, *CallToolRequest) (*RoundTripCallToolResult, error) + +// RoundTripToolHandlerFor is the generic, typed version of RoundTripToolHandler. +// Like ToolHandlerFor, it handles input unmarshaling and validation. +type RoundTripToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *RoundTripCallToolResult, output Out, _ error) + // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { tool *Tool - handler ToolHandler + handler RoundTripToolHandler } // applySchema validates whether data is valid JSON according to the provided