Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions pkg/gateway/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"os"
"slices"
"sort"
"strings"
"time"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/docker/mcp-gateway/pkg/docker"
"github.com/docker/mcp-gateway/pkg/log"
"github.com/docker/mcp-gateway/pkg/oci"
"github.com/docker/mcp-gateway/pkg/policy"
)

type Configurator interface {
Expand Down Expand Up @@ -90,6 +92,74 @@ func (c *Configuration) Find(serverName string) (*catalog.ServerConfig, *map[str
return nil, &byName, true
}

// FilterByPolicy removes servers and tools that are denied by the policy client.
func (c *Configuration) FilterByPolicy(ctx context.Context, pc policy.Client) error {
if pc == nil {
return nil
}

filteredServers := make(map[string]catalog.Server)
filteredServerNames := make([]string, 0, len(c.serverNames))
filteredConfig := make(map[string]map[string]any)
filteredTools := config.ToolsConfig{
ServerTools: make(map[string][]string),
}

for _, name := range c.serverNames {
decision, err := pc.Evaluate(ctx, policy.Request{
Server: name,
Action: policy.ActionLoad,
})
if err != nil {
log.Logf("policy check failed for server %s: %v (denying)", name, err)
}
if decision.Allowed && err == nil {
server := c.servers[name]

// Filter tools for this server if any.
if tools, ok := c.tools.ServerTools[name]; ok {
for _, t := range tools {
toolDecision, tErr := pc.Evaluate(ctx, policy.Request{
Server: name,
Tool: t,
Action: policy.ActionLoad,
})
if tErr != nil {
log.Logf("policy check failed for tool %s/%s: %v (denying)", name, t, tErr)
}
if toolDecision.Allowed && tErr == nil {
filteredTools.ServerTools[name] = append(filteredTools.ServerTools[name], t)
}
}
// Also trim catalog.Tools slice if present.
if len(server.Tools) > 0 {
var kept []catalog.Tool
for _, tool := range server.Tools {
if slices.Contains(filteredTools.ServerTools[name], tool.Name) {
kept = append(kept, tool)
}
}
server.Tools = kept
}
}

filteredServers[name] = server
filteredServerNames = append(filteredServerNames, name)
canon := oci.CanonicalizeServerName(name)
if cfg, ok := c.config[canon]; ok {
filteredConfig[canon] = cfg
}
}
}

c.serverNames = filteredServerNames
c.servers = filteredServers
c.config = filteredConfig
c.tools = filteredTools
// c.secrets unchanged
return nil
}

type FileBasedConfiguration struct {
CatalogPath []string
ServerNames []string // Takes precedence over the RegistryPath
Expand Down
30 changes: 30 additions & 0 deletions pkg/gateway/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"go.opentelemetry.io/otel/metric"

"github.com/docker/mcp-gateway/pkg/catalog"
"github.com/docker/mcp-gateway/pkg/policy"
"github.com/docker/mcp-gateway/pkg/telemetry"
)

Expand Down Expand Up @@ -65,6 +66,21 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, an
return nil, fmt.Errorf("server %q not found in configuration", serverName)
}

if g.policyClient != nil {
decision, err := g.policyClient.Evaluate(ctx, policy.Request{
Server: serverConfig.Name,
Tool: originalToolName,
Action: policy.ActionInvoke,
})
if err != nil {
telemetry.RecordToolError(ctx, nil, serverConfig.Name, inferServerType(serverConfig), req.Params.Name)
return nil, fmt.Errorf("policy check failed for %s/%s: %w", serverConfig.Name, originalToolName, err)
}
if !decision.Allowed {
return nil, fmt.Errorf("policy denied tool %s on server %s: %s", originalToolName, serverConfig.Name, decision.Reason)
}
}

// Debug logging to stderr
if os.Getenv("DOCKER_MCP_TELEMETRY_DEBUG") != "" {
fmt.Fprintf(os.Stderr, "[MCP-HANDLER] Tool call received: %s from server: %s\n", req.Params.Name, serverConfig.Name)
Expand Down Expand Up @@ -166,6 +182,20 @@ func (g *Gateway) mcpServerPromptHandler(serverName string, server *mcp.Server)
return nil, fmt.Errorf("server %q not found in configuration", serverName)
}

if g.policyClient != nil {
decision, err := g.policyClient.Evaluate(ctx, policy.Request{
Server: serverConfig.Name,
Tool: req.Params.Name,
Action: policy.ActionPrompt,
})
if err != nil {
return nil, fmt.Errorf("policy check failed for prompt %s on server %s: %w", req.Params.Name, serverConfig.Name, err)
}
if !decision.Allowed {
return nil, fmt.Errorf("policy denied prompt %s on server %s: %s", req.Params.Name, serverConfig.Name, decision.Reason)
}
}

// Debug logging to stderr
if os.Getenv("DOCKER_MCP_TELEMETRY_DEBUG") != "" {
fmt.Fprintf(os.Stderr, "[MCP-HANDLER] Prompt get received: %s from server: %s\n", req.Params.Name, serverConfig.Name)
Expand Down
24 changes: 24 additions & 0 deletions pkg/gateway/mcpadd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/docker/mcp-gateway/pkg/log"
"github.com/docker/mcp-gateway/pkg/oauth"
"github.com/docker/mcp-gateway/pkg/oci"
"github.com/docker/mcp-gateway/pkg/policy"
)

func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler {
Expand Down Expand Up @@ -58,6 +59,29 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler {
}, nil
}

// Check if server is allowed by policy before adding
if g.policyClient != nil {
decision, err := g.policyClient.Evaluate(ctx, policy.Request{
Server: serverName,
Action: policy.ActionLoad,
})
if err != nil {
log.Logf("policy check failed for server %s: %v (denying)", serverName, err)
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: fmt.Sprintf("Error: Server '%s' blocked by policy check error: %v", serverName, err),
}},
}, nil
}
if !decision.Allowed {
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: fmt.Sprintf("Error: Server '%s' is blocked by policy: %s", serverName, decision.Reason),
}},
}, nil
}
}

