Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,8 +915,26 @@ var clientMethodInfos = map[string]methodInfo{
notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK),
}

var clientSendingMethodInfos = map[string]methodInfo{}

func init() {
// clientSendingMethodInfos is identical to serverMethodInfos, but tools/call
// produces a *CallToolResult, honoring the client's v1 signature, rather than
// the server's internal v2 *RoundTripCallToolResult.
// This is a temporary workaround to make the tests pass, given that we're not updating the client.
for k, v := range serverMethodInfos {
if k == methodCallTool {
clone := v
clone.newResult = func() Result { return &CallToolResult{} }
clientSendingMethodInfos[k] = clone
} else {
clientSendingMethodInfos[k] = v
}
}
}

func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
return serverMethodInfos
return clientSendingMethodInfos
}

func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
Expand Down
2 changes: 2 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,8 @@ func TestToolErrorMiddleware(t *testing.T) {
if err == nil {
if ctr, ok := res.(*CallToolResult); ok {
middleErr = ctr.GetError()
} else if rctr, ok := res.(*RoundTripCallToolResult); ok && rctr.Complete != nil {
middleErr = rctr.Complete.GetError()
}
}
return res, err
Expand Down
59 changes: 59 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ type CallToolParams struct {
// Arguments holds the tool arguments. It can hold any value that can be
// marshaled to JSON.
Arguments any `json:"arguments,omitempty"`

// MRTR: InputResponses map from server-issued ID to client-fulfilled states.
InputResponses map[string]any `json:"inputResponses,omitempty"`
// MRTR: RequestState is an opaque token from an earlier IncompleteResult.
RequestState string `json:"requestState,omitempty"`
}

// CallToolParamsRaw is passed to tool handlers on the server. Its arguments
Expand All @@ -60,6 +65,11 @@ type CallToolParamsRaw struct {
// is the responsibility of the tool handler to unmarshal and validate the
// Arguments (see [AddTool]).
Arguments json.RawMessage `json:"arguments,omitempty"`

// MRTR: InputResponses map from server-issued ID to client-fulfilled states.
InputResponses map[string]any `json:"inputResponses,omitempty"`
// MRTR: RequestState is an opaque token from an earlier IncompleteResult.
RequestState string `json:"requestState,omitempty"`
}

// A CallToolResult is the server's response to a tool call.
Expand Down Expand Up @@ -148,6 +158,55 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error {
return nil
}

// IncompleteResult corresponds to a CallToolResult structurally
// when ResultType == "incomplete".
type IncompleteResult struct {
Meta `json:"_meta,omitempty"`
ResultType string `json:"result_type"` // Expected to be "incomplete"
InputRequests map[string]any `json:"inputRequests,omitempty"`
RequestState string `json:"requestState,omitempty"`
}

// RoundTripCallToolResult wraps a standard CallToolResult or an IncompleteResult.
// Handlers populate exactly one of the fields.
// This is not the final choice to represent the union type, but some type will be needed for that.
type RoundTripCallToolResult struct {
Complete *CallToolResult
Incomplete *IncompleteResult
}

func (r *RoundTripCallToolResult) isResult() {}

func (r *RoundTripCallToolResult) GetMeta() map[string]any {
if r.Incomplete != nil {
return r.Incomplete.GetMeta()
}
if r.Complete != nil {
return r.Complete.GetMeta()
}
return nil
}

func (r *RoundTripCallToolResult) SetMeta(m map[string]any) {
if r.Incomplete != nil {
r.Incomplete.SetMeta(m)
} else if r.Complete != nil {
r.Complete.SetMeta(m)
}
}

// MarshalJSON safely marshals whichever result branch is taken.
func (r *RoundTripCallToolResult) MarshalJSON() ([]byte, error) {
if r.Incomplete != nil {
r.Incomplete.ResultType = "incomplete"
return json.Marshal(r.Incomplete)
}
if r.Complete == nil {
return json.Marshal(&CallToolResult{})
}
return json.Marshal(r.Complete)
}

func (x *CallToolParams) isParams() {}
func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) }
func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) }
Expand Down
117 changes: 73 additions & 44 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (s *Server) RemovePrompts(names ...string) {
//
// Most users should use the top-level function [AddTool], which handles all these
// responsibilities.
func (s *Server) AddTool(t *Tool, h ToolHandler) {
func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) {
if err := validateToolName(t.Name); err != nil {
s.opts.Logger.Error(fmt.Sprintf("AddTool: invalid tool name %q: %v", t.Name, err))
}
Expand Down Expand Up @@ -282,7 +282,20 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true })
}

