Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/bashkit-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ fn main() -> Result<()> {
let args = Args::parse();

match cli_mode(&args) {
CliMode::Mcp => run_mcp(),
CliMode::Mcp => run_mcp(args),
CliMode::Command | CliMode::Script => {
let output = run_oneshot(args)?;
print!("{}", output.stdout);
Expand All @@ -202,12 +202,12 @@ fn main() -> Result<()> {
}
}

fn run_mcp() -> Result<()> {
fn run_mcp(args: Args) -> Result<()> {
Builder::new_multi_thread()
.enable_all()
.build()
.context("Failed to build MCP runtime")?
.block_on(mcp::run())
.block_on(mcp::run(move || build_bash(&args)))
}

fn run_oneshot(args: Args) -> Result<RunOutput> {
Expand Down
65 changes: 51 additions & 14 deletions crates/bashkit-cli/src/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,23 @@ struct ContentItem {
}

/// MCP server with optional ScriptedTool registrations.
///
/// Accepts a factory function that produces configured `Bash` instances,
/// ensuring CLI execution limits (max_commands, etc.) are applied to every
/// MCP `tools/call` invocation.
pub struct McpServer {
bash_factory: Box<dyn Fn() -> bashkit::Bash + Send>,
#[cfg(feature = "scripted_tool")]
scripted_tools: Vec<bashkit::ScriptedTool>,
}

impl McpServer {
/// Create a new MCP server with only the default `bash` tool.
pub fn new() -> Self {
/// Each `tools/call` will create a `Bash` via the provided factory,
/// inheriting whatever limits/configuration the caller sets up.
pub fn new(bash_factory: impl Fn() -> bashkit::Bash + Send + 'static) -> Self {
Self {
bash_factory: Box::new(bash_factory),
#[cfg(feature = "scripted_tool")]
scripted_tools: Vec::new(),
}
Expand Down Expand Up @@ -271,7 +279,7 @@ impl McpServer {
}
};

let mut bash = bashkit::Bash::new();
let mut bash = (self.bash_factory)();
let result = match bash.exec(&args.script).await {
Ok(r) => r,
Err(e) => {
Expand Down Expand Up @@ -358,9 +366,9 @@ impl McpServer {
}
}

/// Run the MCP server (backward-compatible entry point).
pub async fn run() -> Result<()> {
let mut server = McpServer::new();
/// Run the MCP server with a factory that produces configured `Bash` instances.
pub async fn run(bash_factory: impl Fn() -> bashkit::Bash + Send + 'static) -> Result<()> {
let mut server = McpServer::new(bash_factory);
server.run().await
}

Expand All @@ -370,7 +378,7 @@ mod tests {

#[tokio::test]
async fn test_initialize() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
Expand All @@ -385,7 +393,7 @@ mod tests {

#[tokio::test]
async fn test_tools_list_default() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
Expand All @@ -400,7 +408,7 @@ mod tests {

#[tokio::test]
async fn test_tools_call_bash() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
Expand All @@ -418,7 +426,7 @@ mod tests {

#[tokio::test]
async fn test_tools_call_unknown() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
Expand All @@ -434,7 +442,7 @@ mod tests {

#[tokio::test]
async fn test_method_not_found() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
Expand All @@ -446,6 +454,35 @@ mod tests {
assert_eq!(resp.error.expect("error").code, -32601);
}

#[tokio::test]
async fn test_tools_call_respects_max_commands() {
// Factory that creates a Bash with max_commands=2
let mut server = McpServer::new(|| {
bashkit::Bash::builder()
.limits(bashkit::ExecutionLimits::new().max_commands(2))
.build()
});

// Script with 3 commands should hit the limit
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: serde_json::json!(1),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "bash",
"arguments": { "script": "echo a; echo b; echo c" }
}),
};
let resp = server.handle_request(req).await;
let result = resp.result.expect("should have result");
let text = result["content"][0]["text"].as_str().expect("text");
// Should report the limit was exceeded
assert!(
text.contains("limit") || text.contains("exceeded") || result["isError"] == true,
"expected execution limit error, got: {text}"
);
}

#[cfg(feature = "scripted_tool")]
mod scripted_tool_tests {
use super::*;
Expand All @@ -463,7 +500,7 @@ mod tests {

#[tokio::test]
async fn test_tools_list_includes_scripted_tool() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
server.register_scripted_tool(make_test_tool());

let req = JsonRpcRequest {
Expand All @@ -481,7 +518,7 @@ mod tests {

#[tokio::test]
async fn test_tools_call_scripted_tool() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
server.register_scripted_tool(make_test_tool());

let req = JsonRpcRequest {
Expand All @@ -501,7 +538,7 @@ mod tests {

#[tokio::test]
async fn test_tools_call_scripted_tool_error() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
let tool = ScriptedTool::builder("err_api")
.short_description("Error API")
.tool(ToolDef::new("fail", "Always fails"), |_args: &ToolArgs| {
Expand All @@ -526,7 +563,7 @@ mod tests {

#[tokio::test]
async fn test_full_jsonrpc_roundtrip() {
let mut server = McpServer::new();
let mut server = McpServer::new(bashkit::Bash::new);
server.register_scripted_tool(make_test_tool());

// Step 1: initialize
Expand Down
Loading