// Append the new server to the current serverNames if not already present
found = slices.Contains(g.configuration.serverNames, serverName)
if !found {
Expand Down
25 changes: 25 additions & 0 deletions pkg/gateway/mcpexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/docker/mcp-gateway/pkg/log"
"github.com/docker/mcp-gateway/pkg/policy"
)

func addMcpExecHandler(g *Gateway) mcp.ToolHandler {
Expand Down Expand Up @@ -78,6 +79,30 @@ func addMcpExecHandler(g *Gateway) mcp.ToolHandler {
Extra: req.Extra,
}

// Policy check before executing the tool
if g.policyClient != nil {
decision, err := g.policyClient.Evaluate(ctx, policy.Request{
Server: toolReg.ServerName,
Tool: toolName,
Action: policy.ActionInvoke,
})
if err != nil {
log.Logf("policy check failed for mcp-exec %s: %v (denying)", toolName, err)
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: fmt.Sprintf("Error: Tool '%s' blocked due to policy check error: %v", toolName, err),
}},
}, nil
}
if !decision.Allowed {
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{
Text: fmt.Sprintf("Error: Tool '%s' blocked by policy: %s", toolName, decision.Reason),
}},
}, nil
}
}

// Execute the tool using its registered handler
result, err := toolReg.Handler(ctx, toolCallRequest)
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/gateway/policy_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package gateway

import (
"os"

"github.com/docker/mcp-gateway/pkg/desktop"
"github.com/docker/mcp-gateway/pkg/policy"
)

func newPolicyClient() policy.Client {
paths := desktop.Paths()
if _, err := os.Stat(paths.BackendSocket); err != nil {
return policy.NoopClient{}
}
return policy.NewDesktopClient()
}
Loading