func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCache) (*Tool, ToolHandler, error) {
// AddTool adds a [Tool] to the server using the v1 ToolHandler API.
// It wraps the provided ToolHandler into a RoundTripToolHandler and calls AddRoundTripTool.
func (s *Server) AddTool(t *Tool, h ToolHandler) {
adapted := func(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) {
res, err := h(ctx, req)
if err != nil {
return nil, err
}
return &RoundTripCallToolResult{Complete: res}, nil
}
s.AddRoundTripTool(t, adapted)
}

func roundTripToolForErr[In, Out any](t *Tool, h RoundTripToolHandlerFor[In, Out], cache *SchemaCache) (*Tool, RoundTripToolHandler, error) {
tt := *t

// Special handling for an "any" input: treat as an empty object.
Expand Down Expand Up @@ -312,7 +325,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa
}
}

th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
th := func(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) {
var input json.RawMessage
if req.Params.Arguments != nil {
input = req.Params.Arguments
Expand Down Expand Up @@ -347,46 +360,48 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa
// For regular errors, embed them in the tool result as per MCP spec
var errRes CallToolResult
errRes.SetError(err)
return &errRes, nil
return &RoundTripCallToolResult{Complete: &errRes}, nil
}

if res == nil {
res = &CallToolResult{}
res = &RoundTripCallToolResult{Complete: &CallToolResult{}}
}

// Marshal the output and put the RawMessage in the StructuredContent field.
var outval any = out
if elemZero != nil {
// Avoid typed nil, which will serialize as JSON null.
// Instead, use the zero value of the unpointered type.
var z Out
if any(out) == any(z) { // zero is only non-nil if Out is a pointer type
outval = elemZero
if res.Complete != nil {
var outval any = out
if elemZero != nil {
// Avoid typed nil, which will serialize as JSON null.
// Instead, use the zero value of the unpointered type.
var z Out
if any(out) == any(z) { // zero is only non-nil if Out is a pointer type
outval = elemZero
}
}
}
if outval != nil {
outbytes, err := json.Marshal(outval)
if err != nil {
return nil, fmt.Errorf("marshaling output: %w", err)
}
outJSON := json.RawMessage(outbytes)
// Validate the output JSON, and apply defaults.
//
// We validate against the JSON, rather than the output value, as
// some types may have custom JSON marshalling (issue #447).
outJSON, err = applySchema(outJSON, outputResolved)
if err != nil {
return nil, fmt.Errorf("validating tool output: %w", err)
}
res.StructuredContent = outJSON // avoid a second marshal over the wire

// If the Content field isn't being used, return the serialized JSON in a
// TextContent block, as the spec suggests:
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content.
if res.Content == nil {
res.Content = []Content{&TextContent{
Text: string(outJSON),
}}
if outval != nil {
outbytes, err := json.Marshal(outval)
if err != nil {
return nil, fmt.Errorf("marshaling output: %w", err)
}
outJSON := json.RawMessage(outbytes)
// Validate the output JSON, and apply defaults.
//
// We validate against the JSON, rather than the output value, as
// some types may have custom JSON marshalling (issue #447).
outJSON, err = applySchema(outJSON, outputResolved)
if err != nil {
return nil, fmt.Errorf("validating tool output: %w", err)
}
res.Complete.StructuredContent = outJSON // avoid a second marshal over the wire

// If the Content field isn't being used, return the serialized JSON in a
// TextContent block, as the spec suggests:
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content.
if res.Complete.Content == nil {
res.Complete.Content = []Content{&TextContent{
Text: string(outJSON),
}}
}
}
}
return res, nil
Expand Down Expand Up @@ -498,11 +513,27 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa
// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed
// description of this automatic behavior.
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
tt, hh, err := toolForErr(t, h, s.opts.SchemaCache)
adapted := func(ctx context.Context, req *CallToolRequest, in In) (*RoundTripCallToolResult, Out, error) {
res, out, err := h(ctx, req, in)
if err != nil {
var zero Out
return nil, zero, err
}
if res == nil {
res = &CallToolResult{}
}
return &RoundTripCallToolResult{Complete: res}, out, nil
}
AddRoundTripTool(s, t, adapted)
}

// AddRoundTripTool adds a tool and typed MRTR tool handler to the server.
func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) {
tt, hh, err := roundTripToolForErr(t, h, s.opts.SchemaCache)
if err != nil {
panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err))
panic(fmt.Sprintf("AddRoundTripTool: tool %q: %v", t.Name, err))
}
s.AddTool(tt, hh)
s.AddRoundTripTool(tt, hh)
}

