Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/ucode/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
normalize_workspace_url,
run_databricks_login,
)
from ucode.mcp import MCP_CLIENTS, configure_mcp_command, revert_mcp_configs
from ucode.mcp import (
MCP_CLIENTS,
configure_mcp_command,
purge_cross_workspace_mcp_residue,
revert_mcp_configs,
)
from ucode.state import STATE_PATH, clear_state, load_state, save_state
from ucode.ui import (
console,
Expand Down Expand Up @@ -151,6 +156,7 @@ def configure_shared_state(
don't error out. If ``None``, we resolve it from the host after login.
"""
workspace = normalize_workspace_url(workspace)
previous_workspace = load_state().get("workspace")
fetch_all = tools is None
if force_login:
run_databricks_login(workspace, profile)
Expand Down Expand Up @@ -210,6 +216,10 @@ def configure_shared_state(
if fetch_all or "opencode" in tools:
state["opencode_models"] = opencode_models
save_state(state)
# Scrub MCP entries that ucode wrote for the previous workspace so the new
# workspace's agent configs aren't stale.
if previous_workspace and previous_workspace != workspace:
purge_cross_workspace_mcp_residue(state, workspace)
# Diagnostic reasons are transient — attach after save_state so they don't
# land on disk but are available to the caller for this run.
state["_discovery_reasons"] = {
Expand Down
119 changes: 118 additions & 1 deletion src/ucode/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import subprocess
from collections.abc import Callable
from typing import Any
from urllib.parse import urlparse

import questionary
from prompt_toolkit.application import Application
Expand All @@ -29,8 +30,9 @@
list_databricks_apps,
list_databricks_connections,
list_genie_spaces,
workspace_hostname,
)
from ucode.state import load_state, save_state
from ucode.state import load_full_state, load_state, save_state
from ucode.ui import (
print_note,
print_section,
Expand Down Expand Up @@ -432,6 +434,61 @@ def _servers_by_name(mcp_servers: list[dict]) -> dict[str, dict]:
return servers


def _mcp_entry_url_host(entry: dict) -> str | None:
"""Return the host of an MCP entry's URL, or ``None`` if missing/malformed."""
url = entry.get("url")
if not isinstance(url, str) or not url:
return None
try:
return urlparse(url).hostname
except ValueError:
return None


def _partition_mcp_entries_by_workspace(
entries: list[dict], workspace: str
) -> tuple[list[dict], list[dict]]:
"""Split MCP entries into ones that belong to ``workspace`` and ones that don't."""
workspace_host = workspace_hostname(workspace)
current: list[dict] = []
foreign: list[dict] = []
for entry in entries:
if _mcp_entry_url_host(entry) == workspace_host:
current.append(entry)
else:
foreign.append(entry)
return current, foreign


def _mcp_entries_only_in_other_workspaces(current_workspace: str) -> dict[str, set[str]]:
"""Return ``{name: {client, ...}}`` for MCPs ucode tracks only in workspaces other than ``current_workspace``."""
full_state = load_full_state()
workspaces = full_state.get("workspaces")
if not isinstance(workspaces, dict):
return {}

current_names: set[str] = set()
current_bucket = workspaces.get(current_workspace)
if isinstance(current_bucket, dict):
for entry in current_bucket.get("mcp_servers") or []:
name = _server_name(entry)
if name:
current_names.add(name)

external_entries: dict[str, set[str]] = {}
for ws, bucket in workspaces.items():
if ws == current_workspace or not isinstance(bucket, dict):
continue
for entry in bucket.get("mcp_servers") or []:
name = _server_name(entry)
if not name or name in current_names:
continue
client_set = external_entries.setdefault(name, set())
for client in entry.get("clients") or []:
client_set.add(client)
return external_entries


def _server_choice(name: str, checked: bool, title: str | None = None) -> questionary.Choice:
return questionary.Choice(
title=title or name,
Expand Down Expand Up @@ -743,12 +800,72 @@ def apply_mcp_server_changes(
return changed


def purge_cross_workspace_mcp_residue(state: dict, workspace: str) -> None:
installed = set(available_mcp_clients())

raw_mcp_servers = list(state.get("mcp_servers") or [])
current_mcp_servers, foreign_mcp_servers = _partition_mcp_entries_by_workspace(
raw_mcp_servers, workspace
)
if foreign_mcp_servers:
foreign_names = ", ".join(
(_server_name(server) or "(unnamed)") for server in foreign_mcp_servers
)
noun = "entry" if len(foreign_mcp_servers) == 1 else "entries"
print_warning(
f"Dropping {len(foreign_mcp_servers)} stale MCP {noun} "
f"not bound to this workspace: {foreign_names}."
)
for server in foreign_mcp_servers:
name = _server_name(server)
if not name:
continue
for client in server.get("clients") or []:
if client not in installed or client not in MCP_CLIENTS:
continue
try:
remove_client_mcp_server(client, name)
except RuntimeError as exc:
print_warning(
f"Failed to remove `{name}` from {MCP_CLIENTS[client]['display']}: {exc}"
)
state["mcp_servers"] = current_mcp_servers
save_state(state)

other_ws_mcps = _mcp_entries_only_in_other_workspaces(workspace)
actually_removed: list[str] = []
for name in sorted(other_ws_mcps):
any_removed = False
for client in other_ws_mcps[name]:
if client not in installed or client not in MCP_CLIENTS:
continue
try:
removed_scopes = remove_client_mcp_server(client, name)
except RuntimeError as exc:
print_warning(
f"Failed to remove `{name}` from {MCP_CLIENTS[client]['display']}: {exc}"
)
continue
if removed_scopes:
any_removed = True
if any_removed:
actually_removed.append(name)
if actually_removed:
noun = "entry" if len(actually_removed) == 1 else "entries"
print_warning(
f"Removed {len(actually_removed)} MCP {noun} left over from "
f"previously-configured workspaces: {', '.join(actually_removed)}."
)


def configure_mcp_command() -> int:
state = load_state()
workspace = state.get("workspace")
if not workspace:
raise RuntimeError("Workspace is not configured. Run `ucode configure` first.")

purge_cross_workspace_mcp_residue(state, workspace)

installed_clients = available_mcp_clients()
if not installed_clients:
raise RuntimeError(
Expand Down
58 changes: 58 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,61 @@ def fake_configure_shared_state(workspace, profile=None, tools=None, force_login
]
assert saved == ["https://first.com"]
assert configured_tools == [("https://first.com", ["codex"])]


class TestConfigureSharedStateMcpCleanup:
"""A workspace switch should scrub the previous workspace's MCP entries from
installed client configs. Switching to the same workspace must not."""

@staticmethod
def _stub_external_deps(monkeypatch):
import ucode.cli as cli_mod

monkeypatch.setattr(cli_mod, "normalize_workspace_url", lambda w: w)
monkeypatch.setattr(cli_mod, "run_databricks_login", lambda w, p: None)
monkeypatch.setattr(cli_mod, "ensure_databricks_auth", lambda w, p=None: None)
monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None)
monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token")
monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None)
monkeypatch.setattr(cli_mod, "discover_claude_models", lambda w, t: ({}, None))
monkeypatch.setattr(cli_mod, "discover_gemini_models", lambda w, t: ([], None))
monkeypatch.setattr(cli_mod, "discover_codex_models", lambda w, t: ([], None))
monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {})

def test_purges_residue_when_workspace_changes(self, monkeypatch):
import ucode.cli as cli_mod

self._stub_external_deps(monkeypatch)
monkeypatch.setattr(
cli_mod, "load_state", lambda: {"workspace": "https://old.databricks.com"}
)
purge_calls: list[tuple[dict, str]] = []
monkeypatch.setattr(
cli_mod,
"purge_cross_workspace_mcp_residue",
lambda state, workspace: purge_calls.append((state, workspace)),
)

cli_mod.configure_shared_state("https://new.databricks.com")

assert len(purge_calls) == 1
_, called_workspace = purge_calls[0]
assert called_workspace == "https://new.databricks.com"

def test_skips_purge_when_workspace_unchanged(self, monkeypatch):
import ucode.cli as cli_mod

self._stub_external_deps(monkeypatch)
monkeypatch.setattr(
cli_mod, "load_state", lambda: {"workspace": "https://same.databricks.com"}
)
purge_calls: list = []
monkeypatch.setattr(
cli_mod,
"purge_cross_workspace_mcp_residue",
lambda state, workspace: purge_calls.append((state, workspace)),
)

cli_mod.configure_shared_state("https://same.databricks.com")

assert purge_calls == []
Loading
Loading