From dca2351eebb5f14f6e1091adb380d8ce37670014 Mon Sep 17 00:00:00 2001 From: David Ahmann Date: Wed, 25 Feb 2026 15:00:33 -0500 Subject: [PATCH] fix: add collision-resistant request and operation ids (#2099) --- pkg/context/request_ids.go | 44 ++++++++++++++++++ pkg/github/server.go | 23 ++++++++++ pkg/github/server_operation_id_test.go | 53 ++++++++++++++++++++++ pkg/http/headers/headers.go | 2 + pkg/http/middleware/request_config.go | 12 +++++ pkg/http/middleware/request_config_test.go | 52 +++++++++++++++++++++ 6 files changed, 186 insertions(+) create mode 100644 pkg/context/request_ids.go create mode 100644 pkg/github/server_operation_id_test.go create mode 100644 pkg/http/middleware/request_config_test.go diff --git a/pkg/context/request_ids.go b/pkg/context/request_ids.go new file mode 100644 index 000000000..3ccbcbf36 --- /dev/null +++ b/pkg/context/request_ids.go @@ -0,0 +1,44 @@ +package context + +import ( + "context" + "crypto/rand" + "fmt" +) + +type requestIDCtxKey struct{} +type operationIDCtxKey struct{} + +func WithRequestID(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDCtxKey{}, requestID) +} + +func RequestID(ctx context.Context) (string, bool) { + requestID, ok := ctx.Value(requestIDCtxKey{}).(string) + return requestID, ok +} + +func WithOperationID(ctx context.Context, operationID string) context.Context { + return context.WithValue(ctx, operationIDCtxKey{}, operationID) +} + +func OperationID(ctx context.Context) (string, bool) { + operationID, ok := ctx.Value(operationIDCtxKey{}).(string) + return operationID, ok +} + +func GenerateRequestID() (string, error) { + return generateID("req") +} + +func GenerateOperationID() (string, error) { + return generateID("op") +} + +func generateID(prefix string) (string, error) { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("generate %s id: %w", prefix, err) + } + return fmt.Sprintf("%s_%x", prefix, buf), nil +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 06c12575d..036efac9e 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -8,6 +8,7 @@ import ( "strings" "time" + ghcontext "github.com/github/github-mcp-server/pkg/context" gherrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/octicons" @@ -107,6 +108,7 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci // and any middleware that needs to read or modify the context should be before it. ghServer.AddReceivingMiddleware(middleware...) ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps)) + ghServer.AddReceivingMiddleware(withOperationID) ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 { @@ -176,6 +178,27 @@ func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { } } +func withOperationID(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + requestID, ok := ghcontext.RequestID(ctx) + if !ok || requestID == "" { + requestID, err = ghcontext.GenerateRequestID() + if err != nil { + return nil, err + } + ctx = ghcontext.WithRequestID(ctx, requestID) + } + + operationID, err := ghcontext.GenerateOperationID() + if err != nil { + return nil, err + } + ctx = ghcontext.WithOperationID(ctx, operationID) + + return next(ctx, method, req) + } +} + // NewServer creates a new GitHub MCP server with the specified GH client and logger. func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { if opts == nil { diff --git a/pkg/github/server_operation_id_test.go b/pkg/github/server_operation_id_test.go new file mode 100644 index 000000000..d72bb8d7a --- /dev/null +++ b/pkg/github/server_operation_id_test.go @@ -0,0 +1,53 @@ +package github + +import ( + "context" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWithOperationID_PreservesRequestIDAndAddsOperationID(t *testing.T) { + t.Parallel() + + var capturedRequestID string + var capturedOperationID string + handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + var ok bool + capturedRequestID, ok = ghcontext.RequestID(ctx) + require.True(t, ok) + + capturedOperationID, ok = ghcontext.OperationID(ctx) + require.True(t, ok) + return nil, nil + }) + + _, err := handler(ghcontext.WithRequestID(context.Background(), "req_client"), "tools/call", nil) + require.NoError(t, err) + + assert.Equal(t, "req_client", capturedRequestID) + assert.Regexp(t, `^op_[0-9a-f]+$`, capturedOperationID) +} + +func TestWithOperationID_GeneratesUniqueOperationIDs(t *testing.T) { + t.Parallel() + + var operationIDs []string + handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + operationID, ok := ghcontext.OperationID(ctx) + require.True(t, ok) + operationIDs = append(operationIDs, operationID) + return nil, nil + }) + + _, err := handler(context.Background(), "tools/call", nil) + require.NoError(t, err) + _, err = handler(context.Background(), "tools/call", nil) + require.NoError(t, err) + + require.Len(t, operationIDs, 2) + assert.NotEqual(t, operationIDs[0], operationIDs[1]) +} diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index e032a0ce9..a52bf101f 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -25,6 +25,8 @@ const ( ForwardedHostHeader = "X-Forwarded-Host" // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. ForwardedProtoHeader = "X-Forwarded-Proto" + // RequestIDHeader is a standard request-correlation header. + RequestIDHeader = "X-Request-ID" // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/middleware/request_config.go b/pkg/http/middleware/request_config.go index a7311334d..c898843bd 100644 --- a/pkg/http/middleware/request_config.go +++ b/pkg/http/middleware/request_config.go @@ -15,6 +15,18 @@ func WithRequestConfig(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + requestID := strings.TrimSpace(r.Header.Get(headers.RequestIDHeader)) + if requestID == "" { + generatedRequestID, err := ghcontext.GenerateRequestID() + if err != nil { + http.Error(w, "failed to generate request id", http.StatusInternalServerError) + return + } + requestID = generatedRequestID + } + ctx = ghcontext.WithRequestID(ctx, requestID) + w.Header().Set(headers.RequestIDHeader, requestID) + // Readonly mode if relaxedParseBool(r.Header.Get(headers.MCPReadOnlyHeader)) { ctx = ghcontext.WithReadonly(ctx, true) diff --git a/pkg/http/middleware/request_config_test.go b/pkg/http/middleware/request_config_test.go new file mode 100644 index 000000000..6398e8de6 --- /dev/null +++ b/pkg/http/middleware/request_config_test.go @@ -0,0 +1,52 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWithRequestConfig_PreservesProvidedRequestID(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/", nil) + request.Header.Set(headers.RequestIDHeader, "client-request-id") + + var requestID string + handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + var ok bool + requestID, ok = ghcontext.RequestID(r.Context()) + require.True(t, ok) + })) + + handler.ServeHTTP(recorder, request) + + assert.Equal(t, "client-request-id", requestID) + assert.Equal(t, "client-request-id", recorder.Header().Get(headers.RequestIDHeader)) +} + +func TestWithRequestConfig_GeneratesRequestIDWhenMissing(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/", nil) + + var requestID string + handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + var ok bool + requestID, ok = ghcontext.RequestID(r.Context()) + require.True(t, ok) + })) + + handler.ServeHTTP(recorder, request) + + assert.NotEmpty(t, requestID) + assert.Equal(t, requestID, recorder.Header().Get(headers.RequestIDHeader)) + assert.Regexp(t, `^req_[0-9a-f]+$`, requestID) +}