From 0cea048641195a6987a7cb2fe201d4e1e34a2940 Mon Sep 17 00:00:00 2001 From: YuZhangLarry Date: Fri, 8 May 2026 05:12:30 +0800 Subject: [PATCH] feat(mcp): add Model Context Protocol (MCP) server with HTTP transport and authentication ## Summary Implement Model Context Protocol (MCP) server for Dubbo Admin to enable AI integration through standardized tool interfaces. ## Key Features - **Modular Architecture**: Core components (server, registry, tools, transport, types) - **Comprehensive Tool Support**: 11 tools covering cluster info, service discovery, instance management, metrics, and application details - **Dual Transport Support**: - Stdio transport for local Claude Desktop integration - HTTP transport for remote connections with JSON-RPC 2.0 - **Security**: Optional Bearer Token authentication for HTTP endpoint ## Configuration ## Test plan - [x] Unit tests for core components - [x] Integration tests - [x] Manual testing with Claude Desktop Co-Authored-By: Claude Opus 4.7 (1M context) --- pkg/config/app/admin.go | 152 ++++-- pkg/console/component.go | 150 +++++- pkg/core/bootstrap/init.go | 41 ++ pkg/mcp/component.go | 150 ++++++ pkg/mcp/core/builder.go | 64 +++ pkg/mcp/core/constants.go | 45 ++ pkg/mcp/core/server.go | 237 +++++++++ pkg/mcp/core/server_test.go | 159 ++++++ pkg/mcp/core/tool.go | 43 ++ pkg/mcp/core/types.go | 84 +++ pkg/mcp/mcp.go | 110 ++++ pkg/mcp/registry/registry.go | 118 +++++ pkg/mcp/registry/registry_test.go | 167 ++++++ pkg/mcp/tools/detail_tools.go | 493 ++++++++++++++++++ pkg/mcp/tools/integration_test.go | 429 +++++++++++++++ pkg/mcp/tools/live_test.go | 365 +++++++++++++ pkg/mcp/tools/metrics.go | 84 +++ pkg/mcp/tools/resource_search.go | 291 +++++++++++ pkg/mcp/tools/service_discovery.go | 243 +++++++++ pkg/mcp/tools/tools_test.go | 485 +++++++++++++++++ pkg/mcp/tools/utils.go | 150 ++++++ pkg/mcp/transport/http/handler.go | 109 ++++ pkg/mcp/transport/http/http.go | 168 ++++++ pkg/mcp/transport/http/http_test.go | 272 ++++++++++ pkg/mcp/transport/http/sse.go | 191 +++++++ pkg/mcp/transport/stdio/stdio.go | 157 ++++++ .../transport/stdio/stdio_integration_test.go | 407 +++++++++++++++ pkg/mcp/transport/stdio/stdio_test.go | 174 +++++++ pkg/mcp/types/tool.go | 84 +++ pkg/mcp/types/validation.go | 55 ++ 30 files changed, 5631 insertions(+), 46 deletions(-) create mode 100644 pkg/core/bootstrap/init.go create mode 100644 pkg/mcp/component.go create mode 100644 pkg/mcp/core/builder.go create mode 100644 pkg/mcp/core/constants.go create mode 100644 pkg/mcp/core/server.go create mode 100644 pkg/mcp/core/server_test.go create mode 100644 pkg/mcp/core/tool.go create mode 100644 pkg/mcp/core/types.go create mode 100644 pkg/mcp/mcp.go create mode 100644 pkg/mcp/registry/registry.go create mode 100644 pkg/mcp/registry/registry_test.go create mode 100644 pkg/mcp/tools/detail_tools.go create mode 100644 pkg/mcp/tools/integration_test.go create mode 100644 pkg/mcp/tools/live_test.go create mode 100644 pkg/mcp/tools/metrics.go create mode 100644 pkg/mcp/tools/resource_search.go create mode 100644 pkg/mcp/tools/service_discovery.go create mode 100644 pkg/mcp/tools/tools_test.go create mode 100644 pkg/mcp/tools/utils.go create mode 100644 pkg/mcp/transport/http/handler.go create mode 100644 pkg/mcp/transport/http/http.go create mode 100644 pkg/mcp/transport/http/http_test.go create mode 100644 pkg/mcp/transport/http/sse.go create mode 100644 pkg/mcp/transport/stdio/stdio.go create mode 100644 pkg/mcp/transport/stdio/stdio_integration_test.go create mode 100644 pkg/mcp/transport/stdio/stdio_test.go create mode 100644 pkg/mcp/types/tool.go create mode 100644 pkg/mcp/types/validation.go diff --git a/pkg/config/app/admin.go b/pkg/config/app/admin.go index 3d4722ed2..8b7d43057 100644 --- a/pkg/config/app/admin.go +++ b/pkg/config/app/admin.go @@ -18,82 +18,176 @@ package app import ( - "github.com/pkg/errors" + "github.com/duke-git/lancet/v2/slice" "go.uber.org/multierr" + "github.com/apache/dubbo-admin/pkg/common/bizerror" "github.com/apache/dubbo-admin/pkg/config" "github.com/apache/dubbo-admin/pkg/config/console" "github.com/apache/dubbo-admin/pkg/config/diagnostics" "github.com/apache/dubbo-admin/pkg/config/discovery" "github.com/apache/dubbo-admin/pkg/config/engine" - "github.com/apache/dubbo-admin/pkg/config/mode" + "github.com/apache/dubbo-admin/pkg/config/log" + "github.com/apache/dubbo-admin/pkg/config/observability" "github.com/apache/dubbo-admin/pkg/config/store" ) type AdminConfig struct { config.BaseConfig - // Mode in which dubbo admin is running. Available values are: "test", "global", "zone" - Mode mode.Mode `json:"mode" envconfig:"DUBBO_MODE"` + // Log configuration + Log *log.Config `json:"log" yaml:"log"` // Diagnostics configuration - Diagnostics *diagnostics.Config `json:"diagnostics,omitempty"` + Diagnostics *diagnostics.Config `json:"diagnostics,omitempty" yaml:"diagnostics"` + // Observability configuration + Observability *observability.Config `json:"observability" yaml:"observability"` // Console configuration - Console *console.Config `json:"admin"` + Console *console.Config `json:"console" yaml:"console"` // Store configuration - Store *store.Config `json:"store"` + Store *store.Config `json:"store" yaml:"store"` // Discovery configuration - Discovery *discovery.Config `json:"discovery"` + Discovery []*discovery.Config `json:"discovery" yaml:"discovery"` // Engine configuration - Engine *engine.Config `json:"engine"` + Engine *engine.Config `json:"engine" yaml:"engine"` + // MCP configuration + MCP *MCPConfig `json:"mcp,omitempty" yaml:"mcp"` +} + +// MCPConfig MCP配置 +type MCPConfig struct { + // Enabled 是否启用MCP端点 + Enabled bool `json:"enabled" yaml:"enabled"` + // Path MCP端点路径,默认 /api/mcp + Path string `json:"path,omitempty" yaml:"path"` + // APIKey MCP API密钥,用于认证。如果为空则不需要认证 + APIKey string `json:"apiKey,omitempty" yaml:"apiKey"` } var _ = &AdminConfig{} -func (c *AdminConfig) Sanitize() { +var DefaultAdminConfig = func() AdminConfig { + return AdminConfig{ + Log: log.DefaultLogConfig(), + Store: store.DefaultStoreConfig(), + Engine: engine.DefaultResourceEngineConfig(), + Observability: observability.DefaultObservabilityConfig(), + Diagnostics: diagnostics.DefaultDiagnosticsConfig(), + Console: console.DefaultConsoleConfig(), + } +} + +func (c AdminConfig) Sanitize() { c.Engine.Sanitize() - c.Discovery.Sanitize() + for _, d := range c.Discovery { + d.Sanitize() + } c.Store.Sanitize() c.Console.Sanitize() + c.Observability.Sanitize() c.Diagnostics.Sanitize() + c.Log.Sanitize() } -func (c *AdminConfig) PostProcess() error { +func (c AdminConfig) PreProcess() error { + discoveryPreProcess := func() error { + for _, d := range c.Discovery { + if err := d.PreProcess(); err != nil { + return err + } + } + return nil + } + return multierr.Combine( + c.Engine.PreProcess(), + discoveryPreProcess(), + c.Store.PreProcess(), + c.Console.PreProcess(), + c.Observability.PreProcess(), + c.Diagnostics.PreProcess(), + c.Log.PreProcess(), + ) +} + +func (c AdminConfig) PostProcess() error { + discoveryPostProcess := func() error { + for _, d := range c.Discovery { + if err := d.PostProcess(); err != nil { + return err + } + } + return nil + } return multierr.Combine( c.Engine.PostProcess(), - c.Discovery.PostProcess(), + discoveryPostProcess(), c.Store.PostProcess(), c.Console.PostProcess(), + c.Observability.PostProcess(), c.Diagnostics.PostProcess(), + c.Log.PostProcess(), ) } -var DefaultAdminConfig = func() AdminConfig { - return AdminConfig{ - Mode: mode.Zone, - Store: store.DefaultStoreConfig(), - Engine: engine.DefaultResourceEngineConfig(), - Diagnostics: diagnostics.DefaultDiagnosticsConfig(), - Console: console.DefaultConsoleConfig(), - } -} - -func (c *AdminConfig) Validate() error { - if err := mode.ValidateMode(c.Mode); err != nil { - return errors.Wrap(err, "Mode Config validation failed") +func (c AdminConfig) Validate() error { + if c.Log == nil { + c.Log = log.DefaultLogConfig() + } else if err := c.Log.Validate(); err != nil { + return bizerror.Wrap(err, bizerror.ConfigError, "log config validation failed") } if c.Store == nil { c.Store = store.DefaultStoreConfig() } else if err := c.Store.Validate(); err != nil { - return errors.Wrap(err, "Store Config validation failed") + return bizerror.Wrap(err, bizerror.ConfigError, "store config validation failed") } if c.Diagnostics == nil { c.Diagnostics = diagnostics.DefaultDiagnosticsConfig() } else if err := c.Diagnostics.Validate(); err != nil { - return errors.Wrap(err, "Diagnostics Config validation failed") + return bizerror.Wrap(err, bizerror.ConfigError, "diagnostics config validation failed") } if c.Console == nil { c.Console = console.DefaultConsoleConfig() } else if err := c.Console.Validate(); err != nil { - return errors.Wrap(err, "Admin validation failed") + return bizerror.Wrap(err, bizerror.ConfigError, "console config validation failed") + } + if c.Observability == nil { + c.Observability = observability.DefaultObservabilityConfig() + } else if err := c.Observability.Validate(); err != nil { + return bizerror.Wrap(err, bizerror.ConfigError, "observability config validation failed") + } + if c.Discovery == nil || len(c.Discovery) == 0 { + return bizerror.New(bizerror.ConfigError, "discover config is needed") + } + for _, d := range c.Discovery { + if err := d.Validate(); err != nil { + return bizerror.Wrap(err, bizerror.ConfigError, "discovery config validation failed") + } + } + discoveryIDList := slice.Map(c.Discovery, func(index int, item *discovery.Config) string { + return item.ID + }) + if len(discoveryIDList) != len(slice.Unique(discoveryIDList)) { + return bizerror.New(bizerror.ConfigError, "discovery id must be unique") + } + if c.Engine == nil { + c.Engine = engine.DefaultResourceEngineConfig() + } else if err := c.Engine.Validate(); err != nil { + return bizerror.Wrap(err, bizerror.ConfigError, "engine config validation failed") } return nil } + +// FindDiscovery finds the DiscoveryConfig by id, returns nil if not found +func (c AdminConfig) FindDiscovery(id string) *discovery.Config { + for _, d := range c.Discovery { + if d.ID == id { + return d + } + } + return nil +} + +// Meshes return the mesh id list of discoveries +func (c AdminConfig) Meshes() []string { + return slice.Map(c.Discovery, func(index int, item *discovery.Config) string { + return item.ID + }) +} diff --git a/pkg/console/component.go b/pkg/console/component.go index cd9769d9f..631fcaeed 100644 --- a/pkg/console/component.go +++ b/pkg/console/component.go @@ -24,18 +24,25 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" + ginzap "github.com/gin-contrib/zap" "github.com/gin-gonic/gin" ui "github.com/apache/dubbo-admin/app/dubbo-ui" + "github.com/apache/dubbo-admin/pkg/common/bizerror" + "github.com/apache/dubbo-admin/pkg/config/app" "github.com/apache/dubbo-admin/pkg/config/console" consolectx "github.com/apache/dubbo-admin/pkg/console/context" "github.com/apache/dubbo-admin/pkg/console/model" "github.com/apache/dubbo-admin/pkg/console/router" "github.com/apache/dubbo-admin/pkg/core/logger" "github.com/apache/dubbo-admin/pkg/core/runtime" + mcpcore "github.com/apache/dubbo-admin/pkg/mcp/core" + mcphttp "github.com/apache/dubbo-admin/pkg/mcp/transport/http" + mcp_tools "github.com/apache/dubbo-admin/pkg/mcp/tools" ) func init() { @@ -43,25 +50,41 @@ func init() { } type consoleWebServer struct { - Engine *gin.Engine - cfg *console.Config - cs consolectx.Context + Engine *gin.Engine + cfg *console.Config + cs consolectx.Context + mcpPath string // MCP端点路径,用于auth中间件跳过认证 + mcpAPIKey string // MCP API密钥,用于认证 } -func (c *consoleWebServer) Type() runtime.ComponentType { - return runtime.Console +func (c *consoleWebServer) RequiredDependencies() []runtime.ComponentType { + return []runtime.ComponentType{ + runtime.ResourceManager, // Console needs Manager for resource operations + // Note: No need to list ResourceStore explicitly as Manager already depends on it + } } -func (c *consoleWebServer) SubType() runtime.ComponentSubType { - return runtime.DefaultComponentSubType +func (c *consoleWebServer) Type() runtime.ComponentType { + return runtime.Console } func (c *consoleWebServer) Order() int { - return math.MaxInt + return math.MaxInt - 5 } func (c *consoleWebServer) Init(ctx runtime.BuilderContext) error { - r := gin.Default() + c.cfg = ctx.Config().Console + + // 提前读取 MCP 配置,供 auth 中间件使用 + cfg := ctx.Config() + if cfg.MCP != nil { + if cfg.MCP.Path != "" { + c.mcpPath = cfg.MCP.Path + } + c.mcpAPIKey = cfg.MCP.APIKey + } + + r := gin.New() // Admin UI r.StaticFS("/admin", http.FS(ui.FS())) r.NoRoute(func(c *gin.Context) { @@ -79,15 +102,28 @@ func (c *consoleWebServer) Init(ctx runtime.BuilderContext) error { store := cookie.NewStore([]byte("secret")) r.Use(sessions.Sessions("session", store)) r.Use(c.authMiddleware()) + r.Use(ginzap.Ginzap(logger.Logger(), time.RFC3339, true)) + r.Use(ginzap.RecoveryWithZap(logger.Logger(), true)) c.Engine = r - c.cfg = ctx.Config().Console + gin.SetMode(string(c.cfg.GinMode)) return nil } func (c *consoleWebServer) Start(coreRt runtime.Runtime, stop <-chan struct{}) error { + // If console config is nil, skip starting (e.g., MCP mode) + if c.cfg == nil { + logger.Sugar().Info("Console config is nil, skipping console start") + // Wait for stop signal since we need to keep the component "running" + <-stop + return nil + } errChan := make(chan error) c.cs = consolectx.NewConsoleContext(coreRt) router.InitRouter(c.Engine, c.cs) + + // 注册MCP端点(如果启用) + c.registerMCPEndpoints(coreRt, c.Engine) + httpServer := c.startHttpServer(errChan) select { case <-stop: @@ -123,21 +159,101 @@ func (c *consoleWebServer) startHttpServer(errChan chan error) *http.Server { return server } +func (c *consoleWebServer) registerMCPEndpoints(coreRt runtime.Runtime, engine *gin.Engine) { + // 从runtime获取完整配置 + var cfg app.AdminConfig = coreRt.Config() + + // 检查MCP是否启用 + if cfg.MCP == nil || !cfg.MCP.Enabled { + return + } + + // 确定端点路径 + path := cfg.MCP.Path + if path == "" { + path = "/api/mcp" + } + + // 存储MCP路径和API Key供auth中间件使用 + c.mcpPath = path + c.mcpAPIKey = cfg.MCP.APIKey + + // 直接创建MCP服务器 + consoleCtx := consolectx.NewConsoleContext(coreRt) + server := mcpcore.NewServer("dubbo-admin", "1.0.0") + server.SetConsoleContext(consoleCtx) + + // 注册所有工具 + reg := server.GetRegistry() + reg.RegisterRegistrar(&mcp_tools.MetricsRegistrar{}) + reg.RegisterRegistrar(&mcp_tools.ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&mcp_tools.ServiceRegistrar{}) + reg.RegisterRegistrar(&mcp_tools.DetailRegistrar{}) + reg.RegisterAll() + + // 创建HTTP处理器 + handler := mcphttp.NewHandler(server) + + // 注册路由 + engine.POST(path, func(ctx *gin.Context) { + handler.ServeHTTP(ctx.Writer, ctx.Request) + }) + + authStatus := "no-auth" + if c.mcpAPIKey != "" { + authStatus = "with-auth" + } + logger.Sugar().Infof("MCP endpoint registered at %s with %d tools (%s)", path, len(reg.List()), authStatus) +} + func (c *consoleWebServer) authMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { + return func(ctx *gin.Context) { + requestPath := ctx.Request.URL.Path + // skip login api - requestPath := c.Request.URL.Path if strings.HasSuffix(requestPath, "/login") { - c.Next() + ctx.Next() + return + } + + // check MCP endpoint authentication + isMCPRequest := requestPath == "/api/mcp" || (c.mcpPath != "" && requestPath == c.mcpPath) + if isMCPRequest { + // 如果配置了 API Key,验证 Bearer Token + if c.mcpAPIKey != "" { + authHeader := ctx.GetHeader("Authorization") + if authHeader == "" { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"}) + ctx.Abort() + return + } + // 检查 Bearer Token 格式 + if len(authHeader) < 7 || authHeader[:7] != "Bearer " { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization header format. Use: Bearer "}) + ctx.Abort() + return + } + token := authHeader[7:] + if token != c.mcpAPIKey { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"}) + ctx.Abort() + return + } + } + // API Key 验证通过或未配置 API Key,继续处理 + ctx.Next() return } - session := sessions.Default(c) + + // 其他 API 需要会话认证 + session := sessions.Default(ctx) user := session.Get("user") if user == nil { - c.JSON(http.StatusUnauthorized, model.NewUnauthorizedResp()) - c.Abort() + authErr := bizerror.New(bizerror.Unauthorized, "no access, please login") + ctx.JSON(http.StatusUnauthorized, model.NewBizErrorResp(authErr)) + ctx.Abort() return } - c.Next() + ctx.Next() } } diff --git a/pkg/core/bootstrap/init.go b/pkg/core/bootstrap/init.go new file mode 100644 index 000000000..9262029f3 --- /dev/null +++ b/pkg/core/bootstrap/init.go @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bootstrap + +// import all components registered by init function +import ( + _ "github.com/apache/dubbo-admin/pkg/console" + _ "github.com/apache/dubbo-admin/pkg/console/counter" + _ "github.com/apache/dubbo-admin/pkg/core/discovery" + _ "github.com/apache/dubbo-admin/pkg/core/engine" + _ "github.com/apache/dubbo-admin/pkg/core/events" + _ "github.com/apache/dubbo-admin/pkg/core/governor" + _ "github.com/apache/dubbo-admin/pkg/core/manager" + _ "github.com/apache/dubbo-admin/pkg/core/store" + _ "github.com/apache/dubbo-admin/pkg/discovery/mock" + _ "github.com/apache/dubbo-admin/pkg/discovery/nacos2" + _ "github.com/apache/dubbo-admin/pkg/discovery/zk" + _ "github.com/apache/dubbo-admin/pkg/engine/kubernetes" + _ "github.com/apache/dubbo-admin/pkg/engine/mock" + _ "github.com/apache/dubbo-admin/pkg/governor/nacos2" + _ "github.com/apache/dubbo-admin/pkg/governor/zk" + _ "github.com/apache/dubbo-admin/pkg/mcp" + _ "github.com/apache/dubbo-admin/pkg/store/memory" + _ "github.com/apache/dubbo-admin/pkg/store/mysql" + _ "github.com/apache/dubbo-admin/pkg/store/postgres" +) diff --git a/pkg/mcp/component.go b/pkg/mcp/component.go new file mode 100644 index 000000000..e643d2dd6 --- /dev/null +++ b/pkg/mcp/component.go @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mcp + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/config/app" + "github.com/apache/dubbo-admin/pkg/core/runtime" + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/tools" + "github.com/apache/dubbo-admin/pkg/mcp/transport/http" +) + +const ( + // ComponentType MCP组件类型 + ComponentType = runtime.ComponentType("mcp") +) + +// Component MCP组件,集成到admin服务中 +type Component struct { + server *core.Server + consoleCtx consolectx.Context + cfg *app.MCPConfig +} + +func init() { + runtime.RegisterComponent(&Component{}) +} + +// Type 返回组件类型 +func (c *Component) Type() runtime.ComponentType { + return ComponentType +} + +// Order 返回组件启动顺序 +func (c *Component) Order() int { + return 999 // 在Console之后启动 +} + +// RequiredDependencies 返回依赖的组件 +func (c *Component) RequiredDependencies() []runtime.ComponentType { + return []runtime.ComponentType{ + runtime.Console, // 依赖Console(需要runtime) + } +} + +// Init 初始化MCP组件 +func (c *Component) Init(ctx runtime.BuilderContext) error { + cfg := ctx.Config() + + // 从admin配置中获取MCP配置 + if cfg.MCP == nil { + // MCP未配置,使用默认配置(禁用) + c.cfg = &app.MCPConfig{ + Enabled: false, + Path: "/api/mcp", + } + return nil + } + + c.cfg = cfg.MCP + + // 如果未启用,直接返回 + if !c.cfg.Enabled { + return nil + } + + return nil +} + +// Start 启动MCP组件 +func (c *Component) Start(coreRt runtime.Runtime, stop <-chan struct{}) error { + // 如果未启用,直接返回 + if c.cfg == nil || !c.cfg.Enabled { + <-stop + return nil + } + + // 获取console context(由Console组件创建) + c.consoleCtx = consolectx.NewConsoleContext(coreRt) + + // 创建MCP服务器 + c.server = core.NewServer("dubbo-admin", "1.0.0") + c.server.SetConsoleContext(c.consoleCtx) + + // 注册所有工具 + reg := c.server.GetRegistry() + reg.RegisterRegistrar(&tools.MetricsRegistrar{}) + reg.RegisterRegistrar(&tools.ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&tools.ServiceRegistrar{}) + reg.RegisterRegistrar(&tools.DetailRegistrar{}) + reg.RegisterAll() + + // 获取console的HTTP引擎并注册MCP路由 + consoleComp, err := coreRt.GetComponent(runtime.Console) + if err != nil { + // Console未启动,无法注册MCP路由 + <-stop + return nil + } + + // 通过runtime获取gin.Engine + // 这里需要一种方式来访问console组件的gin.Engine + // 暂时使用全局注册方式 + RegisterMCPRoutes(c.server, c.cfg.Path) + + _ = consoleComp // 暂时忽略,实际使用时需要访问console的gin.Engine + + // 等待停止信号 + <-stop + + return nil +} + +// RegisterMCPRoutes 注册MCP路由到gin引擎 +// 这个函数应该在console路由初始化后调用 +// 注意:由于这个函数使用gin.Default(),它创建了一个新的引擎实例 +// 在实际使用中,应该由console组件调用并注册到它自己的引擎上 +func RegisterMCPRoutes(server *core.Server, path string) { + if path == "" { + path = "/api/mcp" + } + + handler := http.NewHandler(server) + + // 注册MCP端点(不需要认证) + // 注意:这只是示例代码,实际注册应该在console组件中完成 + _ = handler // 避免未使用警告 + _ = path // 避免未使用警告 +} + +// GetServer 获取MCP服务器实例 +func (c *Component) GetServer() *core.Server { + return c.server +} diff --git a/pkg/mcp/core/builder.go b/pkg/mcp/core/builder.go new file mode 100644 index 000000000..56f4bd8a7 --- /dev/null +++ b/pkg/mcp/core/builder.go @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" +) + +// ServerBuilder 服务器构建器 +type ServerBuilder struct { + name string + version string + consoleContext consolectx.Context +} + +// NewServerBuilder 创建服务器构建器 +func NewServerBuilder() *ServerBuilder { + return &ServerBuilder{ + name: "mcp-server", + version: "1.0.0", + } +} + +// WithName 设置服务器名称 +func (b *ServerBuilder) WithName(name string) *ServerBuilder { + b.name = name + return b +} + +// WithVersion 设置服务器版本 +func (b *ServerBuilder) WithVersion(version string) *ServerBuilder { + b.version = version + return b +} + +// WithConsoleContext 设置 console context +func (b *ServerBuilder) WithConsoleContext(ctx consolectx.Context) *ServerBuilder { + b.consoleContext = ctx + return b +} + +// Build 构建服务器 +func (b *ServerBuilder) Build() *Server { + server := NewServer(b.name, b.version) + if b.consoleContext != nil { + server.SetConsoleContext(b.consoleContext) + } + return server +} diff --git a/pkg/mcp/core/constants.go b/pkg/mcp/core/constants.go new file mode 100644 index 000000000..b66c98ee9 --- /dev/null +++ b/pkg/mcp/core/constants.go @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +const ( + // JSONRPCVersion JSON-RPC 协议版本 + JSONRPCVersion = "2.0" + + // ProtocolVersion MCP 协议版本 + ProtocolVersion = "2024-11-05" + + // MethodInitialize 初始化方法 + MethodInitialize = "initialize" + // MethodToolsList 工具列表方法 + MethodToolsList = "tools/list" + // MethodToolsCall 工具调用方法 + MethodToolsCall = "tools/call" + + // ContentTypeText 文本内容类型 + ContentTypeText = "text" +) + +// JSONRPC 错误码 +const ( + ErrCodeParseError = -32700 + ErrCodeInvalidRequest = -32600 + ErrCodeMethodNotFound = -32601 + ErrCodeInvalidParams = -32602 + ErrCodeInternalError = -32603 +) diff --git a/pkg/mcp/core/server.go b/pkg/mcp/core/server.go new file mode 100644 index 000000000..d113059e6 --- /dev/null +++ b/pkg/mcp/core/server.go @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/registry" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/gin-gonic/gin" +) + +// Server MCP 服务器 +type Server struct { + name string + version string + registry *registry.Registry + consoleContext consolectx.Context +} + +// NewServer 创建 MCP 服务器 +func NewServer(name, version string) *Server { + return &Server{ + name: name, + version: version, + registry: registry.NewRegistry(), + } +} + +// NewServerWithRegistry 使用指定 Registry 创建 MCP 服务器 +func NewServerWithRegistry(name, version string, reg *registry.Registry) *Server { + return &Server{ + name: name, + version: version, + registry: reg, + } +} + +// GetRegistry 获取工具注册表 +func (s *Server) GetRegistry() *registry.Registry { + return s.registry +} + +// SetConsoleContext 设置 console context +func (s *Server) SetConsoleContext(ctx consolectx.Context) { + s.consoleContext = ctx +} + +// GetConsoleContext 获取 console context +func (s *Server) GetConsoleContext() consolectx.Context { + return s.consoleContext +} + +// ==================== 请求处理 ==================== + +// HandleRequest 处理 JSON-RPC 请求(公开方法,供 transport 层使用) +func (s *Server) HandleRequest(req *JSONRPCRequest) *JSONRPCResponse { + return s.handleRequest(req) +} + +// HandleHTTP 处理 HTTP 请求 +func (s *Server) HandleHTTP(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + s.respondWithError(c, nil, ErrCodeParseError, "Parse error") + return + } + + var req JSONRPCRequest + if err := json.NewDecoder(bytes.NewReader(body)).Decode(&req); err != nil { + s.respondWithError(c, nil, ErrCodeParseError, "Parse error") + return + } + + c.JSON(http.StatusOK, s.handleRequest(&req)) +} + +// respondWithError 返回错误响应 +func (s *Server) respondWithError(c *gin.Context, id interface{}, code int, message string) { + c.JSON(http.StatusBadRequest, JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + Error: &JSONRPCError{ + Code: code, + Message: message, + }, + }) +} + +// handleRequest 处理 JSON-RPC 请求 +func (s *Server) handleRequest(req *JSONRPCRequest) *JSONRPCResponse { + switch req.Method { + case MethodInitialize: + return s.handleInitialize(req) + case MethodToolsList: + return s.handleToolsList(req) + case MethodToolsCall: + return s.handleToolsCall(req) + default: + return s.methodNotFoundResponse(req) + } +} + +// methodNotFoundResponse 方法未找到响应 +func (s *Server) methodNotFoundResponse(req *JSONRPCRequest) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Error: &JSONRPCError{ + Code: ErrCodeMethodNotFound, + Message: fmt.Sprintf("Method not found: %s", req.Method), + }, + } +} + +// newErrorResponse 创建错误响应 +func (s *Server) newErrorResponse(id interface{}, code int, message string) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: id, + Error: &JSONRPCError{ + Code: code, + Message: message, + }, + } +} + +// handleInitialize 处理 initialize 请求 +func (s *Server) handleInitialize(req *JSONRPCRequest) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Result: InitializeResult{ + ProtocolVersion: ProtocolVersion, + ServerInfo: ServerInfo{ + Name: s.name, + Version: s.version, + }, + Capabilities: ServerCapabilities{ + Tools: &ToolsCapability{}, + }, + }, + } +} + +// handleToolsList 处理 tools/list 请求 +func (s *Server) handleToolsList(req *JSONRPCRequest) *JSONRPCResponse { + toolDefs := s.registry.List() + tools := make([]Tool, 0, len(toolDefs)) + for _, def := range toolDefs { + tools = append(tools, Tool{ + Name: def.Name, + Description: def.Description, + InputSchema: def.InputSchema, + }) + } + + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Result: ToolListResult{Tools: tools}, + } +} + +// handleToolsCall 处理 tools/call 请求 +func (s *Server) handleToolsCall(req *JSONRPCRequest) *JSONRPCResponse { + params, ok := req.Params.(map[string]any) + if !ok { + return s.newErrorResponse(req.ID, ErrCodeInvalidParams, "Invalid params") + } + + name, ok := params["name"].(string) + if !ok { + return s.newErrorResponse(req.ID, ErrCodeInvalidParams, "Tool name is required") + } + + tool, ok := s.registry.Get(name) + if !ok { + return s.newErrorResponse(req.ID, ErrCodeMethodNotFound, "Tool not found: "+name) + } + + arguments, _ := params["arguments"].(map[string]any) + + // 验证必需参数 + if err := ValidateRequired(tool.InputSchema, arguments); err != nil { + return s.newErrorResponse(req.ID, ErrCodeInvalidParams, err.Error()) + } + + result, err := tool.Handler(s.consoleContext, arguments) + if err != nil { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Result: CallToolResult{ + Content: []types.Content{{Type: ContentTypeText, Text: err.Error()}}, + IsError: true, + }, + } + } + + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Result: s.convertToCallToolResult(result), + } +} + +// convertToCallToolResult 转换 ToolResult 到 CallToolResult +func (s *Server) convertToCallToolResult(result *ToolResult) CallToolResult { + content := make([]types.Content, len(result.Content)) + for i, c := range result.Content { + content[i] = types.Content{Type: c.Type, Text: c.Text} + } + return CallToolResult{ + Content: content, + IsError: result.IsError, + } +} diff --git a/pkg/mcp/core/server_test.go b/pkg/mcp/core/server_test.go new file mode 100644 index 000000000..834759766 --- /dev/null +++ b/pkg/mcp/core/server_test.go @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import ( + "testing" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/registry" + "github.com/apache/dubbo-admin/pkg/mcp/types" +) + +func TestServer_NewServer(t *testing.T) { + server := NewServer("test", "1.0.0") + + if server.name != "test" { + t.Errorf("expected name 'test', got '%s'", server.name) + } + + if server.version != "1.0.0" { + t.Errorf("expected version '1.0.0', got '%s'", server.version) + } + + if server.registry == nil { + t.Error("expected registry to be initialized") + } +} + +func TestServer_NewServerWithRegistry(t *testing.T) { + reg := registry.NewRegistry() + server := NewServerWithRegistry("test", "1.0.0", reg) + + if server.registry != reg { + t.Error("expected server to use provided registry") + } +} + +func TestServer_GetRegistry(t *testing.T) { + server := NewServer("test", "1.0.0") + + reg := server.GetRegistry() + if reg == nil { + t.Error("expected registry to be returned") + } + + if reg != server.registry { + t.Error("expected returned registry to be the same as server's registry") + } +} + +func TestRegistry_Register(t *testing.T) { + reg := registry.NewRegistry() + + tool := types.ToolDef{ + Name: "test_tool", + Description: "A test tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + } + + err := reg.Register(tool) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if !reg.Has("test_tool") { + t.Error("expected tool to be registered") + } + + if reg.Count() != 1 { + t.Errorf("expected 1 tool, got %d", reg.Count()) + } +} + +func TestRegistry_Unregister(t *testing.T) { + reg := registry.NewRegistry() + + tool := types.ToolDef{ + Name: "test_tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + } + + reg.Register(tool) + removed := reg.Unregister("test_tool") + + if !removed { + t.Error("expected tool to be removed") + } + + if reg.Has("test_tool") { + t.Error("expected tool to be unregistered") + } + + // 再次移除应该返回 false + removed = reg.Unregister("test_tool") + if removed { + t.Error("expected second removal to return false") + } +} + +func TestRegistry_Clear(t *testing.T) { + reg := registry.NewRegistry() + + tool := types.ToolDef{ + Name: "test_tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + } + + reg.Register(tool) + reg.Clear() + + if reg.Count() != 0 { + t.Errorf("expected 0 tools after clear, got %d", reg.Count()) + } +} + +func TestServerBuilder(t *testing.T) { + server := NewServerBuilder(). + WithName("custom-server"). + WithVersion("2.0.0"). + Build() + + if server.name != "custom-server" { + t.Errorf("expected name 'custom-server', got '%s'", server.name) + } + + if server.version != "2.0.0" { + t.Errorf("expected version '2.0.0', got '%s'", server.version) + } +} diff --git a/pkg/mcp/core/tool.go b/pkg/mcp/core/tool.go new file mode 100644 index 000000000..7dd07cfb7 --- /dev/null +++ b/pkg/mcp/core/tool.go @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import "github.com/apache/dubbo-admin/pkg/mcp/types" + +// 类型别名,保持向后兼容 +type ( + ToolDef = types.ToolDef + InputSchema = types.InputSchema + PropertyDef = types.PropertyDef + ToolHandler = types.ToolHandler + ToolResult = types.ToolResult + Content = types.Content +) + +// 工具结果构造函数 +var ( + NewToolResult = types.NewToolResult + NewTextResult = types.NewTextResult + NewErrorResult = types.NewErrorResult +) + +// 验证函数 +var ( + ValidateRequired = types.ValidateRequired + IsEmpty = types.IsEmpty +) diff --git a/pkg/mcp/core/types.go b/pkg/mcp/core/types.go new file mode 100644 index 000000000..0f8531a5c --- /dev/null +++ b/pkg/mcp/core/types.go @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import "github.com/apache/dubbo-admin/pkg/mcp/types" + +// JSONRPCRequest JSON-RPC 2.0 请求 +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params interface{} `json:"params"` +} + +// JSONRPCResponse JSON-RPC 2.0 响应 +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` +} + +// JSONRPCError JSON-RPC 2.0 错误 +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// InitializeResult 初始化结果 +type InitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + ServerInfo ServerInfo `json:"serverInfo"` + Capabilities ServerCapabilities `json:"capabilities"` +} + +// ServerInfo 服务器信息 +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// ServerCapabilities 服务器能力 +type ServerCapabilities struct { + Tools *ToolsCapability `json:"tools,omitempty"` +} + +// ToolsCapability 工具能力 +type ToolsCapability struct { + ListChanged bool `json:"listChanged"` +} + +// ToolListResult 工具列表结果 +type ToolListResult struct { + Tools []Tool `json:"tools"` +} + +// Tool 工具定义(用于响应) +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema types.InputSchema `json:"inputSchema"` +} + +// CallToolResult 调用工具结果 +type CallToolResult struct { + Content []types.Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go new file mode 100644 index 000000000..af0173b78 --- /dev/null +++ b/pkg/mcp/mcp.go @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package mcp 提供 MCP (Model Context Protocol) 服务器实现。 +// +// # 包结构 +// +// - core: MCP 服务器核心功能(Server、类型定义、常量) +// - registry: 工具注册表和注册器接口 +// - tools: 内置工具实现(cluster、search、service) +// +// # 快速开始 +// +// server := mcp.NewServer("dubbo-admin", "1.0.0") +// mcp.RegisterDefaultTools(server) +// +// # 使用构建器 +// +// server := mcp.NewServerBuilder(). +// WithName("custom-server"). +// WithVersion("2.0.0"). +// Build() +// mcp.RegisterDefaultTools(server) +package mcp + +import ( + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/registry" + "github.com/apache/dubbo-admin/pkg/mcp/tools" +) + +// 类型别名,方便外部使用 +type ( + Server = core.Server + ServerBuilder = core.ServerBuilder + ToolDef = core.ToolDef + ToolResult = core.ToolResult + ToolHandler = core.ToolHandler + InputSchema = core.InputSchema + PropertyDef = core.PropertyDef + JSONRPCRequest = core.JSONRPCRequest + JSONRPCResponse = core.JSONRPCResponse + ToolRegistrar = registry.ToolRegistrar + Registry = registry.Registry +) + +// NewServer 创建 MCP 服务器 +func NewServer(name, version string) *core.Server { + return core.NewServer(name, version) +} + +// NewServerBuilder 创建服务器构建器 +func NewServerBuilder() *core.ServerBuilder { + return core.NewServerBuilder() +} + +// NewRegistry 创建空的工具注册表 +func NewRegistry() *registry.Registry { + return registry.NewRegistry() +} + +// RegisterDefaultTools 注册所有默认工具到注册表 +func RegisterDefaultTools(reg *registry.Registry) { + reg.RegisterRegistrar(&tools.MetricsRegistrar{}) + reg.RegisterRegistrar(&tools.ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&tools.ServiceRegistrar{}) + reg.RegisterAll() +} + +// DefaultRegistry 创建包含所有内置工具的注册表 +func DefaultRegistry() *registry.Registry { + reg := registry.NewRegistry() + RegisterDefaultTools(reg) + return reg +} + +// 常量导出 +const ( + JSONRPCVersion = core.JSONRPCVersion + ProtocolVersion = core.ProtocolVersion + MethodInitialize = core.MethodInitialize + MethodToolsList = core.MethodToolsList + MethodToolsCall = core.MethodToolsCall + ContentTypeText = core.ContentTypeText + ErrCodeParseError = core.ErrCodeParseError + ErrCodeInvalidRequest = core.ErrCodeInvalidRequest + ErrCodeMethodNotFound = core.ErrCodeMethodNotFound + ErrCodeInvalidParams = core.ErrCodeInvalidParams +) + +// 工具结果构造函数 +var ( + NewToolResult = core.NewToolResult + NewTextResult = core.NewTextResult + NewErrorResult = core.NewErrorResult +) diff --git a/pkg/mcp/registry/registry.go b/pkg/mcp/registry/registry.go new file mode 100644 index 000000000..9f3e91995 --- /dev/null +++ b/pkg/mcp/registry/registry.go @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package registry + +import ( + "fmt" + + "github.com/apache/dubbo-admin/pkg/mcp/types" +) + +// ToolRegistrar 工具注册器接口 +type ToolRegistrar interface { + RegisterTools(registry *Registry) +} + +// Registry 工具注册表 +type Registry struct { + tools map[string]types.ToolDef + registrars []ToolRegistrar +} + +// NewRegistry 创建工具注册表 +func NewRegistry() *Registry { + return &Registry{ + tools: make(map[string]types.ToolDef), + registrars: make([]ToolRegistrar, 0), + } +} + +// ==================== Tool CRUD 操作 ==================== + +// Register 注册单个工具 +func (r *Registry) Register(tool types.ToolDef) error { + if tool.Name == "" { + return fmt.Errorf("tool name cannot be empty") + } + r.tools[tool.Name] = tool + return nil +} + +// Unregister 注销工具 +func (r *Registry) Unregister(name string) bool { + if _, exists := r.tools[name]; !exists { + return false + } + delete(r.tools, name) + return true +} + +// Get 获取指定工具 +func (r *Registry) Get(name string) (types.ToolDef, bool) { + tool, exists := r.tools[name] + return tool, exists +} + +// Has 检查工具是否存在 +func (r *Registry) Has(name string) bool { + _, exists := r.tools[name] + return exists +} + +// List 列出所有工具 +func (r *Registry) List() []types.ToolDef { + result := make([]types.ToolDef, 0, len(r.tools)) + for _, tool := range r.tools { + result = append(result, tool) + } + return result +} + +// Count 获取工具数量 +func (r *Registry) Count() int { + return len(r.tools) +} + +// Clear 清空所有工具 +func (r *Registry) Clear() { + r.tools = make(map[string]types.ToolDef) +} + +// ==================== Registrar 管理 ==================== + +// RegisterRegistrar 注册注册器 +func (r *Registry) RegisterRegistrar(registrar ToolRegistrar) { + r.registrars = append(r.registrars, registrar) +} + +// RegisterAll 注册所有注册器的工具 +func (r *Registry) RegisterAll() { + for _, registrar := range r.registrars { + registrar.RegisterTools(r) + } +} + +// RegistrarsCount 返回注册器数量 +func (r *Registry) RegistrarsCount() int { + return len(r.registrars) +} + +// ClearRegistrars 清空注册器 +func (r *Registry) ClearRegistrars() { + r.registrars = make([]ToolRegistrar, 0) +} diff --git a/pkg/mcp/registry/registry_test.go b/pkg/mcp/registry/registry_test.go new file mode 100644 index 000000000..3a761ad1e --- /dev/null +++ b/pkg/mcp/registry/registry_test.go @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package registry + +import ( + "testing" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/types" +) + +// mockRegistrar 模拟注册器 +type mockRegistrar struct { + registered bool +} + +func (m *mockRegistrar) RegisterTools(reg *Registry) { + m.registered = true + reg.Register(types.ToolDef{ + Name: "mock_tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("mock", false), nil + }, + }) +} + +func TestRegistry_RegisterAll(t *testing.T) { + reg := NewRegistry() + + mock := &mockRegistrar{} + reg.RegisterRegistrar(mock) + + reg.RegisterAll() + + if !mock.registered { + t.Error("expected registrar to be called") + } + + if !reg.Has("mock_tool") { + t.Error("expected tool to be registered") + } +} + +func TestRegistry_RegistrarsCount(t *testing.T) { + reg := NewRegistry() + + if reg.RegistrarsCount() != 0 { + t.Errorf("expected 0 registrars, got %d", reg.RegistrarsCount()) + } + + reg.RegisterRegistrar(&mockRegistrar{}) + reg.RegisterRegistrar(&mockRegistrar{}) + + if reg.RegistrarsCount() != 2 { + t.Errorf("expected 2 registrars, got %d", reg.RegistrarsCount()) + } +} + +func TestRegistry_ClearRegistrars(t *testing.T) { + reg := NewRegistry() + reg.RegisterRegistrar(&mockRegistrar{}) + reg.RegisterRegistrar(&mockRegistrar{}) + + reg.ClearRegistrars() + + if reg.RegistrarsCount() != 0 { + t.Errorf("expected 0 registrars after clear, got %d", reg.RegistrarsCount()) + } +} + +func TestRegistry_RegisterAndGet(t *testing.T) { + reg := NewRegistry() + + tool := types.ToolDef{ + Name: "test_tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + } + + reg.Register(tool) + + // 测试 Get + retrieved, ok := reg.Get("test_tool") + if !ok { + t.Error("expected tool to be found") + } + + if retrieved.Name != "test_tool" { + t.Errorf("expected tool name 'test_tool', got '%s'", retrieved.Name) + } + + // 测试获取不存在的工具 + _, ok = reg.Get("non_existent") + if ok { + t.Error("expected non-existent tool to not be found") + } +} + +func TestRegistry_List(t *testing.T) { + reg := NewRegistry() + + reg.Register(types.ToolDef{ + Name: "tool1", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + }) + + reg.Register(types.ToolDef{ + Name: "tool2", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + }) + + tools := reg.List() + if len(tools) != 2 { + t.Errorf("expected 2 tools, got %d", len(tools)) + } +} + +func TestRegistry_RegisterEmptyName(t *testing.T) { + reg := NewRegistry() + + err := reg.Register(types.ToolDef{ + Name: "", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("test", false), nil + }, + }) + + if err == nil { + t.Error("expected error when registering tool with empty name") + } +} + diff --git a/pkg/mcp/tools/detail_tools.go b/pkg/mcp/tools/detail_tools.go new file mode 100644 index 000000000..0ffca99ef --- /dev/null +++ b/pkg/mcp/tools/detail_tools.go @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + "fmt" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/console/model" + "github.com/apache/dubbo-admin/pkg/console/service" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/apache/dubbo-admin/pkg/mcp/registry" +) + +// DetailRegistrar 详情查询工具注册器 +type DetailRegistrar struct{} + +// RegisterTools 实现 ToolRegistrar 接口 +func (r *DetailRegistrar) RegisterTools(reg *registry.Registry) { + // 获取服务的实例列表 + reg.Register(types.ToolDef{ + Name: "get_service_instances", + Description: "获取指定服务的所有实例列表", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"serviceName"}, + Properties: map[string]types.PropertyDef{ + "serviceName": { + Type: "string", + Description: "服务名称(完整服务名)", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码", + Default: DefaultPageNumber, + }, + }, + }, + Handler: GetServiceInstances, + }) + + // 获取实例详情 + reg.Register(types.ToolDef{ + Name: "get_instance_detail", + Description: "获取实例的详细信息", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"instanceName"}, + Properties: map[string]types.PropertyDef{ + "instanceName": { + Type: "string", + Description: "实例名称", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + }, + }, + Handler: GetInstanceDetail, + }) + + // 获取实例指标 + reg.Register(types.ToolDef{ + Name: "get_instance_metrics", + Description: "获取实例的监控指标(qps、rt、成功率等)", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"instanceName"}, + Properties: map[string]types.PropertyDef{ + "instanceName": { + Type: "string", + Description: "实例名称", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + }, + }, + Handler: GetInstanceMetrics, + }) + + // 获取应用详情 + reg.Register(types.ToolDef{ + Name: "get_application_detail", + Description: "获取应用的详细信息(端口、版本、协议等)", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"appName"}, + Properties: map[string]types.PropertyDef{ + "appName": { + Type: "string", + Description: "应用名称", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + }, + }, + Handler: GetApplicationDetail, + }) + + // 获取应用的实例列表 + reg.Register(types.ToolDef{ + Name: "get_application_instances", + Description: "获取应用的所有实例", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"appName"}, + Properties: map[string]types.PropertyDef{ + "appName": { + Type: "string", + Description: "应用名称", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码", + Default: DefaultPageNumber, + }, + }, + }, + Handler: GetApplicationInstances, + }) + + // 获取应用的服务列表 + reg.Register(types.ToolDef{ + Name: "get_application_services", + Description: "获取应用提供的所有服务", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"appName"}, + Properties: map[string]types.PropertyDef{ + "appName": { + Type: "string", + Description: "应用名称", + }, + "side": { + Type: "string", + Description: "服务端类型: provider/consumer", + Default: string(ServiceSideProvider), + Enum: []string{string(ServiceSideProvider), string(ServiceSideConsumer)}, + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码", + Default: DefaultPageNumber, + }, + }, + }, + Handler: GetApplicationServices, + }) + + // 搜索实例 + reg.Register(types.ToolDef{ + Name: "search_instances", + Description: "按应用名或关键字搜索实例", + InputSchema: types.InputSchema{ + Type: "object", + Properties: map[string]types.PropertyDef{ + "appName": { + Type: "string", + Description: "按应用名搜索", + }, + "keywords": { + Type: "string", + Description: "按关键字搜索实例名或IP", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码", + Default: DefaultPageNumber, + }, + }, + }, + Handler: SearchInstances, + }) +} + +// GetServiceInstances 获取服务的实例列表 +func GetServiceInstances(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + serviceName, ok := helper.GetRequiredString("serviceName") + if !ok || serviceName == "" { + return ErrorResult(fmt.Errorf("required parameter 'serviceName' is missing")), nil + } + + mesh := GetMeshArg(ctx, args) + pageSize := helper.GetInt("pageSize", DefaultPageSize) + pageNumber := helper.GetInt("pageNumber", DefaultPageNumber) + + // 使用 SearchInstances 按服务名搜索实例 + req := &model.SearchInstanceReq{ + Keywords: serviceName, // 用服务名作为关键字搜索 + Mesh: mesh, + PageReq: BuildPageReq(pageNumber, pageSize), + } + + result, err := service.SearchInstances(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + instances, totalCount := extractSearchInstances(result) + return JsonResult(map[string]any{ + "serviceName": serviceName, + "mesh": mesh, + "instances": instances, + "totalCount": totalCount, + }) +} + +// GetInstanceDetail 获取实例详情 +func GetInstanceDetail(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + instanceName, ok := helper.GetRequiredString("instanceName") + if !ok || instanceName == "" { + return ErrorResult(fmt.Errorf("required parameter 'instanceName' is missing")), nil + } + + req := &model.InstanceDetailReq{ + InstanceName: instanceName, + Mesh: GetMeshArg(ctx, args), + } + + detail, err := service.GetInstanceDetail(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + return JsonResult(detail) +} + +// GetInstanceMetrics 获取实例指标 +func GetInstanceMetrics(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + instanceName, ok := helper.GetRequiredString("instanceName") + if !ok || instanceName == "" { + return ErrorResult(fmt.Errorf("required parameter 'instanceName' is missing")), nil + } + + req := &model.MetricsReq{ + InstanceName: instanceName, + Mesh: GetMeshArg(ctx, args), + } + + metrics, err := service.GetInstanceMetrics(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + if len(metrics) == 0 { + return JsonResult(map[string]any{ + "instanceName": instanceName, + "metrics": []any{}, + "message": "No metrics available", + }) + } + + return JsonResult(metrics[0]) +} + +// GetApplicationDetail 获取应用详情 +func GetApplicationDetail(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + appName, ok := helper.GetRequiredString("appName") + if !ok || appName == "" { + return ErrorResult(fmt.Errorf("required parameter 'appName' is missing")), nil + } + + req := &model.ApplicationDetailReq{ + AppName: appName, + Mesh: GetMeshArg(ctx, args), + } + + detail, err := service.GetApplicationDetail(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + return JsonResult(detail) +} + +// GetApplicationInstances 获取应用的实例列表 +func GetApplicationInstances(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + appName, ok := helper.GetRequiredString("appName") + if !ok || appName == "" { + return ErrorResult(fmt.Errorf("required parameter 'appName' is missing")), nil + } + + req := &model.ApplicationTabInstanceInfoReq{ + AppName: appName, + Mesh: GetMeshArg(ctx, args), + PageReq: BuildPageReq( + helper.GetInt("pageNumber", DefaultPageNumber), + helper.GetInt("pageSize", DefaultPageSize), + ), + } + + result, err := service.GetAppInstanceInfo(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + instances := extractAppInstances(result) + return JsonResult(map[string]any{ + "appName": appName, + "instances": instances, + "totalCount": int(result.PageInfo.Total), + }) +} + +// GetApplicationServices 获取应用的服务列表 +func GetApplicationServices(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + appName, ok := helper.GetRequiredString("appName") + if !ok || appName == "" { + return ErrorResult(fmt.Errorf("required parameter 'appName' is missing")), nil + } + + req := &model.ApplicationServiceFormReq{ + AppName: appName, + Side: helper.GetString("side", string(ServiceSideProvider)), + Mesh: GetMeshArg(ctx, args), + PageReq: BuildPageReq( + helper.GetInt("pageNumber", DefaultPageNumber), + helper.GetInt("pageSize", DefaultPageSize), + ), + } + + result, err := service.GetAppServiceInfo(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + services, _ := extractServicesFromResult(result) + return JsonResult(map[string]any{ + "appName": appName, + "side": req.Side, + "services": services, + "totalCount": int(result.PageInfo.Total), + }) +} + +// SearchInstances 搜索实例 +func SearchInstances(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + appName := helper.GetString("appName", "") + keywords := helper.GetString("keywords", "") + + if appName == "" && keywords == "" { + return ErrorResult(fmt.Errorf("at least one of 'appName' or 'keywords' is required")), nil + } + + req := &model.SearchInstanceReq{ + AppName: appName, + Keywords: keywords, + Mesh: GetMeshArg(ctx, args), + PageReq: BuildPageReq( + helper.GetInt("pageNumber", DefaultPageNumber), + helper.GetInt("pageSize", DefaultPageSize), + ), + } + + result, err := service.SearchInstances(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + instances, totalCount := extractSearchInstances(result) + return JsonResult(map[string]any{ + "appName": appName, + "keywords": keywords, + "instances": instances, + "totalCount": totalCount, + }) +} + +// extractAppInstances 从应用实例结果中提取实例列表 +func extractAppInstances(result *model.SearchPaginationResult) []any { + if result == nil || result.List == nil { + return []any{} + } + + instances, ok := result.List.([]*model.AppInstanceInfoResp) + if !ok { + return []any{} + } + + resultSlice := make([]any, 0, len(instances)) + for _, inst := range instances { + if inst != nil { + resultSlice = append(resultSlice, map[string]any{ + "name": inst.Name, + "ip": inst.IP, + "appName": inst.AppName, + "deployState": inst.DeployState, + "registerState": inst.RegisterState, + "workloadName": inst.WorkloadName, + "createTime": inst.CreateTime, + "registerTime": inst.RegisterTime, + }) + } + } + return resultSlice +} + +// extractSearchInstances 从搜索结果中提取实例列表 +func extractSearchInstances(result *model.SearchPaginationResult) ([]any, int) { + if result == nil || result.List == nil { + return []any{}, 0 + } + + instances, ok := result.List.([]*model.SearchInstanceResp) + if !ok { + return []any{}, 0 + } + + resultSlice := make([]any, 0, len(instances)) + for _, inst := range instances { + if inst != nil { + resultSlice = append(resultSlice, map[string]any{ + "name": inst.Name, + "appName": inst.AppName, + "ip": inst.Ip, + "workloadName": inst.WorkloadName, + "deployState": inst.DeployState, + "deployCluster": inst.DeployCluster, + "registerState": inst.RegisterState, + "registerClusters": inst.RegisterClusters, + "createTime": inst.CreateTime, + "registerTime": inst.RegisterTime, + }) + } + } + return resultSlice, int(result.PageInfo.Total) +} + +// Ensure DetailRegistrar implements ToolRegistrar +var _ registry.ToolRegistrar = (*DetailRegistrar)(nil) diff --git a/pkg/mcp/tools/integration_test.go b/pkg/mcp/tools/integration_test.go new file mode 100644 index 000000000..2dd8d8651 --- /dev/null +++ b/pkg/mcp/tools/integration_test.go @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + "encoding/json" + "testing" + + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/registry" + "github.com/apache/dubbo-admin/pkg/mcp/transport/stdio" + "github.com/apache/dubbo-admin/pkg/mcp/types" + appcfg "github.com/apache/dubbo-admin/pkg/config/app" + ctx "context" +) + +// TestTools_E2E 通过 MCP Server 端到端测试工具 +// 这展示了如何在实际场景中测试 MCP 工具 +func TestTools_E2E(t *testing.T) { + // 创建 MCP 服务器 + server := core.NewServer("dubbo-admin-test", "1.0.0") + + // 注册所有默认工具 + reg := server.GetRegistry() + reg.RegisterRegistrar(&MetricsRegistrar{}) + reg.RegisterRegistrar(&ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&ServiceRegistrar{}) + reg.RegisterAll() + + // 验证工具已注册 + tools := reg.List() + t.Logf("Registered %d tools:", len(tools)) + for _, tool := range tools { + t.Logf(" - %s: %s", tool.Name, tool.Description) + } + + // 测试获取工具列表 + t.Run("ToolsList", func(t *testing.T) { + req := &core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: 1, + Method: core.MethodToolsList, + } + + resp := server.HandleRequest(req) + if resp.Error != nil { + t.Fatalf("Tools list failed: %s", resp.Error.Message) + } + + // 将响应序列化为 JSON 以验证格式 + respData, err := json.MarshalIndent(resp, "", " ") + if err != nil { + t.Fatalf("Failed to marshal response: %v", err) + } + t.Logf("Tools list response:\n%s", string(respData)) + + // Result 是 ToolListResult,通过 JSON 反序列化来验证 + var result map[string]any + err = json.Unmarshal(respData, &result) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + tools := result["result"].(map[string]any)["tools"].([]any) + + // 验证核心工具存在 + toolNames := make(map[string]bool) + for _, t := range tools { + tool := t.(map[string]any) + toolNames[tool["name"].(string)] = true + } + + expectedTools := []string{ + "get_cluster_info", + "global_search", + "search_services", + "get_service_detail", + } + + for _, name := range expectedTools { + if !toolNames[name] { + t.Errorf("Expected tool '%s' not found", name) + } + } + + t.Logf("Found %d tools in tools/list response", len(tools)) + }) + + // 测试工具定义可以正确转换为 JSON + t.Run("ToolSchemaSerialization", func(t *testing.T) { + tools := reg.List() + + for _, tool := range tools { + // 创建不包含 Handler 的工具定义用于序列化 + toolForMarshal := struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema types.InputSchema `json:"inputSchema"` + }{ + Name: tool.Name, + Description: tool.Description, + InputSchema: tool.InputSchema, + } + + // 尝试序列化为 JSON + data, err := json.Marshal(toolForMarshal) + if err != nil { + t.Errorf("Failed to marshal tool %s: %v", tool.Name, err) + continue + } + + // 验证可以反序列化 + var unmarshaled map[string]any + err = json.Unmarshal(data, &unmarshaled) + if err != nil { + t.Errorf("Failed to unmarshal tool %s: %v", tool.Name, err) + } + } + }) + + // 测试工具 schema 验证 + t.Run("ToolSchemaValidation", func(t *testing.T) { + tools := reg.List() + + for _, tool := range tools { + t.Run(tool.Name, func(t *testing.T) { + // 验证必需参数可以正确检测 + if len(tool.InputSchema.Required) > 0 { + // 缺少必需参数 + invalidArgs := map[string]any{} + + err := core.ValidateRequired(tool.InputSchema, invalidArgs) + if err == nil { + t.Errorf("Tool %s: expected validation error for missing required params", tool.Name) + } + } + + // 空参数应该通过(如果无必需参数) + if len(tool.InputSchema.Required) == 0 { + validArgs := map[string]any{} + err := core.ValidateRequired(tool.InputSchema, validArgs) + if err != nil { + t.Errorf("Tool %s: unexpected validation error: %v", tool.Name, err) + } + } + }) + } + }) +} + +// TestTools_Manual 手动测试示例 +// 这展示了如何在需要时手动调用工具 handler +func TestTools_Manual(t *testing.T) { + t.Skip("跳过手动测试示例 - 需要真实的服务依赖") + + // 示例:如何直接调用 GetClusterInfo + // 注意:这需要真实的 CounterManager + /* + ctx := setupRealContext() + args := map[string]any{ + "mesh": "test-mesh", + } + + result, err := GetClusterInfo(ctx, args) + if err != nil { + t.Fatalf("GetClusterInfo failed: %v", err) + } + + t.Logf("Result: %s", result.Content[0].Text) + */ +} + +// MockContext 用于测试的 mock context +type MockContext struct { + meshName string + config *appcfg.AdminConfig +} + +func NewMockContext(meshName string) *MockContext { + return &MockContext{ + meshName: meshName, + config: &appcfg.AdminConfig{ + // 设置必要的配置 + }, + } +} + +func (m *MockContext) Config() appcfg.AdminConfig { + if m.config == nil { + m.config = &appcfg.AdminConfig{} + } + return *m.config +} + +func (m *MockContext) CounterManager() interface{} { + // 返回 mock CounterManager + return nil +} + +func (m *MockContext) AppContext() ctx.Context { + return ctx.Background() +} + +// TestGetClusterInfoSchema 测试 get_cluster_info 的 schema +func TestGetClusterInfoSchema(t *testing.T) { + reg := registry.NewRegistry() + registrar := &MetricsRegistrar{} + registrar.RegisterTools(reg) + + tool, ok := reg.Get("get_cluster_info") + if !ok { + t.Fatal("Tool get_cluster_info not registered") + } + + // 验证 schema + t.Run("Schema", func(t *testing.T) { + if tool.Name != "get_cluster_info" { + t.Errorf("Expected name 'get_cluster_info', got '%s'", tool.Name) + } + + if tool.InputSchema.Type != "object" { + t.Errorf("Expected type 'object', got '%s'", tool.InputSchema.Type) + } + + // 验证可选参数 + if _, ok := tool.InputSchema.Properties["mesh"]; !ok { + t.Error("Missing 'mesh' property") + } + }) + + t.Run("Arguments", func(t *testing.T) { + // 无参数调用(使用默认 mesh) + args := map[string]any{} + err := core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Validation failed for empty args: %v", err) + } + + // 指定 mesh 参数 + args = map[string]any{"mesh": "custom-mesh"} + err = core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Validation failed for mesh arg: %v", err) + } + }) +} + +// TestGlobalSearchSchema 测试 global_search 的 schema +func TestGlobalSearchSchema(t *testing.T) { + reg := registry.NewRegistry() + registrar := &ResourceSearchRegistrar{} + registrar.RegisterTools(reg) + + tool, ok := reg.Get("global_search") + if !ok { + t.Fatal("Tool global_search not registered") + } + + t.Run("RequiredParameters", func(t *testing.T) { + // keyword 现在是可选的,空关键字返回所有数据 + if len(tool.InputSchema.Required) != 0 { + t.Errorf("Expected 0 required parameters (keyword is optional), got %d", len(tool.InputSchema.Required)) + } + }) + + t.Run("Validation", func(t *testing.T) { + // 所有参数都是可选的,不应有验证错误 + args := map[string]any{} + err := core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Expected no validation error for empty args: %v", err) + } + + // 空字符串也是有效的(返回所有数据) + args = map[string]any{"keyword": ""} + err = core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Expected no validation error for empty keyword: %v", err) + } + + // 有效参数 + args = map[string]any{"keyword": "test-service"} + err = core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Validation failed for valid args: %v", err) + } + }) + + t.Run("OptionalParameters", func(t *testing.T) { + // 验证所有可选参数存在 + expectedProps := []string{"keyword", "searchType", "mesh", "pageSize", "pageNumber"} + for _, prop := range expectedProps { + if _, ok := tool.InputSchema.Properties[prop]; !ok { + t.Errorf("Missing property '%s'", prop) + } + } + }) +} + +// TestServiceToolsSchema 测试服务工具的 schema +func TestServiceToolsSchema(t *testing.T) { + reg := registry.NewRegistry() + registrar := &ServiceRegistrar{} + registrar.RegisterTools(reg) + + t.Run("SearchServices", func(t *testing.T) { + tool, ok := reg.Get("search_services") + if !ok { + t.Fatal("Tool search_services not registered") + } + + // 无必需参数 + if len(tool.InputSchema.Required) != 0 { + t.Errorf("Expected no required parameters, got %v", tool.InputSchema.Required) + } + + // 测试验证 + args := map[string]any{} + err := core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Validation failed unexpectedly: %v", err) + } + }) + + t.Run("GetServiceDetail", func(t *testing.T) { + tool, ok := reg.Get("get_service_detail") + if !ok { + t.Fatal("Tool get_service_detail not registered") + } + + // serviceName 是必需的 + if len(tool.InputSchema.Required) != 1 { + t.Errorf("Expected 1 required parameter, got %d", len(tool.InputSchema.Required)) + } + + // 测试验证 + args := map[string]any{} + err := core.ValidateRequired(tool.InputSchema, args) + if err == nil { + t.Error("Expected validation error for missing 'serviceName'") + } + + args = map[string]any{"serviceName": "com.example.Service"} + err = core.ValidateRequired(tool.InputSchema, args) + if err != nil { + t.Errorf("Validation failed for valid args: %v", err) + } + + // 测试 side 参数枚举 + sideProp := tool.InputSchema.Properties["side"] + actualEnum := sideProp.Enum + if len(actualEnum) != 2 { + t.Errorf("Expected enum with 2 values, got %d", len(actualEnum)) + } + }) +} + +// Example_testUsage 演示如何通过 stdio transport 使用工具 +func Example_testUsage() { + // 创建服务器 + server := core.NewServer("dubbo-admin", "1.0.0") + reg := server.GetRegistry() + + // 注册工具 + reg.RegisterRegistrar(&MetricsRegistrar{}) + reg.RegisterRegistrar(&ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&ServiceRegistrar{}) + reg.RegisterAll() + + // 创建 stdio transport + transport := stdio.NewTransport(server) + + // 在实际使用中,你会这样启动: + // transport.Serve(context.Background()) + + // 或者使用自定义 io 进行测试 + // transport := stdio.NewTransportWithIO(server, stdin, stdout) + + _ = transport +} + +// TestDefaultRegistry 测试默认注册表 +func TestDefaultRegistry(t *testing.T) { + // 使用 DefaultRegistry 创建包含所有工具的注册表 + reg := registry.NewRegistry() + + // 注册所有默认工具 + reg.RegisterRegistrar(&MetricsRegistrar{}) + reg.RegisterRegistrar(&ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&ServiceRegistrar{}) + reg.RegisterAll() + + tools := reg.List() + t.Logf("DefaultRegistry has %d tools", len(tools)) + + // 验证核心工具存在 + expectedTools := []string{ + "get_cluster_info", + "global_search", + "search_services", + "get_service_detail", + } + + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + for _, name := range expectedTools { + if !toolNames[name] { + t.Errorf("Expected tool '%s' not found in default registry", name) + } + } +} diff --git a/pkg/mcp/tools/live_test.go b/pkg/mcp/tools/live_test.go new file mode 100644 index 000000000..a96177bca --- /dev/null +++ b/pkg/mcp/tools/live_test.go @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "time" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/core" +) + +// LiveTestOptions 真实测试选项 +type LiveTestOptions struct { + ConfigPath string + Mesh string + Keyword string + ServiceName string +} + +// LiveTester 真实数据测试器 +type LiveTester struct { + ctx context.Context + consoleCtx consolectx.Context + server *core.Server + options *LiveTestOptions +} + +// NewLiveTester 创建真实数据测试器 +func NewLiveTester(configPath string) (*LiveTester, error) { + // 这里需要导入 bootstrap 包来初始化真实的 runtime + // 但为了避免循环导入,我们通过依赖注入的方式接收 context + + return &LiveTester{ + ctx: context.Background(), + options: &LiveTestOptions{ConfigPath: configPath}, + }, nil +} + +// SetConsoleContext 设置 console context(由外部初始化后传入) +func (t *LiveTester) SetConsoleContext(ctx consolectx.Context) { + t.consoleCtx = ctx +} + +// InitServer 初始化 MCP 服务器 +func (t *LiveTester) InitServer() { + t.server = core.NewServer("dubbo-admin-live-test", "1.0.0") + reg := t.server.GetRegistry() + + // 注册所有工具 + reg.RegisterRegistrar(&MetricsRegistrar{}) + reg.RegisterRegistrar(&ResourceSearchRegistrar{}) + reg.RegisterRegistrar(&ServiceRegistrar{}) + reg.RegisterAll() + + // 设置 console context + t.server.SetConsoleContext(t.consoleCtx) +} + +// TestClusterInfo 测试获取集群信息 +func (t *LiveTester) TestClusterInfo(mesh string) (string, error) { + t.options.Mesh = mesh + + args := map[string]any{ + "mesh": mesh, + } + + result, err := GetClusterInfo(t.consoleCtx, args) + if err != nil { + return "", fmt.Errorf("GetClusterInfo failed: %w", err) + } + + return result.Content[0].Text, nil +} + +// TestGlobalSearch 测试全局搜索 +func (t *LiveTester) TestGlobalSearch(keyword, searchType, mesh string) (string, error) { + args := map[string]any{ + "keyword": keyword, + "searchType": searchType, + "mesh": mesh, + "pageSize": 10, + "pageNumber": 1, + } + + result, err := GlobalSearch(t.consoleCtx, args) + if err != nil { + return "", fmt.Errorf("GlobalSearch failed: %w", err) + } + + return result.Content[0].Text, nil +} + +// TestSearchServices 测试搜索服务 +func (t *LiveTester) TestSearchServices(keywords, mesh string) (string, error) { + args := map[string]any{ + "keywords": keywords, + "mesh": mesh, + "pageSize": 10, + "pageNumber": 1, + } + + result, err := SearchServices(t.consoleCtx, args) + if err != nil { + return "", fmt.Errorf("SearchServices failed: %w", err) + } + + return result.Content[0].Text, nil +} + +// TestGetServiceDetail 测试获取服务详情 +func (t *LiveTester) TestGetServiceDetail(serviceName, group, version, side, mesh string) (string, error) { + args := map[string]any{ + "serviceName": serviceName, + "group": group, + "version": version, + "side": side, + "mesh": mesh, + } + + result, err := GetServiceDetail(t.consoleCtx, args) + if err != nil { + return "", fmt.Errorf("GetServiceDetail failed: %w", err) + } + + return result.Content[0].Text, nil +} + +// TestToolViaServer 通过服务器测试工具调用 +func (t *LiveTester) TestToolViaServer(toolName string, args map[string]any) (*core.JSONRPCResponse, error) { + req := &core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: 1, + Method: core.MethodToolsCall, + Params: map[string]any{ + "name": toolName, + "arguments": args, + }, + } + + resp := t.server.HandleRequest(req) + if resp.Error != nil { + return nil, fmt.Errorf("tool call failed: %s", resp.Error.Message) + } + + return resp, nil +} + +// PrintResult 打印结果 +func (t *LiveTester) PrintResult(name string, result string) { + fmt.Printf("\n=== %s ===\n", name) + fmt.Printf("Result:\n%s\n", result) + fmt.Printf("==================\n\n") +} + +// RunAllTests 运行所有测试 +func (t *LiveTester) RunAllTests(mesh, keyword, serviceName string) error { + fmt.Println("🧪 Starting live MCP tools tests...") + + // 1. 测试集群信息 + fmt.Println("1️⃣ Testing GetClusterInfo...") + clusterInfo, err := t.TestClusterInfo(mesh) + if err != nil { + log.Printf("❌ GetClusterInfo failed: %v", err) + } else { + t.PrintResult("Cluster Info", clusterInfo) + } + + // 2. 测试全局搜索 + fmt.Println("2️⃣ Testing GlobalSearch...") + searchResult, err := t.TestGlobalSearch(keyword, "serviceName", mesh) + if err != nil { + log.Printf("❌ GlobalSearch failed: %v", err) + } else { + t.PrintResult("Global Search", searchResult) + } + + // 3. 测试服务搜索 + fmt.Println("3️⃣ Testing SearchServices...") + servicesResult, err := t.TestSearchServices(keyword, mesh) + if err != nil { + log.Printf("❌ SearchServices failed: %v", err) + } else { + t.PrintResult("Search Services", servicesResult) + } + + // 4. 测试服务详情(如果有服务名) + if serviceName != "" { + fmt.Println("4️⃣ Testing GetServiceDetail...") + detailResult, err := t.TestGetServiceDetail(serviceName, "", "", "provider", mesh) + if err != nil { + log.Printf("❌ GetServiceDetail failed: %v", err) + } else { + t.PrintResult("Service Detail", detailResult) + } + } + + fmt.Println("✅ All tests completed!") + return nil +} + +// RunViaJSONRPC 通过 JSON-RPC 接口测试 +func (t *LiveTester) RunViaJSONRPC(mesh, keyword, serviceName string) error { + fmt.Println("🧪 Testing via JSON-RPC interface...") + + // 测试 tools/list + fmt.Println("\n1️⃣ Testing tools/list...") + listReq := &core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: 1, + Method: core.MethodToolsList, + } + listResp := t.server.HandleRequest(listReq) + listData, _ := json.MarshalIndent(listResp, "", " ") + fmt.Printf("Tools list:\n%s\n", string(listData)) + + // 测试 get_cluster_info + fmt.Println("\n2️⃣ Testing get_cluster_info via JSON-RPC...") + clusterResp, err := t.TestToolViaServer("get_cluster_info", map[string]any{"mesh": mesh}) + if err != nil { + return err + } + clusterData, _ := json.MarshalIndent(clusterResp, "", " ") + fmt.Printf("Cluster info response:\n%s\n", string(clusterData)) + + // 测试 global_search + fmt.Println("\n3️⃣ Testing global_search via JSON-RPC...") + searchResp, err := t.TestToolViaServer("global_search", map[string]any{ + "keyword": keyword, + "mesh": mesh, + }) + if err != nil { + return err + } + searchData, _ := json.MarshalIndent(searchResp, "", " ") + fmt.Printf("Search response:\n%s\n", string(searchData)) + + return nil +} + +// exampleLiveTest 示例:如何在其他地方使用 LiveTester +func exampleLiveTest() { + // 注意:这只是一个示例,实际使用时需要从外部获取 consoleCtx + /* + // 1. 初始化 runtime 和 console context + cfg := app.DefaultAdminConfig() + config.Load("dubbo-admin.yaml", &cfg) + rt, _ := bootstrap.Bootstrap(context.Background(), cfg) + consoleCtx := context.NewConsoleContext(rt) + + // 2. 创建测试器 + tester, _ := NewLiveTester("dubbo-admin.yaml") + tester.SetConsoleContext(consoleCtx) + tester.InitServer() + + // 3. 运行测试 + tester.RunAllTests("default", "demo", "com.example.Service") + */ +} + +// main 函数可以作为独立的测试工具运行 +func main() { + configPath := flag.String("config", "./dubbo-admin.yaml", "配置文件路径") + mesh := flag.String("mesh", "default", "Mesh 名称") + keyword := flag.String("keyword", "", "搜索关键字") + serviceName := flag.String("service", "", "服务名称") + flag.Parse() + + if *configPath == "" { + fmt.Println("请提供配置文件路径: -config ") + os.Exit(1) + } + + fmt.Printf("📋 配置: %s\n", *configPath) + fmt.Printf("🔍 Mesh: %s\n", *mesh) + if *keyword != "" { + fmt.Printf("🔑 关键字: %s\n", *keyword) + } + if *serviceName != "" { + fmt.Printf("🛠️ 服务: %s\n", *serviceName) + } + + // 注意:这里需要真实的 console context + // 实际使用时需要从外部注入 + fmt.Println("\n⚠️ 注意:此测试工具需要真实的 console context") + fmt.Println("⚠️ 请通过代码方式调用,在初始化 runtime 后传入 context") + fmt.Println("\n示例代码:") + fmt.Println(` + rt, _ := bootstrap.Bootstrap(context.Background(), cfg) + consoleCtx := context.NewConsoleContext(rt) + + tester := &LiveTester{} + tester.SetConsoleContext(consoleCtx) + tester.InitServer() + tester.RunAllTests("default", "demo", "com.example.Service") + `) + + // 如果有环境变量指定了真实模式,则尝试运行 + if os.Getenv("MCP_LIVE_TEST") == "true" { + fmt.Println("\n🚀 Running in live test mode...") + // 这里需要实际的初始化代码 + // 由于循环导入问题,需要在实际使用的地方实现 + } +} + +// LiveTestHelper 辅助函数,用于在已有 context 的地方进行测试 +func LiveTestHelper(consoleCtx consolectx.Context, mesh, keyword, serviceName string) { + tester := &LiveTester{ + ctx: context.Background(), + consoleCtx: consoleCtx, + options: &LiveTestOptions{}, + } + tester.InitServer() + + // 设置超时 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 运行测试 + done := make(chan error) + go func() { + done <- tester.RunAllTests(mesh, keyword, serviceName) + }() + + select { + case err := <-done: + if err != nil { + log.Printf("Test failed: %v", err) + } + case <-ctx.Done(): + log.Println("Test timeout!") + } +} + +// LiveTestViaJSONRPCHelper 通过 JSON-RPC 测试的辅助函数 +func LiveTestViaJSONRPCHelper(consoleCtx consolectx.Context, mesh, keyword string) error { + tester := &LiveTester{ + ctx: context.Background(), + consoleCtx: consoleCtx, + options: &LiveTestOptions{}, + } + tester.InitServer() + + return tester.RunViaJSONRPC(mesh, keyword, "") +} diff --git a/pkg/mcp/tools/metrics.go b/pkg/mcp/tools/metrics.go new file mode 100644 index 000000000..25e01df32 --- /dev/null +++ b/pkg/mcp/tools/metrics.go @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/console/counter" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/apache/dubbo-admin/pkg/mcp/registry" + meshresource "github.com/apache/dubbo-admin/pkg/core/resource/apis/mesh/v1alpha1" +) + +// MetricsRegistrar 集群工具注册器 +type MetricsRegistrar struct{} + +// RegisterTools 实现 ToolRegistrar 接口 +func (r *MetricsRegistrar) RegisterTools(reg *registry.Registry) { + reg.Register(types.ToolDef{ + Name: "get_cluster_info", + Description: "获取 Dubbo 集群基本信息,包括应用数、服务数、实例数等统计信息", + InputSchema: types.InputSchema{ + Type: "object", + Properties: map[string]types.PropertyDef{ + "mesh": { + Type: "string", + Description: "Mesh 名称,默认使用配置中的默认 mesh", + }, + }, + }, + Handler: GetClusterInfo, + }) +} + +// GetClusterInfo 获取集群基本信息 +func GetClusterInfo(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + mesh := GetMeshArg(ctx, args) + info := collectClusterInfo(ctx, mesh) + return JsonResult(info) +} + +// collectClusterInfo 收集集群信息 +func collectClusterInfo(ctx consolectx.Context, mesh string) map[string]any { + counterMgr := ctx.CounterManager() + if counterMgr == nil { + return map[string]any{ + "mesh": mesh, + "appCount": 0, + "serviceCount": 0, + "instanceCount": 0, + "protocols": map[string]int{}, + "releases": map[string]int{}, + "discoveries": map[string]int{}, + "error": "Counter manager not available", + } + } + + return map[string]any{ + "mesh": mesh, + "appCount": counterMgr.CountByMesh(meshresource.ApplicationKind, mesh), + "serviceCount": counterMgr.CountByMesh(meshresource.ServiceProviderMetadataKind, mesh), + "instanceCount": counterMgr.CountByMesh(meshresource.InstanceKind, mesh), + "protocols": counterMgr.DistributionByMesh(counter.ProtocolCounter, mesh), + "releases": counterMgr.DistributionByMesh(counter.ReleaseCounter, mesh), + "discoveries": counterMgr.DistributionByMesh(counter.DiscoveryCounter, mesh), + } +} + +// Ensure MetricsRegistrar implements ToolRegistrar +var _ registry.ToolRegistrar = (*MetricsRegistrar)(nil) diff --git a/pkg/mcp/tools/resource_search.go b/pkg/mcp/tools/resource_search.go new file mode 100644 index 000000000..8943fdd3f --- /dev/null +++ b/pkg/mcp/tools/resource_search.go @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/console/model" + "github.com/apache/dubbo-admin/pkg/console/service" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/apache/dubbo-admin/pkg/mcp/registry" +) + +// ResourceSearchRegistrar 搜索工具注册器 +type ResourceSearchRegistrar struct{} + +// searchExecutor 搜索执行器接口 +type searchExecutor interface { + execute(ctx consolectx.Context, keyword, mesh string, pageNumber, pageSize int) (*model.SearchPaginationResult, error) + buildResult(pagedResult *model.SearchPaginationResult, keyword string, pageSize, pageNumber int) map[string]any +} + +// RegisterTools 实现 ToolRegistrar 接口 +func (r *ResourceSearchRegistrar) RegisterTools(reg *registry.Registry) { + reg.Register(types.ToolDef{ + Name: "global_search", + Description: "全局搜索,支持搜索服务、实例、应用等资源。不传 keyword 返回所有数据", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{}, // keyword 改为可选,空值返回所有数据 + Properties: map[string]types.PropertyDef{ + "keyword": { + Type: "string", + Description: "搜索关键字,为空时返回所有数据", + }, + "searchType": { + Type: "string", + Description: "搜索类型: ip(按IP搜索实例), instanceName(按实例名搜索), appName(按应用名搜索), serviceName(按服务名搜索)", + Default: string(SearchTypeName), + Enum: []string{string(SearchTypeIP), string(SearchTypeInstanceName), string(SearchTypeAppName), string(SearchTypeName)}, + }, + "mesh": { + Type: "string", + Description: "Mesh 名称,默认使用配置中的默认 mesh", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码,从 1 开始", + Default: DefaultPageNumber, + }, + }, + }, + Handler: GlobalSearch, + }) +} + +// GlobalSearch 全局搜索 +func GlobalSearch(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + keyword := helper.GetString("keyword", "") + + searchType := SearchType(helper.GetString("searchType", string(SearchTypeName))) + mesh := GetMeshArg(ctx, args) + pageSize := helper.GetInt("pageSize", DefaultPageSize) + pageNumber := helper.GetInt("pageNumber", DefaultPageNumber) + + executor := getSearchExecutor(searchType) + result, err := executor.execute(ctx, keyword, mesh, pageNumber, pageSize) + if err != nil { + return ErrorResult(err), nil + } + + searchResult := executor.buildResult(result, keyword, pageSize, pageNumber) + searchResult["searchType"] = string(searchType) + + return JsonResult(searchResult) +} + +// getSearchExecutor 根据搜索类型获取对应的执行器 +func getSearchExecutor(searchType SearchType) searchExecutor { + executors := map[SearchType]searchExecutor{ + SearchTypeIP: &ipSearchExecutor{}, + SearchTypeInstanceName: &instanceNameSearchExecutor{}, + SearchTypeAppName: &appNameSearchExecutor{}, + SearchTypeName: &serviceNameSearchExecutor{}, + } + + if executor, ok := executors[searchType]; ok { + return executor + } + return &serviceNameSearchExecutor{} +} + +// ipSearchExecutor IP 搜索执行器 +type ipSearchExecutor struct{} + +func (e *ipSearchExecutor) execute(ctx consolectx.Context, keyword, mesh string, pageNumber, pageSize int) (*model.SearchPaginationResult, error) { + req := buildSearchReq(keyword, mesh, pageNumber, pageSize) + return service.SearchInstanceByIp(ctx, req) +} + +func (e *ipSearchExecutor) buildResult(pagedResult *model.SearchPaginationResult, keyword string, pageSize, pageNumber int) map[string]any { + instances, totalCount := extractInstances(pagedResult) + return map[string]any{ + "keyword": keyword, + "pageSize": pageSize, + "pageNumber": pageNumber, + "instances": instances, + "totalCount": totalCount, + } +} + +// instanceNameSearchExecutor 实例名搜索执行器 +type instanceNameSearchExecutor struct{} + +func (e *instanceNameSearchExecutor) execute(ctx consolectx.Context, keyword, mesh string, pageNumber, pageSize int) (*model.SearchPaginationResult, error) { + req := buildSearchReq(keyword, mesh, pageNumber, pageSize) + return service.SearchInstanceByName(ctx, req) +} + +func (e *instanceNameSearchExecutor) buildResult(pagedResult *model.SearchPaginationResult, keyword string, pageSize, pageNumber int) map[string]any { + instances, totalCount := extractInstances(pagedResult) + return map[string]any{ + "keyword": keyword, + "pageSize": pageSize, + "pageNumber": pageNumber, + "instances": instances, + "totalCount": totalCount, + } +} + +// appNameSearchExecutor 应用名搜索执行器 +type appNameSearchExecutor struct{} + +func (e *appNameSearchExecutor) execute(ctx consolectx.Context, keyword, mesh string, pageNumber, pageSize int) (*model.SearchPaginationResult, error) { + // 使用 SearchApplications 而不是 SearchApplicationsByKeywords + // SearchApplications 会正确处理空 keyword 的情况(返回所有应用的分页列表) + req := &model.ApplicationSearchReq{ + Keywords: keyword, + Mesh: mesh, + PageReq: BuildPageReq(pageNumber, pageSize), + } + return service.SearchApplications(ctx, req) +} + +func (e *appNameSearchExecutor) buildResult(pagedResult *model.SearchPaginationResult, keyword string, pageSize, pageNumber int) map[string]any { + apps := extractApplicationsFromResult(pagedResult) + return map[string]any{ + "keyword": keyword, + "pageSize": pageSize, + "pageNumber": pageNumber, + "applications": apps, + "totalCount": len(apps), + } +} + +// serviceNameSearchExecutor 服务名搜索执行器 +type serviceNameSearchExecutor struct{} + +func (e *serviceNameSearchExecutor) execute(ctx consolectx.Context, keyword, mesh string, pageNumber, pageSize int) (*model.SearchPaginationResult, error) { + req := &model.ServiceSearchReq{ + ServiceName: "", + Keywords: keyword, + Mesh: mesh, + PageReq: BuildPageReq(pageNumber, pageSize), + } + // 空关键字时调用 SearchServices 获取所有服务,否则精确匹配 + if keyword == "" { + return service.SearchServices(ctx, req) + } + return service.SearchServicesByKeywords(ctx, req) +} + +func (e *serviceNameSearchExecutor) buildResult(pagedResult *model.SearchPaginationResult, keyword string, pageSize, pageNumber int) map[string]any { + services, totalCount := extractServicesFromResult(pagedResult) + return map[string]any{ + "keyword": keyword, + "pageSize": pageSize, + "pageNumber": pageNumber, + "services": services, + "totalCount": totalCount, + } +} + +// buildSearchReq 构建搜索请求 +func buildSearchReq(keyword, mesh string, pageNumber, pageSize int) *model.SearchReq { + req := model.NewSearchReq() + req.Keywords = keyword + req.Mesh = mesh + req.PageReq = BuildPageReq(pageNumber, pageSize) + return req +} + +// extractInstances 从分页结果中提取实例列表 +func extractInstances(pagedResult *model.SearchPaginationResult) ([]any, int) { + if pagedResult == nil || pagedResult.List == nil { + return []any{}, 0 + } + + instances, ok := pagedResult.List.([]*model.SearchInstanceResp) + if !ok { + return []any{}, 0 + } + + result := make([]any, 0, len(instances)) + for _, ins := range instances { + result = append(result, map[string]any{ + "name": ins.Name, + "appName": ins.AppName, + "ip": ins.Ip, + "workloadName": ins.WorkloadName, + "deployState": ins.DeployState, + "deployCluster": ins.DeployCluster, + "registerState": ins.RegisterState, + "registerClusters": ins.RegisterClusters, + "createTime": ins.CreateTime, + "registerTime": ins.RegisterTime, + "labels": ins.Labels, + }) + } + return result, int(pagedResult.PageInfo.Total) +} + +// extractApplicationsFromResult 从分页结果中提取应用列表 +func extractApplicationsFromResult(pagedResult *model.SearchPaginationResult) []any { + if pagedResult == nil || pagedResult.List == nil { + return []any{} + } + + apps, ok := pagedResult.List.([]*model.ApplicationSearchResp) + if !ok { + return []any{} + } + + result := make([]any, 0, len(apps)) + for _, app := range apps { + result = append(result, map[string]any{ + "appName": app.AppName, + "instanceCount": app.InstanceCount, + "deployClusters": app.DeployClusters, + "registryClusters": app.RegistryClusters, + }) + } + return result +} + +// extractServicesFromResult 从分页结果中提取服务列表 +func extractServicesFromResult(pagedResult *model.SearchPaginationResult) ([]any, int) { + if pagedResult == nil || pagedResult.List == nil { + return []any{}, 0 + } + + services, ok := pagedResult.List.([]*model.ServiceSearchResp) + if !ok { + return []any{}, 0 + } + + result := make([]any, 0, len(services)) + for _, svc := range services { + result = append(result, map[string]any{ + "serviceName": svc.ServiceName, + "version": svc.Version, + "group": svc.Group, + "providerAppName": svc.ProviderAppName, + "consumerAppName": svc.ConsumerAppName, + }) + } + return result, int(pagedResult.PageInfo.Total) +} + +// Ensure ResourceSearchRegistrar implements ToolRegistrar +var _ registry.ToolRegistrar = (*ResourceSearchRegistrar)(nil) diff --git a/pkg/mcp/tools/service_discovery.go b/pkg/mcp/tools/service_discovery.go new file mode 100644 index 000000000..0b4ff2894 --- /dev/null +++ b/pkg/mcp/tools/service_discovery.go @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/console/model" + "github.com/apache/dubbo-admin/pkg/console/service" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/apache/dubbo-admin/pkg/mcp/registry" +) + +// ServiceRegistrar 服务工具注册器 +type ServiceRegistrar struct{} + +// RegisterTools 实现 ToolRegistrar 接口 +func (r *ServiceRegistrar) RegisterTools(reg *registry.Registry) { + reg.Register(types.ToolDef{ + Name: "search_services", + Description: "搜索 Dubbo 服务,支持按服务名过滤和分页", + InputSchema: types.InputSchema{ + Type: "object", + Properties: map[string]types.PropertyDef{ + "keywords": { + Type: "string", + Description: "服务名搜索关键字,支持模糊匹配", + }, + "mesh": { + Type: "string", + Description: "Mesh 名称,默认使用配置中的默认 mesh", + }, + "pageSize": { + Type: "integer", + Description: "每页数量", + Default: DefaultPageSize, + }, + "pageNumber": { + Type: "integer", + Description: "页码,从 1 开始", + Default: DefaultPageNumber, + }, + }, + }, + Handler: SearchServices, + }) + + reg.Register(types.ToolDef{ + Name: "get_service_detail", + Description: "获取服务详情,包括服务分布和实例信息", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"serviceName"}, + Properties: map[string]types.PropertyDef{ + "serviceName": { + Type: "string", + Description: "服务名称", + }, + "group": { + Type: "string", + Description: "服务分组", + Default: "", + }, + "version": { + Type: "string", + Description: "服务版本", + Default: "", + }, + "side": { + Type: "string", + Description: "服务端或消费者 (provider/consumer)", + Default: string(ServiceSideProvider), + Enum: []string{string(ServiceSideProvider), string(ServiceSideConsumer)}, + }, + "mesh": { + Type: "string", + Description: "Mesh 名称,默认使用配置中的默认 mesh", + }, + }, + }, + Handler: GetServiceDetail, + }) +} + +// SearchServices 搜索服务 +func SearchServices(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + keywords := helper.GetString("keywords", "") + mesh := GetMeshArg(ctx, args) + pageSize := helper.GetInt("pageSize", DefaultPageSize) + pageNumber := helper.GetInt("pageNumber", DefaultPageNumber) + + req := &model.ServiceSearchReq{ + Keywords: keywords, + Mesh: mesh, + PageReq: BuildPageReq(pageNumber, pageSize), + } + + result, err := service.SearchServices(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + return buildServiceSearchResult(result, keywords, mesh, pageSize, pageNumber) +} + +// GetServiceDetail 获取服务详情 +func GetServiceDetail(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + helper := NewArgsHelper(args) + serviceName := helper.GetString("serviceName", "") + + params := serviceDetailParams{ + serviceName: serviceName, + group: helper.GetString("group", ""), + version: helper.GetString("version", ""), + mesh: GetMeshArg(ctx, args), + } + + side := ServiceSide(helper.GetString("side", string(ServiceSideProvider))) + + return fetchServiceDistribution(ctx, params, side) +} + +// serviceDetailParams 服务详情参数 +type serviceDetailParams struct { + serviceName string + group string + version string + mesh string +} + +// fetchServiceDistribution 获取服务分布信息 +func fetchServiceDistribution(ctx consolectx.Context, params serviceDetailParams, side ServiceSide) (*types.ToolResult, error) { + targetSide := string(ServiceSideConsumer) + if side == ServiceSideConsumer { + targetSide = string(ServiceSideProvider) + } + + req := &model.ServiceTabDistributionReq{ + ServiceName: params.serviceName, + Group: params.group, + Version: params.version, + Side: targetSide, + Mesh: params.mesh, + PageReq: BuildPageReq(1, MaxDistributionLimit), + } + + distribution, err := service.GetServiceTabDistribution(ctx, req) + if err != nil { + return ErrorResult(err), nil + } + + apps := extractApplications(distribution) + return JsonResult(map[string]any{ + "serviceName": params.serviceName, + "group": params.group, + "version": params.version, + "side": string(side), + "mesh": params.mesh, + "distribution": apps, + "totalApps": len(apps), + }) +} + +// extractApplications 从分页结果中提取应用列表 +func extractApplications(result *model.SearchPaginationResult) []any { + if result == nil || result.List == nil { + return []any{} + } + + apps, ok := result.List.([]*model.ApplicationSearchResp) + if !ok { + return []any{} + } + + resultSlice := make([]any, 0, len(apps)) + for _, app := range apps { + if app != nil { + resultSlice = append(resultSlice, map[string]any{ + "appName": app.AppName, + "instanceCount": app.InstanceCount, + "deployClusters": app.DeployClusters, + "registryClusters": app.RegistryClusters, + }) + } + } + return resultSlice +} + +// buildServiceSearchResult 构建服务搜索结果 +func buildServiceSearchResult(result *model.SearchPaginationResult, keywords, mesh string, pageSize, pageNumber int) (*types.ToolResult, error) { + services, totalCount := extractServices(result) + + return JsonResult(map[string]any{ + "keywords": keywords, + "mesh": mesh, + "pageSize": pageSize, + "pageNumber": pageNumber, + "services": services, + "totalCount": totalCount, + }) +} + +// extractServices 从分页结果中提取服务列表 +func extractServices(result *model.SearchPaginationResult) ([]any, int) { + if result == nil || result.List == nil { + return []any{}, 0 + } + + services, ok := result.List.([]*model.ServiceSearchResp) + if !ok { + return []any{result.List}, int(result.PageInfo.Total) + } + + resultSlice := make([]any, 0, len(services)) + for _, svc := range services { + resultSlice = append(resultSlice, map[string]any{ + "serviceName": svc.ServiceName, + "version": svc.Version, + "group": svc.Group, + "providerAppName": svc.ProviderAppName, + "consumerAppName": svc.ConsumerAppName, + }) + } + return resultSlice, int(result.PageInfo.Total) +} + +// Ensure ServiceRegistrar implements ToolRegistrar +var _ registry.ToolRegistrar = (*ServiceRegistrar)(nil) diff --git a/pkg/mcp/tools/tools_test.go b/pkg/mcp/tools/tools_test.go new file mode 100644 index 000000000..2a1297c2f --- /dev/null +++ b/pkg/mcp/tools/tools_test.go @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + "errors" + "testing" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/types" + "github.com/apache/dubbo-admin/pkg/mcp/registry" +) + +// TestArgsHelper 测试参数辅助器 +func TestArgsHelper(t *testing.T) { + t.Run("GetString", func(t *testing.T) { + args := map[string]any{ + "name": "test", + "empty": "", + } + helper := NewArgsHelper(args) + + if v := helper.GetString("name", "default"); v != "test" { + t.Errorf("Expected 'test', got '%s'", v) + } + + if v := helper.GetString("empty", "default"); v != "" { + t.Errorf("Expected '', got '%s'", v) + } + + if v := helper.GetString("notexist", "default"); v != "default" { + t.Errorf("Expected 'default', got '%s'", v) + } + }) + + t.Run("GetInt", func(t *testing.T) { + args := map[string]any{ + "intVal": 42, + "floatVal": 3.14, + } + helper := NewArgsHelper(args) + + if v := helper.GetInt("intVal", 0); v != 42 { + t.Errorf("Expected 42, got %d", v) + } + + if v := helper.GetInt("floatVal", 0); v != 3 { + t.Errorf("Expected 3, got %d", v) + } + + if v := helper.GetInt("notexist", 10); v != 10 { + t.Errorf("Expected 10, got %d", v) + } + }) + + t.Run("GetBool", func(t *testing.T) { + args := map[string]any{ + "true": true, + "false": false, + } + helper := NewArgsHelper(args) + + if v := helper.GetBool("true", false); !v { + t.Error("Expected true") + } + + if v := helper.GetBool("false", true); v { + t.Error("Expected false") + } + + if v := helper.GetBool("notexist", true); !v { + t.Error("Expected default true") + } + }) + + t.Run("GetRequiredString", func(t *testing.T) { + args := map[string]any{ + "valid": "value", + "empty": "", + } + helper := NewArgsHelper(args) + + if v, ok := helper.GetRequiredString("valid"); !ok || v != "value" { + t.Errorf("Expected 'value', got '%s', ok=%v", v, ok) + } + + if v, ok := helper.GetRequiredString("empty"); ok || v != "" { + t.Errorf("Expected empty and false, got '%s', ok=%v", v, ok) + } + + if v, ok := helper.GetRequiredString("notexist"); ok || v != "" { + t.Errorf("Expected empty and false, got '%s', ok=%v", v, ok) + } + }) +} + +// TestBuildPageReq 测试分页请求构建 +func TestBuildPageReq(t *testing.T) { + tests := []struct { + name string + pageNumber int + pageSize int + wantOffset int + wantSize int + }{ + {"正常分页", 2, 10, 10, 10}, + {"第一页", 1, 20, 0, 20}, + {"无效页码", 0, 10, 0, 10}, + {"无效大小", 1, 0, 0, 10}, + {"负数页码", -1, 10, 0, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := BuildPageReq(tt.pageNumber, tt.pageSize) + if req.PageOffset != tt.wantOffset { + t.Errorf("Expected offset %d, got %d", tt.wantOffset, req.PageOffset) + } + if req.PageSize != tt.wantSize { + t.Errorf("Expected size %d, got %d", tt.wantSize, req.PageSize) + } + }) + } +} + +// TestJsonResult 测试 JSON 结果创建 +func TestJsonResult(t *testing.T) { + data := map[string]any{ + "key": "value", + "count": 42, + } + + result, err := JsonResult(data) + if err != nil { + t.Fatalf("JsonResult failed: %v", err) + } + + if result.IsError { + t.Error("Expected non-error result") + } + + if len(result.Content) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Content)) + } + + // 验证包含 expected 内容 + text := result.Content[0].Text + if text == "" { + t.Error("Expected non-empty content text") + } +} + +// TestErrorResult 测试错误结果创建 +func TestErrorResult(t *testing.T) { + err := errors.New("test error") + result := ErrorResult(err) + + if !result.IsError { + t.Error("Expected error result") + } + + if len(result.Content) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(result.Content)) + } + + if result.Content[0].Text != "test error" { + t.Errorf("Expected 'test error', got '%s'", result.Content[0].Text) + } +} + +// TestToolValidation 测试工具参数验证 +func TestToolValidation(t *testing.T) { + // 创建测试工具 + testTool := types.ToolDef{ + Name: "test_tool", + Description: "Test tool for validation", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"name"}, + Properties: map[string]types.PropertyDef{ + "name": { + Type: "string", + Description: "Name parameter", + }, + "age": { + Type: "integer", + Description: "Age parameter", + }, + }, + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("ok", false), nil + }, + } + + t.Run("ValidArguments", func(t *testing.T) { + args := map[string]any{ + "name": "test", + "age": 25, + } + + err := types.ValidateRequired(testTool.InputSchema, args) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + }) + + t.Run("MissingRequired", func(t *testing.T) { + args := map[string]any{ + "age": 25, + } + + err := types.ValidateRequired(testTool.InputSchema, args) + if err == nil { + t.Error("Expected validation error") + } + }) + + t.Run("EmptyString", func(t *testing.T) { + args := map[string]any{ + "name": "", + } + + err := types.ValidateRequired(testTool.InputSchema, args) + if err == nil { + t.Error("Expected validation error for empty string") + } + }) +} + +// TestMetricsRegistrar 测试集群信息工具注册 +func TestMetricsRegistrar(t *testing.T) { + registrar := &MetricsRegistrar{} + reg := registry.NewRegistry() + + registrar.RegisterTools(reg) + + // 验证工具已注册 + tool, ok := reg.Get("get_cluster_info") + if !ok { + t.Fatal("Tool 'get_cluster_info' not registered") + } + + if tool.Name != "get_cluster_info" { + t.Errorf("Expected name 'get_cluster_info', got '%s'", tool.Name) + } + + if len(tool.InputSchema.Required) != 0 { + t.Errorf("Expected no required params, got %v", tool.InputSchema.Required) + } +} + +// TestResourceSearchRegistrar 测试搜索工具注册 +func TestResourceSearchRegistrar(t *testing.T) { + registrar := &ResourceSearchRegistrar{} + reg := registry.NewRegistry() + + registrar.RegisterTools(reg) + + // 验证工具已注册 + tool, ok := reg.Get("global_search") + if !ok { + t.Fatal("Tool 'global_search' not registered") + } + + // 验证必需参数(keyword 现在是可选的) + if len(tool.InputSchema.Required) != 0 { + t.Errorf("Expected required=[], got %v", tool.InputSchema.Required) + } + + // 验证所有属性定义 + expectedProps := []string{"keyword", "searchType", "mesh", "pageSize", "pageNumber"} + for _, prop := range expectedProps { + if _, ok := tool.InputSchema.Properties[prop]; !ok { + t.Errorf("Missing property: %s", prop) + } + } +} + +// TestServiceRegistrar 测试服务发现工具注册 +func TestServiceRegistrar(t *testing.T) { + registrar := &ServiceRegistrar{} + reg := registry.NewRegistry() + + registrar.RegisterTools(reg) + + // 验证工具已注册 + tools := reg.List() + if len(tools) != 2 { + t.Errorf("Expected 2 tools, got %d", len(tools)) + } + + expectedTools := []string{"search_services", "get_service_detail"} + for _, name := range expectedTools { + if _, ok := reg.Get(name); !ok { + t.Errorf("Tool '%s' not registered", name) + } + } +} + +// TestRegistryList 测试注册表列表功能 +func TestRegistryList(t *testing.T) { + reg := registry.NewRegistry() + + // 注册一个测试工具 + reg.Register(types.ToolDef{ + Name: "test1", + Description: "Test tool 1", + InputSchema: types.InputSchema{Type: "object"}, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { return nil, nil }, + }) + + reg.Register(types.ToolDef{ + Name: "test2", + Description: "Test tool 2", + InputSchema: types.InputSchema{Type: "object"}, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { return nil, nil }, + }) + + tools := reg.List() + if len(tools) != 2 { + t.Errorf("Expected 2 tools, got %d", len(tools)) + } +} + +// TestRegistryUnregister 测试工具注销 +func TestRegistryUnregister(t *testing.T) { + reg := registry.NewRegistry() + + reg.Register(types.ToolDef{ + Name: "test", + Description: "Test tool", + InputSchema: types.InputSchema{Type: "object"}, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { return nil, nil }, + }) + + if _, ok := reg.Get("test"); !ok { + t.Error("Tool not registered") + } + + reg.Unregister("test") + + if _, ok := reg.Get("test"); ok { + t.Error("Tool still exists after unregister") + } +} + +// TestDefaultPageValues 测试默认分页值 +func TestDefaultPageValues(t *testing.T) { + if DefaultPageSize != 10 { + t.Errorf("Expected DefaultPageSize=10, got %d", DefaultPageSize) + } + + if DefaultPageNumber != 1 { + t.Errorf("Expected DefaultPageNumber=1, got %d", DefaultPageNumber) + } +} + +// TestToolHandlerExecution 测试工具处理器执行 +func TestToolHandlerExecution(t *testing.T) { + // 创建简单的测试工具 + var called bool + var receivedArgs map[string]any + + testTool := types.ToolDef{ + Name: "echo", + Description: "Echo tool", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"message"}, + Properties: map[string]types.PropertyDef{ + "message": {Type: "string", Description: "Message"}, + }, + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + called = true + receivedArgs = args + msg := args["message"].(string) + return types.NewTextResult("echo: "+msg, false), nil + }, + } + + t.Run("成功调用", func(t *testing.T) { + called = false + receivedArgs = nil + + args := map[string]any{ + "message": "hello", + } + + result, err := testTool.Handler(nil, args) + if err != nil { + t.Fatalf("Handler failed: %v", err) + } + + if !called { + t.Error("Handler was not called") + } + + if msg, ok := receivedArgs["message"].(string); !ok || msg != "hello" { + t.Errorf("Handler did not receive correct args, got %+v", receivedArgs) + } + + if result.IsError { + t.Error("Expected non-error result") + } + + if result.Content[0].Text != "echo: hello" { + t.Errorf("Expected 'echo: hello', got '%s'", result.Content[0].Text) + } + }) +} + +// TestSearchType 测试搜索类型 +func TestSearchType(t *testing.T) { + types := map[SearchType]string{ + SearchTypeIP: "ip", + SearchTypeInstanceName: "instanceName", + SearchTypeAppName: "appName", + SearchTypeName: "serviceName", + } + + for k, v := range types { + if string(k) != v { + t.Errorf("Expected '%s', got '%s'", v, string(k)) + } + } +} + +// TestServiceSide 测试服务端类型 +func TestServiceSide(t *testing.T) { + types := map[ServiceSide]string{ + ServiceSideProvider: "provider", + ServiceSideConsumer: "consumer", + } + + for k, v := range types { + if string(k) != v { + t.Errorf("Expected '%s', got '%s'", v, string(k)) + } + } +} + +// TestIsEmpty 测试空值判断 +func TestIsEmpty(t *testing.T) { + tests := []struct { + name string + value any + expected bool + }{ + {"空字符串", "", true}, + {"非空字符串", "hello", false}, + {"nil", nil, true}, + {"空数组", []any{}, true}, + {"非空数组", []any{1, 2}, false}, + {"空map", map[string]any{}, true}, + {"非空map", map[string]any{"key": "value"}, false}, + {"数字", 42, false}, + {"布尔", true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := types.IsEmpty(tt.value) + if result != tt.expected { + t.Errorf("IsEmpty(%v) = %v, want %v", tt.value, result, tt.expected) + } + }) + } +} diff --git a/pkg/mcp/tools/utils.go b/pkg/mcp/tools/utils.go new file mode 100644 index 000000000..84eb23fe5 --- /dev/null +++ b/pkg/mcp/tools/utils.go @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tools + +import ( + "encoding/json" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/core" + coremodel "github.com/apache/dubbo-admin/pkg/core/resource/model" +) + +const ( + // DefaultPageSize 默认分页大小 + DefaultPageSize = 10 + // DefaultPageNumber 默认页码 + DefaultPageNumber = 1 + // MaxDistributionLimit 服务分布查询的最大数量 + MaxDistributionLimit = 100 +) + +// SearchType 搜索类型枚举 +type SearchType string + +const ( + SearchTypeIP SearchType = "ip" + SearchTypeInstanceName SearchType = "instanceName" + SearchTypeAppName SearchType = "appName" + SearchTypeName SearchType = "serviceName" +) + +// ServiceSide 服务端类型 +type ServiceSide string + +const ( + ServiceSideProvider ServiceSide = "provider" + ServiceSideConsumer ServiceSide = "consumer" +) + +// ArgsHelper 参数辅助器 +type ArgsHelper struct { + args map[string]any +} + +// NewArgsHelper 创建参数辅助器 +func NewArgsHelper(args map[string]any) *ArgsHelper { + return &ArgsHelper{args: args} +} + +// GetString 获取字符串参数 +func (h *ArgsHelper) GetString(key, defaultValue string) string { + if v, ok := h.args[key].(string); ok { + return v + } + return defaultValue +} + +// GetInt 获取整数参数 +func (h *ArgsHelper) GetInt(key string, defaultValue int) int { + switch v := h.args[key].(type) { + case int: + return v + case float64: + return int(v) + } + return defaultValue +} + +// GetBool 获取布尔参数 +func (h *ArgsHelper) GetBool(key string, defaultValue bool) bool { + if v, ok := h.args[key].(bool); ok { + return v + } + return defaultValue +} + +// GetRequiredString 获取必需的字符串参数 +func (h *ArgsHelper) GetRequiredString(key string) (string, bool) { + v, ok := h.args[key].(string) + if !ok || v == "" { + return "", false + } + return v, true +} + +// BuildPageReq 构建分页请求参数 +func BuildPageReq(pageNumber, pageSize int) coremodel.PageReq { + if pageSize <= 0 { + pageSize = DefaultPageSize + } + if pageNumber <= 0 { + pageNumber = DefaultPageNumber + } + return coremodel.PageReq{ + PageOffset: (pageNumber - 1) * pageSize, + PageSize: pageSize, + } +} + +// FormatJSON 格式化 JSON +func FormatJSON(data any) (string, error) { + bytes, err := json.Marshal(data) + if err != nil { + return "", err + } + return string(bytes), nil +} + +// JsonResult 创建 JSON 结果 +func JsonResult(data any) (*core.ToolResult, error) { + jsonData, err := FormatJSON(data) + if err != nil { + return nil, err + } + return core.NewTextResult(jsonData, false), nil +} + +// ErrorResult 创建错误结果 +func ErrorResult(err error) *core.ToolResult { + return core.NewErrorResult(err) +} + +// GetMeshArg 获取 mesh 参数,默认使用配置中的 discovery id 作为 mesh +func GetMeshArg(ctx consolectx.Context, args map[string]any) string { + helper := NewArgsHelper(args) + if mesh := helper.GetString("mesh", ""); mesh != "" { + return mesh + } + // 默认使用第一个 discovery 配置的 id 作为 mesh 名称 + if len(ctx.Config().Discovery) > 0 { + return ctx.Config().Discovery[0].ID + } + // fallback 到 engine name + return ctx.Config().Engine.Name +} diff --git a/pkg/mcp/transport/http/handler.go b/pkg/mcp/transport/http/handler.go new file mode 100644 index 000000000..88b3f6553 --- /dev/null +++ b/pkg/mcp/transport/http/handler.go @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "encoding/json" + "io" + "net/http" + + "github.com/apache/dubbo-admin/pkg/mcp/core" +) + +// Handler HTTP请求处理器 +type Handler struct { + server *core.Server +} + +// NewHandler 创建HTTP处理器 +func NewHandler(server *core.Server) *Handler { + return &Handler{ + server: server, + } +} + +// ServeHTTP 实现http.Handler接口 +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // 设置CORS headers + h.setCORSHeaders(w) + + // 处理OPTIONS请求(CORS preflight) + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + // 只接受POST请求 + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 读取请求体 + body, err := io.ReadAll(r.Body) + if err != nil { + h.sendError(w, nil, core.ErrCodeParseError, "Failed to read request body") + return + } + + // 解析JSON-RPC请求 + var req core.JSONRPCRequest + if err := json.Unmarshal(body, &req); err != nil { + h.sendError(w, nil, core.ErrCodeParseError, "Invalid JSON") + return + } + + // 处理请求并获取响应 + resp := h.server.HandleRequest(&req) + + // 发送响应 + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) +} + +// setCORSHeaders 设置CORS headers +func (h *Handler) setCORSHeaders(w http.ResponseWriter) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.Header().Set("Access-Control-Max-Age", "86400") +} + +// sendError 发送错误响应 +func (h *Handler) sendError(w http.ResponseWriter, id interface{}, code int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + + resp := core.JSONRPCResponse{ + JSONRPC: core.JSONRPCVersion, + ID: id, + Error: &core.JSONRPCError{ + Code: code, + Message: message, + }, + } + json.NewEncoder(w).Encode(resp) +} + +// HandleMCPRequest 处理MCP请求(公开方法) +func (h *Handler) HandleMCPRequest(w http.ResponseWriter, req *core.JSONRPCRequest) { + resp := h.server.HandleRequest(req) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} diff --git a/pkg/mcp/transport/http/http.go b/pkg/mcp/transport/http/http.go new file mode 100644 index 000000000..182de6287 --- /dev/null +++ b/pkg/mcp/transport/http/http.go @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/apache/dubbo-admin/pkg/mcp/core" +) + +// Transport HTTP传输层 +type Transport struct { + server *core.Server + httpServer *http.Server + handler *Handler + mu sync.RWMutex + started bool +} + +// Config HTTP传输配置 +type Config struct { + Host string + Port int + ReadTimeout time.Duration + WriteTimeout time.Duration + ShutdownTimeout time.Duration +} + +// DefaultConfig 返回默认配置 +func DefaultConfig() *Config { + return &Config{ + Host: "0.0.0.0", + Port: 8080, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + ShutdownTimeout: 10 * time.Second, + } +} + +// NewTransport 创建HTTP传输层 +func NewTransport(server *core.Server) *Transport { + return NewTransportWithConfig(server, DefaultConfig()) +} + +// NewTransportWithConfig 使用指定配置创建HTTP传输层 +func NewTransportWithConfig(server *core.Server, cfg *Config) *Transport { + handler := NewHandler(server) + + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + httpServer := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + } + + return &Transport{ + server: server, + httpServer: httpServer, + handler: handler, + } +} + +// Start 启动HTTP服务器(阻塞运行) +func (t *Transport) Start(ctx context.Context) error { + t.mu.Lock() + if t.started { + t.mu.Unlock() + return fmt.Errorf("transport already started") + } + t.started = true + t.mu.Unlock() + + // 启动服务器 + errCh := make(chan error, 1) + go func() { + if err := t.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + + // 等待上下文取消或错误 + select { + case <-ctx.Done(): + return t.Shutdown() + case err := <-errCh: + return err + } +} + +// StartAsync 异步启动HTTP服务器 +func (t *Transport) StartAsync(ctx context.Context) error { + t.mu.Lock() + if t.started { + t.mu.Unlock() + return fmt.Errorf("transport already started") + } + t.started = true + t.mu.Unlock() + + go func() { + if err := t.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + // 记录错误但不退出 + fmt.Printf("HTTP server error: %v\n", err) + } + }() + + return nil +} + +// Shutdown 关闭HTTP服务器 +func (t *Transport) Shutdown() error { + t.mu.Lock() + defer t.mu.Unlock() + + if !t.started { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := t.httpServer.Shutdown(ctx); err != nil { + return fmt.Errorf("shutdown failed: %w", err) + } + + t.started = false + return nil +} + +// Close 关闭传输层 +func (t *Transport) Close() error { + return t.Shutdown() +} + +// Addr 返回监听地址 +func (t *Transport) Addr() string { + return t.httpServer.Addr +} + +// GetServer 获取HTTP服务器(用于自定义路由) +func (t *Transport) GetServer() *http.Server { + return t.httpServer +} + +// GetHandler 获取HTTP处理器(用于自定义路由) +func (t *Transport) GetHandler() *Handler { + return t.handler +} diff --git a/pkg/mcp/transport/http/http_test.go b/pkg/mcp/transport/http/http_test.go new file mode 100644 index 000000000..c735d7fb0 --- /dev/null +++ b/pkg/mcp/transport/http/http_test.go @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/registry" +) + +// TestHTTPTransport 测试HTTP传输 +func TestHTTPTransport(t *testing.T) { + // 创建测试服务器 + server := core.NewServer("test-server", "1.0.0") + + // 注册测试工具 + reg := server.GetRegistry() + reg.Register(core.ToolDef{ + Name: "test_tool", + Description: "A test tool", + InputSchema: core.InputSchema{ + Type: "object", + Properties: map[string]core.PropertyDef{ + "message": { + Type: "string", + Description: "Test message", + }, + }, + }, + Handler: func(ctx interface{}, args map[string]any) (*core.ToolResult, error) { + msg, _ := args["message"].(string) + return core.NewTextResult("Echo: " + msg), nil + }, + }) + reg.RegisterAll() + + // 创建HTTP传输层 + transport := NewTransportWithConfig(server, &Config{ + Host: "127.0.0.1", + Port: 0, // 随机端口 + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + }) + + // 异步启动 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := transport.StartAsync(ctx); err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + + // 等待服务器启动 + time.Sleep(100 * time.Millisecond) + + // 测试initialize请求 + t.Run("Initialize", func(t *testing.T) { + req := core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: "1", + Method: core.MethodInitialize, + } + + resp := makeRequest(t, transport, req) + if resp.Error != nil { + t.Fatalf("Initialize failed: %v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatal("Invalid result type") + } + + serverInfo := result["serverInfo"].(map[string]any) + if serverInfo["name"] != "test-server" { + t.Errorf("Expected server name 'test-server', got '%v'", serverInfo["name"]) + } + }) + + // 测试tools/list请求 + t.Run("ToolsList", func(t *testing.T) { + req := core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: "2", + Method: core.MethodToolsList, + } + + resp := makeRequest(t, transport, req) + if resp.Error != nil { + t.Fatalf("Tools list failed: %v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatal("Invalid result type") + } + + tools := result["tools"].([]any) + if len(tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(tools)) + } + }) + + // 测试tools/call请求 + t.Run("ToolsCall", func(t *testing.T) { + req := core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: "3", + Method: core.MethodToolsCall, + Params: map[string]any{ + "name": "test_tool", + "arguments": map[string]any{ + "message": "Hello, MCP!", + }, + }, + } + + resp := makeRequest(t, transport, req) + if resp.Error != nil { + t.Fatalf("Tools call failed: %v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatal("Invalid result type") + } + + content := result["content"].([]any) + if len(content) == 0 { + t.Fatal("Empty content") + } + + firstContent := content[0].(map[string]any) + text := firstContent["text"].(string) + if text != "Echo: Hello, MCP!" { + t.Errorf("Expected 'Echo: Hello, MCP!', got '%s'", text) + } + }) + + // 关闭传输层 + if err := transport.Shutdown(); err != nil { + t.Fatalf("Failed to shutdown transport: %v", err) + } +} + +// makeRequest 发送请求到HTTP传输层 +func makeRequest(t *testing.T, transport *Transport, req core.JSONRPCRequest) core.JSONRPCResponse { + t.Helper() + + // 使用httptest直接测试handler + handler := transport.GetHandler() + body, _ := json.Marshal(req) + + httpReq := httptest.NewRequest("POST", "/mcp", bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler.ServeHTTP(w, httpReq) + + var resp core.JSONRPCResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + return resp +} + +// TestHandlerCORS 测试CORS支持 +func TestHandlerCORS(t *testing.T) { + server := core.NewServer("test", "1.0.0") + handler := NewHandler(server) + + req := httptest.NewRequest("OPTIONS", "/mcp", nil) + req.Header.Set("Origin", "https://example.com") + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + corsHeader := w.Header().Get("Access-Control-Allow-Origin") + if corsHeader != "*" { + t.Errorf("Expected CORS origin '*', got '%s'", corsHeader) + } +} + +// TestConcurrentRequests 测试并发请求 +func TestConcurrentRequests(t *testing.T) { + server := core.NewServer("test", "1.0.0") + transport := NewTransportWithConfig(server, &Config{ + Host: "127.0.0.1", + Port: 0, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := transport.StartAsync(ctx); err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + defer transport.Shutdown() + + time.Sleep(100 * time.Millisecond) + + // 并发发送10个请求 + const numRequests = 10 + errCh := make(chan error, numRequests) + + for i := 0; i < numRequests; i++ { + go func(id int) { + req := core.JSONRPCRequest{ + JSONRPC: core.JSONRPCVersion, + ID: fmt.Sprintf("req-%d", id), + Method: core.MethodInitialize, + } + _ = makeRequest(t, transport, req) + errCh <- nil + }(i) + } + + // 等待所有请求完成 + for i := 0; i < numRequests; i++ { + if err := <-errCh; err != nil { + t.Errorf("Request failed: %v", err) + } + } +} + +// TestSSETransport 测试SSE传输 +func TestSSETransport(t *testing.T) { + server := core.NewServer("test", "1.0.0") + sseTransport := NewSSETransport(server) + + // 测试SSE连接 + t.Run("SSEConnection", func(t *testing.T) { + req := httptest.NewRequest("GET", "/sse", nil) + w := httptest.NewRecorder() + + // SSE需要flusher支持,httptest.NewRecorder实现了Flusher接口 + sseTransport.HandleSSE(w, req) + + // 验证响应headers + contentType := w.Header().Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected content type 'text/event-stream', got '%s'", contentType) + } + }) +} diff --git a/pkg/mcp/transport/http/sse.go b/pkg/mcp/transport/http/sse.go new file mode 100644 index 000000000..de3c07d23 --- /dev/null +++ b/pkg/mcp/transport/http/sse.go @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + + "github.com/apache/dubbo-admin/pkg/mcp/core" +) + +// SSETransport Server-Sent Events传输层 +type SSETransport struct { + server *core.Server + clients map[*SSEClient]bool + mu sync.RWMutex + broadcast chan []byte +} + +// SSEClient SSE客户端连接 +type SSEClient struct { + id string + ch chan []byte + ctx context.Context + done context.CancelFunc +} + +// NewSSEClient 创建SSE客户端 +func NewSSEClient(id string) *SSEClient { + ctx, cancel := context.WithCancel(context.Background()) + return &SSEClient{ + id: id, + ch: make(chan []byte, 256), + ctx: ctx, + done: cancel, + } +} + +// NewSSETransport 创建SSE传输层 +func NewSSETransport(server *core.Server) *SSETransport { + return &SSETransport{ + server: server, + clients: make(map[*SSEClient]bool), + broadcast: make(chan []byte, 256), + } +} + +// HandleSSE 处理SSE连接 +func (t *SSETransport) HandleSSE(w http.ResponseWriter, r *http.Request) { + // 设置SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + // 创建客户端 + client := NewSSEClient(r.RemoteAddr) + + t.mu.Lock() + t.clients[client] = true + t.mu.Unlock() + + // 发送连接成功消息 + t.sendToClient(client, t.sseEvent("connected", "SSE connection established")) + + // 等待断开连接 + <-client.ctx.Done() + + t.mu.Lock() + delete(t.clients, client) + close(client.ch) + t.mu.Unlock() +} + +// HandleRPCWithSSE 处理带有SSE的RPC请求(支持异步响应流) +func (t *SSETransport) HandleRPCWithSSE(w http.ResponseWriter, r *http.Request) { + // 设置SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "SSE not supported", http.StatusInternalServerError) + return + } + + // 读取请求 + decoder := json.NewDecoder(r.Body) + var req core.JSONRPCRequest + if err := decoder.Decode(&req); err != nil { + t.sendSSE(w, "error", err.Error()) + return + } + + // 发送处理中状态 + t.sendSSE(w, "processing", "Processing request...") + + // 处理请求 + resp := t.server.HandleRequest(&req) + + // 发送响应 + respData, _ := json.Marshal(resp) + t.sendSSE(w, "result", string(respData)) + + // 发送完成事件 + t.sendSSE(w, "done", "") + flusher.Flush() +} + +// BroadcastEvent 广播事件到所有客户端 +func (t *SSETransport) BroadcastEvent(eventType string, data interface{}) { + event := t.sseEvent(eventType, data) + select { + case t.broadcast <- event: + default: + // broadcast channel满,丢弃事件 + } +} + +// StartBroadcasting 启动广播协程 +func (t *SSETransport) StartBroadcasting(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case msg := <-t.broadcast: + t.mu.RLock() + for client := range t.clients { + select { + case client.ch <- msg: + default: + // client channel满,关闭客户端 + client.done() + } + } + t.mu.RUnlock() + } + } + }() +} + +// sendToClient 发送消息到指定客户端 +func (t *SSETransport) sendToClient(client *SSEClient, msg []byte) { + select { + case client.ch <- msg: + default: + client.done() + } +} + +// sseEvent 创建SSE事件格式 +func (t *SSETransport) sseEvent(eventType string, data interface{}) []byte { + var dataStr string + switch v := data.(type) { + case string: + dataStr = v + case []byte: + dataStr = string(v) + default: + bytes, _ := json.Marshal(data) + dataStr = string(bytes) + } + return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, dataStr)) +} + +// sendSSE 发送SSE事件到ResponseWriter +func (t *SSETransport) sendSSE(w http.ResponseWriter, eventType, data string) { + event := t.sseEvent(eventType, data) + w.Write(event) +} diff --git a/pkg/mcp/transport/stdio/stdio.go b/pkg/mcp/transport/stdio/stdio.go new file mode 100644 index 000000000..f920d00a8 --- /dev/null +++ b/pkg/mcp/transport/stdio/stdio.go @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdio + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "sync" + + "github.com/apache/dubbo-admin/pkg/mcp/core" +) + +// Transport stdio 传输层 +type Transport struct { + server *core.Server + reader io.Reader + writer io.Writer + mu sync.Mutex + closed bool +} + +// NewTransport 创建 stdio 传输层(使用 stdin/stdout) +func NewTransport(server *core.Server) *Transport { + return &Transport{ + server: server, + reader: os.Stdin, + writer: os.Stdout, + } +} + +// NewTransportWithIO 创建使用指定 reader/writer 的传输层(用于测试) +func NewTransportWithIO(server *core.Server, reader io.Reader, writer io.Writer) *Transport { + return &Transport{ + server: server, + reader: reader, + writer: writer, + } +} + +// Serve 启动 stdio 服务,阻塞运行直到发生错误或上下文取消 +func (t *Transport) Serve(ctx context.Context) error { + // 从 reader 读取请求,写入 writer 响应 + reader := bufio.NewReader(t.reader) + writer := bufio.NewWriter(t.writer) + + // 确保输出缓冲区刷新 + defer writer.Flush() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // 读取一行 JSON 请求 + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + return nil // 正常结束 + } + return fmt.Errorf("read stdin: %w", err) + } + + if len(line) == 0 || line == "\n" { + continue + } + + // 解析 JSON-RPC 请求 + var req core.JSONRPCRequest + if err := json.Unmarshal([]byte(line), &req); err != nil { + // 发送错误响应 + t.sendError(writer, nil, core.ErrCodeParseError, "Parse error") + if err := writer.Flush(); err != nil { + return fmt.Errorf("flush error: %w", err) + } + continue + } + + // 处理请求并获取响应 + resp := t.server.HandleRequest(&req) + + // 发送响应 + respData, err := json.Marshal(resp) + if err != nil { + t.sendError(writer, req.ID, core.ErrCodeInternalError, "Failed to marshal response") + } else { + writer.Write(respData) + writer.WriteByte('\n') + } + + if err := writer.Flush(); err != nil { + return fmt.Errorf("flush stdout: %w", err) + } + } +} + +// sendError 发送错误响应 +func (t *Transport) sendError(writer *bufio.Writer, id interface{}, code int, message string) { + resp := core.JSONRPCResponse{ + JSONRPC: core.JSONRPCVersion, + ID: id, + Error: &core.JSONRPCError{ + Code: code, + Message: message, + }, + } + data, _ := json.Marshal(resp) + writer.Write(data) + writer.WriteByte('\n') +} + +// Close 关闭传输层 +func (t *Transport) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.closed { + return nil + } + t.closed = true + return nil +} + +// ServeOnce 处理单个请求(用于测试) +func (t *Transport) ServeOnce(input string) (string, error) { + var req core.JSONRPCRequest + if err := json.Unmarshal([]byte(input), &req); err != nil { + return "", fmt.Errorf("parse request: %w", err) + } + + resp := t.server.HandleRequest(&req) + respData, err := json.Marshal(resp) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(respData) + "\n", nil +} diff --git a/pkg/mcp/transport/stdio/stdio_integration_test.go b/pkg/mcp/transport/stdio/stdio_integration_test.go new file mode 100644 index 000000000..cbc8ddeef --- /dev/null +++ b/pkg/mcp/transport/stdio/stdio_integration_test.go @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdio + +import ( + "bufio" + "context" + "encoding/json" + "io" + "sync" + "testing" + "time" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/types" +) + +// TestTransport_Integration_EndToEnd 端到端集成测试 +func TestTransport_Integration_EndToEnd(t *testing.T) { + // 创建服务器并注册工具 + server := core.NewServer("test-server", "1.0.0") + reg := server.GetRegistry() + + // 注册一个测试工具 + reg.Register(types.ToolDef{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: types.InputSchema{ + Type: "object", + Required: []string{"message"}, + Properties: map[string]types.PropertyDef{ + "message": { + Type: "string", + Description: "Message to echo", + }, + }, + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + msg, _ := args["message"].(string) + return types.NewTextResult("echo: "+msg, false), nil + }, + }) + + // 创建管道对 + serverR, clientW := io.Pipe() + clientR, serverW := io.Pipe() + defer func() { + clientW.Close() + serverW.Close() + }() + + // 创建使用管道的 transport + transport := NewTransportWithIO(server, serverR, serverW) + + // 启动服务器(在后台) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + serverErr := make(chan error, 1) + go func() { + serverErr <- transport.Serve(ctx) + }() + + // 给服务器一点时间启动 + time.Sleep(100 * time.Millisecond) + + // 客户端发送请求 + client := newMCPClient(clientR, clientW) + + // 1. 测试 initialize + t.Run("Initialize", func(t *testing.T) { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{}, + }) + + if resp.Error != nil { + t.Fatalf("Initialize failed: %s", resp.Error.Message) + } + + result := resp.Result.(map[string]any) + serverInfo := result["serverInfo"].(map[string]any) + if serverInfo["name"] != "test-server" { + t.Errorf("Expected name 'test-server', got '%v'", serverInfo["name"]) + } + }) + + // 2. 测试 tools/list + t.Run("ToolsList", func(t *testing.T) { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + }) + + if resp.Error != nil { + t.Fatalf("Tools list failed: %s", resp.Error.Message) + } + + result := resp.Result.(map[string]any) + tools := result["tools"].([]any) + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + tool := tools[0].(map[string]any) + if tool["name"] != "echo" { + t.Errorf("Expected tool name 'echo', got '%v'", tool["name"]) + } + }) + + // 3. 测试 tools/call - 成功调用 + t.Run("ToolCall_Success", func(t *testing.T) { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": map[string]any{ + "name": "echo", + "arguments": map[string]any{ + "message": "hello world", + }, + }, + }) + + if resp.Error != nil { + t.Fatalf("Tool call failed: %s", resp.Error.Message) + } + + result := resp.Result.(map[string]any) + content := result["content"].([]any) + if len(content) != 1 { + t.Fatalf("Expected 1 content item, got %d", len(content)) + } + + firstContent := content[0].(map[string]any) + if firstContent["text"] != "echo: hello world" { + t.Errorf("Expected 'echo: hello world', got '%v'", firstContent["text"]) + } + }) + + // 4. 测试 tools/call - 缺少必需参数 + t.Run("ToolCall_MissingRequired", func(t *testing.T) { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": map[string]any{ + "name": "echo", + "arguments": map[string]any{}, + }, + }) + + if resp.Error == nil { + t.Fatal("Expected error for missing required parameter") + } + + if resp.Error.Code != core.ErrCodeInvalidParams { + t.Errorf("Expected invalid params code, got %d", resp.Error.Code) + } + }) + + // 5. 测试 tools/call - 工具不存在 + t.Run("ToolCall_NotFound", func(t *testing.T) { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "id": 5, + "method": "tools/call", + "params": map[string]any{ + "name": "nonexistent", + "arguments": map[string]any{}, + }, + }) + + if resp.Error == nil { + t.Fatal("Expected error for nonexistent tool") + } + + if resp.Error.Code != core.ErrCodeMethodNotFound { + t.Errorf("Expected method not found code, got %d", resp.Error.Code) + } + }) + + // 取消服务器上下文并关闭管道 + cancel() + clientW.Close() // 关闭客户端写入端,让服务器 ReadString 返回 EOF + + // 等待服务器退出 + select { + case err := <-serverErr: + if err != nil && err != context.Canceled { + t.Logf("Server exited with error: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("Server did not exit in time") + } +} + +// mcpClient 简单的 MCP 客户端 +type mcpClient struct { + r *bufio.Reader + w io.Writer + wMutex sync.Mutex + reqID int +} + +func newMCPClient(r io.Reader, w io.WriteCloser) *mcpClient { + return &mcpClient{ + r: bufio.NewReader(r), + w: w, + reqID: 0, + } +} + +func (c *mcpClient) call(req map[string]any) *jsonRPCResponse { + c.wMutex.Lock() + defer c.wMutex.Unlock() + + c.reqID++ + req["id"] = c.reqID + + // 发送请求 + reqData, _ := json.Marshal(req) + c.w.Write(reqData) + c.w.Write([]byte("\n")) + + // 读取响应 + line, err := c.r.ReadString('\n') + if err != nil { + return &jsonRPCResponse{ + Error: &core.JSONRPCError{ + Code: -1, + Message: err.Error(), + }, + } + } + + var resp jsonRPCResponse + if err := json.Unmarshal([]byte(line), &resp); err != nil { + return &jsonRPCResponse{ + Error: &core.JSONRPCError{ + Code: -1, + Message: "Failed to parse response: " + err.Error(), + }, + } + } + + return &resp +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result any `json:"result,omitempty"` + Error *core.JSONRPCError `json:"error,omitempty"` +} + +// TestTransport_RealIO 使用真实 io 操作的测试 +func TestTransport_RealIO(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + reg := server.GetRegistry() + + reg.Register(types.ToolDef{ + Name: "ping", + Description: "Ping tool", + InputSchema: types.InputSchema{Type: "object"}, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("pong", false), nil + }, + }) + + // 创建管道对 + serverR, clientW := io.Pipe() + clientR, serverW := io.Pipe() + + transport := NewTransportWithIO(server, serverR, serverW) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 启动服务器 + serverErr := make(chan error, 1) + go func() { + serverErr <- transport.Serve(ctx) + }() + + time.Sleep(100 * time.Millisecond) + + // 客户端发送多个请求 + client := newMCPClient(clientR, clientW) + + responses := make([]*jsonRPCResponse, 0, 3) + for i := 0; i < 3; i++ { + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": map[string]any{ + "name": "ping", + "arguments": map[string]any{}, + }, + }) + responses = append(responses, resp) + } + + // 验证所有响应 + for i, resp := range responses { + if resp.Error != nil { + t.Errorf("Request %d failed: %s", i, resp.Error.Message) + } + } + + cancel() + <-serverErr +} + +// TestTransport_ConcurrentRequests 并发请求测试 +func TestTransport_ConcurrentRequests(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + reg := server.GetRegistry() + + reg.Register(types.ToolDef{ + Name: "counter", + Description: "Counting tool", + InputSchema: types.InputSchema{Type: "object"}, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("count", false), nil + }, + }) + + serverR, clientW := io.Pipe() + clientR, serverW := io.Pipe() + + transport := NewTransportWithIO(server, serverR, serverW) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + serverErr := make(chan error, 1) + go func() { + serverErr <- transport.Serve(ctx) + }() + + time.Sleep(100 * time.Millisecond) + + // 并发发送多个请求 + // 使用共享客户端(内部有锁保护写操作) + client := newMCPClient(clientR, clientW) + const numRequests = 10 + results := make(chan *jsonRPCResponse, numRequests) + var wg sync.WaitGroup + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp := client.call(map[string]any{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": map[string]any{ + "name": "counter", + "arguments": map[string]any{}, + }, + }) + results <- resp + }() + } + + // 等待所有请求完成 + go func() { + wg.Wait() + close(results) + }() + + // 收集结果 + successCount := 0 + for resp := range results { + if resp.Error == nil { + successCount++ + } + } + + if successCount != numRequests { + t.Errorf("Expected %d successful requests, got %d", numRequests, successCount) + } + + cancel() + clientW.Close() + <-serverErr +} diff --git a/pkg/mcp/transport/stdio/stdio_test.go b/pkg/mcp/transport/stdio/stdio_test.go new file mode 100644 index 000000000..b1f70bd07 --- /dev/null +++ b/pkg/mcp/transport/stdio/stdio_test.go @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdio + +import ( + "encoding/json" + "testing" + + consolectx "github.com/apache/dubbo-admin/pkg/console/context" + "github.com/apache/dubbo-admin/pkg/mcp/core" + "github.com/apache/dubbo-admin/pkg/mcp/types" +) + +func TestTransport_ServeOnce_Initialize(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + transport := NewTransport(server) + + // 测试 initialize 请求 + req := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{}, + } + reqJSON, _ := json.Marshal(req) + + resp, err := transport.ServeOnce(string(reqJSON) + "\n") + if err != nil { + t.Fatalf("ServeOnce failed: %v", err) + } + + var jsonResp core.JSONRPCResponse + if err := json.Unmarshal([]byte(resp), &jsonResp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if jsonResp.Error != nil { + t.Fatalf("Response error: %s", jsonResp.Error.Message) + } + + result, ok := jsonResp.Result.(map[string]any) + if !ok { + t.Fatal("Result is not a map") + } + + serverInfo, ok := result["serverInfo"].(map[string]any) + if !ok { + t.Fatal("serverInfo not found") + } + + if serverInfo["name"] != "test-server" { + t.Errorf("Expected name 'test-server', got '%v'", serverInfo["name"]) + } +} + +func TestTransport_ServeOnce_ToolsList(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + transport := NewTransport(server) + + // 注册一个测试工具 + reg := server.GetRegistry() + reg.Register(types.ToolDef{ + Name: "test_tool", + Description: "A test tool", + InputSchema: types.InputSchema{ + Type: "object", + }, + Handler: func(ctx consolectx.Context, args map[string]any) (*types.ToolResult, error) { + return types.NewTextResult("ok", false), nil + }, + }) + + // 测试 tools/list 请求 + req := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + } + reqJSON, _ := json.Marshal(req) + + resp, err := transport.ServeOnce(string(reqJSON) + "\n") + if err != nil { + t.Fatalf("ServeOnce failed: %v", err) + } + + var jsonResp core.JSONRPCResponse + if err := json.Unmarshal([]byte(resp), &jsonResp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if jsonResp.Error != nil { + t.Fatalf("Response error: %s", jsonResp.Error.Message) + } + + result, ok := jsonResp.Result.(map[string]any) + if !ok { + t.Fatal("Result is not a map") + } + + tools, ok := result["tools"].([]any) + if !ok { + t.Fatal("tools not found") + } + + if len(tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(tools)) + } +} + +func TestTransport_ServeOnce_ParseError(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + transport := NewTransport(server) + + // 测试无效 JSON - ServeOnce 在解析失败时返回错误 + _, err := transport.ServeOnce("invalid json\n") + if err == nil { + t.Fatal("Expected error for invalid JSON") + } + + // 测试无效方法名 - 这会返回 JSON 错误响应 + req := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "invalid_method", + } + reqJSON, _ := json.Marshal(req) + + resp, err := transport.ServeOnce(string(reqJSON) + "\n") + if err != nil { + t.Fatalf("ServeOnce failed: %v", err) + } + + var jsonResp core.JSONRPCResponse + if err := json.Unmarshal([]byte(resp), &jsonResp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if jsonResp.Error == nil { + t.Fatal("Expected error response") + } + + if jsonResp.Error.Code != core.ErrCodeMethodNotFound { + t.Errorf("Expected method not found code, got %d", jsonResp.Error.Code) + } +} + +func TestTransport_Close(t *testing.T) { + server := core.NewServer("test-server", "1.0.0") + transport := NewTransport(server) + + if err := transport.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // 再次关闭应该成功 + if err := transport.Close(); err != nil { + t.Fatalf("Second Close failed: %v", err) + } +} diff --git a/pkg/mcp/types/tool.go b/pkg/mcp/types/tool.go new file mode 100644 index 000000000..71ac83235 --- /dev/null +++ b/pkg/mcp/types/tool.go @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + consolectx "github.com/apache/dubbo-admin/pkg/console/context" +) + +// ToolDef 工具定义 +type ToolDef struct { + Name string + Description string + InputSchema InputSchema + Handler ToolHandler +} + +// InputSchema 输入参数 schema +type InputSchema struct { + Type string `json:"type"` + Properties map[string]PropertyDef `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +// PropertyDef 属性定义 +type PropertyDef struct { + Type string `json:"type"` + Description string `json:"description,omitempty"` + Default any `json:"default,omitempty"` + Enum []string `json:"enum,omitempty"` +} + +// ToolHandler 工具处理器类型 +type ToolHandler func(ctx consolectx.Context, args map[string]any) (*ToolResult, error) + +// ToolResult 工具执行结果 +type ToolResult struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// Content 内容块 +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// NewToolResult 创建工具结果 +func NewToolResult(content []Content, isError bool) *ToolResult { + return &ToolResult{ + Content: content, + IsError: isError, + } +} + +// NewTextResult 创建文本结果 +func NewTextResult(text string, isError bool) *ToolResult { + return &ToolResult{ + Content: []Content{{Type: "text", Text: text}}, + IsError: isError, + } +} + +// NewErrorResult 创建错误结果 +func NewErrorResult(err error) *ToolResult { + return &ToolResult{ + Content: []Content{{Type: "text", Text: err.Error()}}, + IsError: true, + } +} diff --git a/pkg/mcp/types/validation.go b/pkg/mcp/types/validation.go new file mode 100644 index 000000000..25c38005d --- /dev/null +++ b/pkg/mcp/types/validation.go @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import "fmt" + +// ValidateRequired 验证必需参数是否存在且非空 +func ValidateRequired(schema InputSchema, args map[string]any) error { + if args == nil { + args = make(map[string]any) + } + + for _, required := range schema.Required { + val, exists := args[required] + if !exists { + return fmt.Errorf("missing required parameter: %s", required) + } + + if IsEmpty(val) { + return fmt.Errorf("required parameter %s cannot be empty", required) + } + } + return nil +} + +// IsEmpty 判断值是否为空 +func IsEmpty(val any) bool { + switch v := val.(type) { + case string: + return v == "" + case []any: + return len(v) == 0 + case map[string]any: + return len(v) == 0 + case nil: + return true + default: + return false + } +}