diff --git a/CLAUDE.md b/CLAUDE.md index 5ccac7f..b279a4b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,6 +1,6 @@ # Claude Code on Databricks -Welcome! This environment comes pre-configured with 5 AI coding agents, 39 skills, and 2 MCP servers. Hermes Agent is available alongside Claude Code, Codex, Gemini CLI, and OpenCode — launch it with `hermes chat`. +Welcome! This environment comes pre-configured with 5 AI coding agents, 43 skills, and 3 MCP servers. Hermes Agent is available alongside Claude Code, Codex, Gemini CLI, and OpenCode — launch it with `hermes chat`. ## Skills (30 total) @@ -39,6 +39,7 @@ From [obra/superpowers](https://github.com/obra/superpowers): - **DeepWiki** - AI-powered documentation for any GitHub repository - **Exa** - Web search and code context retrieval +- **CoDA** (exposed at `/mcp`) - Delegate coding tasks to AI agents via MCP. Any MCP client (Genie Code, Claude Desktop, Cursor) can call `coda_run`, `coda_inbox`, and `coda_get_result` to submit background tasks, check status, and retrieve results. See `docs/mcp-v2-background-execution.md`. ## Databricks CLI diff --git a/README.md b/README.md index fd492bd..55227c2 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Use this template](https://img.shields.io/badge/Use%20this%20template-2ea44f?logo=github)](https://github.com/datasciencemonkey/coding-agents-databricks-apps/generate) [![Deploy to Databricks](https://img.shields.io/badge/Deploy-Databricks%20Apps-FF3621?logo=databricks&logoColor=white)](docs/deployment.md) [![Agents](https://img.shields.io/badge/Agents-5%20included-green)](#whats-inside) -[![Skills](https://img.shields.io/badge/Skills-39%20built--in-blue)](#-all-39-skills) +[![Skills](https://img.shields.io/badge/Skills-43%20built--in-blue)](#-all-43-skills) > Run Claude Code, Codex, Gemini CLI, Hermes Agent, and OpenCode in your browser — zero setup, wired to your Databricks workspace. @@ -57,7 +57,7 @@ This isn't just a terminal in the cloud. 
Running coding agents on Databricks giv | ✂️ **Split Panes** | Run two sessions side by side with a draggable divider | | 🌐 **WebSocket I/O** | Real-time terminal output over WebSocket — zero-latency, eliminates polling delay | | 🔁 **HTTP Polling Fallback** | Automatic fallback via Web Worker when WebSocket is unavailable | -| 🚀 **Parallel Setup** | 7 agent setups run in parallel (~5x faster startup) | +| 🚀 **Parallel Setup** | 6 agent setups run in parallel (~5x faster startup) | | 🔍 **Search** | Find anything in your terminal history (Ctrl+Shift+F) | | 🎤 **Voice Input** | Dictate commands with your mic (Option+V) | | 📋 **Image Paste** | Paste or drag-and-drop images into the terminal — saved to `~/uploads/`, path inserted automatically | @@ -71,7 +71,7 @@ This isn't just a terminal in the cloud. Running coding agents on Databricks giv ## MLflow Tracing -Every Claude Code session is **automatically traced** to a Databricks MLflow experiment — zero configuration required. +Claude Code sessions can be **automatically traced** to a Databricks MLflow experiment. Tracing is disabled by default — set `MLFLOW_CLAUDE_TRACING_ENABLED=true` in your app environment to enable it. ### How it works @@ -116,11 +116,11 @@ View them in the Databricks UI: **Workspace > Machine Learning > Experiments**. 
### Configuration -Tracing is configured during app startup by `setup_mlflow.py`, which merges the following into `~/.claude/settings.json`: +Tracing is configured during app startup by `setup/setup_mlflow.py`, which merges the following into `~/.claude/settings.json`: | Setting | Value | Purpose | |---------|-------|---------| -| `MLFLOW_CLAUDE_TRACING_ENABLED` | `true` | Enables Claude Code tracing | +| `MLFLOW_CLAUDE_TRACING_ENABLED` | `false` | Claude Code tracing (disabled by default, set to `true` to enable) | | `MLFLOW_TRACKING_URI` | `databricks` | Routes traces to Databricks backend | | `MLFLOW_EXPERIMENT_NAME` | `/Users/{owner}/{app}` | Target experiment path | | `OTEL_EXPORTER_OTLP_ENDPOINT` | `""` | Overrides container OTEL to prevent trace loss | @@ -170,7 +170,7 @@ This template repo opens that vision up for every Databricks user — no IDE set ---
-🧠 All 39 Skills +🧠 All 43 Skills ### Databricks Skills (25) — [ai-dev-kit](https://github.com/databricks-solutions/ai-dev-kit) @@ -195,16 +195,100 @@ This template repo opens that vision up for every Databricks user — no IDE set | Ship | finishing-branch, git-worktrees | | Meta | dispatching-agents, writing-skills, using-superpowers | +### BDD Skills (4) + +| Category | Skills | +|----------|--------| +| Testing | bdd-features, bdd-run, bdd-scaffold, bdd-steps | +
-🔌 2 MCP Servers +🔌 MCP Servers + +### Built-in MCP Servers | Server | What it does | |--------|-------------| | **DeepWiki** | Ask questions about any GitHub repo — gets AI-powered answers from the codebase | | **Exa** | Web search and code context retrieval for up-to-date information | +### CoDA MCP Server (exposed at `/mcp`) + +CoDA itself exposes an **MCP server** that any MCP-compatible client can connect to — delegate coding tasks to AI agents running on Databricks without needing the terminal UI. + +| Tool | Purpose | +|------|---------| +| `coda_run` | Fire-and-forget: submit a coding task, get a task ID back immediately | +| `coda_inbox` | Dashboard: see all running/completed/failed tasks at a glance | +| `coda_get_result` | Pull the full structured result of a completed task | + +**Why this matters:** Any tool that speaks MCP can use your Databricks-hosted coding agents — no custom integration needed. + +#### Example: Databricks Genie Code + +Genie Code connects to CoDA's MCP endpoint and delegates coding work to agents running in the background: + +``` +User → Genie Code: "Build me a sales pipeline using the transactions table" + +Genie Code calls coda_run(prompt="Build a sales pipeline...", email="user@company.com", + context='{"tables": ["sales.transactions"]}') + +→ Returns immediately: {task_id: "task-abc", status: "running"} +→ User keeps chatting with Genie Code while the agent works + +User → Genie Code: "How's my pipeline coming?" + +Genie Code calls coda_inbox() +→ {tasks: [{task_id: "task-abc", status: "completed", summary: "Built pipeline.py..."}]} + +Genie Code calls coda_get_result(task_id="task-abc", session_id="sess-123") +→ {summary: "Created pipeline.py with 3 stages", files_changed: ["pipeline.py"], ...} +``` + +#### Connecting MCP Clients (Claude Code, Claude Desktop, Cursor, etc.) + +Databricks Apps use OAuth — not PATs — for authentication. A static `Authorization: Bearer ` header will get a `302` redirect to the OAuth login page.
To connect any MCP client, use the **stdio bridge** (`tools/coda-bridge.py`) which injects fresh OAuth tokens automatically via `databricks auth token`. + +**1. Copy the bridge script:** + +```bash +mkdir -p ~/.claude/mcp-bridges +cp tools/coda-bridge.py ~/.claude/mcp-bridges/ +``` + +**2. Add to your MCP client settings** (e.g. `~/.claude/settings.json`): + +```json +"coda-mcp": { + "type": "stdio", + "command": "python3", + "args": ["/path/to/.claude/mcp-bridges/coda-bridge.py"], + "env": { + "CODA_MCP_URL": "https://your-app.databricksapps.com/mcp", + "DATABRICKS_PROFILE": "your-profile" + } +} +``` + +**3. Restart your MCP client.** + +The bridge reads `CODA_MCP_URL` and `DATABRICKS_PROFILE` from environment — no hardcoded values. If you redeploy the app or switch workspaces, just update the `env` block. + +**Prerequisites:** `databricks` CLI installed and authenticated (`databricks auth login -p `), Python 3.8+, no pip dependencies. + +**Troubleshooting:** Bridge logs go to stderr. If you see `Auth failed (302)`, refresh your CLI session with `databricks auth login -p `. See [full setup guide](docs/mcp-client-setup.md) for details. + +#### Task Chaining + +Chain tasks by passing `previous_session_id` — the new agent reads the prior task's results for context: + +``` +coda_run(prompt="Add monitoring to the pipeline", previous_session_id="sess-123") +``` + +See [MCP v2 Design Doc](docs/mcp-v2-background-execution.md) for the full protocol reference.
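For orientation, the bridge's two moving parts can be sketched as follows. This is an illustrative sketch, not the actual `tools/coda-bridge.py`, and the function names are made up: it fetches a fresh OAuth token from the Databricks CLI (`databricks auth token` prints JSON containing an `access_token` field), then attaches it as a Bearer header on each forwarded JSON-RPC request.

```python
"""Sketch of the bridge mechanics (hypothetical helper names, stdlib only)."""
import json
import subprocess
import urllib.request


def fetch_oauth_token(profile: str) -> str:
    # Ask the Databricks CLI for a token; a stale CLI session fails loudly
    # here instead of surfacing later as a silent 302 to the login page.
    out = subprocess.run(
        ["databricks", "auth", "token", "-p", profile],
        capture_output=True, text=True, check=True,
    ).stdout
    return json.loads(out)["access_token"]


def build_request(mcp_url: str, token: str, payload: dict) -> urllib.request.Request:
    # Attach the fresh token as a Bearer header on the JSON-RPC POST.
    return urllib.request.Request(
        mcp_url,
        data=json.dumps(payload).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
        method="POST",
    )
```

Because the token is fetched per request rather than hardcoded, redeploying the app or rotating credentials never requires editing the client config — only the `env` block, as noted above.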
@@ -238,7 +322,7 @@ This template repo opens that vision up for every Databricks user — no IDE set 1. Gunicorn starts, calls `initialize_app()` via `post_worker_init` hook 2. App serves the terminal UI with inline setup progress -3. Background thread runs setup: 5 sequential steps (git config, micro editor, GitHub CLI, Databricks CLI upgrade, content-filter proxy), then 6 agent setups (Claude, Codex, OpenCode, Gemini, Databricks CLI config, MLflow) run in parallel via `ThreadPoolExecutor` +3. Background thread runs setup: 5 sequential steps (git config, micro editor, GitHub CLI, Databricks CLI upgrade, content-filter proxy), then 6 agent setups (`setup/setup_claude.py`, `setup/setup_codex.py`, etc.) run in parallel via `ThreadPoolExecutor` 4. `/api/setup-status` endpoint reports progress to the UI 5. Once complete, the terminal becomes interactive @@ -258,6 +342,7 @@ This template repo opens that vision up for every Databricks user — no IDE set | `/api/resize` | POST | Resize terminal dimensions | | `/api/upload` | POST | Upload file (clipboard image paste) | | `/api/session/close` | POST | Close terminal session | +| `/mcp` | POST | MCP JSON-RPC endpoint (CoDA tools) | ### WebSocket Events (Socket.IO) @@ -306,7 +391,7 @@ Production uses `workers=1` (PTY state is process-local), `threads=16` (concurre coding-agents-databricks-apps/ ├── app.py # Flask backend + PTY management + setup orchestration ├── app_state.py # Shared app state (setup progress, session registry) -├── app.yaml.template # Databricks Apps deployment config template +├── app.yaml # Databricks Apps deployment config (gunicorn) ├── cli_auth.py # Interactive PAT setup + CLI credential writer ├── content_filter_proxy.py # Proxy that sanitises empty-content blocks for OpenCode ├── gunicorn.conf.py # Gunicorn production server config @@ -315,18 +400,27 @@ coding-agents-databricks-apps/ ├── requirements.txt # Compiled from pyproject.toml (Dependabot compatibility) ├── requirements.lock # Hash-pinned 
lockfile (auto-regenerated by CI) ├── Makefile # Deploy, redeploy, status, and cleanup targets -├── setup_claude.py # Claude Code CLI + MCP configuration -├── setup_codex.py # Codex CLI configuration -├── setup_gemini.py # Gemini CLI configuration -├── setup_opencode.py # OpenCode configuration -├── setup_databricks.py # Databricks CLI configuration -├── setup_mlflow.py # MLflow tracing auto-configuration -├── setup_proxy.py # Content-filter proxy startup ├── sync_to_workspace.py # Post-commit hook: sync to Workspace -├── install_micro.sh # Micro editor installer -├── install_gh.sh # GitHub CLI installer (OS/arch-aware) -├── install_databricks_cli.sh # Databricks CLI upgrade script -├── utils.py # Utility functions (ensure_https) +├── utils.py # Utility functions (ensure_https, gateway discovery) +├── coda_mcp/ # MCP server package (CoDA — Coding Agents) +│ ├── __init__.py +│ ├── mcp_server.py # FastMCP tool definitions (coda_run, coda_inbox, coda_get_result) +│ ├── mcp_endpoint.py # Flask Blueprint: JSON-RPC /mcp endpoint +│ ├── mcp_asgi.py # ASGI bridge (optional, for native MCP SDK transport) +│ └── task_manager.py # Disk-based session/task state manager +├── setup/ # Agent setup scripts (run at boot) +│ ├── setup_claude.py # Claude Code CLI + MCP configuration +│ ├── setup_codex.py # Codex CLI configuration +│ ├── setup_gemini.py # Gemini CLI configuration +│ ├── setup_opencode.py # OpenCode configuration +│ ├── setup_hermes.py # Hermes Agent configuration +│ ├── setup_databricks.py # Databricks CLI configuration +│ ├── setup_mlflow.py # MLflow tracing auto-configuration +│ └── setup_proxy.py # Content-filter proxy startup +├── scripts/ # Shell scripts +│ ├── install_micro.sh # Micro editor installer +│ ├── install_gh.sh # GitHub CLI installer (OS/arch-aware) +│ └── install_databricks_cli.sh # Databricks CLI upgrade script ├── static/ │ ├── index.html # Terminal UI (xterm.js + split panes + WebSocket) │ ├── favicon.svg # App favicon @@ -340,8 +434,12 @@ 
coding-agents-databricks-apps/ │ └── workflows/ │ ├── dependency-audit.yml # Weekly CVE audit + lockfile drift check │ └── update-lockfile.yml # Auto-regenerate requirements.lock on push +├── tools/ +│ └── coda-bridge.py # Stdio-to-HTTP MCP bridge (OAuth token injection) └── docs/ ├── deployment.md # Full Databricks Apps deployment guide + ├── mcp-client-setup.md # MCP client setup guide (bridge config) + ├── mcp-v2-background-execution.md # MCP server design doc ├── prd/ # Product requirement documents └── plans/ # Design documentation ``` diff --git a/app.py b/app.py index 0c63cad..c22646b 100644 --- a/app.py +++ b/app.py @@ -1,3 +1,4 @@ +import asyncio import os import pty import fcntl @@ -57,8 +58,46 @@ app.config['MAX_CONTENT_LENGTH'] = 32 * 1024 * 1024 # 32 MB — aligned with Claude Code's 30 MB file limit # WebSocket support via Flask-SocketIO (simple-websocket transport, threading mode) +# Used for local dev (python app.py). Under uvicorn/ASGI, the AsyncServer in +# mcp_asgi.py intercepts /socket.io/ before WSGIMiddleware, so these handlers +# are only active in WSGI mode. socketio = SocketIO(app, async_mode='threading', cors_allowed_origins=[], logger=False, engineio_logger=False) +# ── ASGI WebSocket support (python-socketio AsyncServer) ───────────── +# Set by mcp_asgi.py at startup. Background threads use _emit_from_thread() +# which routes to the async server (ASGI) or Flask-SocketIO (WSGI) automatically. +_async_sio = None +_event_loop = None + + +def set_async_sio(sio_instance, loop): + """Called by mcp_asgi.py to wire up the ASGI Socket.IO server.""" + global _async_sio, _event_loop + _async_sio = sio_instance + _event_loop = loop + + +def _emit_from_thread(event, data, room=None): + """Thread-safe emit for background threads (PTY reader, cleanup, SIGTERM). + + Routes to AsyncServer (ASGI mode) or Flask-SocketIO (WSGI mode) automatically. 
+ """ + if _async_sio and _event_loop and _event_loop.is_running(): + try: + asyncio.run_coroutine_threadsafe( + _async_sio.emit(event, data, room=room), + _event_loop, + ) + except Exception: + pass + else: + # WSGI mode (local dev) — use Flask-SocketIO directly + try: + socketio.emit(event, data, room=room) + except Exception: + pass + + # Store sessions: {session_id: {"master_fd": fd, "pid": pid, "output_buffer": deque, "lock": Lock, ...}} # sessions_lock guards dict-level ops (add/remove/iterate); each session["lock"] guards per-session state sessions = {} @@ -85,10 +124,7 @@ def handle_sigterm(signum, frame): shutting_down = True logger.info("SIGTERM received — setting shutting_down flag for clients") # Notify WS clients immediately (HTTP poll clients will see shutting_down on next poll) - try: - socketio.emit('shutting_down', {}) - except Exception: - pass + _emit_from_thread('shutting_down', {}) # NOTE: Do not register SIGTERM handler at module level. # It is installed in initialize_app() for gunicorn only. @@ -149,6 +185,11 @@ def _run_step(step_id, command): env.pop("DATABRICKS_CLIENT_ID", None) env.pop("DATABRICKS_CLIENT_SECRET", None) + # Ensure setup scripts can still import from repo root (e.g. `from utils import ...`) + app_dir = os.path.dirname(os.path.abspath(__file__)) + existing_pp = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = f"{app_dir}:{existing_pp}" if existing_pp else app_dir + result = subprocess.run(command, env=env, capture_output=True, text=True, timeout=300) if result.returncode == 0: _update_step(step_id, status="complete", completed_at=time.time()) @@ -323,8 +364,14 @@ def _configure_all_cli_auth(token): # 3. 
Re-run Codex, OpenCode, Gemini setup scripts with token in env # They are idempotent: detect CLI already installed, just write config files - env = {**os.environ, "DATABRICKS_TOKEN": token} - for script in ["setup_codex.py", "setup_opencode.py", "setup_gemini.py", "setup_hermes.py"]: + app_dir = os.path.dirname(os.path.abspath(__file__)) + existing_pp = os.environ.get("PYTHONPATH", "") + env = { + **os.environ, + "DATABRICKS_TOKEN": token, + "PYTHONPATH": f"{app_dir}:{existing_pp}" if existing_pp else app_dir, + } + for script in ["setup/setup_codex.py", "setup/setup_opencode.py", "setup/setup_gemini.py", "setup/setup_hermes.py"]: try: result = subprocess.run( ["uv", "run", "python", script], @@ -357,26 +404,26 @@ def run_setup(): _update_step("git", status="error", completed_at=time.time(), error=str(e)) _run_step("micro", ["bash", "-c", - "mkdir -p ~/.local/bin && bash install_micro.sh && mv micro ~/.local/bin/ 2>/dev/null || true"]) + "mkdir -p ~/.local/bin && bash scripts/install_micro.sh && mv micro ~/.local/bin/ 2>/dev/null || true"]) - _run_step("gh", ["bash", "install_gh.sh"]) + _run_step("gh", ["bash", "scripts/install_gh.sh"]) # --- Upgrade Databricks CLI (runtime image ships an older version) --- - _run_step("dbcli", ["bash", "install_databricks_cli.sh"]) + _run_step("dbcli", ["bash", "scripts/install_databricks_cli.sh"]) # --- Content-filter proxy (must be running before OpenCode starts) --- # Sanitizes requests/responses between OpenCode and Databricks # (see OpenCode #5028, docs/plans/2026-03-11-litellm-empty-content-blocks-design.md) - _run_step("proxy", ["uv", "run", "python", "setup_proxy.py"]) + _run_step("proxy", ["uv", "run", "python", "setup/setup_proxy.py"]) # --- Parallel agent setup (all independent of each other) --- parallel_steps = [ - ("claude", ["uv", "run", "python", "setup_claude.py"]), - ("codex", ["uv", "run", "python", "setup_codex.py"]), - ("opencode", ["uv", "run", "python", "setup_opencode.py"]), - ("gemini", ["uv", "run", 
"python", "setup_gemini.py"]), - ("hermes", ["uv", "run", "python", "setup_hermes.py"]), - ("databricks", ["uv", "run", "python", "setup_databricks.py"]), + ("claude", ["uv", "run", "python", "setup/setup_claude.py"]), + ("codex", ["uv", "run", "python", "setup/setup_codex.py"]), + ("opencode", ["uv", "run", "python", "setup/setup_opencode.py"]), + ("gemini", ["uv", "run", "python", "setup/setup_gemini.py"]), + ("hermes", ["uv", "run", "python", "setup/setup_hermes.py"]), + ("databricks", ["uv", "run", "python", "setup/setup_databricks.py"]), ] with ThreadPoolExecutor(max_workers=len(parallel_steps)) as executor: @@ -389,7 +436,7 @@ def run_setup(): # --- MLflow setup runs AFTER claude setup to avoid settings.json race --- # setup_mlflow.py merges env vars into ~/.claude/settings.json which # setup_claude.py also writes; running sequentially prevents clobbering. - _run_step("mlflow", ["uv", "run", "python", "setup_mlflow.py"]) + _run_step("mlflow", ["uv", "run", "python", "setup/setup_mlflow.py"]) # Sync latest token into all CLI configs — covers the race where PAT # rotation happened while a setup script was still installing (the @@ -527,7 +574,132 @@ def _check_ws_authorization(): return True -# ── WebSocket Event Handlers ────────────────────────────────────────────── +def _check_ws_authorization_from_environ(environ): + """Check authorization from WSGI environ dict (for ASGI WebSocket via python-socketio). + + Same logic as _check_ws_authorization() but reads headers from the environ + dict instead of Flask's request context. WSGI environ stores HTTP headers as + HTTP_X_FORWARDED_EMAIL (uppercase, underscores, HTTP_ prefix). 
+ """ + if not app_owner: + if _is_databricks_apps(): + logger.error("SECURITY: app_owner not resolved — denying WebSocket (fail-closed)") + return False + return True # Local dev only + + raw_user = ( + environ.get("HTTP_X_FORWARDED_EMAIL") + or environ.get("HTTP_X_FORWARDED_USER") + or environ.get("HTTP_X_DATABRICKS_USER_EMAIL") + ) + current_user = raw_user.lower() if raw_user else raw_user + + if not current_user: + if _is_databricks_apps(): + logger.warning("No user identity in WebSocket request on Databricks Apps — denying") + return False + return True # Local dev only + + if current_user != app_owner: + logger.warning(f"WebSocket unauthorized: {current_user} (owner: {app_owner})") + return False + return True + + +def register_sio_handlers(sio): + """Register Socket.IO event handlers on an AsyncServer for ASGI mode. + + Called by mcp_asgi.py. The handlers mirror the Flask-SocketIO handlers below + but use python-socketio's async API (explicit sid, enter_room/leave_room, + async def, ConnectionRefusedError for auth denial). 
+ """ + + @sio.on('connect') + async def handle_connect(sid, environ, auth): + # Capture event loop on first connection for _emit_from_thread() + set_async_sio(sio, asyncio.get_running_loop()) + + # Diagnostic: log transport and header presence for debugging proxy behavior + transport = environ.get('QUERY_STRING', '') + has_email = bool(environ.get('HTTP_X_FORWARDED_EMAIL')) + has_user = bool(environ.get('HTTP_X_FORWARDED_USER')) + logger.info(f"WS connect: sid={sid}, qs={transport}, " + f"has_email={has_email}, has_user={has_user}") + + if not _check_ws_authorization_from_environ(environ): + raise ConnectionRefusedError('unauthorized') + logger.info("WebSocket client connected (ASGI)") + + @sio.on('join_session') + async def handle_join_session(sid, data): + session_id = data.get('session_id') + if not session_id: + return {'status': 'error', 'message': 'session_id required'} + sess = _get_session(session_id) + if not sess: + return {'status': 'error', 'message': 'Session not found'} + with sess["lock"]: + sess["last_poll_time"] = time.time() + sess["output_buffer"].clear() + await sio.enter_room(sid, session_id) + logger.info(f"WebSocket client joined session room {session_id}") + return {'status': 'ok'} + + @sio.on('leave_session') + async def handle_leave_session(sid, data): + session_id = data.get('session_id') + if session_id: + await sio.leave_room(sid, session_id) + logger.info(f"WebSocket client left session room {session_id}") + + @sio.on('terminal_input') + async def handle_terminal_input(sid, data): + session_id = data.get('session_id') + input_data = data.get('input', '') + sess = _get_session(session_id) + if not sess: + return + with sess["lock"]: + sess["last_poll_time"] = time.time() + fd = sess["master_fd"] + try: + os.write(fd, input_data.encode()) + except OSError as e: + logger.warning(f"WebSocket input write error for {session_id}: {e}") + + @sio.on('terminal_resize') + async def handle_terminal_resize(sid, data): + session_id = 
data.get('session_id') + cols = data.get('cols', 80) + rows = data.get('rows', 24) + sess = _get_session(session_id) + if not sess: + return + with sess["lock"]: + sess["last_poll_time"] = time.time() + fd = sess["master_fd"] + try: + winsize = struct.pack("HHHH", rows, cols, 0, 0) + fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize) + except OSError as e: + logger.warning(f"WebSocket resize error for {session_id}: {e}") + + @sio.on('heartbeat') + async def handle_heartbeat(sid, data): + session_ids = data.get('session_ids', []) + now = time.time() + for s_id in session_ids: + sess = _get_session(s_id) + if sess: + with sess["lock"]: + sess["last_poll_time"] = now + + @sio.on('disconnect') + async def handle_disconnect(sid): + logger.info("WebSocket client disconnected (ASGI)") + + +# ── WebSocket Event Handlers (Flask-SocketIO — WSGI/local dev only) ────── @socketio.on('connect') def handle_ws_connect(): @@ -658,12 +830,9 @@ def read_pty_output(session_id, fd): session["output_buffer"].append(decoded) session["last_poll_time"] = time.time() # Keep session alive during WS output # Push via WebSocket to the session room (AC-8) - try: - socketio.emit('terminal_output', + _emit_from_thread('terminal_output', {'session_id': session_id, 'output': decoded}, room=session_id) - except Exception: - pass # No WebSocket clients — HTTP polling handles it else: # select timed out — check if process is still alive try: @@ -678,10 +847,7 @@ def read_pty_output(session_id, fd): break # Process exited or fd closed — notify WebSocket clients (AC-9) - try: - socketio.emit('session_exited', {'session_id': session_id}, room=session_id) - except Exception: - pass + _emit_from_thread('session_exited', {'session_id': session_id}, room=session_id) logger.info(f"Session {session_id} process exited") @@ -695,10 +861,7 @@ def terminate_session(session_id, pid, master_fd): logger.info(f"Terminating stale session {session_id} (pid={pid})") # Notify WebSocket clients that the session is closed - try: 
- socketio.emit('session_closed', {'session_id': session_id}, room=session_id) - except Exception: - pass + _emit_from_thread('session_closed', {'session_id': session_id}, room=session_id) try: os.kill(pid, signal.SIGHUP) @@ -805,7 +968,7 @@ def cleanup_stale_sessions(): def authorize_request(): """Check authorization before processing any request.""" # Skip auth for health check, setup status, and Socket.IO (has own auth via connect event) - if request.path in ("/health", "/api/setup-status", "/api/pat-status", "/api/configure-pat", "/api/app-state") or request.path.startswith("/socket.io"): + if request.path in ("/health", "/api/setup-status", "/api/pat-status", "/api/configure-pat", "/api/app-state") or request.path.startswith("/socket.io") or request.path.startswith("/mcp"): return None authorized, user = check_authorization() @@ -820,6 +983,10 @@ def authorize_request(): @app.after_request def set_security_headers(response): + # MCP endpoint handles its own CORS/headers — skip security headers + # that might interfere (CSP connect-src, X-Frame-Options, etc.) + if request.path.startswith("/mcp"): + return response response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" @@ -1080,6 +1247,92 @@ def create_session(): return jsonify({"error": str(e)}), 500 +# ── MCP Integration Helpers ────────────────────────────────────────── + + +def mcp_create_pty_session(label: str = "hermes-mcp") -> str: + """Create a PTY session for MCP use. Returns the PTY session_id.""" + with sessions_lock: + if len(sessions) >= MAX_CONCURRENT_SESSIONS: + raise RuntimeError( + f"Maximum {MAX_CONCURRENT_SESSIONS} concurrent sessions reached." 
+ ) + + master_fd, slave_fd = pty.openpty() + + shell_env = os.environ.copy() + shell_env["TERM"] = "xterm-256color" + shell_env.pop("CLAUDECODE", None) + shell_env.pop("CLAUDE_CODE_SESSION", None) + shell_env.pop("DATABRICKS_TOKEN", None) + shell_env.pop("DATABRICKS_HOST", None) + shell_env.pop("GEMINI_API_KEY", None) + if not shell_env.get("HOME") or shell_env["HOME"] == "/": + shell_env["HOME"] = "/app/python/source_code" + local_bin = f"{shell_env['HOME']}/.local/bin" + shell_env["PATH"] = f"{local_bin}:{shell_env.get('PATH', '')}" + + projects_dir = os.path.join(shell_env["HOME"], "projects") + os.makedirs(projects_dir, exist_ok=True) + + pid = subprocess.Popen( + ["/bin/bash"], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + preexec_fn=os.setsid, + env=shell_env, + cwd=projects_dir, + ).pid + os.close(slave_fd) + + session_id = str(uuid.uuid4()) + + with sessions_lock: + if len(sessions) >= MAX_CONCURRENT_SESSIONS: + os.close(master_fd) + try: + os.kill(pid, signal.SIGKILL) + except OSError: + pass + raise RuntimeError( + f"Maximum {MAX_CONCURRENT_SESSIONS} concurrent sessions reached." 
+ ) + sessions[session_id] = { + "master_fd": master_fd, + "pid": pid, + "output_buffer": deque(maxlen=1000), + "lock": threading.Lock(), + "last_poll_time": time.time(), + "created_at": time.time(), + "label": label, + } + + thread = threading.Thread( + target=read_pty_output, args=(session_id, master_fd), daemon=True + ) + thread.start() + + return session_id + + +def mcp_send_input(session_id: str, data: str): + """Send input to a PTY session.""" + session = _get_session(session_id) + if not session: + raise RuntimeError(f"Session {session_id} not found") + with session["lock"]: + os.write(session["master_fd"], data.encode()) + + +def mcp_close_pty_session(session_id: str): + """Close a PTY session.""" + session = _get_session(session_id) + if not session: + return + terminate_session(session_id, session["pid"], session["master_fd"]) + + @app.route("/api/input", methods=["POST"]) def send_input(): """Send input to the terminal.""" @@ -1297,6 +1550,20 @@ def initialize_app(local_dev=False): logger.info(f"Started session cleanup thread (timeout={SESSION_TIMEOUT_SECONDS}s, interval={CLEANUP_INTERVAL_SECONDS}s)") +# ── MCP Endpoint ───────────────────────────────────────────────────── +from coda_mcp.mcp_endpoint import mcp_bp +from coda_mcp.mcp_server import set_app_hooks + +app.register_blueprint(mcp_bp) + +# Wire MCP tools to PTY infrastructure +set_app_hooks( + create_session_fn=mcp_create_pty_session, + send_input_fn=mcp_send_input, + close_session_fn=mcp_close_pty_session, +) + + if __name__ == "__main__": # Local dev — no SIGTERM handler (SIG_DFL), no shutting_down flag initialize_app(local_dev=True) diff --git a/app.yaml b/app.yaml index a0f443c..380a434 100644 --- a/app.yaml +++ b/app.yaml @@ -1,6 +1,10 @@ command: - - gunicorn - - app:app + - uvicorn + - coda_mcp.mcp_asgi:app + - --host + - 0.0.0.0 + - --port + - "8000" env: - name: HOME value: /app/python/source_code diff --git a/cli_auth.py b/cli_auth.py index 61c9f25..53c2a25 100644 --- a/cli_auth.py +++ 
b/cli_auth.py @@ -35,6 +35,7 @@ def _update_claude(token): settings["env"]["ANTHROPIC_AUTH_TOKEN"] = token with open(path, "w") as f: json.dump(settings, f, indent=2) + os.chmod(path, 0o600) except (OSError, json.JSONDecodeError): pass # file doesn't exist yet — initial setup hasn't run @@ -59,6 +60,7 @@ def _update_opencode(token): if changed: with open(path, "w") as f: json.dump(auth, f, indent=2) + os.chmod(path, 0o600) except (OSError, json.JSONDecodeError): pass @@ -84,6 +86,7 @@ def _update_hermes(token): if new_content != content: with open(path, "w") as f: f.write(new_content) + os.chmod(path, 0o600) except OSError: pass @@ -102,5 +105,6 @@ def _replace_dotenv_key(path, key, value): if new_content != content: with open(path, "w") as f: f.write(new_content) + os.chmod(path, 0o600) except OSError: pass diff --git a/coda_mcp/__init__.py b/coda_mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/coda_mcp/mcp_asgi.py b/coda_mcp/mcp_asgi.py new file mode 100644 index 0000000..c90a939 --- /dev/null +++ b/coda_mcp/mcp_asgi.py @@ -0,0 +1,91 @@ +"""Native MCP ASGI app with WebSocket support for terminal I/O. 
+ +Architecture (all on one port, one uvicorn process): + + socketio.ASGIApp ← /socket.io/ → native ASGI WebSocket (terminal) + └── mcp_starlette ← /mcp → FastMCP Streamable HTTP (Genie Code) + └── WSGI(Flask) ← /* → REST API, static files (HTTP only) + +Usage in app.yaml:: + + command: ["uvicorn", "coda_mcp.mcp_asgi:app", "--host", "0.0.0.0", "--port", "8000"] +""" + +import os +import logging +import warnings + +import socketio as socketio_lib +from starlette.middleware.cors import CORSMiddleware + +with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from starlette.middleware.wsgi import WSGIMiddleware + +from coda_mcp.mcp_server import mcp as mcp_instance, set_app_hooks +from utils import ensure_https + +logger = logging.getLogger(__name__) + +# ── Build allowed origins ───────────────────────────────────────── +# The browser connects from the app's own URL (e.g. mcp-test-coda-*.databricksapps.com) +# which differs from DATABRICKS_HOST (workspace URL). Databricks proxy handles auth, +# so Socket.IO CORS can safely allow all origins. Starlette CORSMiddleware below +# uses the same list for MCP/Flask routes. +_databricks_host = os.environ.get("DATABRICKS_HOST", "") +ALLOWED_ORIGINS = [] +if _databricks_host: + ALLOWED_ORIGINS.append(ensure_https(_databricks_host).rstrip("/")) + +# ── Import and initialize Flask app ──────────────────────────────── +from app import ( + app as flask_app, + initialize_app, + mcp_create_pty_session, + mcp_send_input, + mcp_close_pty_session, + register_sio_handlers, +) + +initialize_app() + +# Wire MCP tools to PTY infrastructure +set_app_hooks( + create_session_fn=mcp_create_pty_session, + send_input_fn=mcp_send_input, + close_session_fn=mcp_close_pty_session, +) + +# ── Async Socket.IO server (native ASGI WebSocket) ─────────────── +# python-socketio AsyncServer handles /socket.io/ with real WebSocket, +# eliminating the WSGIMiddleware limitation that forced HTTP polling fallback. 
+sio = socketio_lib.AsyncServer( + async_mode='asgi', + cors_allowed_origins='*', # App URL differs from DATABRICKS_HOST; proxy handles auth + logger=False, + engineio_logger=False, +) + +# Register terminal I/O event handlers (connect, join_session, terminal_input, etc.) +register_sio_handlers(sio) + +# ── Build the ASGI app per Genie Code docs ───────────────────────── +mcp_starlette = mcp_instance.streamable_http_app() + +# Mount Flask as catch-all via WSGI adapter (HTTP routes only) +flask_asgi = WSGIMiddleware(flask_app.wsgi_app) +mcp_starlette.mount("/", app=flask_asgi) + +# CORS for MCP and Flask routes +mcp_starlette.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS or ["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# ── Top-level ASGI app ──────────────────────────────────────────── +# socketio.ASGIApp intercepts /socket.io/ for WebSocket + polling, +# passes everything else to mcp_starlette (MCP at /mcp, Flask at /) +app = socketio_lib.ASGIApp(sio, other_asgi_app=mcp_starlette) diff --git a/coda_mcp/mcp_endpoint.py b/coda_mcp/mcp_endpoint.py new file mode 100644 index 0000000..ce4ab27 --- /dev/null +++ b/coda_mcp/mcp_endpoint.py @@ -0,0 +1,171 @@ +"""Flask-native MCP JSON-RPC endpoint. + +Implements the MCP protocol as a plain Flask route — no ASGI bridge needed. +This keeps gunicorn + Flask-SocketIO working for WebSocket terminal I/O +while serving MCP over standard HTTP. 
+""" +import asyncio +import json +import logging +from flask import Blueprint, request, jsonify + +logger = logging.getLogger(__name__) + +mcp_bp = Blueprint("mcp", __name__) + +# Import tool functions from mcp_server.py +from coda_mcp.mcp_server import ( + mcp as mcp_instance, + coda_run, + coda_inbox, + coda_get_result, +) + +# Tool function dispatch +_TOOL_DISPATCH = { + "coda_run": coda_run, + "coda_inbox": coda_inbox, + "coda_get_result": coda_get_result, +} + +SERVER_INFO = { + "name": "coda", + "version": "1.0.0", +} + +CAPABILITIES = { + "tools": {"listChanged": False}, +} + + + +def _cors_headers(): + """Build CORS response headers. + + Permissive CORS for /mcp — the Databricks Apps proxy handles auth. + """ + headers = {} + origin = request.headers.get("Origin", "") + if origin: + headers["Access-Control-Allow-Origin"] = origin + headers["Access-Control-Allow-Methods"] = "GET, POST, DELETE, OPTIONS" + # Explicitly list all headers Genie Code might send + # (wildcard * is incompatible with credentials=true per CORS spec) + allowed_headers = ", ".join([ + "Content-Type", "Authorization", "Accept", + "Mcp-Session-Id", "X-Request-Id", "X-Requested-With", + "X-Forwarded-Email", "X-Forwarded-User", "X-Databricks-User-Email", + "Cookie", "Origin", "Referer", + ]) + headers["Access-Control-Allow-Headers"] = allowed_headers + headers["Access-Control-Allow-Credentials"] = "true" + headers["Access-Control-Max-Age"] = "86400" + return headers + + +@mcp_bp.route("/mcp", methods=["POST", "OPTIONS", "GET"]) +def mcp_handler(): + # Handle CORS preflight + if request.method == "OPTIONS": + resp = jsonify({}) + resp.status_code = 204 + for k, v in _cors_headers().items(): + resp.headers[k] = v + return resp + + # Handle GET for SSE (not supported in stateless mode) + if request.method == "GET": + resp = jsonify({"error": "SSE not supported. Use POST."}) + resp.status_code = 405 + return resp + + # Origin validation skipped — Databricks Apps proxy handles auth. 
+ + data = request.get_json(silent=True) or {} + method = data.get("method", "") + req_id = data.get("id") + params = data.get("params", {}) + + # Route by method + if method == "initialize": + result = { + "protocolVersion": params.get("protocolVersion", "2025-03-26"), + "capabilities": CAPABILITIES, + "serverInfo": SERVER_INFO, + "instructions": mcp_instance._instructions if hasattr(mcp_instance, '_instructions') else "", + } + resp = jsonify({"jsonrpc": "2.0", "id": req_id, "result": result}) + + elif method == "notifications/initialized": + # No-op acknowledgment — return empty OK + resp = jsonify({}) + resp.status_code = 200 + + elif method == "tools/list": + tools = _build_tools_list() + resp = jsonify({"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}}) + + elif method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + tool_fn = _TOOL_DISPATCH.get(tool_name) + if not tool_fn: + resp = jsonify({ + "jsonrpc": "2.0", "id": req_id, + "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"} + }) + else: + try: + # Tool functions are async — run them + result_str = asyncio.run(tool_fn(**arguments)) + result_data = json.loads(result_str) + resp = jsonify({ + "jsonrpc": "2.0", "id": req_id, + "result": { + "content": [{"type": "text", "text": result_str}], + "isError": "error" in result_data, + } + }) + except Exception as e: + resp = jsonify({ + "jsonrpc": "2.0", "id": req_id, + "error": {"code": -32603, "message": str(e)} + }) + + elif method == "ping": + resp = jsonify({"jsonrpc": "2.0", "id": req_id, "result": {}}) + + else: + resp = jsonify({ + "jsonrpc": "2.0", "id": req_id, + "error": {"code": -32601, "message": f"Method not found: {method}"} + }) + + # Add CORS headers + for k, v in _cors_headers().items(): + resp.headers[k] = v + + return resp + + +def _build_tools_list(): + """Extract tool definitions from FastMCP registry.""" + tools = [] + # Access FastMCP's internal tool manager + 
tool_manager = mcp_instance._tool_manager + for name, tool in tool_manager._tools.items(): + tool_dict = { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters if hasattr(tool, 'parameters') else {}, + } + if hasattr(tool, 'annotations') and tool.annotations: + tool_dict["annotations"] = {} + if tool.annotations.readOnlyHint is not None: + tool_dict["annotations"]["readOnlyHint"] = tool.annotations.readOnlyHint + if tool.annotations.destructiveHint is not None: + tool_dict["annotations"]["destructiveHint"] = tool.annotations.destructiveHint + if tool.annotations.idempotentHint is not None: + tool_dict["annotations"]["idempotentHint"] = tool.annotations.idempotentHint + tools.append(tool_dict) + return tools diff --git a/coda_mcp/mcp_server.py b/coda_mcp/mcp_server.py new file mode 100644 index 0000000..c4884e6 --- /dev/null +++ b/coda_mcp/mcp_server.py @@ -0,0 +1,365 @@ +"""MCP server exposing CoDA session/task tools via FastMCP. + +v2: Background execution + inbox pattern. +- ``coda_run`` — fire-and-forget task submission (auto-creates ephemeral session) +- ``coda_inbox`` — dashboard of all background tasks +- ``coda_get_result`` — pull full structured result for a completed task + +Delegates all disk state to ``task_manager.py``. PTY operations are +handled through optional app hooks set via ``set_app_hooks()``. 
+ +Run standalone for testing:: + + python mcp_server.py # stdio transport +""" + +import json +import logging +import os +import threading +import time + +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.server import TransportSecuritySettings +from mcp.types import ToolAnnotations + +from coda_mcp import task_manager + +logger = logging.getLogger(__name__) + +# ── FastMCP instance ──────────────────────────────────────────────── + +# Build allowed origins from DATABRICKS_HOST for Genie Code requests +_databricks_host = os.environ.get("DATABRICKS_HOST", "") +_allowed_origins = [] +if _databricks_host: + # Ensure https:// prefix, strip trailing slash + origin = _databricks_host if _databricks_host.startswith("https://") else f"https://{_databricks_host}" + _allowed_origins.append(origin.rstrip("/")) + +mcp = FastMCP( + "coda", + instructions=( + "CoDA MCP server — delegate coding tasks to AI agents on Databricks.\n\n" + "CRITICAL — FIRE AND FORGET:\n" + "coda_run submits work and returns IMMEDIATELY. The task runs autonomously " + "in the background. After calling coda_run, DO NOT call coda_inbox or " + "coda_get_result to check on it. Do NOT loop, poll, or wait. Simply tell " + "the user the task was submitted and MOVE ON to their next request.\n\n" + "WHEN TO CHECK INBOX:\n" + "Call coda_inbox ONLY when the user explicitly asks about background tasks " + "(e.g. 'how's my task going?', 'check on that', 'what's in my inbox'). " + "Never call it proactively, automatically, or in a loop.\n\n" + "WORKFLOW:\n" + "1) coda_run — submit work, get back task_id. Tell user it's running. Stop.\n" + "2) Continue chatting about other topics — the task runs independently.\n" + "3) coda_inbox — ONLY when user asks. Shows all tasks from last 24h.\n" + "4) coda_get_result — for completed tasks, get full structured output.\n\n" + "CHAINING: pass previous_session_id from a completed task's session_id " + "to give the new task context of what was done before." 
+ ), + stateless_http=True, + json_response=True, + transport_security=TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ), +) + +# ── App hooks (PTY integration) ───────────────────────────────────── + +_app_create_session = None +_app_send_input = None +_app_close_session = None + + +def set_app_hooks(create_session_fn, send_input_fn, close_session_fn): + """Wire up Flask app callbacks for PTY operations. + + When hooks are set: + - ``coda_run`` creates a PTY via ``create_session_fn(label=...)`` + - ``coda_run`` sends the hermes command via ``send_input_fn(pty_id, cmd)`` + - Task completion destroys the PTY via ``close_session_fn(pty_id)`` + + When hooks are *not* set (e.g. in tests), only disk state is managed. + """ + global _app_create_session, _app_send_input, _app_close_session + _app_create_session = create_session_fn + _app_send_input = send_input_fn + _app_close_session = close_session_fn + + +# ── Background watcher ────────────────────────────────────────────── + + +def _watch_task(session_id: str, task_id: str, timeout_s: int) -> None: + """Poll for result.json in a daemon thread. + + - Checks every 5 seconds for ``result.json`` in the task directory. + - If found, calls ``task_manager.complete_task()`` (which auto-closes session). + - Tracks last activity from ``status.jsonl`` mtime. + - Timeout: if wall clock exceeds *timeout_s* AND no status update + in the last 5 minutes, writes a timeout result and completes. + - On completion, closes the PTY if hooks are wired. 
+ """ + tdir = task_manager._task_dir(session_id, task_id) + status_path = os.path.join(tdir, "status.jsonl") + start = time.time() + stale_threshold = 300 # 5 minutes + + while True: + time.sleep(5) + + # Check for result.json (may be at root or in results/ subdir) + result_path = task_manager._find_result_json(tdir) + if result_path: + try: + task_manager.complete_task(session_id, task_id) + _close_pty_for_session(session_id) + logger.info("Watcher: task %s completed (result found)", task_id) + except Exception: + logger.exception("Watcher: error completing task %s", task_id) + return + + # Check timeout + elapsed = time.time() - start + if elapsed > timeout_s: + # Check last activity + try: + last_activity = os.path.getmtime(status_path) + except OSError: + last_activity = start + + if (time.time() - last_activity) > stale_threshold: + # Write timeout result and complete + try: + timeout_result_path = os.path.join(tdir, "result.json") + task_manager._write_json(timeout_result_path, { + "status": "timeout", + "summary": "Task timed out", + "files_changed": [], + "artifacts": [], + "errors": [f"Timeout after {timeout_s}s with no activity for 5 min"], + }) + task_manager.complete_task(session_id, task_id) + _close_pty_for_session(session_id) + logger.warning("Watcher: task %s timed out", task_id) + except Exception: + logger.exception("Watcher: error timing out task %s", task_id) + return + + +def _close_pty_for_session(session_id: str) -> None: + """Close the PTY associated with a session, if hooks are wired.""" + if _app_close_session is None: + return + try: + session = task_manager._read_session(session_id) + pty_session_id = session.get("pty_session_id") + if pty_session_id: + _app_close_session(pty_session_id) + except Exception: + logger.debug("Could not close PTY for session %s", session_id, exc_info=True) + + +# ── Tool definitions ──────────────────────────────────────────────── + + +@mcp.tool( + annotations=ToolAnnotations( + readOnlyHint=False, + 
destructiveHint=False, + idempotentHint=False, + ), +) +async def coda_run( + prompt: str, + email: str, + context: str = "{}", + previous_session_id: str = "", + permissions: str = "smart", + timeout_s: int = 3600, +) -> str: + """Submit a coding task — FIRE AND FORGET. + + Returns IMMEDIATELY with a task_id. The task runs autonomously in the + background. After receiving the response, tell the user the task was + submitted and move on. Do NOT follow up with coda_inbox or coda_get_result + unless the user explicitly asks to check status later. + + ``context`` is a JSON string with Unity Catalog metadata (tables, schemas). + ``previous_session_id`` chains to a prior task's session for context continuity. + ``permissions`` can be ``"smart"`` (default, safe) or ``"yolo"`` (auto-approve all). + + Returns JSON with ``task_id``, ``session_id``, and ``status: "running"``. + """ + try: + # Check concurrency limit + running = task_manager.count_running_tasks() + if running >= task_manager.MAX_CONCURRENT_TASKS: + return json.dumps({ + "status": "error", + "error": f"Concurrency limit reached ({task_manager.MAX_CONCURRENT_TASKS} " + f"tasks running). 
Try again when a task completes.", + }) + + # Parse context JSON + try: + ctx = json.loads(context) if context else None + except json.JSONDecodeError: + return json.dumps({ + "status": "error", + "error": f"Invalid JSON in context parameter: {context!r}", + }) + + # Auto-create ephemeral session + session_result = task_manager.create_session(email, "", label="hermes-mcp") + session_id = session_result["session_id"] + + # Create PTY if hooks are wired + if _app_create_session is not None: + pty_session_id = _app_create_session(label="hermes-mcp") + task_manager._update_session_field( + session_id, "pty_session_id", pty_session_id + ) + + # Create task with chaining support + result = task_manager.create_task( + session_id=session_id, + prompt=prompt, + email=email, + context=ctx, + timeout_s=timeout_s, + permissions=permissions, + previous_session_id=previous_session_id or None, + ) + task_id = result["task_id"] + + # Send to PTY if hooks are wired + if _app_send_input is not None: + session = task_manager._read_session(session_id) + pty_session_id = session.get("pty_session_id") + if pty_session_id: + # Build hermes command + tdir = task_manager._task_dir(session_id, task_id) + prompt_path = os.path.join(tdir, "prompt.txt") + cmd = f'hermes -z "{prompt_path}"' + if permissions == "yolo": + cmd += " --yolo" + cmd += "\n" + + _app_send_input(pty_session_id, cmd) + + # Start background watcher + t = threading.Thread( + target=_watch_task, + args=(session_id, task_id, timeout_s), + daemon=True, + ) + t.start() + + return json.dumps({ + "task_id": task_id, + "session_id": session_id, + "status": "running", + }) + + except Exception as exc: + return json.dumps({"status": "error", "error": str(exc)}) + + +@mcp.tool( + annotations=ToolAnnotations( + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + ), +) +async def coda_inbox( + email: str = "", + status: str = "", +) -> str: + """Check status of all background tasks — your inbox. 
+ + Call this instead of polling — it returns ALL tasks at once. + No need to track individual task_ids; the inbox shows everything + from the last 24 hours: running, completed, and failed tasks. + + By default returns all tasks. Filter by ``status`` to narrow: + ``"running"`` for in-progress only, ``"completed"`` for finished, + ``"failed"`` for errors, or ``""`` (default) for everything. + + Each task includes: ``task_id``, ``session_id``, ``status``, + ``elapsed_s``, ``prompt_summary`` (first 100 chars of what was asked), + ``previous_session_id`` (if chained from prior work). + Completed tasks also include ``summary`` (what was done). + Running tasks also include ``progress`` (latest agent step). + + Returns JSON with ``tasks`` (list sorted most recent first) + and ``counts`` (e.g. ``{"running": 1, "completed": 2, "failed": 0}``). + """ + try: + tasks = task_manager.list_all_tasks(email=email, status_filter=status) + + counts = {"running": 0, "completed": 0, "failed": 0} + for t in tasks: + s = t.get("status", "") + if s in counts: + counts[s] += 1 + elif s == "done": + counts["completed"] += 1 + elif s == "timeout": + counts["failed"] += 1 + + return json.dumps({"tasks": tasks, "counts": counts}) + except Exception as exc: + return json.dumps({"status": "error", "error": str(exc)}) + + +@mcp.tool( + annotations=ToolAnnotations( + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + ), +) +async def coda_get_result( + task_id: str, + session_id: str, +) -> str: + """Retrieve the structured result of a completed task. + + Call this AFTER coda_inbox shows a task as "completed" or "failed". + + Returns JSON with ``task_id``, ``session_id``, ``status``, ``summary`` + (what was done), ``files_changed`` (list of modified files), + ``artifacts`` (job IDs, commit hashes, etc.), and ``errors`` (if any). 
+ """ + try: + result = task_manager.get_task_result(task_id, session_id) + if result is None: + # No result yet — return current status + status = task_manager.get_task_status(task_id, session_id) + return json.dumps({ + "task_id": task_id, + "session_id": session_id, + "status": status.get("status", "unknown"), + "message": "Result not yet available — task is still in progress.", + }) + + result["task_id"] = task_id + result["session_id"] = session_id + # Ensure standard fields exist + result.setdefault("status", "done") + result.setdefault("summary", "") + result.setdefault("files_changed", []) + result.setdefault("artifacts", []) + result.setdefault("errors", []) + return json.dumps(result) + except Exception as exc: + return json.dumps({"status": "error", "task_id": task_id, "error": str(exc)}) + + +# ── Standalone entry point ────────────────────────────────────────── + +if __name__ == "__main__": + mcp.run() diff --git a/coda_mcp/task_manager.py b/coda_mcp/task_manager.py new file mode 100644 index 0000000..9718638 --- /dev/null +++ b/coda_mcp/task_manager.py @@ -0,0 +1,551 @@ +"""Disk-based state manager for MCP sessions and tasks. + +Pure Python module — no Flask dependency. Just file I/O. 
+ +Layout on disk +-------------- +~/.coda/sessions/{session-id}/ + session.json – session metadata + tasks/{task-id}/ + prompt.txt – wrapped prompt sent to the agent + meta.json – task metadata (email, timestamps, chaining) + status.jsonl – append-only progress log + result.json – final output (written by the agent) +""" + +import json +import os +import secrets +import time +import logging + +logger = logging.getLogger(__name__) + +# ── Root directory (patched in tests) ──────────────────────────────── + +SESSIONS_DIR = os.path.join( + os.environ.get("HOME", "/app/python/source_code"), ".coda", "sessions" +) + +# ── Concurrency limit ─────────────────────────────────────────────── + +MAX_CONCURRENT_TASKS = int(os.environ.get("CODA_MAX_CONCURRENT", "5")) + +# ── Task TTL (seconds) ────────────────────────────────────────────── + +TASK_TTL_S = int(os.environ.get("CODA_TASK_TTL", str(24 * 3600))) # 24h + +# ── Exceptions ─────────────────────────────────────────────────────── + + +class SessionBusyError(Exception): + """Raised when a task is submitted to a session that already has one running.""" + + +class SessionNotFoundError(Exception): + """Raised when the requested session does not exist or is closed.""" + + +class ConcurrencyLimitError(Exception): + """Raised when MAX_CONCURRENT_TASKS running tasks already exist.""" + + +# ── ID generators ──────────────────────────────────────────────────── + + +def _new_session_id() -> str: + return f"sess-{secrets.token_hex(6)}" + + +def _new_task_id() -> str: + return f"task-{secrets.token_hex(4)}" + + +# ── Low-level I/O ──────────────────────────────────────────────────── + + +def _session_dir(session_id: str) -> str: + return os.path.join(SESSIONS_DIR, session_id) + + +def _session_file(session_id: str) -> str: + return os.path.join(_session_dir(session_id), "session.json") + + +def _task_dir(session_id: str, task_id: str) -> str: + """Return the path to a task's directory.""" + return 
os.path.join(_session_dir(session_id), "tasks", task_id) + + +def _write_json(path: str, data: dict) -> None: + """Atomic write via tmp-then-rename.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump(data, f, indent=2) + os.replace(tmp, path) + + +def _read_session(session_id: str) -> dict: + """Read session.json or raise SessionNotFoundError.""" + path = _session_file(session_id) + try: + with open(path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError): + raise SessionNotFoundError(f"Session {session_id} not found or corrupt") + + +def _update_session_field(session_id: str, key: str, value) -> None: + """Update a single field in session.json (read-modify-write).""" + data = _read_session(session_id) + data[key] = value + _write_json(_session_file(session_id), data) + + +# ── Session lifecycle ──────────────────────────────────────────────── + + +def create_session(email: str, user_id: str, label: str = "") -> dict: + """Create a new session directory with session.json. + + Returns ``{"session_id": "sess-…", "status": "ready"}``. + """ + session_id = _new_session_id() + data = { + "session_id": session_id, + "email": email, + "user_id": user_id, + "label": label, + "status": "ready", + "current_task": None, + "completed_tasks": [], + "created_at": time.time(), + } + _write_json(_session_file(session_id), data) + logger.info("Created session %s for %s", session_id, email) + return {"session_id": session_id, "status": "ready"} + + +def close_session(session_id: str) -> None: + """Mark a session as closed. 
Raises SessionNotFoundError if missing.""" + _read_session(session_id) # existence check + _update_session_field(session_id, "status", "closed") + logger.info("Closed session %s", session_id) + + +# ── Prompt wrapping ────────────────────────────────────────────────── + + +def wrap_prompt( + task_id: str, + session_id: str, + email: str, + prompt: str, + context: dict | None, + results_dir: str, + context_hint: str | None = None, + previous_session_id: str | None = None, +) -> str: + """Build the full prompt string written to ``prompt.txt``. + + Uses the ``---CODA-TASK---`` envelope convention so the agent can + parse metadata from the prompt deterministically. + """ + context_block = "" + if context: + context_block = f"\nCONTEXT:\n{json.dumps(context, indent=2)}\n" + + hint_line = "" + if context_hint: + hint_line = f"context_hint: {context_hint}\n" + + prior_session_block = "" + if previous_session_id: + prior_dir = _session_dir(previous_session_id) + prior_session_block = ( + f"\nPRIOR SESSION: {previous_session_id}\n" + f"Read {prior_dir}/tasks/*/result.json for context on prior work.\n" + ) + + return ( + f"---CODA-TASK---\n" + f"task_id: {task_id}\n" + f"session_id: {session_id}\n" + f"user: {email}\n" + f"{hint_line}" + f"{prior_session_block}" + f"{context_block}\n" + f"TASK:\n" + f"{prompt}\n" + f"\n" + f"INSTRUCTIONS:\n" + f"1. As you work, append progress lines to {results_dir}/status.jsonl\n" + f' Each line must be valid JSON: {{"step": "label", "message": "what you are doing"}}\n' + f"\n" + f"2. 
When you are COMPLETELY DONE, write a SINGLE FILE at this exact path:\n" + f" {results_dir}/result.json\n" + f" It must contain this JSON structure:\n" + f" {{\n" + f' "status": "completed",\n' + f' "summary": "one paragraph describing what you did",\n' + f' "files_changed": ["list", "of", "file", "paths"],\n' + f' "artifacts": {{}},\n' + f' "errors": []\n' + f" }}\n" + f" If you failed, set status to \"failed\" and describe the error.\n" + f" IMPORTANT: result.json is a FILE not a directory. Write it with:\n" + f" echo '{{...}}' > {results_dir}/result.json\n" + f"\n" + f"3. If you delegate to a sub-agent, update status.jsonl with delegation steps.\n" + f"\n" + f"SAFETY:\n" + f"- Do NOT delete, drop, or truncate tables, schemas, catalogs, or volumes.\n" + f"- Do NOT delete files outside the current project directory.\n" + f"- Do NOT run destructive Databricks CLI commands (e.g. databricks clusters delete, " + f"databricks jobs delete, databricks pipelines delete).\n" + f"- Do NOT modify permissions, grants, or access controls unless explicitly requested.\n" + f"- Prefer CREATE OR REPLACE over DROP+CREATE. Prefer INSERT/MERGE over DELETE+INSERT.\n" + f"- If the task requires a destructive operation, describe what you would do in " + f"result.json with status \"needs_approval\" instead of executing it.\n" + f"---END-CODA-TASK---" + ) + + +# ── Task lifecycle ─────────────────────────────────────────────────── + + +def create_task( + session_id: str, + prompt: str, + email: str, + context: dict | None = None, + context_hint: str | None = None, + timeout_s: int | None = None, + permissions: str | None = None, + previous_session_id: str | None = None, +) -> dict: + """Create a task inside an existing session. + + Raises + ------ + SessionNotFoundError + If the session does not exist or is closed. + SessionBusyError + If the session already has a running task. + + Returns ``{"task_id": "task-…", "status": "running"}``. 
+ """ + session = _read_session(session_id) + + if session.get("status") == "closed": + raise SessionNotFoundError(f"Session {session_id} is closed") + + if session.get("status") == "busy": + raise SessionBusyError( + f"Session {session_id} already has a running task: " + f"{session.get('current_task')}" + ) + + task_id = _new_task_id() + tdir = _task_dir(session_id, task_id) + os.makedirs(tdir, exist_ok=True) + + # Write wrapped prompt + results_dir = os.path.join(tdir, "results") + wrapped = wrap_prompt( + task_id=task_id, + session_id=session_id, + email=email, + prompt=prompt, + context=context, + results_dir=results_dir, + context_hint=context_hint, + previous_session_id=previous_session_id, + ) + with open(os.path.join(tdir, "prompt.txt"), "w") as f: + f.write(wrapped) + + # Write meta.json for inbox scanning + now = time.time() + meta = { + "email": email, + "created_at": now, + "previous_session_id": previous_session_id or "", + "permissions": permissions or "smart", + "timeout_s": timeout_s or 3600, + "prompt_summary": prompt[:100], + } + _write_json(os.path.join(tdir, "meta.json"), meta) + + # Seed status log + with open(os.path.join(tdir, "status.jsonl"), "w") as f: + f.write(json.dumps({"status": "running", "ts": now}) + "\n") + + # Mark session busy + data = _read_session(session_id) + data["status"] = "busy" + data["current_task"] = task_id + _write_json(_session_file(session_id), data) + + logger.info("Created task %s in session %s", task_id, session_id) + return {"task_id": task_id, "status": "running"} + + +# ── Task queries ───────────────────────────────────────────────────── + + +def get_task_status(task_id: str, session_id: str) -> dict: + """Read the last line of status.jsonl for the task. + + Returns ``{"status": "not_found"}`` if the task directory is missing. 
+ """ + status_path = os.path.join(_task_dir(session_id, task_id), "status.jsonl") + try: + last = None + with open(status_path) as f: + for line in f: + line = line.strip() + if line: + last = json.loads(line) + return last or {"status": "not_found"} + except (OSError, json.JSONDecodeError): + return {"status": "not_found"} + + +def _find_result_json(task_dir: str) -> str | None: + """Find result.json — agents may write it at root or in results/ subdir.""" + for candidate in [ + os.path.join(task_dir, "result.json"), + os.path.join(task_dir, "results", "result.json"), + ]: + if os.path.isfile(candidate): + return candidate + return None + + +def get_task_result(task_id: str, session_id: str) -> dict | None: + """Read result.json if it exists; otherwise return None.""" + result_path = _find_result_json(_task_dir(session_id, task_id)) + if not result_path: + return None + try: + with open(result_path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError): + return None + + +# ── Task completion ────────────────────────────────────────────────── + + +def complete_task(session_id: str, task_id: str) -> None: + """Mark a task as done and auto-close the session. + + Appends a ``done`` entry to status.jsonl, adds task_id to + ``completed_tasks``, and closes the session (v2: ephemeral sessions). 
+ """ + session = _read_session(session_id) + + # Append done to status log + status_path = os.path.join(_task_dir(session_id, task_id), "status.jsonl") + with open(status_path, "a") as f: + f.write(json.dumps({"status": "done", "ts": time.time()}) + "\n") + + # Update session — auto-close (v2: sessions are ephemeral) + session["status"] = "closed" + session["current_task"] = None + session["closed_at"] = time.time() + if task_id not in session["completed_tasks"]: + session["completed_tasks"].append(task_id) + _write_json(_session_file(session_id), session) + + logger.info("Completed task %s in session %s (auto-closed)", task_id, session_id) + + +# ── Inbox: list all tasks across sessions ─────────────────────────── + + +def list_all_tasks(email: str = "", status_filter: str = "") -> list[dict]: + """Scan all sessions and return a flat list of tasks for the inbox. + + Returns tasks from the last ``TASK_TTL_S`` seconds, sorted most recent first. + Each entry includes task_id, session_id, status, elapsed_s, prompt_summary, + summary (if completed), progress (if running), previous_session_id, created_at. 
+ """ + now = time.time() + cutoff = now - TASK_TTL_S + tasks = [] + + if not os.path.isdir(SESSIONS_DIR): + return tasks + + for sess_name in os.listdir(SESSIONS_DIR): + sess_dir = os.path.join(SESSIONS_DIR, sess_name) + if not os.path.isdir(sess_dir): + continue + + tasks_dir = os.path.join(sess_dir, "tasks") + if not os.path.isdir(tasks_dir): + continue + + for task_name in os.listdir(tasks_dir): + task_dir = os.path.join(tasks_dir, task_name) + if not os.path.isdir(task_dir): + continue + + # Read meta.json + meta_path = os.path.join(task_dir, "meta.json") + try: + with open(meta_path) as f: + meta = json.load(f) + except (OSError, json.JSONDecodeError): + # Legacy task without meta.json — skip or build minimal entry + meta = {} + + created_at = meta.get("created_at", 0) + if created_at < cutoff: + continue + + # Filter by email + if email and meta.get("email", "") != email: + continue + + # Determine task status from status.jsonl + task_status = _read_last_status(task_dir) + + # Check for result.json to determine completion + result_path = _find_result_json(task_dir) + summary = "" + if result_path: + try: + with open(result_path) as f: + result_data = json.load(f) + task_status = result_data.get("status", "completed") + summary = result_data.get("summary", "") + except (OSError, json.JSONDecodeError): + pass + + # Filter by status + if status_filter and task_status != status_filter: + continue + + # Get progress for running tasks + progress = "" + if task_status == "running": + progress = _read_last_progress(task_dir) + + elapsed_s = round(now - created_at, 1) + + entry = { + "task_id": task_name, + "session_id": sess_name, + "status": task_status, + "elapsed_s": elapsed_s, + "prompt_summary": meta.get("prompt_summary", ""), + "previous_session_id": meta.get("previous_session_id", ""), + "created_at": created_at, + } + if summary: + entry["summary"] = summary + if progress: + entry["progress"] = progress + + tasks.append(entry) + + # Sort most recent first + 
tasks.sort(key=lambda t: t["created_at"], reverse=True) + return tasks + + +def _read_last_status(task_dir: str) -> str: + """Read the last status from status.jsonl.""" + status_path = os.path.join(task_dir, "status.jsonl") + try: + last = None + with open(status_path) as f: + for line in f: + line = line.strip() + if line: + last = json.loads(line) + return (last or {}).get("status", "unknown") + except (OSError, json.JSONDecodeError): + return "unknown" + + +def _read_last_progress(task_dir: str) -> str: + """Read the last progress message from status.jsonl.""" + status_path = os.path.join(task_dir, "status.jsonl") + try: + last = None + with open(status_path) as f: + for line in f: + line = line.strip() + if line: + last = json.loads(line) + return (last or {}).get("message", "") + except (OSError, json.JSONDecodeError): + return "" + + +# ── Concurrency check ────────────────────────────────────────────── + + +def count_running_tasks() -> int: + """Count tasks currently in 'running' state across all sessions.""" + count = 0 + if not os.path.isdir(SESSIONS_DIR): + return count + + for sess_name in os.listdir(SESSIONS_DIR): + sess_file = os.path.join(SESSIONS_DIR, sess_name, "session.json") + try: + with open(sess_file) as f: + session = json.load(f) + if session.get("status") == "busy": + count += 1 + except (OSError, json.JSONDecodeError): + continue + return count + + +# ── Cleanup expired sessions ──────────────────────────────────────── + + +def cleanup_expired_tasks() -> int: + """Remove session directories older than TASK_TTL_S. 
Returns count removed."""
+    import shutil
+
+    now = time.time()
+    cutoff = now - TASK_TTL_S
+    removed = 0
+
+    if not os.path.isdir(SESSIONS_DIR):
+        return removed
+
+    for sess_name in os.listdir(SESSIONS_DIR):
+        sess_dir = os.path.join(SESSIONS_DIR, sess_name)
+        if not os.path.isdir(sess_dir):
+            continue
+
+        sess_file = os.path.join(sess_dir, "session.json")
+        try:
+            with open(sess_file) as f:
+                session = json.load(f)
+        except (OSError, json.JSONDecodeError):
+            continue
+
+        # Only clean closed sessions past TTL
+        if session.get("status") != "closed":
+            continue
+
+        closed_at = session.get("closed_at", session.get("created_at", 0))
+        if closed_at < cutoff:
+            try:
+                shutil.rmtree(sess_dir)
+                removed += 1
+                logger.info("Cleaned up expired session %s", sess_name)
+            except OSError:
+                logger.warning("Failed to clean up session %s", sess_name)
+
+    return removed
diff --git a/docs/mcp-client-setup.md b/docs/mcp-client-setup.md
new file mode 100644
index 0000000..f8e1bb6
--- /dev/null
+++ b/docs/mcp-client-setup.md
@@ -0,0 +1,73 @@
+# CoDA MCP Client Setup
+
+CoDA exposes an MCP endpoint at `/mcp` on the Databricks App. Databricks Apps use OAuth (not PATs) for authentication, so MCP clients need a stdio bridge that injects fresh OAuth tokens.
+
+## How it works
+
+`tools/coda-bridge.py` is a zero-dependency Python script that:
+
+1. Is launched by Claude Code as a stdio MCP server
+2. Reads JSON-RPC messages from stdin
+3. Fetches a fresh OAuth token via `databricks auth token`
+4. Forwards requests to the App's HTTP endpoint with the token
+5. Returns responses on stdout
+
+Tokens are cached for 30 minutes (they expire after 60).
+
+## Setup
+
+### 1. Copy the bridge script
+
+```bash
+mkdir -p ~/.claude/mcp-bridges
+cp tools/coda-bridge.py ~/.claude/mcp-bridges/
+```
+
+### 2. Add to Claude Code settings
+
+Add this to `mcpServers` in `~/.claude/settings.json`:
+
+```json
+"coda-mcp": {
+  "type": "stdio",
+  "command": "python3",
+  "args": ["/path/to/.claude/mcp-bridges/coda-bridge.py"],
+  "env": {
+    "CODA_MCP_URL": "https://.databricksapps.com/mcp",
+    "DATABRICKS_PROFILE": ""
+  }
+}
+```
+
+### 3. Restart Claude Code
+
+The MCP server starts automatically on the next session.
+
+## Configuration
+
+| Environment Variable | Description | Example |
+|---------------------|-------------|---------|
+| `CODA_MCP_URL` | Full URL to the app's `/mcp` endpoint | `https://mcp-test-coda-747...com/mcp` |
+| `DATABRICKS_PROFILE` | Databricks CLI profile name | `9cefok` |
+
+## Prerequisites
+
+- `databricks` CLI installed and authenticated (`databricks auth login -p `)
+- Python 3.8+
+- No pip dependencies required (stdlib only)
+
+## Troubleshooting
+
+Bridge logs go to stderr. Check the bridge end to end with:
+
+```bash
+export CODA_MCP_URL="https://your-app.databricksapps.com/mcp"
+export DATABRICKS_PROFILE="your-profile"
+echo '{"jsonrpc":"2.0","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}},"id":1}' | python3 tools/coda-bridge.py
+```
+
+The variables are exported so they reach `python3` on the right-hand side of the pipe; prefix-style `VAR=... cmd | ...` assignments apply only to the first command in the pipeline.
+
+If you see `Auth failed (302)`, your Databricks CLI session may have expired. Run:
+
+```bash
+databricks auth login -p 
+```
diff --git a/docs/mcp-v2-background-execution.md b/docs/mcp-v2-background-execution.md
new file mode 100644
index 0000000..3d7557c
--- /dev/null
+++ b/docs/mcp-v2-background-execution.md
@@ -0,0 +1,171 @@
+# CoDA MCP v2 — Background Execution + Inbox Pattern
+
+## Overview
+
+CoDA exposes 3 MCP tools so Databricks GenieCode (or any MCP client) can delegate
+coding tasks to AI agents running in the background. GenieCode's chat context stays
+free while tasks execute — no polling required.
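The three-tool handoff can be sketched with stand-in Python functions. This is an illustration only — in practice these are MCP tool calls, not local imports, and the IDs and summaries below are invented:

```python
# Stand-ins that mirror the documented return shapes of the three tools.
# Real callers invoke these over MCP; IDs and summaries here are made up.

def coda_run(prompt: str, email: str) -> dict:
    # Fire-and-forget: returns immediately with a handle to the running task.
    return {"task_id": "task-001", "session_id": "sess-001", "status": "running"}

def coda_inbox(email: str = "", status: str = "") -> dict:
    # Dashboard: every background task at once, plus per-status counts.
    tasks = [{"task_id": "task-001", "session_id": "sess-001",
              "status": "completed", "summary": "Pipeline created"}]
    if status:
        tasks = [t for t in tasks if t["status"] == status]
    return {"tasks": tasks, "counts": {"running": 0, "completed": 1, "failed": 0}}

def coda_get_result(task_id: str, session_id: str) -> dict:
    # Full structured result, as read from the task's result.json on disk.
    return {"task_id": task_id, "session_id": session_id, "status": "completed",
            "summary": "Pipeline created", "files_changed": [], "artifacts": {},
            "errors": []}

# 1. Submit the task; the chat context is free as soon as this returns.
handle = coda_run("build an ETL pipeline", "owner@example.com")

# 2. When the user asks, check the inbox instead of polling one task.
inbox = coda_inbox(status="completed")

# 3. Pull the full result for a completed task.
result = coda_get_result(handle["task_id"], handle["session_id"])
print(result["summary"])  # → Pipeline created
```

The key property is step 1: `coda_run` returns a handle immediately, so the caller never blocks on the executor.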
+ +## Tools + +| Tool | Purpose | +|------|---------| +| `coda_run` | Fire-and-forget task submission | +| `coda_inbox` | Dashboard of all background tasks | +| `coda_get_result` | Pull full structured result | + +## Flow Diagram + +``` +┌─────────────┐ ┌──────────────┐ ┌─────────────┐ +│ GenieCode │ │ CoDA MCP │ │ Hermes │ +│ (caller) │ │ (3 tools) │ │ (executor) │ +└──────┬──────┘ └──────┬───────┘ └──────┬──────┘ + │ │ │ + │ 1. coda_run(prompt) │ │ + │──────────────────────>│ │ + │ │ auto-create session │ + │ │ + PTY + task dir │ + │ │ write prompt.txt │ + │ │ write meta.json │ + │ │ │ + │ {task_id, sess_id, │ hermes -z prompt.txt │ + │ status: "running"} │───────────────────────>│ + │<──────────────────────│ │ + │ │ _watch_task thread │ + │ ✓ context is FREE │ monitors result.json │ + │ user keeps chatting │ │ + │ │ │ works... + │ ... │ │ delegates + │ │ │ to claude/ + │ │ │ codex/gemini + │ │ │ + │ 2. coda_inbox() │ │ writes + │──────────────────────>│ │ status.jsonl + │ │ scan all sessions │ + │ {tasks: [...], │ read meta + status │ + │ counts: {run:1}} │ │ + │<──────────────────────│ │ + │ │ │ + │ ... │ │ writes + │ │ │ result.json + │ │ │ + │ │ _watch_task detects │ + │ │ result.json exists │ + │ │ → complete_task() │ + │ │ → auto-close session │ + │ │ → free PTY │ + │ │ │ + │ 3. coda_inbox() │ │ + │──────────────────────>│ │ + │ {tasks: [{status: │ │ + │ "completed", │ │ + │ summary: "..."}]} │ │ + │<──────────────────────│ │ + │ │ │ + │ 4. coda_get_result() │ │ + │──────────────────────>│ │ + │ {summary, files, │ read result.json │ + │ artifacts, errors} │ │ + │<──────────────────────│ │ + │ │ │ + ├── CHAINING ───────────┤ │ + │ │ │ + │ 5. 
coda_run(prompt, │ │ + │ previous_session_id) │ new session + PTY │ + │──────────────────────>│ inject PRIOR SESSION │ + │ │ block in prompt │ + │ {new task_id, │───────────────────────>│ + │ new sess_id} │ │ reads prior + │<──────────────────────│ │ result.json + │ │ │ for context +``` + +## Key Design Decisions + +### Sessions are ephemeral, tasks are persistent +- Session = PTY + Hermes instance. Auto-closes when task completes. +- Task state (prompt, status, result) persists on disk for 24 hours. +- Continuity via `previous_session_id`, not long-lived sessions. + +### No polling from GenieCode +- `coda_inbox` replaces `coda_get_status` — shows ALL tasks at once. +- GenieCode checks when the user asks, not on a timer. +- CoDA's internal `_watch_task` thread polls the filesystem (invisible to caller). + +### Task chaining +- `previous_session_id` points to a prior session's disk state. +- Hermes reads `~/.coda/sessions/{prev_id}/tasks/*/result.json` for context. +- Chain depth: one level. Hermes can walk deeper if needed. + +### Concurrency +- `CODA_MAX_CONCURRENT` env var (default: 5). +- Each task gets its own session — no "session busy" errors. +- Exceeding the limit returns a clear error. 
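The task-chaining decision above can be sketched in a few lines. `load_prior_results` is a hypothetical helper (not part of CoDA's API), assuming only the on-disk layout documented here:

```python
# Hypothetical helper showing how an executor could load prior-session
# context from the documented layout:
#   {sessions_dir}/{previous_session_id}/tasks/{task_id}/result.json
import glob
import json
import os
import tempfile

def load_prior_results(sessions_dir: str, previous_session_id: str) -> list:
    pattern = os.path.join(sessions_dir, previous_session_id,
                           "tasks", "*", "result.json")
    results = []
    for path in sorted(glob.glob(pattern)):
        try:
            with open(path) as f:
                results.append(json.load(f))
        except (OSError, json.JSONDecodeError):
            continue  # skip unreadable results rather than failing the new task
    return results

# Demo against a throwaway directory tree shaped like the data model.
root = tempfile.mkdtemp()
task_dir = os.path.join(root, "sess-prev", "tasks", "task-1")
os.makedirs(task_dir)
with open(os.path.join(task_dir, "result.json"), "w") as f:
    json.dump({"summary": "built job 123", "artifacts": {"job_id": "123"}}, f)

prior = load_prior_results(root, "sess-prev")
print(prior[0]["artifacts"]["job_id"])  # → 123
```

Because continuity lives entirely in these result files, a chained task needs only the `previous_session_id` string, never a live session.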
+ +## Data Model + +``` +~/.coda/sessions/{session-id}/ + session.json # metadata + auto-close timestamp + tasks/{task-id}/ + prompt.txt # wrapped prompt sent to Hermes + meta.json # {email, created_at, previous_session_id, permissions} + status.jsonl # append-only progress log + result.json # final structured output +``` + +## Tool Reference + +### `coda_run` + +```python +coda_run( + prompt: str, # what to do + email: str, # who's asking + context: str = "{}", # UC metadata (tables, schemas) + previous_session_id: str = "", # chain from prior work + permissions: str = "smart", # "smart" or "yolo" + timeout_s: int = 3600, # max 1 hour default +) +# Returns: {"task_id", "session_id", "status": "running"} +``` + +### `coda_inbox` + +```python +coda_inbox( + email: str = "", # filter by user + status: str = "", # "running", "completed", "failed", or "" for all +) +# Returns: {"tasks": [...], "counts": {"running": N, "completed": N, "failed": N}} +``` + +Each task entry: `task_id`, `session_id`, `status`, `elapsed_s`, `prompt_summary`, +`summary` (completed), `progress` (running), `previous_session_id`, `created_at`. + +### `coda_get_result` + +```python +coda_get_result(task_id: str, session_id: str) +# Returns: {"task_id", "session_id", "status", "summary", +# "files_changed", "artifacts", "errors"} +``` + +## Migration from v1 + +| v1 Tool | v2 Equivalent | +|---------|--------------| +| `coda_create_session` | Removed — auto-created by `coda_run` | +| `coda_run_task` | `coda_run` (simplified, auto-session) | +| `coda_get_status` | `coda_inbox` (all tasks at once) | +| `coda_get_result` | `coda_get_result` (unchanged) | +| `coda_close_session` | Removed — auto-closed on completion | + +## Limitations + +- **Ephemeral filesystem**: On Databricks Apps, `~/.coda/` is local disk. App + redeployment wipes task state. Real artifacts (git commits, jobs, workspace files) + are unaffected. 
+- **No push notifications**: GenieCode must call `coda_inbox` to discover completions. + SSE/streaming is a future consideration if polling proves insufficient. diff --git a/install_databricks_cli.sh b/scripts/install_databricks_cli.sh similarity index 100% rename from install_databricks_cli.sh rename to scripts/install_databricks_cli.sh diff --git a/install_gh.sh b/scripts/install_gh.sh similarity index 100% rename from install_gh.sh rename to scripts/install_gh.sh diff --git a/install_micro.sh b/scripts/install_micro.sh similarity index 100% rename from install_micro.sh rename to scripts/install_micro.sh diff --git a/setup_claude.py b/setup/setup_claude.py similarity index 90% rename from setup_claude.py rename to setup/setup_claude.py index 9815ef5..a28734e 100644 --- a/setup_claude.py +++ b/setup/setup_claude.py @@ -90,13 +90,17 @@ local_bin = home / ".local" / "bin" claude_bin = local_bin / "claude" -print("Installing/upgrading Claude Code CLI...") -result = subprocess.run( - ["bash", "-c", "curl -fsSL https://claude.ai/install.sh | bash"], - env={**os.environ, "HOME": str(home)}, - capture_output=True, - text=True -) +if os.environ.get("SKIP_CLAUDE_INSTALL"): + print("SKIP_CLAUDE_INSTALL set — skipping CLI install") + result = type("R", (), {"returncode": 0, "stderr": ""})() +else: + print("Installing/upgrading Claude Code CLI...") + result = subprocess.run( + ["bash", "-c", "curl -fsSL https://claude.ai/install.sh | bash"], + env={**os.environ, "HOME": str(home)}, + capture_output=True, + text=True + ) if result.returncode == 0: print("Claude Code CLI installed successfully") else: diff --git a/setup_codex.py b/setup/setup_codex.py similarity index 100% rename from setup_codex.py rename to setup/setup_codex.py diff --git a/setup_databricks.py b/setup/setup_databricks.py similarity index 100% rename from setup_databricks.py rename to setup/setup_databricks.py diff --git a/setup_gemini.py b/setup/setup_gemini.py similarity index 100% rename from 
setup_gemini.py rename to setup/setup_gemini.py diff --git a/setup_hermes.py b/setup/setup_hermes.py similarity index 52% rename from setup_hermes.py rename to setup/setup_hermes.py index 07bb030..4f56aaf 100644 --- a/setup_hermes.py +++ b/setup/setup_hermes.py @@ -216,6 +216,172 @@ def _run(cmd, **kwargs): cli_name="Hermes", ) +# 5b. Append CoDA orchestrator instructions to HERMES.md +CODA_ORCHESTRATOR_INSTRUCTIONS = """ + +## CoDA Constitution (NON-NEGOTIABLE) + +This is the single most important rule. It applies to you AND every sub-agent you delegate to. + +**NO DESTRUCTIVE ACTIONS on pre-existing assets.** Specifically: +- **NEVER delete** files, tables, jobs, notebooks, pipelines, or any resource that was NOT + created during the current session — unless you have EXPLICIT confirmation from the user + or upstream caller. +- **NEVER drop** database tables, schemas, or catalogs that existed before the task started. +- **NEVER overwrite** existing files without confirmation if the content would be lost. +- **NEVER run** destructive CLI commands (`rm -rf`, `databricks jobs delete`, `DROP TABLE`, etc.) + on assets you didn't create. + +**What IS allowed without confirmation:** +- Creating new files, tables, jobs, pipelines, notebooks — building is always OK. +- Modifying files you created during the session. +- Deleting temporary files or artifacts you created during the session. +- Iterating on work in progress — edit, refactor, rebuild freely. +- Overwriting files you created in this session. + +**When in doubt:** Report back to the upstream caller (Genie Code or the user) describing +what you want to delete and why, and ask for confirmation before proceeding. This applies +to you directly AND to any sub-agent you delegate to — pass this rule in every delegation prompt. + +## CoDA Orchestrator Role + +You are Hermes, the primary orchestrator inside **CoDA** (Coding Agents on Databricks Apps). 
+You are not just a chat assistant — you are the brain that receives tasks and decides how
+to execute them, either directly or by delegating to specialized sub-agents.
+
+### Your Environment
+
+- You are running inside a Databricks App with full workspace access.
+- The Databricks CLI is pre-configured: `databricks` commands work out of the box.
+- Unity Catalog, Jobs, Workflows, Notebooks, MLflow — all accessible.
+- Projects live at `~/projects/` and sync to `/Workspace/Users/{email}/` on git commit.
+- You have 43 Databricks and workflow skills available.
+
+### Prior Session Context
+
+When your prompt includes a `PRIOR SESSION:` block, it means this task continues
+work from a previous session. The prior session's results are stored on disk:
+
+```
+~/.coda/sessions/{previous_session_id}/tasks/*/result.json
+```
+
+**Read those result files** to understand what was done before. Each result.json contains:
+- `summary` — what the prior task accomplished
+- `files_changed` — which files were created or modified
+- `artifacts` — job IDs, commit hashes, dashboard URLs, etc.
+
+Use this context to continue the work without asking the user to repeat themselves.
+
+### Sub-Agents Available
+
+You have three coding agents you can delegate work to.
Choose the best one for each subtask:
+
+**Claude Code** — Deep work, complex implementations, orchestration
+```bash
+claude -p "your prompt here" --allowedTools "Read,Edit,Bash" --max-turns 50
+```
+- Best for: multi-step implementations, planning, debugging, code review
+- Can spawn teams: assign roles, goals, and backstory to parallel workers
+- Has access to all 43 skills (Databricks + workflow)
+- Use `--max-turns` to bound execution, `--max-budget-usd` for cost control
+
+**Codex** — Fast edits, refactoring, structured transforms
+```bash
+codex -q "your prompt here"
+```
+- Best for: quick code changes, targeted refactors, code review
+- Lightweight and fast — use when the task is well-scoped
+
+**Gemini** — Research, documentation, large-context analysis
+```bash
+gemini -p "your prompt here"
+```
+- Best for: broad codebase analysis, documentation generation, research tasks
+- Large context window — good for understanding big codebases
+
+### How to Delegate
+
+1. **Assess the task.** Is it something you can handle directly, or does it need a specialist?
+2. **Pick the right agent.** Match the task to the agent's strengths (see above).
+3. **Be specific.** Give the sub-agent a clear, self-contained prompt with all context it needs.
+4. **Collect results.** Read the sub-agent's output and incorporate it into your response.
+5. **Chain when needed.** Plan with Claude, implement with Codex, review with Gemini.
+
+### For Complex Tasks — Use Claude Code Teams
+
+When a task is large enough to benefit from parallel work, use Claude Code's team capability:
+```bash
+claude -p "Create a team of 3 agents to: [task]. Agent 1 handles [X], Agent 2 handles [Y], Agent 3 handles [Z]. Coordinate and merge results." --allowedTools "Read,Edit,Bash" --max-turns 100
+```
+
+### Ephemeral Session Model
+
+Each task runs in its own short-lived session. When the task completes, the session closes
+automatically. You will NOT receive follow-up tasks in the same session.
+ +**What this means for you:** +- **Be self-contained.** Complete the entire task in one go — there is no "next message." +- **Read prior context if provided.** If the prompt has a `PRIOR SESSION:` block, read + those result files to understand what was done before. This is how task chaining works. +- **Write thorough results.** Your `result.json` is the only thing the next task (or the + user) will see. Include a clear summary, all files changed, and any artifacts created. +- **Don't rely on in-memory state.** Anything you want to persist must go to disk — + either in the result files, git commits, or the workspace. + +### Single-User Mode + +You are operating in **single-user mode**. Every task comes from the same person — the app owner. +This means: + +- **Learn their patterns.** Pay attention to how they work, what tools they prefer, what + coding style they use, and what kind of tasks they send. +- **Remember across tasks.** If they always work with certain tables, frameworks, or patterns, + carry that knowledge forward. Use your memory system to persist insights. +- **Be proactive.** If you notice patterns, suggest improvements: + - "I've noticed you frequently create similar pipelines — want me to template this?" + - "Based on your last 3 tasks, you might want to consider..." + - "This task is similar to what you asked last time. Should I reuse that approach?" +- **Adapt your communication style.** Match their level of detail preference, verbosity, + and technical depth. Some users want terse results, others want explanations. +- **Build a profile over time.** Track their preferred tools, common workflows, recurring + patterns, and pain points. The longer you work together, the better you should get. + +### Task Protocol (CODA-TASK Convention) + +When you receive a task wrapped in `---CODA-TASK---` markers, follow this protocol: + +1. **Read the envelope.** Extract task_id, session_id, user, context, and the actual task. +2. 
**Write progress.** As you work, append lines to `{results_dir}/status.jsonl`: + ```json + {"step": "planning", "message": "Analyzing task requirements"} + {"step": "delegating", "message": "Sending implementation to Claude Code"} + {"step": "complete", "message": "Pipeline created successfully"} + ``` +3. **Write result.** When done, write `{results_dir}/result.json`: + ```json + { + "status": "completed", + "summary": "One paragraph of what was done", + "files_changed": ["path/to/file1.py"], + "artifacts": {"job_id": "123", "commit": "abc123"}, + "errors": [] + } + ``` + IMPORTANT: `result.json` must be a FILE, not a directory. + +4. **If you delegate,** update `status.jsonl` with delegation steps so the caller can track + which sub-agent is doing what. +""" + +if hermes_md.exists(): + existing_content = hermes_md.read_text() + if "CoDA Orchestrator Role" not in existing_content: + hermes_md.write_text(existing_content + CODA_ORCHESTRATOR_INSTRUCTIONS) + print("CoDA orchestrator instructions appended to HERMES.md") + else: + print("CoDA orchestrator instructions already present in HERMES.md") + # 6. Create projects directory (parity with other agents) projects_dir = home / "projects" projects_dir.mkdir(exist_ok=True) diff --git a/setup_mlflow.py b/setup/setup_mlflow.py similarity index 100% rename from setup_mlflow.py rename to setup/setup_mlflow.py diff --git a/setup_opencode.py b/setup/setup_opencode.py similarity index 100% rename from setup_opencode.py rename to setup/setup_opencode.py diff --git a/setup_proxy.py b/setup/setup_proxy.py similarity index 100% rename from setup_proxy.py rename to setup/setup_proxy.py diff --git a/static/index.html b/static/index.html index c1f53fa..c986aa9 100644 --- a/static/index.html +++ b/static/index.html @@ -955,7 +955,10 @@

General

return; } - socket = io({ transports: ['websocket', 'polling'] }); + // Start with polling (HTTP) so Databricks proxy identity headers are present + // for auth, then upgrade to WebSocket transparently. Direct WebSocket-first + // fails because the proxy doesn't inject X-Forwarded-Email on WS upgrade. + socket = io({ transports: ['polling', 'websocket'] }); socket.on('connect', () => { // Check actual transport — Socket.IO reports connected=true even on long-polling diff --git a/tests/test_content_filter_proxy.py b/tests/test_content_filter_proxy.py new file mode 100644 index 0000000..4aad029 --- /dev/null +++ b/tests/test_content_filter_proxy.py @@ -0,0 +1,556 @@ +"""Tests for content_filter_proxy — request/response sanitization for OpenCode.""" + +import json +import time + +import pytest +from unittest import mock + + +# --------------------------------------------------------------------------- +# strip_unsupported_schema_keys +# --------------------------------------------------------------------------- + +class TestStripUnsupportedSchemaKeys: + def test_strips_top_level_keys(self): + from content_filter_proxy import strip_unsupported_schema_keys + obj = {"type": "object", "$schema": "http://...", "additionalProperties": False, "title": "Foo"} + result = strip_unsupported_schema_keys(obj) + assert result == {"type": "object", "title": "Foo"} + + def test_strips_nested_keys(self): + from content_filter_proxy import strip_unsupported_schema_keys + obj = { + "type": "object", + "properties": { + "name": {"type": "string", "$ref": "#/defs/Name", "$comment": "ignore"}, + }, + } + result = strip_unsupported_schema_keys(obj) + assert result == { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + } + + def test_strips_inside_lists(self): + from content_filter_proxy import strip_unsupported_schema_keys + obj = [{"$id": "x", "type": "string"}, {"type": "int"}] + result = strip_unsupported_schema_keys(obj) + assert result == [{"type": "string"}, 
{"type": "int"}] + + def test_passes_through_primitives(self): + from content_filter_proxy import strip_unsupported_schema_keys + assert strip_unsupported_schema_keys("hello") == "hello" + assert strip_unsupported_schema_keys(42) == 42 + assert strip_unsupported_schema_keys(None) is None + + +# --------------------------------------------------------------------------- +# sanitize_tool_schemas +# --------------------------------------------------------------------------- + +class TestSanitizeToolSchemas: + def test_cleans_tool_parameters(self): + from content_filter_proxy import sanitize_tool_schemas + data = { + "tools": [ + {"function": {"name": "foo", "parameters": {"$schema": "x", "type": "object"}}}, + ], + } + result = sanitize_tool_schemas(data) + assert result["tools"][0]["function"]["parameters"] == {"type": "object"} + + def test_strips_top_level_request_keys(self): + from content_filter_proxy import sanitize_tool_schemas + data = { + "tools": [{"function": {"name": "foo", "parameters": {"type": "object"}}}], + "stream_options": {"include_usage": True}, + "$schema": "x", + } + result = sanitize_tool_schemas(data) + assert "stream_options" not in result + assert "$schema" not in result + + def test_no_tools_is_noop(self): + from content_filter_proxy import sanitize_tool_schemas + data = {"messages": [{"role": "user", "content": "hi"}]} + result = sanitize_tool_schemas(data) + assert result == data + + +# --------------------------------------------------------------------------- +# _extract_tool_ids_from_message +# --------------------------------------------------------------------------- + +class TestExtractToolIds: + def test_anthropic_format(self): + from content_filter_proxy import _extract_tool_ids_from_message + msg = { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "tu_1", "name": "bash"}, + {"type": "text", "text": "running..."}, + {"type": "tool_use", "id": "tu_2", "name": "read"}, + ], + } + assert 
_extract_tool_ids_from_message(msg) == {"tu_1", "tu_2"} + + def test_openai_format(self): + from content_filter_proxy import _extract_tool_ids_from_message + msg = { + "role": "assistant", + "tool_calls": [ + {"id": "tc_1", "function": {"name": "bash"}}, + {"id": "tc_2", "function": {"name": "read"}}, + ], + } + assert _extract_tool_ids_from_message(msg) == {"tc_1", "tc_2"} + + def test_no_tools(self): + from content_filter_proxy import _extract_tool_ids_from_message + msg = {"role": "assistant", "content": "hello"} + assert _extract_tool_ids_from_message(msg) == set() + + +# --------------------------------------------------------------------------- +# _extract_tool_refs_from_message +# --------------------------------------------------------------------------- + +class TestExtractToolRefs: + def test_anthropic_tool_result(self): + from content_filter_proxy import _extract_tool_refs_from_message + msg = { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "tu_1", "content": "ok"}, + ], + } + assert _extract_tool_refs_from_message(msg) == {"tu_1"} + + def test_openai_tool_message(self): + from content_filter_proxy import _extract_tool_refs_from_message + msg = {"role": "tool", "tool_call_id": "tc_1", "content": "result"} + assert _extract_tool_refs_from_message(msg) == {"tc_1"} + + def test_no_refs(self): + from content_filter_proxy import _extract_tool_refs_from_message + msg = {"role": "user", "content": "hi"} + assert _extract_tool_refs_from_message(msg) == set() + + +# --------------------------------------------------------------------------- +# sanitize_messages — the big one +# --------------------------------------------------------------------------- + +class TestSanitizeMessages: + def test_strips_empty_text_blocks(self): + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "user", "content": [ + {"type": "text", "text": "hello"}, + {"type": "text", "text": ""}, + {"type": "text", "text": " "}, + ]}, + ] + 
result = sanitize_messages(messages) + assert len(result) == 1 + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0]["text"] == "hello" + + def test_strips_orphaned_tool_result_anthropic(self): + """tool_result referencing a tool_use ID that doesn't exist in prev assistant msg.""" + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "content": [ + {"type": "tool_use", "id": "tu_1", "name": "bash"}, + ]}, + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu_1", "content": "ok"}, + {"type": "tool_result", "tool_use_id": "tu_ORPHAN", "content": "stale"}, + ]}, + ] + result = sanitize_messages(messages) + assert len(result) == 2 + # Only tu_1 should survive + user_blocks = result[1]["content"] + assert len(user_blocks) == 1 + assert user_blocks[0]["tool_use_id"] == "tu_1" + + def test_strips_orphaned_openai_tool_message(self): + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "tool_calls": [{"id": "tc_1", "function": {"name": "bash"}}]}, + {"role": "tool", "tool_call_id": "tc_1", "content": "ok"}, + {"role": "tool", "tool_call_id": "tc_ORPHAN", "content": "stale"}, + ] + result = sanitize_messages(messages) + assert len(result) == 2 + assert result[1]["role"] == "tool" + assert result[1]["tool_call_id"] == "tc_1" + + def test_cascading_orphan_removal(self): + """Dropping one message can make the next one orphaned too — multi-pass.""" + from content_filter_proxy import sanitize_messages + messages = [ + # assistant with tool_use tu_A + {"role": "assistant", "content": [{"type": "tool_use", "id": "tu_A", "name": "bash"}]}, + # user responds to tu_A + {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "tu_A", "content": "ok"}]}, + # assistant with tool_use tu_B (referencing something dropped) + {"role": "assistant", "content": [{"type": "tool_use", "id": "tu_B", "name": "read"}]}, + # user responds to tu_B AND orphan tu_C (no 
matching tool_use) + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu_B", "content": "ok"}, + {"type": "tool_result", "tool_use_id": "tu_C", "content": "orphan"}, + ]}, + ] + result = sanitize_messages(messages) + # tu_C should be stripped, tu_A and tu_B should survive + assert len(result) == 4 + last_user_blocks = result[3]["content"] + assert len(last_user_blocks) == 1 + assert last_user_blocks[0]["tool_use_id"] == "tu_B" + + def test_drops_empty_user_message_after_filter(self): + """If all content blocks are stripped, the user message is dropped entirely.""" + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "content": [{"type": "tool_use", "id": "tu_1", "name": "bash"}]}, + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu_ORPHAN", "content": "stale"}, + ]}, + ] + result = sanitize_messages(messages) + # The user message should be dropped (all blocks were orphaned) + assert len(result) == 1 + assert result[0]["role"] == "assistant" + + def test_keeps_empty_assistant_message(self): + """Empty assistant messages are kept (not dropped) to preserve alternation.""" + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "content": [{"type": "text", "text": ""}]}, + ] + result = sanitize_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "assistant" + + def test_replaces_null_assistant_content(self): + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "content": None}, + ] + result = sanitize_messages(messages) + assert result[0]["content"] == "." + + def test_replaces_empty_string_assistant(self): + from content_filter_proxy import sanitize_messages + messages = [ + {"role": "assistant", "content": " "}, + ] + result = sanitize_messages(messages) + assert result[0]["content"] == "." 
+
+    def test_strips_empty_string_user(self):
+        from content_filter_proxy import sanitize_messages
+        messages = [
+            {"role": "user", "content": "hello"},
+            {"role": "assistant", "content": "hi"},
+            {"role": "user", "content": ""},
+        ]
+        result = sanitize_messages(messages)
+        assert len(result) == 2  # empty user dropped
+
+    def test_passthrough_non_list(self):
+        from content_filter_proxy import sanitize_messages
+        assert sanitize_messages("not a list") == "not a list"
+        assert sanitize_messages(None) is None
+
+    def test_preserves_non_dict_blocks(self):
+        """Non-dict items in content list are preserved as-is."""
+        from content_filter_proxy import sanitize_messages
+        messages = [
+            {"role": "user", "content": ["plain string", {"type": "text", "text": "hi"}]},
+        ]
+        result = sanitize_messages(messages)
+        assert len(result[0]["content"]) == 2
+
+    def test_null_assistant_with_tool_calls_not_replaced(self):
+        """Assistant msg with null content but tool_calls should NOT get placeholder."""
+        from content_filter_proxy import sanitize_messages
+        messages = [
+            {"role": "assistant", "content": None, "tool_calls": [{"id": "tc_1"}]},
+        ]
+        result = sanitize_messages(messages)
+        assert result[0]["content"] is None  # preserved because tool_calls exist
+
+
+# ---------------------------------------------------------------------------
+# remap_tool_call
+# ---------------------------------------------------------------------------
+
+class TestRemapToolCall:
+    def test_remaps_databricks_tool_call(self):
+        from content_filter_proxy import remap_tool_call
+        tc = {
+            "id": "tc_1",
+            "function": {
+                "name": "databricks-tool-call",
+                "arguments": json.dumps({"name": "execute_sql", "query": "SELECT 1"}),
+            },
+        }
+        result = remap_tool_call(tc)
+        assert result["function"]["name"] == "execute_sql"
+        args = json.loads(result["function"]["arguments"])
+        assert "name" not in args
+        assert args["query"] == "SELECT 1"
+
+    def test_passthrough_normal_tool(self):
+        from content_filter_proxy import remap_tool_call
+        tc = {"id": "tc_1", "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}}
+        result = remap_tool_call(tc)
+        assert result["function"]["name"] == "bash"
+
+    def test_handles_invalid_json_args(self):
+        from content_filter_proxy import remap_tool_call
+        tc = {"id": "tc_1", "function": {"name": "databricks-tool-call", "arguments": "not json"}}
+        result = remap_tool_call(tc)
+        assert result["function"]["name"] == "databricks-tool-call"  # unchanged
+
+
+# ---------------------------------------------------------------------------
+# fix_response_data
+# ---------------------------------------------------------------------------
+
+class TestFixResponseData:
+    def test_remaps_tool_calls_in_message(self):
+        from content_filter_proxy import fix_response_data
+        data = {
+            "choices": [{
+                "message": {
+                    "tool_calls": [{
+                        "id": "tc_1",
+                        "function": {
+                            "name": "databricks-tool-call",
+                            "arguments": json.dumps({"name": "run_sql", "q": "SELECT 1"}),
+                        },
+                    }],
+                },
+                "finish_reason": "stop",
+            }],
+        }
+        result = fix_response_data(data)
+        assert result["choices"][0]["message"]["tool_calls"][0]["function"]["name"] == "run_sql"
+        assert result["choices"][0]["finish_reason"] == "tool_calls"
+
+    def test_fixes_streaming_delta(self):
+        from content_filter_proxy import fix_response_data
+        data = {
+            "choices": [{
+                "delta": {
+                    "tool_calls": [{
+                        "id": "tc_1",
+                        "function": {
+                            "name": "databricks-tool-call",
+                            "arguments": json.dumps({"name": "run_sql"}),
+                        },
+                    }],
+                },
+                "finish_reason": "stop",
+            }],
+        }
+        result = fix_response_data(data)
+        assert result["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] == "run_sql"
+        assert result["choices"][0]["finish_reason"] == "tool_calls"
+
+    def test_noop_on_non_dict(self):
+        from content_filter_proxy import fix_response_data
+        assert fix_response_data("string") == "string"
+        assert fix_response_data(None) is None
+
+    def test_no_choices_is_noop(self):
+        from content_filter_proxy import fix_response_data
+        data = {"id": "resp_1"}
+        assert fix_response_data(data) == data
+
+
+# ---------------------------------------------------------------------------
+# SSEProcessor
+# ---------------------------------------------------------------------------
+
+class TestSSEProcessor:
+    def test_passthrough_non_data_lines(self):
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        assert proc.process_line("event: message") == ["event: message"]
+        assert proc.process_line(": comment") == [": comment"]
+
+    def test_passthrough_done_signal(self):
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        result = proc.process_line("data: [DONE]")
+        assert "data: [DONE]" in result
+
+    def test_passthrough_normal_tool(self):
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        event = {
+            "choices": [{
+                "delta": {"tool_calls": [{"index": 0, "function": {"name": "bash"}}]},
+                "finish_reason": None,
+            }],
+        }
+        result = proc.process_line(f"data: {json.dumps(event)}")
+        assert len(result) == 1
+        assert "bash" in result[0]
+
+    def test_buffers_databricks_tool_call(self):
+        """First chunk with databricks-tool-call name should be buffered."""
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        event = {
+            "choices": [{
+                "delta": {
+                    "tool_calls": [{
+                        "index": 0,
+                        "function": {"name": "databricks-tool-call", "arguments": ""},
+                    }],
+                },
+                "finish_reason": None,
+            }],
+        }
+        result = proc.process_line(f"data: {json.dumps(event)}")
+        assert result == []  # buffered, not sent
+
+    def test_resolves_name_from_args(self):
+        """Once args JSON is complete, name is resolved and buffered events flushed."""
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        # First chunk — name is databricks-tool-call
+        event1 = {
+            "choices": [{
+                "delta": {
+                    "tool_calls": [{
+                        "index": 0,
+                        "function": {"name": "databricks-tool-call", "arguments": ""},
+                    }],
+                },
+                "finish_reason": None,
+            }],
+        }
+        proc.process_line(f"data: {json.dumps(event1)}")
+
+        # Second chunk — args with real name
+        event2 = {
+            "choices": [{
+                "delta": {
+                    "tool_calls": [{
+                        "index": 0,
+                        "function": {"arguments": json.dumps({"name": "execute_sql", "query": "SELECT 1"})},
+                    }],
+                },
+                "finish_reason": None,
+            }],
+        }
+        result = proc.process_line(f"data: {json.dumps(event2)}")
+        # Should flush buffered events + current event
+        assert len(result) >= 1
+        # The resolved name should appear in flushed output
+        combined = " ".join(result)
+        assert "execute_sql" in combined
+
+    def test_flush_remaining(self):
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        # Buffer a databricks-tool-call but never resolve it
+        event = {
+            "choices": [{
+                "delta": {
+                    "tool_calls": [{
+                        "index": 0,
+                        "function": {"name": "databricks-tool-call", "arguments": '{"partial'},
+                    }],
+                },
+                "finish_reason": None,
+            }],
+        }
+        proc.process_line(f"data: {json.dumps(event)}")
+        remaining = proc.flush_remaining()
+        assert len(remaining) >= 1  # buffered lines flushed as-is
+
+    def test_fixes_finish_reason_on_stop(self):
+        """finish_reason 'stop' with active tool state should become 'tool_calls'."""
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        # Seed tool state
+        proc._tool_state[0] = {"args_buffer": "", "resolved_name": "bash", "buffered_lines": []}
+        event = {
+            "choices": [{"delta": {}, "finish_reason": "stop"}],
+        }
+        result = proc.process_line(f"data: {json.dumps(event)}")
+        parsed = json.loads(result[0][6:])  # strip "data: "
+        assert parsed["choices"][0]["finish_reason"] == "tool_calls"
+
+    def test_invalid_json_passthrough(self):
+        from content_filter_proxy import SSEProcessor
+        proc = SSEProcessor()
+        result = proc.process_line("data: {invalid json}")
+        assert result == ["data: {invalid json}"]
+
+
+# ---------------------------------------------------------------------------
+# _get_fresh_token
+# ---------------------------------------------------------------------------
+
+class TestGetFreshToken:
+    def setup_method(self):
+        """Reset token cache before each test."""
+        from content_filter_proxy import _TOKEN_CACHE
+        _TOKEN_CACHE["token"] = None
+        _TOKEN_CACHE["read_at"] = 0.0
+
+    def test_reads_from_databrickscfg(self, tmp_path):
+        from content_filter_proxy import _get_fresh_token, _TOKEN_CACHE
+        cfg = tmp_path / ".databrickscfg"
+        cfg.write_text("[DEFAULT]\nhost = https://test.cloud.databricks.com\ntoken = dapi_test123\n")
+        with mock.patch("content_filter_proxy._DATABRICKSCFG_PATH", str(cfg)):
+            token = _get_fresh_token()
+            assert token == "dapi_test123"
+            assert _TOKEN_CACHE["token"] == "dapi_test123"
+
+    def test_returns_cached_within_ttl(self, tmp_path):
+        from content_filter_proxy import _get_fresh_token, _TOKEN_CACHE
+        _TOKEN_CACHE["token"] = "cached_token"
+        _TOKEN_CACHE["read_at"] = time.time()  # just now
+        # Even with a bad path, should return cached
+        with mock.patch("content_filter_proxy._DATABRICKSCFG_PATH", "/nonexistent"):
+            token = _get_fresh_token()
+            assert token == "cached_token"
+
+    def test_refreshes_after_ttl(self, tmp_path):
+        from content_filter_proxy import _get_fresh_token, _TOKEN_CACHE
+        _TOKEN_CACHE["token"] = "old_token"
+        _TOKEN_CACHE["read_at"] = time.time() - 60  # expired
+        cfg = tmp_path / ".databrickscfg"
+        cfg.write_text("[DEFAULT]\nhost = https://test.cloud.databricks.com\ntoken = new_token\n")
+        with mock.patch("content_filter_proxy._DATABRICKSCFG_PATH", str(cfg)):
+            token = _get_fresh_token()
+            assert token == "new_token"
+
+    def test_returns_stale_on_read_error(self, tmp_path):
+        from content_filter_proxy import _get_fresh_token, _TOKEN_CACHE
+        _TOKEN_CACHE["token"] = "stale_token"
+        _TOKEN_CACHE["read_at"] = 0.0  # force re-read
+        with mock.patch("content_filter_proxy._DATABRICKSCFG_PATH", "/nonexistent"):
+            token = _get_fresh_token()
+            assert token == "stale_token"
+
+    def test_returns_none_when_no_cache_and_no_file(self):
+        from content_filter_proxy import _get_fresh_token, _TOKEN_CACHE
+        _TOKEN_CACHE["token"] = None
+        _TOKEN_CACHE["read_at"] = 0.0
+        with mock.patch("content_filter_proxy._DATABRICKSCFG_PATH", "/nonexistent"):
+            token = _get_fresh_token()
+            assert token is None
diff --git a/tests/test_gateway_discovery.py b/tests/test_gateway_discovery.py
index 698445a..92ca725 100644
--- a/tests/test_gateway_discovery.py
+++ b/tests/test_gateway_discovery.py
@@ -132,7 +132,7 @@ def test_workspace_id_whitespace_stripped(self, mock_probe):
 # Integration tests — verify endpoint URLs constructed by setup scripts
 # ---------------------------------------------------------------------------
 
-SETUP_DIR = Path(__file__).parent.parent
+SETUP_DIR = Path(__file__).parent.parent / "setup"
 
 
 class TestEndpointConstruction:
@@ -146,9 +146,11 @@ def _run_setup(self, script_name, tmp_path, env_overrides=None):
             "DATABRICKS_TOKEN": "dapi_test_token",
             "DATABRICKS_WORKSPACE_ID": "6280049833385130",
             "PATH": os.environ.get("PATH", ""),
-            "PYTHONPATH": str(SETUP_DIR),
+            "PYTHONPATH": str(SETUP_DIR.parent),
             # Pre-resolve gateway so subprocess skips the network probe
             "_GATEWAY_RESOLVED": "",
+            # Skip CLI install (curl | bash) — tests only verify config files
+            "SKIP_CLAUDE_INSTALL": "1",
         }
         # Ensure DATABRICKS_GATEWAY_HOST is NOT set (test auto-discovery)
         env.pop("DATABRICKS_GATEWAY_HOST", None)
@@ -175,15 +177,15 @@ def test_setup_claude_falls_back_when_gateway_unreachable(self, tmp_path):
         # Gateway is unreachable from test env, so should fall back
         import json
         settings_path = tmp_path / ".claude" / "settings.json"
-        if settings_path.exists():
-            settings = json.loads(settings_path.read_text())
-            base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
-            assert base_url.endswith("/anthropic")
-            # Either gateway or serving-endpoints is valid
-            assert (
-                "ai-gateway.cloud.databricks.com" in base_url
-                or "serving-endpoints/anthropic" in base_url
-            )
+        assert settings_path.exists(), "settings.json was not written"
+        settings = json.loads(settings_path.read_text())
+        base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
+        assert base_url.endswith("/anthropic")
+        # Either gateway or serving-endpoints is valid
+        assert (
+            "ai-gateway.cloud.databricks.com" in base_url
+            or "serving-endpoints/anthropic" in base_url
+        )
 
     def test_setup_claude_explicit_override(self, tmp_path):
         """setup_claude.py should prefer explicit DATABRICKS_GATEWAY_HOST."""
@@ -196,10 +198,10 @@ def test_setup_claude_explicit_override(self, tmp_path):
 
         import json
         settings_path = tmp_path / ".claude" / "settings.json"
-        if settings_path.exists():
-            settings = json.loads(settings_path.read_text())
-            base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
-            assert "custom.gateway.example.com" in base_url
+        assert settings_path.exists(), "settings.json was not written"
+        settings = json.loads(settings_path.read_text())
+        base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
+        assert "custom.gateway.example.com" in base_url
 
     def test_setup_claude_fallback_no_gateway(self, tmp_path):
         """setup_claude.py falls back to DATABRICKS_HOST when no gateway available."""
@@ -210,10 +212,10 @@ def test_setup_claude_fallback_no_gateway(self, tmp_path):
 
         import json
         settings_path = tmp_path / ".claude" / "settings.json"
-        if settings_path.exists():
-            settings = json.loads(settings_path.read_text())
-            base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
-            assert "test.cloud.databricks.com/serving-endpoints/anthropic" in base_url
+        assert settings_path.exists(), "settings.json was not written"
+        settings = json.loads(settings_path.read_text())
+        base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
+        assert "test.cloud.databricks.com/serving-endpoints/anthropic" in base_url
 
     @mock.patch("utils._probe_gateway", return_value=True)
     def test_codex_gateway_url_construction(self, mock_probe):
diff --git a/tests/test_mcp_integration.py b/tests/test_mcp_integration.py
new file mode 100644
index 0000000..2dfbc1a
--- /dev/null
+++ b/tests/test_mcp_integration.py
@@ -0,0 +1,290 @@
+"""End-to-end MCP integration tests — v2 background execution + inbox API.
+
+Exercises the full flow: coda_run -> coda_inbox -> coda_get_result.
+No real PTY — app hooks are mocked.
+"""
+
+import json
+import os
+import time
+from unittest.mock import MagicMock
+
+import pytest
+
+
+# ── helpers ──────────────────────────────────────────────────────────
+
+
+def _parse(result: str) -> dict:
+    """Parse JSON string returned by MCP tools."""
+    return json.loads(result)
+
+
+# ── fixture ──────────────────────────────────────────────────────────
+
+
+@pytest.fixture(autouse=True)
+def isolated_env(tmp_path):
+    """Redirect state to tmp and mock PTY hooks."""
+    from coda_mcp import task_manager as tm
+    from coda_mcp import mcp_server as ms
+
+    original_dir = tm.SESSIONS_DIR
+    tm.SESSIONS_DIR = str(tmp_path / "sessions")
+
+    mock_send = MagicMock()
+    mock_close = MagicMock()
+    ms.set_app_hooks(
+        create_session_fn=lambda label: f"pty-mock-{label}",
+        send_input_fn=mock_send,
+        close_session_fn=mock_close,
+    )
+
+    yield {"tmp": tmp_path, "mock_send": mock_send, "mock_close": mock_close}
+
+    tm.SESSIONS_DIR = original_dir
+    ms.set_app_hooks(None, None, None)
+
+
+# ── 1. Happy-path: fire-and-forget → inbox → result ──────────────────
+
+
+class TestFullMcpFlow:
+    @pytest.mark.asyncio
+    async def test_full_background_flow(self, isolated_env):
+        """Happy path: run (fire-and-forget) → inbox → result."""
+        from coda_mcp import mcp_server as ms
+        from coda_mcp import task_manager as tm
+
+        # Step 1: submit task (returns immediately); patch threading so the
+        # watcher thread is never actually spawned
+        with pytest.MonkeyPatch.context() as mp:
+            mp.setattr("coda_mcp.mcp_server.threading", MagicMock())
+            raw = await ms.coda_run(
+                prompt="create a sales pipeline",
+                email="alice@test.com",
+                context='{"tables": ["sales.transactions"]}',
+            )
+
+        task = _parse(raw)
+        assert task["status"] == "running"
+        task_id = task["task_id"]
+        session_id = task["session_id"]
+        assert task_id.startswith("task-")
+        assert session_id.startswith("sess-")
+
+        # Step 2: inbox shows running task
+        raw = await ms.coda_inbox()
+        inbox = _parse(raw)
+        assert len(inbox["tasks"]) == 1
+        assert inbox["tasks"][0]["task_id"] == task_id
+        assert inbox["tasks"][0]["status"] == "running"
+        assert inbox["counts"]["running"] == 1
+
+        # Step 3: simulate agent writing result.json
+        tdir = tm._task_dir(session_id, task_id)
+        result_path = os.path.join(tdir, "result.json")
+        with open(result_path, "w") as f:
+            json.dump({
+                "status": "completed",
+                "summary": "Created sales pipeline with 3 stages",
+                "files_changed": ["pipeline.py", "config.yaml"],
+                "artifacts": ["/workspace/pipeline.py"],
+                "errors": [],
+            }, f)
+
+        # Step 4: complete_task (simulating what _watch_task does)
+        tm.complete_task(session_id, task_id)
+
+        # Step 5: inbox shows completed
+        raw = await ms.coda_inbox()
+        inbox = _parse(raw)
+        assert len(inbox["tasks"]) == 1
+        assert inbox["tasks"][0]["status"] == "completed"
+        assert inbox["tasks"][0]["summary"] == "Created sales pipeline with 3 stages"
+        assert inbox["counts"]["completed"] == 1
+
+        # Step 6: get full result
+        raw = await ms.coda_get_result(task_id=task_id, session_id=session_id)
+        result = _parse(raw)
+        assert result["task_id"] == task_id
+        assert result["summary"] == "Created sales pipeline with 3 stages"
+        assert result["files_changed"] == ["pipeline.py", "config.yaml"]
+
+        # Step 7: session was auto-closed
+        session = tm._read_session(session_id)
+        assert session["status"] == "closed"
+
+
+# ── 2. Task chaining with previous_session_id ────────────────────────
+
+
+class TestTaskChaining:
+    @pytest.mark.asyncio
+    async def test_chained_task_references_prior_session(self, isolated_env):
+        """A chained task includes prior session context in prompt."""
+        from coda_mcp import mcp_server as ms
+        from coda_mcp import task_manager as tm
+
+        # First task
+        raw = await ms.coda_run(
+            prompt="build pipeline",
+            email="bob@test.com",
+        )
+        first = _parse(raw)
+        first_sid = first["session_id"]
+        first_tid = first["task_id"]
+
+        # Complete first task
+        tdir = tm._task_dir(first_sid, first_tid)
+        with open(os.path.join(tdir, "result.json"), "w") as f:
+            json.dump({
+                "status": "completed",
+                "summary": "Built pipeline.py",
+                "files_changed": ["pipeline.py"],
+            }, f)
+        tm.complete_task(first_sid, first_tid)
+
+        # Second task chained to first
+        raw = await ms.coda_run(
+            prompt="add tests for the pipeline",
+            email="bob@test.com",
+            previous_session_id=first_sid,
+        )
+        second = _parse(raw)
+        second_sid = second["session_id"]
+        second_tid = second["task_id"]
+
+        # Verify prompt references prior session
+        prompt_path = os.path.join(
+            tm._task_dir(second_sid, second_tid), "prompt.txt"
+        )
+        with open(prompt_path) as f:
+            prompt_text = f.read()
+        assert f"PRIOR SESSION: {first_sid}" in prompt_text
+
+        # Verify meta.json has previous_session_id
+        meta_path = os.path.join(
+            tm._task_dir(second_sid, second_tid), "meta.json"
+        )
+        with open(meta_path) as f:
+            meta = json.load(f)
+        assert meta["previous_session_id"] == first_sid
+
+        # Verify inbox shows chaining
+        raw = await ms.coda_inbox()
+        inbox = _parse(raw)
+        running_tasks = [t for t in inbox["tasks"] if t["status"] == "running"]
+        assert len(running_tasks) == 1
+        assert running_tasks[0]["previous_session_id"] == first_sid
+
+
+# ── 3. Concurrency limit ─────────────────────────────────────────────
+
+
+class TestConcurrencyLimit:
+    @pytest.mark.asyncio
+    async def test_exceeding_limit_returns_error(self, isolated_env):
+        """Exceeding MAX_CONCURRENT_TASKS returns a clear error."""
+        from coda_mcp import mcp_server as ms
+        from unittest.mock import patch
+
+        with patch("coda_mcp.task_manager.MAX_CONCURRENT_TASKS", 1):
+            r1 = await ms.coda_run(prompt="task1", email="a@b.com")
+            assert _parse(r1)["status"] == "running"
+
+            r2 = await ms.coda_run(prompt="task2", email="a@b.com")
+            d2 = _parse(r2)
+            assert d2["status"] == "error"
+            assert "concurrency" in d2["error"].lower()
+
+
+# ── 4. Yolo permissions → --yolo flag ────────────────────────────────
+
+
+class TestYoloPermissions:
+    @pytest.mark.asyncio
+    async def test_yolo_permissions(self, isolated_env):
+        """permissions='yolo' causes the PTY command to include --yolo."""
+        from coda_mcp import mcp_server as ms
+
+        mock_send = isolated_env["mock_send"]
+
+        # Patch threading so no watcher thread is spawned
+        with pytest.MonkeyPatch.context() as mp:
+            mp.setattr("coda_mcp.mcp_server.threading", MagicMock())
+            await ms.coda_run(
+                prompt="deploy everything",
+                email="dave@test.com",
+                permissions="yolo",
+            )
+
+        mock_send.assert_called_once()
+        cmd = mock_send.call_args[0][1]
+        assert "--yolo" in cmd
+
+
+# ── 5. Session auto-close on completion ──────────────────────────────
+
+
+class TestAutoClose:
+    @pytest.mark.asyncio
+    async def test_session_auto_closes(self, isolated_env):
+        """Session is auto-closed when task completes."""
+        from coda_mcp import mcp_server as ms
+        from coda_mcp import task_manager as tm
+
+        raw = await ms.coda_run(prompt="quick job", email="a@b.com")
+        d = _parse(raw)
+
+        # Session should be busy
+        session = tm._read_session(d["session_id"])
+        assert session["status"] == "busy"
+
+        # Complete the task
+        tdir = tm._task_dir(d["session_id"], d["task_id"])
+        with open(os.path.join(tdir, "result.json"), "w") as f:
+            json.dump({"status": "completed", "summary": "done"}, f)
+        tm.complete_task(d["session_id"], d["task_id"])
+
+        # Session should now be closed
+        session = tm._read_session(d["session_id"])
+        assert session["status"] == "closed"
+        assert "closed_at" in session
+
+
+# ── 6. Cleanup expired tasks ─────────────────────────────────────────
+
+
+class TestCleanup:
+    @pytest.mark.asyncio
+    async def test_cleanup_removes_expired(self, isolated_env):
+        """cleanup_expired_tasks removes old closed sessions."""
+        from coda_mcp import mcp_server as ms
+        from coda_mcp import task_manager as tm
+
+        raw = await ms.coda_run(prompt="old task", email="a@b.com")
+        d = _parse(raw)
+
+        # Complete and close
+        tdir = tm._task_dir(d["session_id"], d["task_id"])
+        with open(os.path.join(tdir, "result.json"), "w") as f:
+            json.dump({"status": "completed", "summary": "done"}, f)
+        tm.complete_task(d["session_id"], d["task_id"])
+
+        # Backdate closed_at to expire it
+        session = tm._read_session(d["session_id"])
+        session["closed_at"] = time.time() - 90000  # 25 hours ago
+        tm._write_json(tm._session_file(d["session_id"]), session)
+
+        # Cleanup should remove it
+        removed = tm.cleanup_expired_tasks()
+        assert removed == 1
+
+        # Inbox should be empty now
+        raw = await ms.coda_inbox()
+        inbox = _parse(raw)
+        assert len(inbox["tasks"]) == 0
diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py
new file mode 100644
index 0000000..4b20a8e
--- /dev/null
+++ b/tests/test_mcp_server.py
@@ -0,0 +1,342 @@
+"""Tests for mcp_server — v2 background execution + inbox API."""
+
+import json
+import os
+from unittest import mock
+
+import pytest
+
+
+# ── helpers ──────────────────────────────────────────────────────────
+
+
+@pytest.fixture(autouse=True)
+def _reset_hooks():
+    """Clear app hooks before/after each test."""
+    from coda_mcp import mcp_server
+
+    mcp_server._app_create_session = None
+    mcp_server._app_send_input = None
+    mcp_server._app_close_session = None
+    yield
+    mcp_server._app_create_session = None
+    mcp_server._app_send_input = None
+    mcp_server._app_close_session = None
+
+
+@pytest.fixture(autouse=True)
+def _isolated_sessions(tmp_path):
+    """Point task_manager.SESSIONS_DIR at a temp dir."""
+    sessions_dir = str(tmp_path / ".coda" / "sessions")
+    with mock.patch("coda_mcp.task_manager.SESSIONS_DIR", sessions_dir):
+        yield sessions_dir
+
+
+def _parse(result: str) -> dict:
+    """Parse JSON string returned by MCP tools."""
+    return json.loads(result)
+
+
+# ── Tool registration ────────────────────────────────────────────────
+
+
+class TestToolRegistration:
+    def test_three_tools_registered(self):
+        from coda_mcp import mcp_server
+
+        tool_mgr = mcp_server.mcp._tool_manager
+        tool_names = set(tool_mgr._tools.keys())
+        expected = {"coda_run", "coda_inbox", "coda_get_result"}
+        assert expected == tool_names, f"Expected {expected}, got {tool_names}"
+
+    def test_tool_count_is_three(self):
+        from coda_mcp import mcp_server
+
+        tool_mgr = mcp_server.mcp._tool_manager
+        assert len(tool_mgr._tools) == 3
+
+
+# ── coda_run ─────────────────────────────────────────────────────────
+
+
+class TestCodaRun:
+    @pytest.mark.asyncio
+    async def test_creates_task_disk_only(self):
+        """Without app hooks, creates session+task on disk, returns immediately."""
+        from coda_mcp import mcp_server
+
+        result = await mcp_server.coda_run(
+            prompt="fix the bug",
+            email="a@b.com",
+        )
+        data = _parse(result)
+        assert data["status"] == "running"
+        assert data["task_id"].startswith("task-")
+        assert data["session_id"].startswith("sess-")
+
+    @pytest.mark.asyncio
+    async def test_auto_creates_session(self):
+        """coda_run auto-creates a session — no separate create_session needed."""
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        result = await mcp_server.coda_run(
+            prompt="build pipeline",
+            email="a@b.com",
+        )
+        data = _parse(result)
+        session = task_manager._read_session(data["session_id"])
+        assert session["email"] == "a@b.com"
+        assert session["status"] == "busy"  # task is running
+
+    @pytest.mark.asyncio
+    async def test_sends_to_pty_when_hooks_set(self):
+        """With hooks, creates PTY and sends hermes command."""
+        from coda_mcp import mcp_server
+
+        mock_create = mock.Mock(return_value="pty-xyz")
+        mock_send = mock.Mock()
+        mcp_server.set_app_hooks(
+            create_session_fn=mock_create,
+            send_input_fn=mock_send,
+            close_session_fn=mock.Mock(),
+        )
+
+        with mock.patch("coda_mcp.mcp_server.threading"):
+            result = await mcp_server.coda_run(
+                prompt="fix the bug",
+                email="a@b.com",
+            )
+
+        data = _parse(result)
+        assert data["status"] == "running"
+        mock_create.assert_called_once_with(label="hermes-mcp")
+        mock_send.assert_called_once()
+        assert "hermes" in mock_send.call_args[0][1]
+
+    @pytest.mark.asyncio
+    async def test_yolo_permission(self):
+        """permissions='yolo' produces --yolo flag in PTY command."""
+        from coda_mcp import mcp_server
+
+        mock_send = mock.Mock()
+        mcp_server.set_app_hooks(
+            create_session_fn=mock.Mock(return_value="pty-1"),
+            send_input_fn=mock_send,
+            close_session_fn=mock.Mock(),
+        )
+
+        with mock.patch("coda_mcp.mcp_server.threading"):
+            await mcp_server.coda_run(
+                prompt="go fast",
+                email="a@b.com",
+                permissions="yolo",
+            )
+
+        cmd = mock_send.call_args[0][1]
+        assert "--yolo" in cmd
+
+    @pytest.mark.asyncio
+    async def test_previous_session_id_in_prompt(self):
+        """previous_session_id appears in the wrapped prompt."""
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        # Create a "prior" session with a completed task
+        prior = task_manager.create_session("a@b.com", "u1")
+        prior_sid = prior["session_id"]
+
+        result = await mcp_server.coda_run(
+            prompt="add tests",
+            email="a@b.com",
+            previous_session_id=prior_sid,
+        )
+        data = _parse(result)
+
+        # Read the prompt.txt and verify prior session reference
+        tdir = task_manager._task_dir(data["session_id"], data["task_id"])
+        with open(os.path.join(tdir, "prompt.txt")) as f:
+            prompt_text = f.read()
+
+        assert f"PRIOR SESSION: {prior_sid}" in prompt_text
+
+    @pytest.mark.asyncio
+    async def test_meta_json_written(self):
+        """coda_run writes meta.json with task metadata."""
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        result = await mcp_server.coda_run(
+            prompt="build a dashboard for sales",
+            email="alice@test.com",
+            previous_session_id="sess-old",
+        )
+        data = _parse(result)
+
+        meta_path = os.path.join(
+            task_manager._task_dir(data["session_id"], data["task_id"]),
+            "meta.json",
+        )
+        with open(meta_path) as f:
+            meta = json.load(f)
+
+        assert meta["email"] == "alice@test.com"
+        assert meta["previous_session_id"] == "sess-old"
+        assert meta["prompt_summary"] == "build a dashboard for sales"
+        assert "created_at" in meta
+
+    @pytest.mark.asyncio
+    async def test_concurrency_limit(self):
+        """Exceeding MAX_CONCURRENT_TASKS returns an error."""
+        from coda_mcp import mcp_server
+
+        with mock.patch("coda_mcp.task_manager.MAX_CONCURRENT_TASKS", 1):
+            # First task succeeds
+            r1 = await mcp_server.coda_run(prompt="task1", email="a@b.com")
+            assert _parse(r1)["status"] == "running"
+
+            # Second task should fail (1 already running)
+            r2 = await mcp_server.coda_run(prompt="task2", email="a@b.com")
+            d2 = _parse(r2)
+            assert d2["status"] == "error"
+            assert "concurrency" in d2["error"].lower()
+
+
+# ── coda_inbox ───────────────────────────────────────────────────────
+
+
+class TestCodaInbox:
+    @pytest.mark.asyncio
+    async def test_empty_inbox(self):
+        """No tasks → empty inbox."""
+        from coda_mcp import mcp_server
+
+        result = await mcp_server.coda_inbox()
+        data = _parse(result)
+        assert data["tasks"] == []
+        assert data["counts"] == {"running": 0, "completed": 0, "failed": 0}
+
+    @pytest.mark.asyncio
+    async def test_running_task_in_inbox(self):
+        """A running task shows up in the inbox."""
+        from coda_mcp import mcp_server
+
+        await mcp_server.coda_run(prompt="build pipeline", email="a@b.com")
+
+        result = await mcp_server.coda_inbox()
+        data = _parse(result)
+        assert len(data["tasks"]) == 1
+        assert data["tasks"][0]["status"] == "running"
+        assert data["tasks"][0]["prompt_summary"] == "build pipeline"
+        assert data["counts"]["running"] == 1
+
+    @pytest.mark.asyncio
+    async def test_completed_task_in_inbox(self):
+        """A completed task shows summary in inbox."""
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        r = await mcp_server.coda_run(prompt="fix bug", email="a@b.com")
+        d = _parse(r)
+
+        # Simulate agent writing result.json
+        tdir = task_manager._task_dir(d["session_id"], d["task_id"])
+        result_path = os.path.join(tdir, "result.json")
+        with open(result_path, "w") as f:
+            json.dump({
+                "status": "completed",
+                "summary": "Fixed the login bug",
+                "files_changed": ["auth.py"],
+                "artifacts": [],
+                "errors": [],
+            }, f)
+
+        result = await mcp_server.coda_inbox()
+        data = _parse(result)
+        assert len(data["tasks"]) == 1
+        assert data["tasks"][0]["status"] == "completed"
+        assert data["tasks"][0]["summary"] == "Fixed the login bug"
+
+    @pytest.mark.asyncio
+    async def test_status_filter(self):
+        """Filtering inbox by status works."""
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        # Create two tasks — one running, one completed
+        r1 = await mcp_server.coda_run(prompt="task1", email="a@b.com")
+        d1 = _parse(r1)
+
+        r2 = await mcp_server.coda_run(prompt="task2", email="a@b.com")
+        d2 = _parse(r2)
+
+        # Complete task2
+        tdir = task_manager._task_dir(d2["session_id"], d2["task_id"])
+        with open(os.path.join(tdir, "result.json"), "w") as f:
+            json.dump({"status": "completed", "summary": "done"}, f)
+
+        # Filter running only
+        result = await mcp_server.coda_inbox(status="running")
+        data = _parse(result)
+        assert len(data["tasks"]) == 1
+        assert data["tasks"][0]["task_id"] == d1["task_id"]
+
+    @pytest.mark.asyncio
+    async def test_multiple_tasks_sorted_recent_first(self):
+        """Inbox returns tasks sorted most recent first."""
+        from coda_mcp import mcp_server
+
+        r1 = await mcp_server.coda_run(prompt="first", email="a@b.com")
+        r2 = await mcp_server.coda_run(prompt="second", email="a@b.com")
+
+        result = await mcp_server.coda_inbox()
+        data = _parse(result)
+        assert len(data["tasks"]) == 2
+        # Most recent first
+        assert data["tasks"][0]["prompt_summary"] == "second"
+        assert data["tasks"][1]["prompt_summary"] == "first"
+
+
+# ── coda_get_result ──────────────────────────────────────────────────
+
+
+class TestCodaGetResult:
+    @pytest.mark.asyncio
+    async def test_returns_result(self):
+        from coda_mcp import mcp_server
+        from coda_mcp import task_manager
+
+        r = await mcp_server.coda_run(prompt="go", email="a@b.com")
+        d = _parse(r)
+
+        # Simulate agent writing result.json
+        tdir = task_manager._task_dir(d["session_id"], d["task_id"])
+        with open(os.path.join(tdir, "result.json"), "w") as f:
+            json.dump({
+                "summary": "Fixed the bug",
+                "files_changed": ["app.py"],
+                "artifacts": [],
+                "errors": [],
+            }, f)
+
+        result = await mcp_server.coda_get_result(
+            task_id=d["task_id"], session_id=d["session_id"]
+        )
+        data = _parse(result)
+        assert data["task_id"] == d["task_id"]
+        assert data["session_id"] == d["session_id"]
+        assert data["summary"] == "Fixed the bug"
+
+    @pytest.mark.asyncio
+    async def test_no_result_yet(self):
+        from coda_mcp import mcp_server
+
+        r = await mcp_server.coda_run(prompt="go", email="a@b.com")
+        d = _parse(r)
+
+        result = await mcp_server.coda_get_result(
+            task_id=d["task_id"], session_id=d["session_id"]
+        )
+        data = _parse(result)
+        assert data["status"] == "running"
+        assert "not yet available" in data["message"]
diff --git a/tests/test_mlflow_tracing.py b/tests/test_mlflow_tracing.py
index 02a6eb1..59e4ed0 100644
--- a/tests/test_mlflow_tracing.py
+++ b/tests/test_mlflow_tracing.py
@@ -14,7 +14,7 @@
 # Helpers
 # ---------------------------------------------------------------------------
 
-SETUP_MLFLOW = Path(__file__).parent.parent / "setup_mlflow.py"
+SETUP_MLFLOW = Path(__file__).parent.parent / "setup" / "setup_mlflow.py"
 
 
 def run_setup_mlflow(tmp_path, env_overrides=None):
diff --git a/tests/test_npm_version_pinning.py b/tests/test_npm_version_pinning.py
index 1024242..d156588 100644
--- a/tests/test_npm_version_pinning.py
+++ b/tests/test_npm_version_pinning.py
@@ -139,8 +139,12 @@ class TestNpmVersionLive:
     """Run against real npm registry to verify the function works end-to-end."""
 
     @pytest.mark.skipif(
-        not __import__("shutil").which("npm"),
-        reason="npm not installed"
+        not __import__("shutil").which("npm") or
+        __import__("subprocess").run(
+            ["npm", "view", "npm", "version"],
+            capture_output=True, timeout=15
+        ).returncode != 0,
+        reason="npm not installed or not functional"
     )
     def test_resolves_real_package(self):
         get_npm_version = _get_npm_version()
diff --git a/tests/test_run_step.py b/tests/test_run_step.py
new file mode 100644
index 0000000..af09733
--- /dev/null
+++ b/tests/test_run_step.py
@@ -0,0 +1,170 @@
+"""Tests for _run_step and _configure_all_cli_auth — env setup for subprocesses."""
+
+import os
+import subprocess
+from unittest import mock
+
+import pytest
+
+
+# We need to test _run_step from app.py. It calls subprocess.run, so we mock that.
+# The function also updates setup_state, so we mock that too.
+
+
+@pytest.fixture
+def patch_app_globals():
+    """Patch app.py globals needed by _run_step."""
+    with mock.patch("app._update_step"):
+        yield
+
+
+class TestRunStepEnvStripping:
+    """Verify _run_step strips OAuth credentials from subprocess env."""
+
+    def test_strips_databricks_client_id(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {
+            "DATABRICKS_CLIENT_ID": "sp-client-id",
+            "DATABRICKS_CLIENT_SECRET": "sp-client-secret",
+            "HOME": "/tmp/test-home",
+        }), mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(
+                returncode=0, stdout="ok", stderr=""
+            )
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            assert "DATABRICKS_CLIENT_ID" not in call_env
+            assert "DATABRICKS_CLIENT_SECRET" not in call_env
+
+    def test_preserves_other_env_vars(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {
+            "HOME": "/tmp/test-home",
+            "MY_CUSTOM_VAR": "keep-this",
+            "DATABRICKS_CLIENT_ID": "remove-this",
+        }), mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            assert call_env.get("MY_CUSTOM_VAR") == "keep-this"
+
+
+class TestRunStepPythonpath:
+    """Verify _run_step injects PYTHONPATH for setup script imports."""
+
+    def test_sets_pythonpath_to_app_dir(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {"HOME": "/tmp/test-home"}), \
+                mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            # PYTHONPATH should contain the app directory (dirname of app.py)
+            assert "PYTHONPATH" in call_env
+            assert call_env["PYTHONPATH"]  # non-empty
+
+    def test_prepends_to_existing_pythonpath(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {
+            "HOME": "/tmp/test-home",
+            "PYTHONPATH": "/existing/path",
+        }), mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            assert "/existing/path" in call_env["PYTHONPATH"]
+
+
+class TestRunStepPath:
+    """Verify _run_step adds ~/.local/bin to PATH."""
+
+    def test_adds_local_bin_to_path(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {
+            "HOME": "/tmp/test-home",
+            "PATH": "/usr/bin",
+        }), mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            assert "/tmp/test-home/.local/bin" in call_env["PATH"]
+
+    def test_skips_if_already_in_path(self, patch_app_globals):
+        from app import _run_step
+        with mock.patch.dict(os.environ, {
+            "HOME": "/tmp/test-home",
+            "PATH": "/tmp/test-home/.local/bin:/usr/bin",
+        }), mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            # Should not duplicate
+            assert call_env["PATH"].count(".local/bin") == 1
+
+    def test_defaults_home_when_empty(self, patch_app_globals):
+        """When HOME is empty or '/', should default to /app/python/source_code."""
+        from app import _run_step
+        with mock.patch.dict(os.environ, {"HOME": ""}, clear=False), \
+                mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="ok", stderr="")
+            _run_step("test-step", "echo hello")
+
+            call_env = mock_run.call_args.kwargs.get("env", {})
+            assert "/app/python/source_code" in call_env.get("HOME", "")
+
+
+# ---------------------------------------------------------------------------
+# _configure_all_cli_auth — PAT reconfiguration path
+# ---------------------------------------------------------------------------
+
+class TestConfigureAllCliAuth:
+    """Verify _configure_all_cli_auth injects PYTHONPATH for setup script imports.
+
+    This is a separate code path from _run_step — it runs setup scripts via
+    subprocess.run after PAT rotation. Without PYTHONPATH, the scripts can't
+    `from utils import ...` since they live in the setup/ subdirectory.
+    """
+
+    def _call_configure(self, mock_run, tmp_path, token="dapi_test"):
+        """Helper to call _configure_all_cli_auth with all dependencies mocked."""
+        from app import _configure_all_cli_auth
+        # Create .claude dir so settings.json write succeeds
+        (tmp_path / ".claude").mkdir(exist_ok=True)
+        with mock.patch("utils.resolve_and_cache_gateway"), \
+                mock.patch("app.get_gateway_host", return_value=None), \
+                mock.patch("app.ensure_https", return_value="https://test.databricks.com"), \
+                mock.patch("app.pat_rotator"), \
+                mock.patch.dict(os.environ, {"HOME": str(tmp_path)}):
+            _configure_all_cli_auth(token)
+
+    def test_injects_pythonpath(self, tmp_path):
+        with mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value = mock.MagicMock(returncode=0, stdout="", stderr="")
+            self._call_configure(mock_run, tmp_path)
+
+            # Find a subprocess call that runs a setup script
+            setup_calls = [c for c in mock_run.call_args_list
+                           if any("setup/" in str(a) for a in c[0][0])]
+            assert len(setup_calls) > 0, "Expected subprocess calls for setup scripts"
+
+            for call in setup_calls:
+                call_env = call.kwargs.get("env") or call[1].get("env", {})
+                assert "PYTHONPATH" in call_env, f"PYTHONPATH missing from env for {call[0][0]}"
+                assert call_env["PYTHONPATH"], "PYTHONPATH should not be empty"
+
+    def test_passes_token_in_env(self, tmp_path):
+        with mock.patch("subprocess.run") as mock_run:
+            mock_run.return_value
= mock.MagicMock(returncode=0, stdout="", stderr="") + self._call_configure(mock_run, tmp_path, token="dapi_mytoken") + + setup_calls = [c for c in mock_run.call_args_list + if any("setup/" in str(a) for a in c[0][0])] + for call in setup_calls: + call_env = call.kwargs.get("env") or call[1].get("env", {}) + assert call_env.get("DATABRICKS_TOKEN") == "dapi_mytoken" diff --git a/tests/test_session_detach.py b/tests/test_session_detach.py index c381a40..6e3b60f 100644 --- a/tests/test_session_detach.py +++ b/tests/test_session_detach.py @@ -7,7 +7,6 @@ import os import subprocess -import sys import threading import time from collections import deque @@ -40,42 +39,23 @@ def test_detects_child_process_name(self): """When a shell has a child process, return the child's name.""" app_mod = _get_app() - # Launch a shell (bash) with a child process (sleep) - shell = subprocess.Popen( - ["bash", "-c", "sleep 300"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # Give the child time to spawn - time.sleep(0.5) - - try: - result = app_mod._get_session_process(shell.pid) - assert result == "sleep", f"Expected 'sleep', got '{result}'" - finally: - shell.kill() - shell.wait() + # Mock pgrep returning a child PID, then ps resolving it to "sleep" + pgrep_result = mock.Mock(returncode=0, stdout="12345\n") + ps_result = mock.Mock(returncode=0, stdout="sleep\n") + with mock.patch("subprocess.run", side_effect=[pgrep_result, ps_result]): + result = app_mod._get_session_process(100) + assert result == "sleep", f"Expected 'sleep', got '{result}'" def test_returns_parent_process_name_when_no_children(self): """When a shell has no foreground children, return the shell name.""" app_mod = _get_app() - # Launch a bare shell that just sleeps via bash built-in wait - # Use cat which will block on stdin with no children of its own - proc = subprocess.Popen( - ["cat"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - try: - 
result = app_mod._get_session_process(proc.pid) - assert result == "cat", f"Expected 'cat', got '{result}'" - finally: - proc.kill() - proc.wait() + # Mock pgrep finding no children (exit 1), then ps resolving the process itself + pgrep_result = mock.Mock(returncode=1, stdout="") + ps_result = mock.Mock(returncode=0, stdout="cat\n") + with mock.patch("subprocess.run", side_effect=[pgrep_result, ps_result]): + result = app_mod._get_session_process(100) + assert result == "cat", f"Expected 'cat', got '{result}'" def test_returns_unknown_for_dead_pid(self): """Return 'unknown' when the PID does not exist.""" @@ -230,28 +210,31 @@ def setup_app(self): app_module.sessions.clear() def test_exited_session_removed_from_dict(self): - import pty - master_fd, slave_fd = pty.openpty() + fake_master = 50 + # Use a completed process so waitpid works proc = subprocess.Popen( - ["bash", "-c", "echo hello && exit 0"], - stdin=slave_fd, stdout=slave_fd, stderr=slave_fd, - preexec_fn=os.setsid + ["bash", "-c", "exit 0"], + stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - os.close(slave_fd) + proc.wait() session_id = "sess-eof-test" with self.app_module.sessions_lock: self.app_module.sessions[session_id] = { "pid": proc.pid, - "master_fd": master_fd, + "master_fd": fake_master, "output_buffer": deque(maxlen=1000), "lock": threading.Lock(), "last_poll_time": time.time(), "created_at": time.time(), } - # read_pty_output should detect EOF and call terminate_session - self.app_module.read_pty_output(session_id, master_fd) + # Simulate EOF: select says readable, os.read returns empty bytes + with mock.patch("select.select", return_value=([fake_master], [], [])), \ + mock.patch("os.read", return_value=b""), \ + mock.patch("os.close"), \ + mock.patch("os.kill"): + self.app_module.read_pty_output(session_id, fake_master) with self.app_module.sessions_lock: assert session_id not in self.app_module.sessions diff --git a/tests/test_sync_to_workspace.py 
b/tests/test_sync_to_workspace.py new file mode 100644 index 0000000..6faedf4 --- /dev/null +++ b/tests/test_sync_to_workspace.py @@ -0,0 +1,181 @@ +"""Tests for sync_to_workspace — path-escape guard and workspace sync.""" + +import subprocess +from pathlib import Path +from unittest import mock + +import pytest + + +# --------------------------------------------------------------------------- +# _read_databrickscfg +# --------------------------------------------------------------------------- + +class TestReadDatabrickscfg: + def test_reads_host_and_token(self, tmp_path): + cfg = tmp_path / ".databrickscfg" + cfg.write_text("[DEFAULT]\nhost = https://test.cloud.databricks.com\ntoken = dapi_abc123\n") + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path): + from sync_to_workspace import _read_databrickscfg + host, token = _read_databrickscfg() + assert host == "https://test.cloud.databricks.com" + assert token == "dapi_abc123" + + def test_returns_none_when_missing(self, tmp_path): + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path): + from sync_to_workspace import _read_databrickscfg + host, token = _read_databrickscfg() + assert host is None + assert token is None + + def test_returns_none_for_missing_keys(self, tmp_path): + cfg = tmp_path / ".databrickscfg" + cfg.write_text("[DEFAULT]\n# empty section\n") + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path): + from sync_to_workspace import _read_databrickscfg + host, token = _read_databrickscfg() + assert host is None + assert token is None + + +# --------------------------------------------------------------------------- +# get_user_email +# --------------------------------------------------------------------------- + +class TestGetUserEmail: + def test_raises_when_no_config(self, tmp_path): + from sync_to_workspace import get_user_email + with mock.patch("sync_to_workspace._read_databrickscfg", return_value=(None, None)): + with 
pytest.raises(RuntimeError, match="missing host or token"): + get_user_email() + + def test_raises_when_no_token(self): + from sync_to_workspace import get_user_email + with mock.patch("sync_to_workspace._read_databrickscfg", return_value=("https://host", None)): + with pytest.raises(RuntimeError, match="missing host or token"): + get_user_email() + + def test_returns_email(self): + from sync_to_workspace import get_user_email + mock_user = mock.MagicMock() + mock_user.user_name = "test@example.com" + mock_client = mock.MagicMock() + mock_client.current_user.me.return_value = mock_user + with mock.patch("sync_to_workspace._read_databrickscfg", return_value=("https://host", "tok")): + with mock.patch("sync_to_workspace.WorkspaceClient", return_value=mock_client): + email = get_user_email() + assert email == "test@example.com" + + +# --------------------------------------------------------------------------- +# sync_project — path-escape guard +# --------------------------------------------------------------------------- + +class TestSyncProject: + def test_rejects_path_outside_projects_dir(self, tmp_path, capsys): + from sync_to_workspace import sync_project + # Create a path outside ~/projects/ + outside = tmp_path / "evil-repo" + outside.mkdir() + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path): + sync_project(outside) + captured = capsys.readouterr() + assert "SKIP" in captured.err + assert "outside" in captured.err + + def test_accepts_path_inside_projects_dir(self, tmp_path): + from sync_to_workspace import sync_project + projects = tmp_path / "projects" + projects.mkdir() + repo = projects / "my-repo" + repo.mkdir() + + mock_user = mock.MagicMock() + mock_user.user_name = "test@example.com" + mock_client = mock.MagicMock() + mock_client.current_user.me.return_value = mock_user + + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path), \ + mock.patch("sync_to_workspace._read_databrickscfg", return_value=("https://host", 
"tok")), \ + mock.patch("sync_to_workspace.WorkspaceClient", return_value=mock_client), \ + mock.patch("sync_to_workspace.subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess([], 0, stdout="", stderr="") + sync_project(repo) + + mock_run.assert_called_once() + args = mock_run.call_args + assert "databricks" in args[0][0][0] + assert "sync" in args[0][0][1] + + def test_strips_oauth_env_from_subprocess(self, tmp_path): + """Verify OAuth credentials are stripped so CLI falls through to ~/.databrickscfg.""" + from sync_to_workspace import sync_project + projects = tmp_path / "projects" + projects.mkdir() + repo = projects / "my-repo" + repo.mkdir() + + mock_user = mock.MagicMock() + mock_user.user_name = "test@example.com" + mock_client = mock.MagicMock() + mock_client.current_user.me.return_value = mock_user + + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path), \ + mock.patch("sync_to_workspace._read_databrickscfg", return_value=("https://host", "tok")), \ + mock.patch("sync_to_workspace.WorkspaceClient", return_value=mock_client), \ + mock.patch("sync_to_workspace.subprocess.run") as mock_run, \ + mock.patch.dict("os.environ", { + "DATABRICKS_CLIENT_ID": "sp-id", + "DATABRICKS_CLIENT_SECRET": "sp-secret", + "DATABRICKS_HOST": "https://host", + "DATABRICKS_TOKEN": "dapi_tok", + }): + mock_run.return_value = subprocess.CompletedProcess([], 0, stdout="", stderr="") + sync_project(repo) + + call_env = mock_run.call_args[1].get("env") or mock_run.call_args.kwargs.get("env", {}) + assert "DATABRICKS_CLIENT_ID" not in call_env + assert "DATABRICKS_CLIENT_SECRET" not in call_env + assert "DATABRICKS_HOST" not in call_env + assert "DATABRICKS_TOKEN" not in call_env + + def test_logs_error_on_failure(self, tmp_path, capsys): + from sync_to_workspace import sync_project + projects = tmp_path / "projects" + projects.mkdir() + repo = projects / "my-repo" + repo.mkdir() + + with mock.patch("sync_to_workspace.Path.home", 
return_value=tmp_path), \ + mock.patch("sync_to_workspace.get_user_email", side_effect=Exception("auth failed")): + sync_project(repo) + + captured = capsys.readouterr() + assert "Sync failed" in captured.err + # Error should be logged to file + error_log = tmp_path / ".sync-errors.log" + assert error_log.exists() + assert "auth failed" in error_log.read_text() + + def test_sync_failure_warns(self, tmp_path, capsys): + """Non-zero return code from databricks sync should print warning.""" + from sync_to_workspace import sync_project + projects = tmp_path / "projects" + projects.mkdir() + repo = projects / "my-repo" + repo.mkdir() + + mock_user = mock.MagicMock() + mock_user.user_name = "test@example.com" + mock_client = mock.MagicMock() + mock_client.current_user.me.return_value = mock_user + + with mock.patch("sync_to_workspace.Path.home", return_value=tmp_path), \ + mock.patch("sync_to_workspace._read_databrickscfg", return_value=("https://host", "tok")), \ + mock.patch("sync_to_workspace.WorkspaceClient", return_value=mock_client), \ + mock.patch("sync_to_workspace.subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess([], 1, stdout="", stderr="permission denied") + sync_project(repo) + + captured = capsys.readouterr() + assert "Sync warning" in captured.err diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py new file mode 100644 index 0000000..b9717c2 --- /dev/null +++ b/tests/test_task_manager.py @@ -0,0 +1,448 @@ +"""Tests for task_manager — disk-based MCP session/task state.""" + +import json +import os +import time +from unittest import mock + +import pytest + + +@pytest.fixture(autouse=True) +def isolated_sessions(tmp_path): + """Point task_manager.SESSIONS_DIR at a temp dir.""" + sessions_dir = str(tmp_path / ".coda" / "sessions") + with mock.patch("coda_mcp.task_manager.SESSIONS_DIR", sessions_dir): + yield sessions_dir + + +# ── helpers ────────────────────────────────────────────────────────── + + +def 
_read_json(path): + with open(path) as f: + return json.load(f) + + +def _read_text(path): + with open(path) as f: + return f.read() + + +def _read_jsonl(path): + lines = [] + with open(path) as f: + for line in f: + line = line.strip() + if line: + lines.append(json.loads(line)) + return lines + + +# ── Session lifecycle ──────────────────────────────────────────────── + + +class TestCreateSession: + def test_returns_session_id_and_status(self): + from coda_mcp import task_manager + + result = task_manager.create_session("a@b.com", "u1", "my-label") + assert result["status"] == "ready" + assert result["session_id"].startswith("sess-") + assert len(result["session_id"]) == 5 + 12 # "sess-" + 12 hex + + def test_creates_session_json_on_disk(self, isolated_sessions): + from coda_mcp import task_manager + + result = task_manager.create_session("a@b.com", "u1", "my-label") + sid = result["session_id"] + path = os.path.join(isolated_sessions, sid, "session.json") + assert os.path.isfile(path) + data = _read_json(path) + assert data["email"] == "a@b.com" + assert data["user_id"] == "u1" + assert data["label"] == "my-label" + assert data["status"] == "ready" + assert data["current_task"] is None + assert data["completed_tasks"] == [] + assert "created_at" in data + + def test_unique_ids(self): + from coda_mcp import task_manager + + ids = {task_manager.create_session("a@b.com", "u1")["session_id"] for _ in range(20)} + assert len(ids) == 20 + + +class TestCloseSession: + def test_marks_session_closed(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + task_manager.close_session(sid) + data = _read_json(os.path.join(isolated_sessions, sid, "session.json")) + assert data["status"] == "closed" + + def test_close_nonexistent_raises(self): + from coda_mcp import task_manager + + with pytest.raises(task_manager.SessionNotFoundError): + task_manager.close_session("sess-doesnotexist") + + +class 
TestReadSession: + def test_read_existing(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1", "lbl")["session_id"] + data = task_manager._read_session(sid) + assert data["email"] == "a@b.com" + + def test_read_nonexistent_raises(self): + from coda_mcp import task_manager + + with pytest.raises(task_manager.SessionNotFoundError): + task_manager._read_session("sess-000000000000") + + +class TestUpdateSessionField: + def test_updates_single_field(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + task_manager._update_session_field(sid, "status", "busy") + data = task_manager._read_session(sid) + assert data["status"] == "busy" + + def test_preserves_other_fields(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1", "lbl")["session_id"] + task_manager._update_session_field(sid, "status", "busy") + data = task_manager._read_session(sid) + assert data["email"] == "a@b.com" + assert data["label"] == "lbl" + + +# ── Task lifecycle ─────────────────────────────────────────────────── + + +class TestCreateTask: + def test_returns_task_id_and_running(self): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + result = task_manager.create_task(sid, "do something", "a@b.com") + assert result["status"] == "running" + assert result["task_id"].startswith("task-") + assert len(result["task_id"]) == 5 + 8 # "task-" + 8 hex + + def test_creates_task_directory_with_files(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "do something", "a@b.com")["task_id"] + task_dir = task_manager._task_dir(sid, tid) + assert os.path.isdir(task_dir) + assert os.path.isfile(os.path.join(task_dir, "prompt.txt")) + assert 
os.path.isfile(os.path.join(task_dir, "status.jsonl")) + + def test_prompt_txt_contains_wrapped_prompt(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "fix the bug", "a@b.com")["task_id"] + prompt = _read_text(os.path.join(task_manager._task_dir(sid, tid), "prompt.txt")) + assert "---CODA-TASK---" in prompt + assert "fix the bug" in prompt + + def test_session_marked_busy(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + task_manager.create_task(sid, "do it", "a@b.com") + data = task_manager._read_session(sid) + assert data["status"] == "busy" + + def test_session_current_task_set(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "do it", "a@b.com")["task_id"] + data = task_manager._read_session(sid) + assert data["current_task"] == tid + + def test_busy_session_raises(self): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + task_manager.create_task(sid, "first", "a@b.com") + with pytest.raises(task_manager.SessionBusyError): + task_manager.create_task(sid, "second", "a@b.com") + + def test_nonexistent_session_raises(self): + from coda_mcp import task_manager + + with pytest.raises(task_manager.SessionNotFoundError): + task_manager.create_task("sess-doesnotexist", "p", "e@x.com") + + def test_status_jsonl_has_initial_entry(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + entries = _read_jsonl( + os.path.join(task_manager._task_dir(sid, tid), "status.jsonl") + ) + assert len(entries) == 1 + assert entries[0]["status"] == "running" + + def 
test_optional_params_stored(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task( + sid, "go", "a@b.com", + context={"repo": "myrepo"}, + context_hint="look at utils.py", + timeout_s=120, + permissions=["read", "write"], + )["task_id"] + prompt = _read_text(os.path.join(task_manager._task_dir(sid, tid), "prompt.txt")) + assert "myrepo" in prompt + assert "utils.py" in prompt + + +class TestTaskDir: + def test_returns_correct_path(self, isolated_sessions): + from coda_mcp import task_manager + + path = task_manager._task_dir("sess-aabbccddee01", "task-11223344") + expected = os.path.join( + isolated_sessions, "sess-aabbccddee01", "tasks", "task-11223344" + ) + assert path == expected + + +# ── Task status / result ───────────────────────────────────────────── + + +class TestGetTaskStatus: + def test_returns_latest_status(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + status = task_manager.get_task_status(tid, sid) + assert status["status"] == "running" + + def test_reads_appended_lines(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + # simulate agent appending progress + status_path = os.path.join(task_manager._task_dir(sid, tid), "status.jsonl") + with open(status_path, "a") as f: + f.write(json.dumps({"status": "progress", "pct": 50, "ts": time.time()}) + "\n") + status = task_manager.get_task_status(tid, sid) + assert status["status"] == "progress" + assert status["pct"] == 50 + + def test_missing_task_returns_not_found(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + status 
= task_manager.get_task_status("task-nonexist", sid) + assert status["status"] == "not_found" + + +class TestGetTaskResult: + def test_returns_result_when_present(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + # simulate agent writing result + result_path = os.path.join(task_manager._task_dir(sid, tid), "result.json") + with open(result_path, "w") as f: + json.dump({"answer": 42}, f) + result = task_manager.get_task_result(tid, sid) + assert result["answer"] == 42 + + def test_returns_none_when_absent(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + result = task_manager.get_task_result(tid, sid) + assert result is None + + def test_missing_task_returns_none(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + result = task_manager.get_task_result("task-nonexist", sid) + assert result is None + + +# ── Complete task ───────────────────────────────────────────────────── + + +class TestCompleteTask: + def test_marks_session_closed(self, isolated_sessions): + """v2: sessions are ephemeral — complete_task auto-closes the session.""" + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + task_manager.complete_task(sid, tid) + data = task_manager._read_session(sid) + assert data["status"] == "closed" + assert "closed_at" in data + + def test_appends_to_completed_tasks(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + 
task_manager.complete_task(sid, tid) + data = task_manager._read_session(sid) + assert tid in data["completed_tasks"] + + def test_closed_session_rejects_new_task(self, isolated_sessions): + """v2: ephemeral sessions — new tasks need new sessions.""" + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid1 = task_manager.create_task(sid, "first", "a@b.com")["task_id"] + task_manager.complete_task(sid, tid1) + with pytest.raises(task_manager.SessionNotFoundError): + task_manager.create_task(sid, "second", "a@b.com") + + def test_appends_done_to_status_jsonl(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, "go", "a@b.com")["task_id"] + task_manager.complete_task(sid, tid) + entries = _read_jsonl( + os.path.join(task_manager._task_dir(sid, tid), "status.jsonl") + ) + assert entries[-1]["status"] == "done" + + def test_nonexistent_session_raises(self): + from coda_mcp import task_manager + + with pytest.raises(task_manager.SessionNotFoundError): + task_manager.complete_task("sess-doesnotexist", "task-00000000") + + +# ── Prompt wrapping ────────────────────────────────────────────────── + + +class TestWrapPrompt: + def test_contains_marker(self): + from coda_mcp import task_manager + + wrapped = task_manager.wrap_prompt( + task_id="task-aabbccdd", + session_id="sess-112233445566", + email="a@b.com", + prompt="fix the bug", + context=None, + results_dir="/tmp/r", + context_hint=None, + ) + assert "---CODA-TASK---" in wrapped + assert "fix the bug" in wrapped + assert "task-aabbccdd" in wrapped + assert "sess-112233445566" in wrapped + assert "a@b.com" in wrapped + assert "/tmp/r" in wrapped + + def test_includes_context_when_provided(self): + from coda_mcp import task_manager + + wrapped = task_manager.wrap_prompt( + task_id="task-aabbccdd", + session_id="sess-112233445566", + email="a@b.com", + 
prompt="go", + context={"repo": "myrepo", "branch": "main"}, + results_dir="/tmp/r", + context_hint=None, + ) + assert "myrepo" in wrapped + assert "main" in wrapped + + def test_includes_context_hint(self): + from coda_mcp import task_manager + + wrapped = task_manager.wrap_prompt( + task_id="task-aabbccdd", + session_id="sess-112233445566", + email="a@b.com", + prompt="go", + context=None, + results_dir="/tmp/r", + context_hint="look at utils.py first", + ) + assert "look at utils.py first" in wrapped + + def test_no_context_still_valid(self): + from coda_mcp import task_manager + + wrapped = task_manager.wrap_prompt( + task_id="task-aabbccdd", + session_id="sess-112233445566", + email="a@b.com", + prompt="hello", + context=None, + results_dir="/tmp/r", + context_hint=None, + ) + assert "---CODA-TASK---" in wrapped + assert "hello" in wrapped + + +# ── Edge cases ──────────────────────────────────────────────────────── + + +class TestEdgeCases: + def test_closed_session_rejects_task(self, isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + task_manager.close_session(sid) + with pytest.raises(task_manager.SessionNotFoundError): + task_manager.create_task(sid, "go", "a@b.com") + + def test_multiple_tasks_across_sessions(self, isolated_sessions): + """v2: each task gets its own ephemeral session; all appear in list_all_tasks.""" + from coda_mcp import task_manager + + tids = [] + for i in range(3): + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + tid = task_manager.create_task(sid, f"task {i}", "a@b.com")["task_id"] + task_manager.complete_task(sid, tid) + tids.append(tid) + # Each session auto-closes + data = task_manager._read_session(sid) + assert data["status"] == "closed" + + all_tasks = task_manager.list_all_tasks() + all_tids = [t["task_id"] for t in all_tasks] + for tid in tids: + assert tid in all_tids + + def test_corrupt_session_json_raises(self, 
isolated_sessions): + from coda_mcp import task_manager + + sid = task_manager.create_session("a@b.com", "u1")["session_id"] + path = os.path.join(isolated_sessions, sid, "session.json") + with open(path, "w") as f: + f.write("{bad json") + with pytest.raises(task_manager.SessionNotFoundError): + task_manager._read_session(sid) diff --git a/tools/coda-bridge.py b/tools/coda-bridge.py new file mode 100644 index 0000000..c67b54c --- /dev/null +++ b/tools/coda-bridge.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +"""Stdio-to-HTTP MCP bridge with Databricks OAuth token injection. + +Proxies MCP JSON-RPC (stdio) to a Databricks App (Streamable HTTP), +injecting fresh OAuth tokens via `databricks auth token`. + +Config via environment variables (set in Claude Code settings.json): + + CODA_MCP_URL — App MCP endpoint URL + DATABRICKS_PROFILE — Databricks CLI profile for auth +""" + +import json +import os +import subprocess +import sys +import time +import urllib.request +import urllib.error + +APP_URL = os.environ.get("CODA_MCP_URL", "") +PROFILE = os.environ.get("DATABRICKS_PROFILE", "DEFAULT") +TOKEN_TTL = 1800 # cache 30 min (tokens last 60) + +_cache = {"token": None, "expires_at": 0.0} +_session_id = None + + +def _log(msg): + print(f"[coda-bridge] {msg}", file=sys.stderr, flush=True) + + +def _get_token(force=False): + now = time.time() + if not force and _cache["token"] and now < _cache["expires_at"]: + return _cache["token"] + result = subprocess.run( + ["databricks", "auth", "token", "-p", PROFILE], + capture_output=True, text=True, timeout=15, + ) + if result.returncode != 0: + raise RuntimeError(f"databricks auth token failed: {result.stderr.strip()}") + data = json.loads(result.stdout) + _cache["token"] = data["access_token"] + _cache["expires_at"] = now + TOKEN_TTL + _log("OAuth token refreshed") + return _cache["token"] + + +def _forward(line): + global _session_id + token = _get_token() + + headers = { + "Content-Type": "application/json", + "Accept": 
"application/json, text/event-stream", + "Authorization": f"Bearer {token}", + } + if _session_id: + headers["Mcp-Session-Id"] = _session_id + + req = urllib.request.Request(APP_URL, data=line.encode(), headers=headers, method="POST") + try: + with urllib.request.urlopen(req, timeout=300) as resp: + sid = resp.headers.get("Mcp-Session-Id") + if sid: + _session_id = sid + body = resp.read().decode() + if body.strip(): + sys.stdout.write(body.rstrip("\n") + "\n") + sys.stdout.flush() + except urllib.error.HTTPError as e: + if e.code in (302, 401, 403): + _log(f"Auth failed ({e.code}), forcing token refresh") + token = _get_token(force=True) + headers["Authorization"] = f"Bearer {token}" + retry = urllib.request.Request(APP_URL, data=line.encode(), headers=headers, method="POST") + with urllib.request.urlopen(retry, timeout=300) as resp: + sid = resp.headers.get("Mcp-Session-Id") + if sid: + _session_id = sid + body = resp.read().decode() + if body.strip(): + sys.stdout.write(body.rstrip("\n") + "\n") + sys.stdout.flush() + else: + raise + + +def main(): + if not APP_URL: + _log("FATAL: CODA_MCP_URL not set") + sys.exit(1) + _log(f"Proxying to {APP_URL} (profile={PROFILE})") + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + _forward(line) + except Exception as e: + _log(f"Error: {e}") + try: + msg_id = json.loads(line).get("id") + except Exception: + msg_id = None + if msg_id is not None: + err = json.dumps({ + "jsonrpc": "2.0", + "id": msg_id, + "error": {"code": -32000, "message": str(e)}, + }) + sys.stdout.write(err + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + main()