From 4bc52c76fa4a00fe450d5d38ad1e3af1b0dd2044 Mon Sep 17 00:00:00 2001 From: kpom Date: Thu, 21 May 2026 13:31:03 -0500 Subject: [PATCH 1/2] feature(mcp-server): Build an MCP server for dawgrun --- README.md | 5 + go.mod | 7 + go.sum | 12 ++ tools/dawgrun/README.md | 47 ++++++ tools/dawgrun/cmd/dawgrun-mcp/main.go | 28 ++++ tools/dawgrun/pkg/commands/cypher.go | 140 ++++++++--------- tools/dawgrun/pkg/commands/db.go | 31 +++- tools/dawgrun/pkg/commands/helpers.go | 7 +- tools/dawgrun/pkg/mcpserver/cypher.go | 133 ++++++++++++++++ tools/dawgrun/pkg/mcpserver/kinds.go | 80 ++++++++++ tools/dawgrun/pkg/mcpserver/opengraph.go | 158 +++++++++++++++++++ tools/dawgrun/pkg/mcpserver/server.go | 167 +++++++++++++++++++++ tools/dawgrun/pkg/mcpserver/server_test.go | 92 ++++++++++++ tools/dawgrun/pkg/texttools/cypher.go | 8 + 14 files changed, 838 insertions(+), 77 deletions(-) create mode 100644 tools/dawgrun/cmd/dawgrun-mcp/main.go create mode 100644 tools/dawgrun/pkg/mcpserver/cypher.go create mode 100644 tools/dawgrun/pkg/mcpserver/kinds.go create mode 100644 tools/dawgrun/pkg/mcpserver/opengraph.go create mode 100644 tools/dawgrun/pkg/mcpserver/server.go create mode 100644 tools/dawgrun/pkg/mcpserver/server_test.go diff --git a/README.md b/README.md index e1aad7bb..8399bb5f 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,11 @@ export CONNECTION_STRING="neo4j://neo4j:weneedbetterpasswords@localhost:7687" Use `make test` for unit tests only and `make test_integration` for integration tests only. +### Dawgrun MCP Server + +`go tool dawgrun-mcp` runs a stdio MCP server for agent access to dawgrun-backed DAWGS inspection tools. See +[`tools/dawgrun/README.md`](tools/dawgrun/README.md) for client configuration and the current tool surface. + ### Test Metrics `make test` writes unit test coverage artifacts under `.coverage/`: diff --git a/go.mod b/go.mod index 58534cd7..5d3d1cfd 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/kanmu/go-sqlfmt v0.0.2-0.20200215095417-d1e63e2ee5eb github.com/mitchellh/go-wordwrap v1.0.1 + github.com/modelcontextprotocol/go-sdk v1.6.0 github.com/specterops/go-repl v1.0.1 golang.org/x/term v0.39.0 ) @@ -129,6 +130,7 @@ require ( github.com/golangci/swaggoswag v0.0.0-20250504205917-77f2aca3143e // indirect github.com/golangci/unconvert v0.0.0-20250410112200-a129a6e6413e // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/gordonklaus/ineffassign v0.2.0 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.5.0 // indirect @@ -207,6 +209,8 @@ require ( github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.29.0 // indirect github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sonatard/noctx v0.5.1 // indirect @@ -235,6 +239,7 @@ require ( github.com/yagipy/maintidx v1.0.0 // indirect github.com/yeya24/promlinter v0.3.0 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect gitlab.com/bosi/decorder v0.4.2 // indirect go-simpler.org/musttag v0.14.0 // indirect go-simpler.org/sloglint v0.11.1 // indirect @@ -246,6 +251,7 @@ require ( golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.34.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect @@ -262,5 +268,6 @@ tool ( github.com/fzipp/gocyclo/cmd/gocyclo github.com/golangci/golangci-lint/v2/cmd/golangci-lint github.com/specterops/dawgs/tools/dawgrun/cmd/dawgrun + github.com/specterops/dawgs/tools/dawgrun/cmd/dawgrun-mcp github.com/specterops/dawgs/tools/metrics/cmd/dawgs-metrics ) diff --git a/go.sum b/go.sum index ff24078a..7c24d4a0 100644 --- a/go.sum +++ b/go.sum @@ -284,6 +284,8 @@ github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -351,6 +353,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -563,6 +567,8 @@ github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQ github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= +github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -672,6 +678,10 @@ github.com/sashamelentyev/usestdlibvars v1.29.0/go.mod h1:8PpnjHMk5VdeWlVb4wCdrB github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 h1:AoLtJX4WUtZkhhUUMFy3GgecAALp/Mb4S1iyQOA2s0U= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08/go.mod h1:+XLCJiRE95ga77XInNELh2M6zQP+PdqiT9Zpm0D9Wpk= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= @@ -761,6 +771,8 @@ github.com/yeya24/promlinter v0.3.0 h1:JVDbMp08lVCP7Y6NP3qHroGAO6z2yGKQtS5Jsjqto github.com/yeya24/promlinter v0.3.0/go.mod h1:cDfJQQYv9uYciW60QT0eeHlFodotkYZlL+YcPQN+mW4= github.com/ykadowak/zerologlint v0.1.5 h1:Gy/fMz1dFQN9JZTPjv1hxEk+sRWm05row04Yoolgdiw= github.com/ykadowak/zerologlint v0.1.5/go.mod h1:KaUskqF3e/v59oPmdq1U1DnKcuHokl2/K1U4pmIELKg= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/tools/dawgrun/README.md b/tools/dawgrun/README.md index c92676ab..f0e1b25c 100644 --- a/tools/dawgrun/README.md +++ b/tools/dawgrun/README.md @@ -66,6 +66,53 @@ connection names without an interactive `open` first. At any time, run `help` to list commands or `help ` for detailed usage, flag defaults, and description. +## MCP server + +`dawgrun-mcp` is a stdio MCP server that exposes a small dawgrun tool +surface to agent clients. It uses the official +`github.com/modelcontextprotocol/go-sdk` and keeps named backend +connections in memory for the lifetime of the MCP process. +It does not read or write dawgrun local config; use `open_connection` +to create MCP-session-local connections explicitly. + +Run it from a `DAWGS` checkout with: + + go tool dawgrun-mcp + +An MCP client can launch it with a config like: + +```json +{ + "mcpServers": { + "dawgrun": { + "command": "go", + "args": ["tool", "dawgrun-mcp"] + } + } +} +``` + +For OpenCode specifically, use the array-form local command: + +```json +{ + "$schema": "https://opencode.ai/config.json", + "mcp": { + "dawgrun": { + "type": "local", + "command": ["go", "tool", "dawgrun-mcp"], + "enabled": true + } + } +} +``` + +The tool set includes `list_connections`, `open_connection`, +`parse_cypher`, `translate_cypher_to_pgsql`, `query_cypher`, +`explain_psql`, `save_opengraph`, `load_db_kinds`, `lookup_kind`, and +`lookup_kind_id`. Write-capable tools, currently `load_opengraph` and +`copy_opengraph`, require `--allow-writes`. + ## Commands The REPL supports command-name completion with `Tab`; ambiguous matches render a transient popover list near the prompt and can be dismissed with `Esc`. diff --git a/tools/dawgrun/cmd/dawgrun-mcp/main.go b/tools/dawgrun/cmd/dawgrun-mcp/main.go new file mode 100644 index 00000000..a00b7d14 --- /dev/null +++ b/tools/dawgrun/cmd/dawgrun-mcp/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/davecgh/go-spew/spew" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" + "github.com/specterops/dawgs/tools/dawgrun/pkg/mcpserver" +) + +func main() { + spew.Config.DisablePointerAddresses = true + + var allowWrites bool + flag.BoolVar(&allowWrites, "allow-writes", false, "enable write-capable DAWGS tools") + flag.Parse() + + scope := commands.NewScope(commands.RunModeCLI) + server := mcpserver.NewWithScope(scope, mcpserver.Options{AllowWrites: allowWrites}) + if err := server.Run(context.Background()); err != nil { + fmt.Fprintf(os.Stderr, "dawgrun-mcp failed: %v\n", err) + os.Exit(1) + } +} diff --git a/tools/dawgrun/pkg/commands/cypher.go b/tools/dawgrun/pkg/commands/cypher.go index 2d2e05ca..22938d18 100644 --- a/tools/dawgrun/pkg/commands/cypher.go +++ b/tools/dawgrun/pkg/commands/cypher.go @@ -21,6 +21,61 @@ const ( queryCypherOutputFormatJSON = "json" ) +type TranslateCypherOptions struct { + Connection string + DumpPGAst bool + SQLFormatDistance int +} + +type TranslateCypherOutput struct { + SQL string + PGAst string +} + +// TranslateCypherToPsql parses and translates a Cypher query to PostgreSQL SQL. +func TranslateCypherToPsql(ctx *CommandContext, queryText string, options TranslateCypherOptions) (TranslateCypherOutput, error) { + query, err := ParseQueryText(queryText) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("error trying to parse query: %w", err) + } + + kindMapper := stubs.EmptyMapper() + if strings.TrimSpace(options.Connection) != "" { + kindMap, err := loadKindMap(ctx, options.Connection) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not load kind map for translation: %w", err) + } + kindMapper = stubs.MapperFromKindMap(kindMap) + } + + result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, options.Connection)) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not translate cypher query to pgsql: %w", err) + } + + queryBuilder := format.NewOutputBuilder() + if result.Parameters != nil { + queryBuilder.WithMaterializedParameters(result.Parameters) + } + + sqlQuery, err := format.Statement(result.Statement, queryBuilder) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not format translated statement into a string query: %w", err) + } + + formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{Distance: options.SQLFormatDistance}) + if err != nil { + formattedQuery = sqlQuery + } + + output := TranslateCypherOutput{SQL: formattedQuery} + if options.DumpPGAst { + output.PGAst = spew.Sdump(result.Statement) + } + + return output, nil +} + func parseCmd() CommandDesc { return CommandDesc{ args: []string{"<...query>"}, @@ -65,52 +120,21 @@ func translateToPsqlCmd() CommandDesc { } fields = flagSet.Args() - query, err := parseQueryArray(fields) + output, err := TranslateCypherToPsql(ctx, strings.Join(fields, " "), TranslateCypherOptions{ + Connection: kindMapperConnRef, + DumpPGAst: dumpTranslatedAst, + }) if err != nil { - return fmt.Errorf("error trying to parse query '%s': %w", fields, err) - } - - kindMapper := stubs.EmptyMapper() - if kindMapperConnRef != "" { - // Fetch kinds regardless of if it's already loaded. - kindMap, err := loadKindMap(ctx, kindMapperConnRef) - if err != nil { - return fmt.Errorf("could not load kind map for explain: %w", err) - } - kindMapper = stubs.MapperFromKindMap(kindMap) + return err } - result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, kindMapperConnRef)) - if err != nil { - return fmt.Errorf("could not translate cypher query to pgsql: %w", err) - } - if dumpTranslatedAst { + if output.PGAst != "" { fmt.Fprintf(ctx.output, "TRANSLATOR AST\n\n") - ctx.output.WriteHighlighted(spew.Sdump(result.Statement), "golang") + ctx.output.WriteHighlighted(output.PGAst, "golang") fmt.Fprintf(ctx.output, "\n") } - // Certain queries will materialize parameters into the output when translated, so we need to build - // an OutputBuilder so we can carry forward those params. - queryBuilder := format.NewOutputBuilder() - if result.Parameters != nil { - queryBuilder.WithMaterializedParameters(result.Parameters) - } - - sqlQuery, err := format.Statement(result.Statement, queryBuilder) - if err != nil { - return fmt.Errorf("could not format translated statement into a string query: %w", err) - } - - formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{ - Distance: 0, - }) - if err != nil { - ctx.output.Warnf("could not format query: %s", err.Error()) - formattedQuery = sqlQuery - } - - ctx.output.WriteHighlighted(formattedQuery, "postgres") + ctx.output.WriteHighlighted(output.SQL, "postgres") return nil }, } @@ -133,44 +157,12 @@ func explainAsPsqlCmd() CommandDesc { return err } - // Fetch kinds regardless of if it's already loaded. - kindMap, err := loadKindMap(ctx, connName) - if err != nil { - return fmt.Errorf("could not load kind map for explain: %w", err) - } - - query, err := parseQueryArray(fields[1:]) - if err != nil { - return fmt.Errorf("could not parse query: %w", err) - } - - // Populate a DumbKindMapper from the database's kinds table - kindMapper := stubs.MapperFromKindMap(kindMap) - result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, connName)) + translation, err := TranslateCypherToPsql(ctx, strings.Join(fields[1:], " "), TranslateCypherOptions{Connection: connName, SQLFormatDistance: 2}) if err != nil { - return fmt.Errorf("could not translate cypher query to pgsql: %w", err) - } - - // Certain queries will materialize parameters into the output when translated, so we need to build - // an OutputBuilder so we can carry forward those params. - queryBuilder := format.NewOutputBuilder() - if result.Parameters != nil { - queryBuilder.WithMaterializedParameters(result.Parameters) - } - - sqlQuery, err := format.Statement(result.Statement, queryBuilder) - if err != nil { - return fmt.Errorf("could not format translated statement into a string query: %w", err) + return err } - formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{ - Distance: 2, - }) - if err != nil { - ctx.output.Warnf("could not format query: %s", err.Error()) - formattedQuery = sqlQuery - } - explainSQLQuery := fmt.Sprintf("EXPLAIN %s", formattedQuery) + explainSQLQuery := fmt.Sprintf("EXPLAIN %s", translation.SQL) ctx.output.WriteHighlighted(explainSQLQuery, "postgres") fmt.Fprint(ctx.output, "\n\n") diff --git a/tools/dawgrun/pkg/commands/db.go b/tools/dawgrun/pkg/commands/db.go index 895da3b4..33deeeb0 100644 --- a/tools/dawgrun/pkg/commands/db.go +++ b/tools/dawgrun/pkg/commands/db.go @@ -19,6 +19,30 @@ import ( "github.com/specterops/dawgs/tools/dawgrun/pkg/types" ) +type OpenConnectionOptions struct { + Name string + ConnectionString string + Driver string + DefaultGraph string + InitGraph bool +} + +// OpenConnection opens a DAWGS backend connection and stores it in the command scope. +func OpenConnection(ctx *CommandContext, options OpenConnectionOptions) error { + name := strings.TrimSpace(options.Name) + connStr := strings.TrimSpace(options.ConnectionString) + if name == "" || connStr == "" { + return fmt.Errorf("connection name and connection string are required") + } + + _, err := ctx.OpenConnection(name, connStr, openConnectionOptions{ + driverName: strings.ToLower(strings.TrimSpace(options.Driver)), + defaultGraphName: options.DefaultGraph, + initGraphOnFail: options.InitGraph, + }) + return err +} + func listConnectionsCmd() CommandDesc { return CommandDesc{ args: []string{}, @@ -292,6 +316,11 @@ func loadKindMap(ctx *CommandContext, connName string) (stubs.KindMap, error) { return kindMap, nil } +// LoadKindMap loads and caches the database kind map for an open connection. +func LoadKindMap(ctx *CommandContext, connName string) (stubs.KindMap, error) { + return loadKindMap(ctx, connName) +} + func lookupKindCmd() CommandDesc { return CommandDesc{ args: []string{"", ""}, @@ -307,7 +336,6 @@ func lookupKindCmd() CommandDesc { kindMap, ok := ctx.scope.connKindMaps[connName] if !ok { - // Try to fetch the kind map if the connection is open var err error kindMap, err = loadKindMap(ctx, connName) if err != nil { @@ -347,7 +375,6 @@ func lookupKindIDCmd() CommandDesc { kindMap, ok := ctx.scope.connKindMaps[connName] if !ok { - // Try to fetch the kind map if the connection is open var err error kindMap, err = loadKindMap(ctx, connName) if err != nil { diff --git a/tools/dawgrun/pkg/commands/helpers.go b/tools/dawgrun/pkg/commands/helpers.go index cbc4222f..e6df5505 100644 --- a/tools/dawgrun/pkg/commands/helpers.go +++ b/tools/dawgrun/pkg/commands/helpers.go @@ -8,6 +8,11 @@ import ( ) func parseQueryArray(fields []string) (*cypherModels.RegularQuery, error) { + return ParseQueryText(strings.Join(fields, " ")) +} + +// ParseQueryText parses a Cypher query string using dawgrun's default parser settings. +func ParseQueryText(query string) (*cypherModels.RegularQuery, error) { cypherCtx := cypherFrontend.DefaultCypherContext() - return cypherFrontend.ParseCypher(cypherCtx, strings.Join(fields, " ")) + return cypherFrontend.ParseCypher(cypherCtx, query) } diff --git a/tools/dawgrun/pkg/mcpserver/cypher.go b/tools/dawgrun/pkg/mcpserver/cypher.go new file mode 100644 index 00000000..dc725e4e --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/cypher.go @@ -0,0 +1,133 @@ +package mcpserver + +import ( + "context" + "fmt" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/specterops/dawgs/graph" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" + "github.com/specterops/dawgs/tools/dawgrun/pkg/texttools" +) + +const ( + defaultQueryLimit = 100 + maxQueryLimit = 1000 +) + +type QueryCypherInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + Query string `json:"query" jsonschema:"CySQL/Cypher query to execute"` + Limit int `json:"limit,omitempty" jsonschema:"maximum rows to return; defaults to 100 and caps at 1000"` +} + +type QueryCypherOutput struct { + Columns []string `json:"columns"` + Rows []map[string]any `json:"rows"` + RowCount int `json:"row_count"` + Truncated bool `json:"truncated"` +} + +type ExplainPsqlInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + Query string `json:"query" jsonschema:"CySQL/Cypher query to explain"` +} + +type ExplainPsqlOutput struct { + SQL string `json:"sql"` + Plan []string `json:"plan"` +} + +func (s *Server) queryCypher(ctx context.Context, _ *mcp.CallToolRequest, input QueryCypherInput) (*mcp.CallToolResult, QueryCypherOutput, error) { + cmdCtx := s.commandContext(ctx) + conn, err := cmdCtx.EnsureConnection(input.Connection) + if err != nil { + return nil, QueryCypherOutput{}, err + } + + limit := normalizeQueryLimit(input.Limit) + output := QueryCypherOutput{Rows: []map[string]any{}} + err = conn.ReadTransaction(ctx, func(tx graph.Transaction) error { + result := tx.Query(input.Query, nil) + if err := result.Error(); err != nil { + return fmt.Errorf("error running cypher query: %w", err) + } + defer result.Close() + + for result.Next() { + values := result.Values() + if output.Columns == nil { + output.Columns = texttools.BuildCypherResultColumns(result.Keys(), len(values)) + } + + if len(output.Rows) >= limit { + output.Truncated = true + break + } + + output.Rows = append(output.Rows, texttools.BuildCypherResultJSONRow(output.Columns, values)) + } + + return result.Error() + }) + if err != nil { + return nil, QueryCypherOutput{}, err + } + + if output.Columns == nil { + output.Columns = []string{} + } + output.RowCount = len(output.Rows) + + return nil, output, nil +} + +func (s *Server) explainPsql(ctx context.Context, _ *mcp.CallToolRequest, input ExplainPsqlInput) (*mcp.CallToolResult, ExplainPsqlOutput, error) { + cmdCtx := s.commandContext(ctx) + conn, err := cmdCtx.EnsureConnection(input.Connection) + if err != nil { + return nil, ExplainPsqlOutput{}, err + } + + translation, err := commands.TranslateCypherToPsql(cmdCtx, input.Query, commands.TranslateCypherOptions{Connection: input.Connection, SQLFormatDistance: 2}) + if err != nil { + return nil, ExplainPsqlOutput{}, err + } + + explainSQLQuery := fmt.Sprintf("EXPLAIN %s", translation.SQL) + output := ExplainPsqlOutput{SQL: explainSQLQuery} + err = conn.ReadTransaction(ctx, func(tx graph.Transaction) error { + result := tx.Raw(explainSQLQuery, nil) + if err := result.Error(); err != nil { + return fmt.Errorf("error running raw query: %w", err) + } + defer result.Close() + + for result.Next() { + var value string + if err := graph.ScanNextResult(result, &value); err != nil { + return fmt.Errorf("could not scan EXPLAIN row: %w", err) + } + output.Plan = append(output.Plan, value) + } + + return result.Error() + }) + if err != nil { + return nil, ExplainPsqlOutput{}, err + } + + return nil, output, nil +} + +func normalizeQueryLimit(limit int) int { + if limit <= 0 { + return defaultQueryLimit + } + if limit > maxQueryLimit { + return maxQueryLimit + } + + return limit +} diff --git a/tools/dawgrun/pkg/mcpserver/kinds.go b/tools/dawgrun/pkg/mcpserver/kinds.go new file mode 100644 index 00000000..7f87fa36 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/kinds.go @@ -0,0 +1,80 @@ +package mcpserver + +import ( + "context" + "strconv" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/specterops/dawgs/graph" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" + "github.com/specterops/dawgs/tools/dawgrun/pkg/stubs" +) + +type LoadDBKindsInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` +} + +type LoadDBKindsOutput struct { + Kinds map[string]string `json:"kinds"` +} + +type LookupKindInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + Kind string `json:"kind" jsonschema:"kind name"` +} + +type LookupKindOutput struct { + Kind string `json:"kind"` + KindID int16 `json:"kind_id"` +} + +type LookupKindIDInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + KindID int16 `json:"kind_id" jsonschema:"kind ID"` +} + +type LookupKindIDOutput struct { + KindID int16 `json:"kind_id"` + Kind string `json:"kind"` +} + +func (s *Server) loadDBKinds(ctx context.Context, _ *mcp.CallToolRequest, input LoadDBKindsInput) (*mcp.CallToolResult, LoadDBKindsOutput, error) { + kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) + if err != nil { + return nil, LoadDBKindsOutput{}, err + } + + output := LoadDBKindsOutput{Kinds: make(map[string]string, len(kindMap))} + for kindID, kind := range kindMap { + output.Kinds[strconv.FormatInt(int64(kindID), 10)] = kind.String() + } + + return nil, output, nil +} + +func (s *Server) lookupKind(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindInput) (*mcp.CallToolResult, LookupKindOutput, error) { + kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) + if err != nil { + return nil, LookupKindOutput{}, err + } + kindID, err := stubs.MapperFromKindMap(kindMap).GetIDByKind(graph.StringKind(input.Kind)) + if err != nil { + return nil, LookupKindOutput{}, err + } + + return nil, LookupKindOutput{Kind: input.Kind, KindID: kindID}, nil +} + +func (s *Server) lookupKindID(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindIDInput) (*mcp.CallToolResult, LookupKindIDOutput, error) { + kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) + if err != nil { + return nil, LookupKindIDOutput{}, err + } + kind, err := stubs.MapperFromKindMap(kindMap).GetKindByID(input.KindID) + if err != nil { + return nil, LookupKindIDOutput{}, err + } + + return nil, LookupKindIDOutput{KindID: input.KindID, Kind: kind.String()}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/opengraph.go b/tools/dawgrun/pkg/mcpserver/opengraph.go new file mode 100644 index 00000000..8c7c24b5 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/opengraph.go @@ -0,0 +1,158 @@ +package mcpserver + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/specterops/dawgs/opengraph" +) + +const maxInlineOpenGraphBytes = 1024 * 1024 + +type SaveOpenGraphInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + OutputPath string `json:"output_path,omitempty" jsonschema:"optional file path for OpenGraph JSON output"` +} + +type SaveOpenGraphOutput struct { + OutputPath string `json:"output_path,omitempty"` + Content string `json:"content,omitempty"` + Bytes int `json:"bytes"` +} + +type LoadOpenGraphInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + InputPath string `json:"input_path" jsonschema:"OpenGraph JSON file path to load"` +} + +type LoadOpenGraphOutput struct { + Nodes int `json:"nodes"` + Edges int `json:"edges"` +} + +type CopyOpenGraphInput struct { + FromConnection string `json:"from_connection" jsonschema:"source open connection name"` + ToConnection string `json:"to_connection" jsonschema:"destination open connection name"` +} + +type CopyOpenGraphOutput struct { + Nodes int `json:"nodes"` + Edges int `json:"edges"` +} + +func (s *Server) saveOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input SaveOpenGraphInput) (*mcp.CallToolResult, SaveOpenGraphOutput, error) { + conn, err := s.commandContext(ctx).EnsureConnection(input.Connection) + if err != nil { + return nil, SaveOpenGraphOutput{}, err + } + + outputPath := strings.TrimSpace(input.OutputPath) + if outputPath != "" { + outFile, err := os.Create(outputPath) + if err != nil { + return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not create output file %s: %w", outputPath, err) + } + defer outFile.Close() + + if err := opengraph.Export(ctx, conn, outFile); err != nil { + return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not export opengraph data: %w", err) + } + + if fileInfo, err := outFile.Stat(); err == nil { + return nil, SaveOpenGraphOutput{OutputPath: outputPath, Bytes: int(fileInfo.Size())}, nil + } + + return nil, SaveOpenGraphOutput{OutputPath: outputPath}, nil + } + + buffer := new(bytes.Buffer) + if err := opengraph.Export(ctx, conn, buffer); err != nil { + return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not export opengraph data: %w", err) + } + if buffer.Len() > maxInlineOpenGraphBytes { + return nil, SaveOpenGraphOutput{}, fmt.Errorf("OpenGraph export is %d bytes, exceeding inline limit %d; provide output_path", buffer.Len(), maxInlineOpenGraphBytes) + } + + return nil, SaveOpenGraphOutput{Content: buffer.String(), Bytes: buffer.Len()}, nil +} + +func (s *Server) loadOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input LoadOpenGraphInput) (*mcp.CallToolResult, LoadOpenGraphOutput, error) { + if !s.allowWrites { + return nil, LoadOpenGraphOutput{}, fmt.Errorf("load_opengraph requires --allow-writes") + } + + inputFile, err := os.Open(input.InputPath) + if err != nil { + return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not open opengraph input file %s: %w", input.InputPath, err) + } + defer inputFile.Close() + + conn, err := s.commandContext(ctx).EnsureConnection(input.Connection) + if err != nil { + return nil, LoadOpenGraphOutput{}, err + } + doc, err := opengraph.ParseDocument(inputFile) + if err != nil { + return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not parse opengraph input file %s: %w", input.InputPath, err) + } + if _, err := opengraph.WriteGraph(ctx, conn, &doc.Graph); err != nil { + return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not write opengraph data into connection %s: %w", input.Connection, err) + } + + return nil, LoadOpenGraphOutput{Nodes: len(doc.Graph.Nodes), Edges: len(doc.Graph.Edges)}, nil +} + +func (s *Server) copyOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input CopyOpenGraphInput) (*mcp.CallToolResult, CopyOpenGraphOutput, error) { + if !s.allowWrites { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("copy_opengraph requires --allow-writes") + } + + if input.FromConnection == input.ToConnection { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("source and destination connections must differ") + } + + cmdCtx := s.commandContext(ctx) + fromConn, err := cmdCtx.EnsureConnection(input.FromConnection) + if err != nil { + return nil, CopyOpenGraphOutput{}, err + } + toConn, err := cmdCtx.EnsureConnection(input.ToConnection) + if err != nil { + return nil, CopyOpenGraphOutput{}, err + } + + pipeReader, pipeWriter := io.Pipe() + exportErrCh := make(chan error, 1) + go func() { + if err := opengraph.Export(ctx, fromConn, pipeWriter); err != nil { + _ = pipeWriter.CloseWithError(err) + exportErrCh <- err + return + } + + exportErrCh <- pipeWriter.Close() + }() + + doc, err := opengraph.ParseDocument(pipeReader) + if err != nil { + _ = pipeReader.CloseWithError(err) + } + exportErr := <-exportErrCh + if exportErr != nil { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not export opengraph data from connection %s: %w", input.FromConnection, exportErr) + } + if err != nil { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not parse streamed opengraph data from connection %s: %w", input.FromConnection, err) + } + + if _, err := opengraph.WriteGraph(ctx, toConn, &doc.Graph); err != nil { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not copy opengraph data into connection %s: %w", input.ToConnection, err) + } + + return nil, CopyOpenGraphOutput{Nodes: len(doc.Graph.Nodes), Edges: len(doc.Graph.Edges)}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/server.go b/tools/dawgrun/pkg/mcpserver/server.go new file mode 100644 index 00000000..877ede94 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/server.go @@ -0,0 +1,167 @@ +package mcpserver + +import ( + "context" + "strings" + + "github.com/davecgh/go-spew/spew" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" +) + +type Options struct { + AllowWrites bool +} + +type Server struct { + server *mcp.Server + scope *commands.Scope + allowWrites bool +} + +type ListConnectionsInput struct{} + +type ListConnectionsOutput struct { + Open []string `json:"open"` +} + +type OpenConnectionInput struct { + Name string `json:"name" jsonschema:"session-local connection name"` + ConnectionString string `json:"connection_string" jsonschema:"PostgreSQL or Neo4j connection string; treated as sensitive"` + Driver string `json:"driver,omitempty" jsonschema:"optional driver override: pg or neo4j"` + DefaultGraph string `json:"default_graph,omitempty" jsonschema:"default graph name; defaults to default"` + InitGraph bool `json:"init_graph,omitempty" jsonschema:"create default graph if selecting it fails"` +} + +type OpenConnectionOutput struct { + Name string `json:"name"` +} + +type ParseCypherInput struct { + Query string `json:"query" jsonschema:"CySQL/Cypher query to parse"` +} + +type ParseCypherOutput struct { + AST string `json:"ast"` +} + +type TranslateCypherInput struct { + Query string `json:"query" jsonschema:"CySQL/Cypher query to translate"` + Connection string `json:"connection,omitempty" jsonschema:"optional open connection name for live kind mapping"` + DumpPGAst bool `json:"dump_pg_ast,omitempty" jsonschema:"include PostgreSQL AST dump"` +} + +type TranslateCypherOutput struct { + SQL string `json:"sql"` + PGAst string `json:"pg_ast,omitempty"` +} + +func New(options Options) *Server { + return NewWithScope(commands.NewScope(commands.RunModeCLI), options) +} + +func NewWithScope(scope *commands.Scope, options Options) *Server { + s := &Server{ + server: mcp.NewServer(&mcp.Implementation{ + Name: "dawgrun-mcp", + Version: "0.1.0", + }, nil), + scope: scope, + allowWrites: options.AllowWrites, + } + s.registerTools() + + return s +} + +func (s *Server) Run(ctx context.Context) error { + defer s.scope.CloseConnections(context.Background()) + return s.server.Run(ctx, &mcp.StdioTransport{}) +} + +func (s *Server) MCPServer() *mcp.Server { + return s.server +} + +func (s *Server) commandContext(ctx context.Context) *commands.CommandContext { + cmdCtx := commands.NewCommandContext(ctx, nil, s.scope, "") + cmdCtx.SetStyledOutputEnabled(false) + return cmdCtx +} + +func (s *Server) registerTools() { + mcp.AddTool(s.server, readOnlyTool("list_connections", "List Connections", "List open dawgrun connection names without exposing connection strings."), s.listConnections) + mcp.AddTool(s.server, readOnlyTool("open_connection", "Open Connection", "Open a named DAWGS backend connection. Connection strings are treated as sensitive and are not echoed."), s.openConnection) + mcp.AddTool(s.server, readOnlyTool("parse_cypher", "Parse Cypher", "Parse a CySQL/Cypher query and return an AST dump."), s.parseCypher) + mcp.AddTool(s.server, readOnlyTool("translate_cypher_to_pgsql", "Translate Cypher To PGSQL", "Translate a CySQL/Cypher query to PostgreSQL SQL."), s.translateCypher) + mcp.AddTool(s.server, readOnlyTool("query_cypher", "Query Cypher", "Execute a CySQL/Cypher query against a named DAWGS connection and return bounded rows."), s.queryCypher) + mcp.AddTool(s.server, readOnlyTool("explain_psql", "Explain PGSQL", "Translate a CySQL/Cypher query and return the PostgreSQL EXPLAIN plan."), s.explainPsql) + mcp.AddTool(s.server, readOnlyTool("save_opengraph", "Save OpenGraph", "Export OpenGraph JSON from a named connection."), s.saveOpenGraph) + mcp.AddTool(s.server, writeTool("load_opengraph", "Load OpenGraph", "Load OpenGraph JSON into a named connection. Requires --allow-writes."), s.loadOpenGraph) + mcp.AddTool(s.server, writeTool("copy_opengraph", "Copy OpenGraph", "Copy OpenGraph data between named connections. Requires --allow-writes."), s.copyOpenGraph) + mcp.AddTool(s.server, readOnlyTool("load_db_kinds", "Load DB Kinds", "Load and return the kind mapping from a named PostgreSQL connection."), s.loadDBKinds) + mcp.AddTool(s.server, readOnlyTool("lookup_kind", "Lookup Kind", "Resolve a kind name to its PostgreSQL kind ID."), s.lookupKind) + mcp.AddTool(s.server, readOnlyTool("lookup_kind_id", "Lookup Kind ID", "Resolve a PostgreSQL kind ID to its kind name."), s.lookupKindID) +} + +func readOnlyTool(name, title, description string) *mcp.Tool { + closedWorld := false + return &mcp.Tool{ + Name: name, + Title: title, + Description: description, + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true, OpenWorldHint: &closedWorld}, + } +} + +func writeTool(name, title, description string) *mcp.Tool { + destructive := true + closedWorld := false + return &mcp.Tool{ + Name: name, + Title: title, + Description: description, + Annotations: &mcp.ToolAnnotations{DestructiveHint: &destructive, OpenWorldHint: &closedWorld}, + } +} + +func (s *Server) listConnections(context.Context, *mcp.CallToolRequest, ListConnectionsInput) (*mcp.CallToolResult, ListConnectionsOutput, error) { + return nil, ListConnectionsOutput{Open: s.scope.GetConnectionNames()}, nil +} + +func (s *Server) openConnection(ctx context.Context, _ *mcp.CallToolRequest, input OpenConnectionInput) (*mcp.CallToolResult, OpenConnectionOutput, error) { + name := strings.TrimSpace(input.Name) + if err := commands.OpenConnection(s.commandContext(ctx), commands.OpenConnectionOptions{ + Name: name, + ConnectionString: input.ConnectionString, + Driver: input.Driver, + DefaultGraph: input.DefaultGraph, + InitGraph: input.InitGraph, + }); err != nil { + return nil, OpenConnectionOutput{}, err + } + + return nil, OpenConnectionOutput{Name: name}, nil +} + +func (s *Server) parseCypher(_ context.Context, _ *mcp.CallToolRequest, input ParseCypherInput) (*mcp.CallToolResult, ParseCypherOutput, error) { + query, err := commands.ParseQueryText(input.Query) + if err != nil { + return nil, ParseCypherOutput{}, err + } + + return nil, ParseCypherOutput{AST: spew.Sdump(query)}, nil +} + +func (s *Server) translateCypher(ctx context.Context, _ *mcp.CallToolRequest, input TranslateCypherInput) (*mcp.CallToolResult, TranslateCypherOutput, error) { + output, err := commands.TranslateCypherToPsql(s.commandContext(ctx), input.Query, commands.TranslateCypherOptions{ + Connection: input.Connection, + DumpPGAst: input.DumpPGAst, + }) + if err != nil { + return nil, TranslateCypherOutput{}, err + } + + return nil, TranslateCypherOutput{SQL: output.SQL, PGAst: output.PGAst}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/server_test.go b/tools/dawgrun/pkg/mcpserver/server_test.go new file mode 100644 index 00000000..99deed31 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/server_test.go @@ -0,0 +1,92 @@ +package mcpserver + +import ( + "context" + "encoding/json" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" +) + +func TestServerAdvertisesExpectedTools(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + tools := make(map[string]bool) + for tool, err := range session.Tools(ctx, nil) { + require.NoError(t, err) + tools[tool.Name] = true + } + + require.Equal(t, map[string]bool{ + "list_connections": true, + "open_connection": true, + "parse_cypher": true, + "translate_cypher_to_pgsql": true, + "query_cypher": true, + "explain_psql": true, + "save_opengraph": true, + "load_opengraph": true, + "copy_opengraph": true, + "load_db_kinds": true, + "lookup_kind": true, + "lookup_kind_id": true, + }, tools) +} + +func TestListConnectionsReturnsOpenNamesOnly(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + result, err := session.CallTool(ctx, &mcp.CallToolParams{Name: "list_connections"}) + require.NoError(t, err) + require.False(t, result.IsError) + + output := decodeToolOutput[ListConnectionsOutput](t, result) + require.Empty(t, output.Open) +} + +func TestParseCypherReturnsStructuredOutput(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "parse_cypher", + Arguments: map[string]any{ + "query": "MATCH (n) RETURN n", + }, + }) + require.NoError(t, err) + require.False(t, result.IsError) + + output := decodeToolOutput[ParseCypherOutput](t, result) + require.Contains(t, output.AST, "RegularQuery") +} + +func connectTestServer(t *testing.T, ctx context.Context, server *Server) *mcp.ClientSession { + t.Helper() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + serverSession, err := server.MCPServer().Connect(ctx, serverTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { serverSession.Close() }) + + client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { clientSession.Close() }) + + return clientSession +} + +func decodeToolOutput[T any](t *testing.T, result *mcp.CallToolResult) T { + t.Helper() + + var output T + data, err := json.Marshal(result.StructuredContent) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(data, &output)) + + return output +} diff --git a/tools/dawgrun/pkg/texttools/cypher.go b/tools/dawgrun/pkg/texttools/cypher.go index 1bb54b82..93cf3eae 100644 --- a/tools/dawgrun/pkg/texttools/cypher.go +++ b/tools/dawgrun/pkg/texttools/cypher.go @@ -140,6 +140,10 @@ func buildCypherResultColumns(keys []string, numValues int) []string { return columns } +func BuildCypherResultColumns(keys []string, numValues int) []string { + return buildCypherResultColumns(keys, numValues) +} + func buildCypherResultHeader(columns []string) table.Row { row := make(table.Row, len(columns)) for idx, key := range columns { @@ -162,6 +166,10 @@ func buildCypherResultJSONRow(columns []string, values []any) map[string]any { return row } +func BuildCypherResultJSONRow(columns []string, values []any) map[string]any { + return buildCypherResultJSONRow(columns, values) +} + func formatCypherResultJSONValue(value any) any { switch typed := value.(type) { case nil: From f8fe2216271b4f088a4df69c2e7de7475eadf41f Mon Sep 17 00:00:00 2001 From: kpom Date: Fri, 22 May 2026 10:33:33 -0500 Subject: [PATCH 2/2] general clean up of MCP stuff --- drivers/pg/batch.go | 4 +- tools/dawgrun/pkg/commands/db.go | 29 ------ tools/dawgrun/pkg/mcpserver/cypher.go | 102 ++------------------- tools/dawgrun/pkg/mcpserver/kinds.go | 37 ++------ tools/dawgrun/pkg/mcpserver/opengraph.go | 108 +++-------------------- tools/dawgrun/pkg/mcpserver/server.go | 43 +++++++-- tools/dawgrun/pkg/texttools/cypher.go | 8 -- 7 files changed, 67 insertions(+), 264 deletions(-) diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index ebea05cd..59fa4946 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -679,10 +679,10 @@ func (s *relationshipCreateBatchBuilder) Add(ctx context.Context, kindMapper Kin edgeProperties = edge.Properties.Clone() ) - if edgeKindID, err := kindMapper.MapKind(ctx, edge.Kind); err != nil { + if edgeKindIDs, err := kindMapper.AssertKinds(ctx, graph.Kinds{edge.Kind}); err != nil { return err } else { - s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) + s.relationshipUpdateBatch.Add(startID, endID, edgeKindIDs[0]) } s.keyToEdgeID[key] = edgeID diff --git a/tools/dawgrun/pkg/commands/db.go b/tools/dawgrun/pkg/commands/db.go index 33deeeb0..71fcee36 100644 --- a/tools/dawgrun/pkg/commands/db.go +++ b/tools/dawgrun/pkg/commands/db.go @@ -19,30 +19,6 @@ import ( "github.com/specterops/dawgs/tools/dawgrun/pkg/types" ) -type OpenConnectionOptions struct { - Name string - ConnectionString string - Driver string - DefaultGraph string - InitGraph bool -} - -// OpenConnection opens a DAWGS backend connection and stores it in the command scope. -func OpenConnection(ctx *CommandContext, options OpenConnectionOptions) error { - name := strings.TrimSpace(options.Name) - connStr := strings.TrimSpace(options.ConnectionString) - if name == "" || connStr == "" { - return fmt.Errorf("connection name and connection string are required") - } - - _, err := ctx.OpenConnection(name, connStr, openConnectionOptions{ - driverName: strings.ToLower(strings.TrimSpace(options.Driver)), - defaultGraphName: options.DefaultGraph, - initGraphOnFail: options.InitGraph, - }) - return err -} - func listConnectionsCmd() CommandDesc { return CommandDesc{ args: []string{}, @@ -316,11 +292,6 @@ func loadKindMap(ctx *CommandContext, connName string) (stubs.KindMap, error) { return kindMap, nil } -// LoadKindMap loads and caches the database kind map for an open connection. -func LoadKindMap(ctx *CommandContext, connName string) (stubs.KindMap, error) { - return loadKindMap(ctx, connName) -} - func lookupKindCmd() CommandDesc { return CommandDesc{ args: []string{"", ""}, diff --git a/tools/dawgrun/pkg/mcpserver/cypher.go b/tools/dawgrun/pkg/mcpserver/cypher.go index dc725e4e..8406c4d4 100644 --- a/tools/dawgrun/pkg/mcpserver/cypher.go +++ b/tools/dawgrun/pkg/mcpserver/cypher.go @@ -2,31 +2,17 @@ package mcpserver import ( "context" - "fmt" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/specterops/dawgs/graph" - - "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" - "github.com/specterops/dawgs/tools/dawgrun/pkg/texttools" -) - -const ( - defaultQueryLimit = 100 - maxQueryLimit = 1000 ) type QueryCypherInput struct { Connection string `json:"connection" jsonschema:"open connection name"` Query string `json:"query" jsonschema:"CySQL/Cypher query to execute"` - Limit int `json:"limit,omitempty" jsonschema:"maximum rows to return; defaults to 100 and caps at 1000"` } type QueryCypherOutput struct { - Columns []string `json:"columns"` - Rows []map[string]any `json:"rows"` - RowCount int `json:"row_count"` - Truncated bool `json:"truncated"` + Output string `json:"output"` } type ExplainPsqlInput struct { @@ -35,99 +21,23 @@ type ExplainPsqlInput struct { } type ExplainPsqlOutput struct { - SQL string `json:"sql"` - Plan []string `json:"plan"` + Output string `json:"output"` } func (s *Server) queryCypher(ctx context.Context, _ *mcp.CallToolRequest, input QueryCypherInput) (*mcp.CallToolResult, QueryCypherOutput, error) { - cmdCtx := s.commandContext(ctx) - conn, err := cmdCtx.EnsureConnection(input.Connection) - if err != nil { - return nil, QueryCypherOutput{}, err - } - - limit := normalizeQueryLimit(input.Limit) - output := QueryCypherOutput{Rows: []map[string]any{}} - err = conn.ReadTransaction(ctx, func(tx graph.Transaction) error { - result := tx.Query(input.Query, nil) - if err := result.Error(); err != nil { - return fmt.Errorf("error running cypher query: %w", err) - } - defer result.Close() - - for result.Next() { - values := result.Values() - if output.Columns == nil { - output.Columns = texttools.BuildCypherResultColumns(result.Keys(), len(values)) - } - - if len(output.Rows) >= limit { - output.Truncated = true - break - } - - output.Rows = append(output.Rows, texttools.BuildCypherResultJSONRow(output.Columns, values)) - } - - return result.Error() - }) + output, err := s.runCommand(ctx, "query-cypher", []string{"-format", "json", input.Connection, input.Query}) if err != nil { return nil, QueryCypherOutput{}, err } - if output.Columns == nil { - output.Columns = []string{} - } - output.RowCount = len(output.Rows) - - return nil, output, nil + return nil, QueryCypherOutput{Output: output}, nil } func (s *Server) explainPsql(ctx context.Context, _ *mcp.CallToolRequest, input ExplainPsqlInput) (*mcp.CallToolResult, ExplainPsqlOutput, error) { - cmdCtx := s.commandContext(ctx) - conn, err := cmdCtx.EnsureConnection(input.Connection) - if err != nil { - return nil, ExplainPsqlOutput{}, err - } - - translation, err := commands.TranslateCypherToPsql(cmdCtx, input.Query, commands.TranslateCypherOptions{Connection: input.Connection, SQLFormatDistance: 2}) - if err != nil { - return nil, ExplainPsqlOutput{}, err - } - - explainSQLQuery := fmt.Sprintf("EXPLAIN %s", translation.SQL) - output := ExplainPsqlOutput{SQL: explainSQLQuery} - err = conn.ReadTransaction(ctx, func(tx graph.Transaction) error { - result := tx.Raw(explainSQLQuery, nil) - if err := result.Error(); err != nil { - return fmt.Errorf("error running raw query: %w", err) - } - defer result.Close() - - for result.Next() { - var value string - if err := graph.ScanNextResult(result, &value); err != nil { - return fmt.Errorf("could not scan EXPLAIN row: %w", err) - } - output.Plan = append(output.Plan, value) - } - - return result.Error() - }) + output, err := s.runCommand(ctx, "explain-psql", []string{input.Connection, input.Query}) if err != nil { return nil, ExplainPsqlOutput{}, err } - return nil, output, nil -} - -func normalizeQueryLimit(limit int) int { - if limit <= 0 { - return defaultQueryLimit - } - if limit > maxQueryLimit { - return maxQueryLimit - } - - return limit + return nil, ExplainPsqlOutput{Output: output}, nil } diff --git a/tools/dawgrun/pkg/mcpserver/kinds.go b/tools/dawgrun/pkg/mcpserver/kinds.go index 7f87fa36..ece58977 100644 --- a/tools/dawgrun/pkg/mcpserver/kinds.go +++ b/tools/dawgrun/pkg/mcpserver/kinds.go @@ -5,10 +5,6 @@ import ( "strconv" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/specterops/dawgs/graph" - - "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" - "github.com/specterops/dawgs/tools/dawgrun/pkg/stubs" ) type LoadDBKindsInput struct { @@ -16,7 +12,7 @@ type LoadDBKindsInput struct { } type LoadDBKindsOutput struct { - Kinds map[string]string `json:"kinds"` + Output string `json:"output"` } type LookupKindInput struct { @@ -25,8 +21,7 @@ type LookupKindInput struct { } type LookupKindOutput struct { - Kind string `json:"kind"` - KindID int16 `json:"kind_id"` + Output string `json:"output"` } type LookupKindIDInput struct { @@ -35,46 +30,32 @@ type LookupKindIDInput struct { } type LookupKindIDOutput struct { - KindID int16 `json:"kind_id"` - Kind string `json:"kind"` + Output string `json:"output"` } func (s *Server) loadDBKinds(ctx context.Context, _ *mcp.CallToolRequest, input LoadDBKindsInput) (*mcp.CallToolResult, LoadDBKindsOutput, error) { - kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) + output, err := s.runCommand(ctx, "load-db-kinds", []string{input.Connection}) if err != nil { return nil, LoadDBKindsOutput{}, err } - output := LoadDBKindsOutput{Kinds: make(map[string]string, len(kindMap))} - for kindID, kind := range kindMap { - output.Kinds[strconv.FormatInt(int64(kindID), 10)] = kind.String() - } - - return nil, output, nil + return nil, LoadDBKindsOutput{Output: output}, nil } func (s *Server) lookupKind(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindInput) (*mcp.CallToolResult, LookupKindOutput, error) { - kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) - if err != nil { - return nil, LookupKindOutput{}, err - } - kindID, err := stubs.MapperFromKindMap(kindMap).GetIDByKind(graph.StringKind(input.Kind)) + output, err := s.runCommand(ctx, "lookup-kind", []string{input.Connection, input.Kind}) if err != nil { return nil, LookupKindOutput{}, err } - return nil, LookupKindOutput{Kind: input.Kind, KindID: kindID}, nil + return nil, LookupKindOutput{Output: output}, nil } func (s *Server) lookupKindID(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindIDInput) (*mcp.CallToolResult, LookupKindIDOutput, error) { - kindMap, err := commands.LoadKindMap(s.commandContext(ctx), input.Connection) - if err != nil { - return nil, LookupKindIDOutput{}, err - } - kind, err := stubs.MapperFromKindMap(kindMap).GetKindByID(input.KindID) + output, err := s.runCommand(ctx, "lookup-kind-id", []string{input.Connection, strconv.FormatInt(int64(input.KindID), 10)}) if err != nil { return nil, LookupKindIDOutput{}, err } - return nil, LookupKindIDOutput{KindID: input.KindID, Kind: kind.String()}, nil + return nil, LookupKindIDOutput{Output: output}, nil } diff --git a/tools/dawgrun/pkg/mcpserver/opengraph.go b/tools/dawgrun/pkg/mcpserver/opengraph.go index 8c7c24b5..70582874 100644 --- a/tools/dawgrun/pkg/mcpserver/opengraph.go +++ b/tools/dawgrun/pkg/mcpserver/opengraph.go @@ -1,28 +1,20 @@ package mcpserver import ( - "bytes" "context" "fmt" - "io" - "os" "strings" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/specterops/dawgs/opengraph" ) -const maxInlineOpenGraphBytes = 1024 * 1024 - type SaveOpenGraphInput struct { Connection string `json:"connection" jsonschema:"open connection name"` OutputPath string `json:"output_path,omitempty" jsonschema:"optional file path for OpenGraph JSON output"` } type SaveOpenGraphOutput struct { - OutputPath string `json:"output_path,omitempty"` - Content string `json:"content,omitempty"` - Bytes int `json:"bytes"` + Output string `json:"output"` } type LoadOpenGraphInput struct { @@ -31,8 +23,7 @@ type LoadOpenGraphInput struct { } type LoadOpenGraphOutput struct { - Nodes int `json:"nodes"` - Edges int `json:"edges"` + Output string `json:"output"` } type CopyOpenGraphInput struct { @@ -41,44 +32,23 @@ type CopyOpenGraphInput struct { } type CopyOpenGraphOutput struct { - Nodes int `json:"nodes"` - Edges int `json:"edges"` + Output string `json:"output"` } func (s *Server) saveOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input SaveOpenGraphInput) (*mcp.CallToolResult, SaveOpenGraphOutput, error) { - conn, err := s.commandContext(ctx).EnsureConnection(input.Connection) - if err != nil { - return nil, SaveOpenGraphOutput{}, err - } - + args := []string{} outputPath := strings.TrimSpace(input.OutputPath) if outputPath != "" { - outFile, err := os.Create(outputPath) - if err != nil { - return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not create output file %s: %w", outputPath, err) - } - defer outFile.Close() - - if err := opengraph.Export(ctx, conn, outFile); err != nil { - return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not export opengraph data: %w", err) - } - - if fileInfo, err := outFile.Stat(); err == nil { - return nil, SaveOpenGraphOutput{OutputPath: outputPath, Bytes: int(fileInfo.Size())}, nil - } - - return nil, SaveOpenGraphOutput{OutputPath: outputPath}, nil + args = append(args, "-out", outputPath) } + args = append(args, input.Connection) - buffer := new(bytes.Buffer) - if err := opengraph.Export(ctx, conn, buffer); err != nil { - return nil, SaveOpenGraphOutput{}, fmt.Errorf("could not export opengraph data: %w", err) - } - if buffer.Len() > maxInlineOpenGraphBytes { - return nil, SaveOpenGraphOutput{}, fmt.Errorf("OpenGraph export is %d bytes, exceeding inline limit %d; provide output_path", buffer.Len(), maxInlineOpenGraphBytes) + output, err := s.runCommand(ctx, "save-opengraph", args) + if err != nil { + return nil, SaveOpenGraphOutput{}, err } - return nil, SaveOpenGraphOutput{Content: buffer.String(), Bytes: buffer.Len()}, nil + return nil, SaveOpenGraphOutput{Output: output}, nil } func (s *Server) loadOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input LoadOpenGraphInput) (*mcp.CallToolResult, LoadOpenGraphOutput, error) { @@ -86,25 +56,12 @@ func (s *Server) loadOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, inpu return nil, LoadOpenGraphOutput{}, fmt.Errorf("load_opengraph requires --allow-writes") } - inputFile, err := os.Open(input.InputPath) - if err != nil { - return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not open opengraph input file %s: %w", input.InputPath, err) - } - defer inputFile.Close() - - conn, err := s.commandContext(ctx).EnsureConnection(input.Connection) + output, err := s.runCommand(ctx, "load-opengraph", []string{input.Connection, input.InputPath}) if err != nil { return nil, LoadOpenGraphOutput{}, err } - doc, err := opengraph.ParseDocument(inputFile) - if err != nil { - return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not parse opengraph input file %s: %w", input.InputPath, err) - } - if _, err := opengraph.WriteGraph(ctx, conn, &doc.Graph); err != nil { - return nil, LoadOpenGraphOutput{}, fmt.Errorf("could not write opengraph data into connection %s: %w", input.Connection, err) - } - return nil, LoadOpenGraphOutput{Nodes: len(doc.Graph.Nodes), Edges: len(doc.Graph.Edges)}, nil + return nil, LoadOpenGraphOutput{Output: output}, nil } func (s *Server) copyOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input CopyOpenGraphInput) (*mcp.CallToolResult, CopyOpenGraphOutput, error) { @@ -112,47 +69,10 @@ func (s *Server) copyOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, inpu return nil, CopyOpenGraphOutput{}, fmt.Errorf("copy_opengraph requires --allow-writes") } - if input.FromConnection == input.ToConnection { - return nil, CopyOpenGraphOutput{}, fmt.Errorf("source and destination connections must differ") - } - - cmdCtx := s.commandContext(ctx) - fromConn, err := cmdCtx.EnsureConnection(input.FromConnection) - if err != nil { - return nil, CopyOpenGraphOutput{}, err - } - toConn, err := cmdCtx.EnsureConnection(input.ToConnection) + output, err := s.runCommand(ctx, "copy-opengraph", []string{input.FromConnection, input.ToConnection}) if err != nil { return nil, CopyOpenGraphOutput{}, err } - pipeReader, pipeWriter := io.Pipe() - exportErrCh := make(chan error, 1) - go func() { - if err := opengraph.Export(ctx, fromConn, pipeWriter); err != nil { - _ = pipeWriter.CloseWithError(err) - exportErrCh <- err - return - } - - exportErrCh <- pipeWriter.Close() - }() - - doc, err := opengraph.ParseDocument(pipeReader) - if err != nil { - _ = pipeReader.CloseWithError(err) - } - exportErr := <-exportErrCh - if exportErr != nil { - return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not export opengraph data from connection %s: %w", input.FromConnection, exportErr) - } - if err != nil { - return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not parse streamed opengraph data from connection %s: %w", input.FromConnection, err) - } - - if _, err := opengraph.WriteGraph(ctx, toConn, &doc.Graph); err != nil { - return nil, CopyOpenGraphOutput{}, fmt.Errorf("could not copy opengraph data into connection %s: %w", input.ToConnection, err) - } - - return nil, CopyOpenGraphOutput{Nodes: len(doc.Graph.Nodes), Edges: len(doc.Graph.Edges)}, nil + return nil, CopyOpenGraphOutput{Output: output}, nil } diff --git a/tools/dawgrun/pkg/mcpserver/server.go b/tools/dawgrun/pkg/mcpserver/server.go index 877ede94..3081eb2c 100644 --- a/tools/dawgrun/pkg/mcpserver/server.go +++ b/tools/dawgrun/pkg/mcpserver/server.go @@ -2,7 +2,9 @@ package mcpserver import ( "context" + "fmt" "strings" + "sync" "github.com/davecgh/go-spew/spew" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -17,6 +19,7 @@ type Options struct { type Server struct { server *mcp.Server scope *commands.Scope + commandMu sync.Mutex allowWrites bool } @@ -90,6 +93,26 @@ func (s *Server) commandContext(ctx context.Context) *commands.CommandContext { return cmdCtx } +func (s *Server) runCommand(ctx context.Context, name string, args []string) (string, error) { + s.commandMu.Lock() + defer s.commandMu.Unlock() + + cmd, ok := commands.Registry()[name] + if !ok { + return "", fmt.Errorf("unknown command %s", name) + } + + cmdCtx := s.commandContext(ctx) + if cmd.ClearFlagsFn != nil { + defer cmd.ClearFlagsFn() + } + if err := cmd.Fn(cmdCtx, args); err != nil { + return "", err + } + + return cmdCtx.OutputString(), nil +} + func (s *Server) registerTools() { mcp.AddTool(s.server, readOnlyTool("list_connections", "List Connections", "List open dawgrun connection names without exposing connection strings."), s.listConnections) mcp.AddTool(s.server, readOnlyTool("open_connection", "Open Connection", "Open a named DAWGS backend connection. Connection strings are treated as sensitive and are not echoed."), s.openConnection) @@ -132,13 +155,19 @@ func (s *Server) listConnections(context.Context, *mcp.CallToolRequest, ListConn func (s *Server) openConnection(ctx context.Context, _ *mcp.CallToolRequest, input OpenConnectionInput) (*mcp.CallToolResult, OpenConnectionOutput, error) { name := strings.TrimSpace(input.Name) - if err := commands.OpenConnection(s.commandContext(ctx), commands.OpenConnectionOptions{ - Name: name, - ConnectionString: input.ConnectionString, - Driver: input.Driver, - DefaultGraph: input.DefaultGraph, - InitGraph: input.InitGraph, - }); err != nil { + args := []string{} + if driver := strings.TrimSpace(input.Driver); driver != "" { + args = append(args, "-driver", driver) + } + if defaultGraph := strings.TrimSpace(input.DefaultGraph); defaultGraph != "" { + args = append(args, "-default-graph", defaultGraph) + } + if input.InitGraph { + args = append(args, "-init-graph") + } + args = append(args, name, input.ConnectionString) + + if _, err := s.runCommand(ctx, "open", args); err != nil { return nil, OpenConnectionOutput{}, err } diff --git a/tools/dawgrun/pkg/texttools/cypher.go b/tools/dawgrun/pkg/texttools/cypher.go index 93cf3eae..1bb54b82 100644 --- a/tools/dawgrun/pkg/texttools/cypher.go +++ b/tools/dawgrun/pkg/texttools/cypher.go @@ -140,10 +140,6 @@ func buildCypherResultColumns(keys []string, numValues int) []string { return columns } -func BuildCypherResultColumns(keys []string, numValues int) []string { - return buildCypherResultColumns(keys, numValues) -} - func buildCypherResultHeader(columns []string) table.Row { row := make(table.Row, len(columns)) for idx, key := range columns { @@ -166,10 +162,6 @@ func buildCypherResultJSONRow(columns []string, values []any) map[string]any { return row } -func BuildCypherResultJSONRow(columns []string, values []any) map[string]any { - return buildCypherResultJSONRow(columns, values) -} - func formatCypherResultJSONValue(value any) any { switch typed := value.(type) { case nil: