diff --git a/crates/bashkit-cli/src/main.rs b/crates/bashkit-cli/src/main.rs index 6b1ec0a4..a23853aa 100644 --- a/crates/bashkit-cli/src/main.rs +++ b/crates/bashkit-cli/src/main.rs @@ -185,7 +185,7 @@ fn main() -> Result<()> { let args = Args::parse(); match cli_mode(&args) { - CliMode::Mcp => run_mcp(), + CliMode::Mcp => run_mcp(args), CliMode::Command | CliMode::Script => { let output = run_oneshot(args)?; print!("{}", output.stdout); @@ -202,12 +202,12 @@ fn main() -> Result<()> { } } -fn run_mcp() -> Result<()> { +fn run_mcp(args: Args) -> Result<()> { Builder::new_multi_thread() .enable_all() .build() .context("Failed to build MCP runtime")? - .block_on(mcp::run()) + .block_on(mcp::run(move || build_bash(&args))) } fn run_oneshot(args: Args) -> Result { diff --git a/crates/bashkit-cli/src/mcp.rs b/crates/bashkit-cli/src/mcp.rs index df9a2a92..468f306b 100644 --- a/crates/bashkit-cli/src/mcp.rs +++ b/crates/bashkit-cli/src/mcp.rs @@ -120,15 +120,23 @@ struct ContentItem { } /// MCP server with optional ScriptedTool registrations. +/// +/// Accepts a factory function that produces configured `Bash` instances, +/// ensuring CLI execution limits (max_commands, etc.) are applied to every +/// MCP `tools/call` invocation. pub struct McpServer { + bash_factory: Box bashkit::Bash + Send>, #[cfg(feature = "scripted_tool")] scripted_tools: Vec, } impl McpServer { /// Create a new MCP server with only the default `bash` tool. - pub fn new() -> Self { + /// Each `tools/call` will create a `Bash` via the provided factory, + /// inheriting whatever limits/configuration the caller sets up. + pub fn new(bash_factory: impl Fn() -> bashkit::Bash + Send + 'static) -> Self { Self { + bash_factory: Box::new(bash_factory), #[cfg(feature = "scripted_tool")] scripted_tools: Vec::new(), } @@ -271,7 +279,7 @@ impl McpServer { } }; - let mut bash = bashkit::Bash::new(); + let mut bash = (self.bash_factory)(); let result = match bash.exec(&args.script).await { Ok(r) => r, Err(e) => { @@ -358,9 +366,9 @@ impl McpServer { } } -/// Run the MCP server (backward-compatible entry point). -pub async fn run() -> Result<()> { - let mut server = McpServer::new(); +/// Run the MCP server with a factory that produces configured `Bash` instances. +pub async fn run(bash_factory: impl Fn() -> bashkit::Bash + Send + 'static) -> Result<()> { + let mut server = McpServer::new(bash_factory); server.run().await } @@ -370,7 +378,7 @@ mod tests { #[tokio::test] async fn test_initialize() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(1), @@ -385,7 +393,7 @@ mod tests { #[tokio::test] async fn test_tools_list_default() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(1), @@ -400,7 +408,7 @@ mod tests { #[tokio::test] async fn test_tools_call_bash() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(1), @@ -418,7 +426,7 @@ mod tests { #[tokio::test] async fn test_tools_call_unknown() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(1), @@ -434,7 +442,7 @@ mod tests { #[tokio::test] async fn test_method_not_found() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(1), @@ -446,6 +454,35 @@ mod tests { assert_eq!(resp.error.expect("error").code, -32601); } + #[tokio::test] + async fn test_tools_call_respects_max_commands() { + // Factory that creates a Bash with max_commands=2 + let mut server = McpServer::new(|| { + bashkit::Bash::builder() + .limits(bashkit::ExecutionLimits::new().max_commands(2)) + .build() + }); + + // Script with 3 commands should hit the limit + let req = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/call".to_string(), + params: serde_json::json!({ + "name": "bash", + "arguments": { "script": "echo a; echo b; echo c" } + }), + }; + let resp = server.handle_request(req).await; + let result = resp.result.expect("should have result"); + let text = result["content"][0]["text"].as_str().expect("text"); + // Should report the limit was exceeded + assert!( + text.contains("limit") || text.contains("exceeded") || result["isError"] == true, + "expected execution limit error, got: {text}" + ); + } + #[cfg(feature = "scripted_tool")] mod scripted_tool_tests { use super::*; @@ -463,7 +500,7 @@ mod tests { #[tokio::test] async fn test_tools_list_includes_scripted_tool() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); server.register_scripted_tool(make_test_tool()); let req = JsonRpcRequest { @@ -481,7 +518,7 @@ mod tests { #[tokio::test] async fn test_tools_call_scripted_tool() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); server.register_scripted_tool(make_test_tool()); let req = JsonRpcRequest { @@ -501,7 +538,7 @@ mod tests { #[tokio::test] async fn test_tools_call_scripted_tool_error() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); let tool = ScriptedTool::builder("err_api") .short_description("Error API") .tool(ToolDef::new("fail", "Always fails"), |_args: &ToolArgs| { @@ -526,7 +563,7 @@ mod tests { #[tokio::test] async fn test_full_jsonrpc_roundtrip() { - let mut server = McpServer::new(); + let mut server = McpServer::new(bashkit::Bash::new); server.register_scripted_tool(make_test_tool()); // Step 1: initialize