diff --git a/README.md b/README.md index e1aad7b..8399bb5 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,11 @@ export CONNECTION_STRING="neo4j://neo4j:weneedbetterpasswords@localhost:7687" Use `make test` for unit tests only and `make test_integration` for integration tests only. +### Dawgrun MCP Server + +`go tool dawgrun-mcp` runs a stdio MCP server for agent access to dawgrun-backed DAWGS inspection tools. See +[`tools/dawgrun/README.md`](tools/dawgrun/README.md) for client configuration and the current tool surface. + ### Test Metrics `make test` writes unit test coverage artifacts under `.coverage/`: diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go index ebea05c..59fa494 100644 --- a/drivers/pg/batch.go +++ b/drivers/pg/batch.go @@ -679,10 +679,10 @@ func (s *relationshipCreateBatchBuilder) Add(ctx context.Context, kindMapper Kin edgeProperties = edge.Properties.Clone() ) - if edgeKindID, err := kindMapper.MapKind(ctx, edge.Kind); err != nil { + if edgeKindIDs, err := kindMapper.AssertKinds(ctx, graph.Kinds{edge.Kind}); err != nil { return err } else { - s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) + s.relationshipUpdateBatch.Add(startID, endID, edgeKindIDs[0]) } s.keyToEdgeID[key] = edgeID diff --git a/go.mod b/go.mod index 58534cd..5d3d1cf 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/kanmu/go-sqlfmt v0.0.2-0.20200215095417-d1e63e2ee5eb github.com/mitchellh/go-wordwrap v1.0.1 + github.com/modelcontextprotocol/go-sdk v1.6.0 github.com/specterops/go-repl v1.0.1 golang.org/x/term v0.39.0 ) @@ -129,6 +130,7 @@ require ( github.com/golangci/swaggoswag v0.0.0-20250504205917-77f2aca3143e // indirect github.com/golangci/unconvert v0.0.0-20250410112200-a129a6e6413e // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/gordonklaus/ineffassign v0.2.0 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.5.0 // indirect @@ -207,6 +209,8 @@ require ( github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.29.0 // indirect github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sonatard/noctx v0.5.1 // indirect @@ -235,6 +239,7 @@ require ( github.com/yagipy/maintidx v1.0.0 // indirect github.com/yeya24/promlinter v0.3.0 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect gitlab.com/bosi/decorder v0.4.2 // indirect go-simpler.org/musttag v0.14.0 // indirect go-simpler.org/sloglint v0.11.1 // indirect @@ -246,6 +251,7 @@ require ( golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.34.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect @@ -262,5 +268,6 @@ tool ( github.com/fzipp/gocyclo/cmd/gocyclo github.com/golangci/golangci-lint/v2/cmd/golangci-lint github.com/specterops/dawgs/tools/dawgrun/cmd/dawgrun + github.com/specterops/dawgs/tools/dawgrun/cmd/dawgrun-mcp github.com/specterops/dawgs/tools/metrics/cmd/dawgs-metrics ) diff --git a/go.sum b/go.sum index ff24078..7c24d4a 100644 --- a/go.sum +++ b/go.sum @@ -284,6 +284,8 @@ github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -351,6 +353,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -563,6 +567,8 @@ github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQ github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= +github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -672,6 +678,10 @@ github.com/sashamelentyev/usestdlibvars v1.29.0/go.mod h1:8PpnjHMk5VdeWlVb4wCdrB github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 h1:AoLtJX4WUtZkhhUUMFy3GgecAALp/Mb4S1iyQOA2s0U= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08/go.mod h1:+XLCJiRE95ga77XInNELh2M6zQP+PdqiT9Zpm0D9Wpk= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= @@ -761,6 +771,8 @@ github.com/yeya24/promlinter v0.3.0 h1:JVDbMp08lVCP7Y6NP3qHroGAO6z2yGKQtS5Jsjqto github.com/yeya24/promlinter v0.3.0/go.mod h1:cDfJQQYv9uYciW60QT0eeHlFodotkYZlL+YcPQN+mW4= github.com/ykadowak/zerologlint v0.1.5 h1:Gy/fMz1dFQN9JZTPjv1hxEk+sRWm05row04Yoolgdiw= github.com/ykadowak/zerologlint v0.1.5/go.mod h1:KaUskqF3e/v59oPmdq1U1DnKcuHokl2/K1U4pmIELKg= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/tools/dawgrun/README.md b/tools/dawgrun/README.md index c92676a..f0e1b25 100644 --- a/tools/dawgrun/README.md +++ b/tools/dawgrun/README.md @@ -66,6 +66,53 @@ connection names without an interactive `open` first. At any time, run `help` to list commands or `help ` for detailed usage, flag defaults, and description. +## MCP server + +`dawgrun-mcp` is a stdio MCP server that exposes a small dawgrun tool +surface to agent clients. It uses the official +`github.com/modelcontextprotocol/go-sdk` and keeps named backend +connections in memory for the lifetime of the MCP process. +It does not read or write dawgrun local config; use `open_connection` +to create MCP-session-local connections explicitly. + +Run it from a `DAWGS` checkout with: + + go tool dawgrun-mcp + +An MCP client can launch it with a config like: + +```json +{ + "mcpServers": { + "dawgrun": { + "command": "go", + "args": ["tool", "dawgrun-mcp"] + } + } +} +``` + +For OpenCode specifically, use the array-form local command: + +```json +{ + "$schema": "https://opencode.ai/config.json", + "mcp": { + "dawgrun": { + "type": "local", + "command": ["go", "tool", "dawgrun-mcp"], + "enabled": true + } + } +} +``` + +The tool set includes `list_connections`, `open_connection`, +`parse_cypher`, `translate_cypher_to_pgsql`, `query_cypher`, +`explain_psql`, `save_opengraph`, `load_db_kinds`, `lookup_kind`, and +`lookup_kind_id`. Write-capable tools, currently `load_opengraph` and +`copy_opengraph`, require `--allow-writes`. + ## Commands The REPL supports command-name completion with `Tab`; ambiguous matches render a transient popover list near the prompt and can be dismissed with `Esc`. diff --git a/tools/dawgrun/cmd/dawgrun-mcp/main.go b/tools/dawgrun/cmd/dawgrun-mcp/main.go new file mode 100644 index 0000000..a00b7d1 --- /dev/null +++ b/tools/dawgrun/cmd/dawgrun-mcp/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/davecgh/go-spew/spew" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" + "github.com/specterops/dawgs/tools/dawgrun/pkg/mcpserver" +) + +func main() { + spew.Config.DisablePointerAddresses = true + + var allowWrites bool + flag.BoolVar(&allowWrites, "allow-writes", false, "enable write-capable DAWGS tools") + flag.Parse() + + scope := commands.NewScope(commands.RunModeCLI) + server := mcpserver.NewWithScope(scope, mcpserver.Options{AllowWrites: allowWrites}) + if err := server.Run(context.Background()); err != nil { + fmt.Fprintf(os.Stderr, "dawgrun-mcp failed: %v\n", err) + os.Exit(1) + } +} diff --git a/tools/dawgrun/pkg/commands/cypher.go b/tools/dawgrun/pkg/commands/cypher.go index 2d2e05c..22938d1 100644 --- a/tools/dawgrun/pkg/commands/cypher.go +++ b/tools/dawgrun/pkg/commands/cypher.go @@ -21,6 +21,61 @@ const ( queryCypherOutputFormatJSON = "json" ) +type TranslateCypherOptions struct { + Connection string + DumpPGAst bool + SQLFormatDistance int +} + +type TranslateCypherOutput struct { + SQL string + PGAst string +} + +// TranslateCypherToPsql parses and translates a Cypher query to PostgreSQL SQL. +func TranslateCypherToPsql(ctx *CommandContext, queryText string, options TranslateCypherOptions) (TranslateCypherOutput, error) { + query, err := ParseQueryText(queryText) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("error trying to parse query: %w", err) + } + + kindMapper := stubs.EmptyMapper() + if strings.TrimSpace(options.Connection) != "" { + kindMap, err := loadKindMap(ctx, options.Connection) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not load kind map for translation: %w", err) + } + kindMapper = stubs.MapperFromKindMap(kindMap) + } + + result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, options.Connection)) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not translate cypher query to pgsql: %w", err) + } + + queryBuilder := format.NewOutputBuilder() + if result.Parameters != nil { + queryBuilder.WithMaterializedParameters(result.Parameters) + } + + sqlQuery, err := format.Statement(result.Statement, queryBuilder) + if err != nil { + return TranslateCypherOutput{}, fmt.Errorf("could not format translated statement into a string query: %w", err) + } + + formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{Distance: options.SQLFormatDistance}) + if err != nil { + formattedQuery = sqlQuery + } + + output := TranslateCypherOutput{SQL: formattedQuery} + if options.DumpPGAst { + output.PGAst = spew.Sdump(result.Statement) + } + + return output, nil +} + func parseCmd() CommandDesc { return CommandDesc{ args: []string{"<...query>"}, @@ -65,52 +120,21 @@ func translateToPsqlCmd() CommandDesc { } fields = flagSet.Args() - query, err := parseQueryArray(fields) + output, err := TranslateCypherToPsql(ctx, strings.Join(fields, " "), TranslateCypherOptions{ + Connection: kindMapperConnRef, + DumpPGAst: dumpTranslatedAst, + }) if err != nil { - return fmt.Errorf("error trying to parse query '%s': %w", fields, err) - } - - kindMapper := stubs.EmptyMapper() - if kindMapperConnRef != "" { - // Fetch kinds regardless of if it's already loaded. - kindMap, err := loadKindMap(ctx, kindMapperConnRef) - if err != nil { - return fmt.Errorf("could not load kind map for explain: %w", err) - } - kindMapper = stubs.MapperFromKindMap(kindMap) + return err } - result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, kindMapperConnRef)) - if err != nil { - return fmt.Errorf("could not translate cypher query to pgsql: %w", err) - } - if dumpTranslatedAst { + if output.PGAst != "" { fmt.Fprintf(ctx.output, "TRANSLATOR AST\n\n") - ctx.output.WriteHighlighted(spew.Sdump(result.Statement), "golang") + ctx.output.WriteHighlighted(output.PGAst, "golang") fmt.Fprintf(ctx.output, "\n") } - // Certain queries will materialize parameters into the output when translated, so we need to build - // an OutputBuilder so we can carry forward those params. - queryBuilder := format.NewOutputBuilder() - if result.Parameters != nil { - queryBuilder.WithMaterializedParameters(result.Parameters) - } - - sqlQuery, err := format.Statement(result.Statement, queryBuilder) - if err != nil { - return fmt.Errorf("could not format translated statement into a string query: %w", err) - } - - formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{ - Distance: 0, - }) - if err != nil { - ctx.output.Warnf("could not format query: %s", err.Error()) - formattedQuery = sqlQuery - } - - ctx.output.WriteHighlighted(formattedQuery, "postgres") + ctx.output.WriteHighlighted(output.SQL, "postgres") return nil }, } @@ -133,44 +157,12 @@ func explainAsPsqlCmd() CommandDesc { return err } - // Fetch kinds regardless of if it's already loaded. - kindMap, err := loadKindMap(ctx, connName) - if err != nil { - return fmt.Errorf("could not load kind map for explain: %w", err) - } - - query, err := parseQueryArray(fields[1:]) - if err != nil { - return fmt.Errorf("could not parse query: %w", err) - } - - // Populate a DumbKindMapper from the database's kinds table - kindMapper := stubs.MapperFromKindMap(kindMap) - result, err := translate.Translate(ctx, query, kindMapper, nil, defaultGraphID(ctx, connName)) + translation, err := TranslateCypherToPsql(ctx, strings.Join(fields[1:], " "), TranslateCypherOptions{Connection: connName, SQLFormatDistance: 2}) if err != nil { - return fmt.Errorf("could not translate cypher query to pgsql: %w", err) - } - - // Certain queries will materialize parameters into the output when translated, so we need to build - // an OutputBuilder so we can carry forward those params. - queryBuilder := format.NewOutputBuilder() - if result.Parameters != nil { - queryBuilder.WithMaterializedParameters(result.Parameters) - } - - sqlQuery, err := format.Statement(result.Statement, queryBuilder) - if err != nil { - return fmt.Errorf("could not format translated statement into a string query: %w", err) + return err } - formattedQuery, err := sqlfmt.Format(sqlQuery, &sqlfmt.Options{ - Distance: 2, - }) - if err != nil { - ctx.output.Warnf("could not format query: %s", err.Error()) - formattedQuery = sqlQuery - } - explainSQLQuery := fmt.Sprintf("EXPLAIN %s", formattedQuery) + explainSQLQuery := fmt.Sprintf("EXPLAIN %s", translation.SQL) ctx.output.WriteHighlighted(explainSQLQuery, "postgres") fmt.Fprint(ctx.output, "\n\n") diff --git a/tools/dawgrun/pkg/commands/db.go b/tools/dawgrun/pkg/commands/db.go index 895da3b..71fcee3 100644 --- a/tools/dawgrun/pkg/commands/db.go +++ b/tools/dawgrun/pkg/commands/db.go @@ -307,7 +307,6 @@ func lookupKindCmd() CommandDesc { kindMap, ok := ctx.scope.connKindMaps[connName] if !ok { - // Try to fetch the kind map if the connection is open var err error kindMap, err = loadKindMap(ctx, connName) if err != nil { @@ -347,7 +346,6 @@ func lookupKindIDCmd() CommandDesc { kindMap, ok := ctx.scope.connKindMaps[connName] if !ok { - // Try to fetch the kind map if the connection is open var err error kindMap, err = loadKindMap(ctx, connName) if err != nil { diff --git a/tools/dawgrun/pkg/commands/helpers.go b/tools/dawgrun/pkg/commands/helpers.go index cbc4222..e6df550 100644 --- a/tools/dawgrun/pkg/commands/helpers.go +++ b/tools/dawgrun/pkg/commands/helpers.go @@ -8,6 +8,11 @@ import ( ) func parseQueryArray(fields []string) (*cypherModels.RegularQuery, error) { + return ParseQueryText(strings.Join(fields, " ")) +} + +// ParseQueryText parses a Cypher query string using dawgrun's default parser settings. +func ParseQueryText(query string) (*cypherModels.RegularQuery, error) { cypherCtx := cypherFrontend.DefaultCypherContext() - return cypherFrontend.ParseCypher(cypherCtx, strings.Join(fields, " ")) + return cypherFrontend.ParseCypher(cypherCtx, query) } diff --git a/tools/dawgrun/pkg/mcpserver/cypher.go b/tools/dawgrun/pkg/mcpserver/cypher.go new file mode 100644 index 0000000..8406c4d --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/cypher.go @@ -0,0 +1,43 @@ +package mcpserver + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type QueryCypherInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + Query string `json:"query" jsonschema:"CySQL/Cypher query to execute"` +} + +type QueryCypherOutput struct { + Output string `json:"output"` +} + +type ExplainPsqlInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + Query string `json:"query" jsonschema:"CySQL/Cypher query to explain"` +} + +type ExplainPsqlOutput struct { + Output string `json:"output"` +} + +func (s *Server) queryCypher(ctx context.Context, _ *mcp.CallToolRequest, input QueryCypherInput) (*mcp.CallToolResult, QueryCypherOutput, error) { + output, err := s.runCommand(ctx, "query-cypher", []string{"-format", "json", input.Connection, input.Query}) + if err != nil { + return nil, QueryCypherOutput{}, err + } + + return nil, QueryCypherOutput{Output: output}, nil +} + +func (s *Server) explainPsql(ctx context.Context, _ *mcp.CallToolRequest, input ExplainPsqlInput) (*mcp.CallToolResult, ExplainPsqlOutput, error) { + output, err := s.runCommand(ctx, "explain-psql", []string{input.Connection, input.Query}) + if err != nil { + return nil, ExplainPsqlOutput{}, err + } + + return nil, ExplainPsqlOutput{Output: output}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/kinds.go b/tools/dawgrun/pkg/mcpserver/kinds.go new file mode 100644 index 0000000..ece5897 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/kinds.go @@ -0,0 +1,61 @@ +package mcpserver + +import ( + "context" + "strconv" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type LoadDBKindsInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` +} + +type LoadDBKindsOutput struct { + Output string `json:"output"` +} + +type LookupKindInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + Kind string `json:"kind" jsonschema:"kind name"` +} + +type LookupKindOutput struct { + Output string `json:"output"` +} + +type LookupKindIDInput struct { + Connection string `json:"connection" jsonschema:"open PostgreSQL connection name"` + KindID int16 `json:"kind_id" jsonschema:"kind ID"` +} + +type LookupKindIDOutput struct { + Output string `json:"output"` +} + +func (s *Server) loadDBKinds(ctx context.Context, _ *mcp.CallToolRequest, input LoadDBKindsInput) (*mcp.CallToolResult, LoadDBKindsOutput, error) { + output, err := s.runCommand(ctx, "load-db-kinds", []string{input.Connection}) + if err != nil { + return nil, LoadDBKindsOutput{}, err + } + + return nil, LoadDBKindsOutput{Output: output}, nil +} + +func (s *Server) lookupKind(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindInput) (*mcp.CallToolResult, LookupKindOutput, error) { + output, err := s.runCommand(ctx, "lookup-kind", []string{input.Connection, input.Kind}) + if err != nil { + return nil, LookupKindOutput{}, err + } + + return nil, LookupKindOutput{Output: output}, nil +} + +func (s *Server) lookupKindID(ctx context.Context, _ *mcp.CallToolRequest, input LookupKindIDInput) (*mcp.CallToolResult, LookupKindIDOutput, error) { + output, err := s.runCommand(ctx, "lookup-kind-id", []string{input.Connection, strconv.FormatInt(int64(input.KindID), 10)}) + if err != nil { + return nil, LookupKindIDOutput{}, err + } + + return nil, LookupKindIDOutput{Output: output}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/opengraph.go b/tools/dawgrun/pkg/mcpserver/opengraph.go new file mode 100644 index 0000000..7058287 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/opengraph.go @@ -0,0 +1,78 @@ +package mcpserver + +import ( + "context" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type SaveOpenGraphInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + OutputPath string `json:"output_path,omitempty" jsonschema:"optional file path for OpenGraph JSON output"` +} + +type SaveOpenGraphOutput struct { + Output string `json:"output"` +} + +type LoadOpenGraphInput struct { + Connection string `json:"connection" jsonschema:"open connection name"` + InputPath string `json:"input_path" jsonschema:"OpenGraph JSON file path to load"` +} + +type LoadOpenGraphOutput struct { + Output string `json:"output"` +} + +type CopyOpenGraphInput struct { + FromConnection string `json:"from_connection" jsonschema:"source open connection name"` + ToConnection string `json:"to_connection" jsonschema:"destination open connection name"` +} + +type CopyOpenGraphOutput struct { + Output string `json:"output"` +} + +func (s *Server) saveOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input SaveOpenGraphInput) (*mcp.CallToolResult, SaveOpenGraphOutput, error) { + args := []string{} + outputPath := strings.TrimSpace(input.OutputPath) + if outputPath != "" { + args = append(args, "-out", outputPath) + } + args = append(args, input.Connection) + + output, err := s.runCommand(ctx, "save-opengraph", args) + if err != nil { + return nil, SaveOpenGraphOutput{}, err + } + + return nil, SaveOpenGraphOutput{Output: output}, nil +} + +func (s *Server) loadOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input LoadOpenGraphInput) (*mcp.CallToolResult, LoadOpenGraphOutput, error) { + if !s.allowWrites { + return nil, LoadOpenGraphOutput{}, fmt.Errorf("load_opengraph requires --allow-writes") + } + + output, err := s.runCommand(ctx, "load-opengraph", []string{input.Connection, input.InputPath}) + if err != nil { + return nil, LoadOpenGraphOutput{}, err + } + + return nil, LoadOpenGraphOutput{Output: output}, nil +} + +func (s *Server) copyOpenGraph(ctx context.Context, _ *mcp.CallToolRequest, input CopyOpenGraphInput) (*mcp.CallToolResult, CopyOpenGraphOutput, error) { + if !s.allowWrites { + return nil, CopyOpenGraphOutput{}, fmt.Errorf("copy_opengraph requires --allow-writes") + } + + output, err := s.runCommand(ctx, "copy-opengraph", []string{input.FromConnection, input.ToConnection}) + if err != nil { + return nil, CopyOpenGraphOutput{}, err + } + + return nil, CopyOpenGraphOutput{Output: output}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/server.go b/tools/dawgrun/pkg/mcpserver/server.go new file mode 100644 index 0000000..3081eb2 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/server.go @@ -0,0 +1,196 @@ +package mcpserver + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/davecgh/go-spew/spew" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/specterops/dawgs/tools/dawgrun/pkg/commands" +) + +type Options struct { + AllowWrites bool +} + +type Server struct { + server *mcp.Server + scope *commands.Scope + commandMu sync.Mutex + allowWrites bool +} + +type ListConnectionsInput struct{} + +type ListConnectionsOutput struct { + Open []string `json:"open"` +} + +type OpenConnectionInput struct { + Name string `json:"name" jsonschema:"session-local connection name"` + ConnectionString string `json:"connection_string" jsonschema:"PostgreSQL or Neo4j connection string; treated as sensitive"` + Driver string `json:"driver,omitempty" jsonschema:"optional driver override: pg or neo4j"` + DefaultGraph string `json:"default_graph,omitempty" jsonschema:"default graph name; defaults to default"` + InitGraph bool `json:"init_graph,omitempty" jsonschema:"create default graph if selecting it fails"` +} + +type OpenConnectionOutput struct { + Name string `json:"name"` +} + +type ParseCypherInput struct { + Query string `json:"query" jsonschema:"CySQL/Cypher query to parse"` +} + +type ParseCypherOutput struct { + AST string `json:"ast"` +} + +type TranslateCypherInput struct { + Query string `json:"query" jsonschema:"CySQL/Cypher query to translate"` + Connection string `json:"connection,omitempty" jsonschema:"optional open connection name for live kind mapping"` + DumpPGAst bool `json:"dump_pg_ast,omitempty" jsonschema:"include PostgreSQL AST dump"` +} + +type TranslateCypherOutput struct { + SQL string `json:"sql"` + PGAst string `json:"pg_ast,omitempty"` +} + +func New(options Options) *Server { + return NewWithScope(commands.NewScope(commands.RunModeCLI), options) +} + +func NewWithScope(scope *commands.Scope, options Options) *Server { + s := &Server{ + server: mcp.NewServer(&mcp.Implementation{ + Name: "dawgrun-mcp", + Version: "0.1.0", + }, nil), + scope: scope, + allowWrites: options.AllowWrites, + } + s.registerTools() + + return s +} + +func (s *Server) Run(ctx context.Context) error { + defer s.scope.CloseConnections(context.Background()) + return s.server.Run(ctx, &mcp.StdioTransport{}) +} + +func (s *Server) MCPServer() *mcp.Server { + return s.server +} + +func (s *Server) commandContext(ctx context.Context) *commands.CommandContext { + cmdCtx := commands.NewCommandContext(ctx, nil, s.scope, "") + cmdCtx.SetStyledOutputEnabled(false) + return cmdCtx +} + +func (s *Server) runCommand(ctx context.Context, name string, args []string) (string, error) { + s.commandMu.Lock() + defer s.commandMu.Unlock() + + cmd, ok := commands.Registry()[name] + if !ok { + return "", fmt.Errorf("unknown command %s", name) + } + + cmdCtx := s.commandContext(ctx) + if cmd.ClearFlagsFn != nil { + defer cmd.ClearFlagsFn() + } + if err := cmd.Fn(cmdCtx, args); err != nil { + return "", err + } + + return cmdCtx.OutputString(), nil +} + +func (s *Server) registerTools() { + mcp.AddTool(s.server, readOnlyTool("list_connections", "List Connections", "List open dawgrun connection names without exposing connection strings."), s.listConnections) + mcp.AddTool(s.server, readOnlyTool("open_connection", "Open Connection", "Open a named DAWGS backend connection. Connection strings are treated as sensitive and are not echoed."), s.openConnection) + mcp.AddTool(s.server, readOnlyTool("parse_cypher", "Parse Cypher", "Parse a CySQL/Cypher query and return an AST dump."), s.parseCypher) + mcp.AddTool(s.server, readOnlyTool("translate_cypher_to_pgsql", "Translate Cypher To PGSQL", "Translate a CySQL/Cypher query to PostgreSQL SQL."), s.translateCypher) + mcp.AddTool(s.server, readOnlyTool("query_cypher", "Query Cypher", "Execute a CySQL/Cypher query against a named DAWGS connection and return bounded rows."), s.queryCypher) + mcp.AddTool(s.server, readOnlyTool("explain_psql", "Explain PGSQL", "Translate a CySQL/Cypher query and return the PostgreSQL EXPLAIN plan."), s.explainPsql) + mcp.AddTool(s.server, readOnlyTool("save_opengraph", "Save OpenGraph", "Export OpenGraph JSON from a named connection."), s.saveOpenGraph) + mcp.AddTool(s.server, writeTool("load_opengraph", "Load OpenGraph", "Load OpenGraph JSON into a named connection. Requires --allow-writes."), s.loadOpenGraph) + mcp.AddTool(s.server, writeTool("copy_opengraph", "Copy OpenGraph", "Copy OpenGraph data between named connections. Requires --allow-writes."), s.copyOpenGraph) + mcp.AddTool(s.server, readOnlyTool("load_db_kinds", "Load DB Kinds", "Load and return the kind mapping from a named PostgreSQL connection."), s.loadDBKinds) + mcp.AddTool(s.server, readOnlyTool("lookup_kind", "Lookup Kind", "Resolve a kind name to its PostgreSQL kind ID."), s.lookupKind) + mcp.AddTool(s.server, readOnlyTool("lookup_kind_id", "Lookup Kind ID", "Resolve a PostgreSQL kind ID to its kind name."), s.lookupKindID) +} + +func readOnlyTool(name, title, description string) *mcp.Tool { + closedWorld := false + return &mcp.Tool{ + Name: name, + Title: title, + Description: description, + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true, OpenWorldHint: &closedWorld}, + } +} + +func writeTool(name, title, description string) *mcp.Tool { + destructive := true + closedWorld := false + return &mcp.Tool{ + Name: name, + Title: title, + Description: description, + Annotations: &mcp.ToolAnnotations{DestructiveHint: &destructive, OpenWorldHint: &closedWorld}, + } +} + +func (s *Server) listConnections(context.Context, *mcp.CallToolRequest, ListConnectionsInput) (*mcp.CallToolResult, ListConnectionsOutput, error) { + return nil, ListConnectionsOutput{Open: s.scope.GetConnectionNames()}, nil +} + +func (s *Server) openConnection(ctx context.Context, _ *mcp.CallToolRequest, input OpenConnectionInput) (*mcp.CallToolResult, OpenConnectionOutput, error) { + name := strings.TrimSpace(input.Name) + args := []string{} + if driver := strings.TrimSpace(input.Driver); driver != "" { + args = append(args, "-driver", driver) + } + if defaultGraph := strings.TrimSpace(input.DefaultGraph); defaultGraph != "" { + args = append(args, "-default-graph", defaultGraph) + } + if input.InitGraph { + args = append(args, "-init-graph") + } + args = append(args, name, input.ConnectionString) + + if _, err := s.runCommand(ctx, "open", args); err != nil { + return nil, OpenConnectionOutput{}, err + } + + return nil, OpenConnectionOutput{Name: name}, nil +} + +func (s *Server) parseCypher(_ context.Context, _ *mcp.CallToolRequest, input ParseCypherInput) (*mcp.CallToolResult, ParseCypherOutput, error) { + query, err := commands.ParseQueryText(input.Query) + if err != nil { + return nil, ParseCypherOutput{}, err + } + + return nil, ParseCypherOutput{AST: spew.Sdump(query)}, nil +} + +func (s *Server) translateCypher(ctx context.Context, _ *mcp.CallToolRequest, input TranslateCypherInput) (*mcp.CallToolResult, TranslateCypherOutput, error) { + output, err := commands.TranslateCypherToPsql(s.commandContext(ctx), input.Query, commands.TranslateCypherOptions{ + Connection: input.Connection, + DumpPGAst: input.DumpPGAst, + }) + if err != nil { + return nil, TranslateCypherOutput{}, err + } + + return nil, TranslateCypherOutput{SQL: output.SQL, PGAst: output.PGAst}, nil +} diff --git a/tools/dawgrun/pkg/mcpserver/server_test.go b/tools/dawgrun/pkg/mcpserver/server_test.go new file mode 100644 index 0000000..99deed3 --- /dev/null +++ b/tools/dawgrun/pkg/mcpserver/server_test.go @@ -0,0 +1,92 @@ +package mcpserver + +import ( + "context" + "encoding/json" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" +) + +func TestServerAdvertisesExpectedTools(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + tools := make(map[string]bool) + for tool, err := range session.Tools(ctx, nil) { + require.NoError(t, err) + tools[tool.Name] = true + } + + require.Equal(t, map[string]bool{ + "list_connections": true, + "open_connection": true, + "parse_cypher": true, + "translate_cypher_to_pgsql": true, + "query_cypher": true, + "explain_psql": true, + "save_opengraph": true, + "load_opengraph": true, + "copy_opengraph": true, + "load_db_kinds": true, + "lookup_kind": true, + "lookup_kind_id": true, + }, tools) +} + +func TestListConnectionsReturnsOpenNamesOnly(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + result, err := session.CallTool(ctx, &mcp.CallToolParams{Name: "list_connections"}) + require.NoError(t, err) + require.False(t, result.IsError) + + output := decodeToolOutput[ListConnectionsOutput](t, result) + require.Empty(t, output.Open) +} + +func TestParseCypherReturnsStructuredOutput(t *testing.T) { + ctx := context.Background() + session := connectTestServer(t, ctx, New(Options{})) + + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "parse_cypher", + Arguments: map[string]any{ + "query": "MATCH (n) RETURN n", + }, + }) + require.NoError(t, err) + require.False(t, result.IsError) + + output := decodeToolOutput[ParseCypherOutput](t, result) + require.Contains(t, output.AST, "RegularQuery") +} + +func connectTestServer(t *testing.T, ctx context.Context, server *Server) *mcp.ClientSession { + t.Helper() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + serverSession, err := server.MCPServer().Connect(ctx, serverTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { serverSession.Close() }) + + client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { clientSession.Close() }) + + return clientSession +} + +func decodeToolOutput[T any](t *testing.T, result *mcp.CallToolResult) T { + t.Helper() + + var output T + data, err := json.Marshal(result.StructuredContent) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(data, &output)) + + return output +}