// RemoveTools removes the tools with the given names.
Expand Down Expand Up @@ -729,7 +760,7 @@ func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListTools
})
}

func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*RoundTripCallToolResult, error) {
s.mu.Lock()
st, ok := s.tools.get(req.Params.Name)
s.mu.Unlock()
Expand All @@ -740,10 +771,8 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR
}
}
res, err := st.handler(ctx, req)
if err == nil && res != nil && res.Content == nil {
res2 := *res
res2.Content = []Content{} // avoid "null"
res = &res2
if err == nil && res != nil && res.Complete != nil && res.Complete.Content == nil {
res.Complete.Content = []Content{} // avoid "null"
}
return res, err
}
Expand Down
10 changes: 5 additions & 5 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,10 @@ type schema = jsonschema.Schema

func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut any, wantErrContaining string) {
t.Helper()
th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) {
th := func(context.Context, *CallToolRequest, In) (*RoundTripCallToolResult, Out, error) {
return nil, out, nil
}
gott, goth, err := toolForErr(tool, th, nil)
gott, goth, err := roundTripToolForErr(tool, th, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -687,10 +687,10 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out
t.Errorf("got error %v, want no error", err)
}

if gott.OutputSchema != nil && err == nil && !result.IsError {
if gott.OutputSchema != nil && err == nil && result.Complete != nil && !result.Complete.IsError {
// Check that structured content matches exactly.
unstructured := result.Content[0].(*TextContent).Text
structured := string(result.StructuredContent.(json.RawMessage))
unstructured := result.Complete.Content[0].(*TextContent).Text
structured := string(result.Complete.StructuredContent.(json.RawMessage))
if diff := cmp.Diff(unstructured, structured); diff != "" {
t.Errorf("Unstructured content does not match structured content exactly (-unstructured +structured):\n%s", diff)
}
Expand Down
10 changes: 9 additions & 1 deletion mcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,18 @@ type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error
// or error. The effective result will be populated as described above.
type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error)

// RoundTripToolHandler handles tools/call requests taking advantage of Multi Round-Trip Requests.
// It returns a *RoundTripCallToolResult which encapsulates either a complete or incomplete result.
type RoundTripToolHandler func(context.Context, *CallToolRequest) (*RoundTripCallToolResult, error)

// RoundTripToolHandlerFor is the generic, typed version of RoundTripToolHandler.
// Like ToolHandlerFor, it handles input unmarshaling and validation.
type RoundTripToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *RoundTripCallToolResult, output Out, _ error)

// A serverTool is a tool definition that is bound to a tool handler.
type serverTool struct {
tool *Tool
handler ToolHandler
handler RoundTripToolHandler
}

// applySchema validates whether data is valid JSON according to the provided
Expand Down
Loading