Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/ucode/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def _extract_connection_page(payload: object) -> tuple[list[dict], str | None]:
return [item for item in raw_connections if isinstance(item, dict)], next_page_token


def list_databricks_connections(workspace: str) -> list[dict]:
def list_databricks_connections(workspace: str, profile: str | None = None) -> list[dict]:
env = build_databricks_cli_env(workspace)
connections: list[dict] = []
page_token: str | None = None
Expand All @@ -640,6 +640,7 @@ def list_databricks_connections(workspace: str) -> list[dict]:
"databricks",
"connections",
"list",
*_profile_args(profile),
"--max-results",
"0",
"--output",
Expand Down Expand Up @@ -690,7 +691,7 @@ def _extract_genie_spaces_page(payload: object) -> tuple[list[dict], str | None]
return [item for item in raw_spaces if isinstance(item, dict)], next_page_token


def list_genie_spaces(workspace: str) -> list[dict]:
def list_genie_spaces(workspace: str, profile: str | None = None) -> list[dict]:
env = build_databricks_cli_env(workspace)
spaces: list[dict] = []
page_token: str | None = None
Expand All @@ -702,6 +703,7 @@ def list_genie_spaces(workspace: str) -> list[dict]:
"databricks",
"genie",
"list-spaces",
*_profile_args(profile),
"--page-size",
"100",
"--output",
Expand Down Expand Up @@ -749,14 +751,15 @@ def _extract_apps_payload(payload: object) -> list[dict]:
raise RuntimeError("Databricks apps listing returned invalid JSON.")


def list_databricks_apps(workspace: str) -> list[dict]:
def list_databricks_apps(workspace: str, profile: str | None = None) -> list[dict]:
env = build_databricks_cli_env(workspace)
try:
result = run(
[
"databricks",
"apps",
"list",
*_profile_args(profile),
"--limit",
"1000",
"--output",
Expand Down
53 changes: 29 additions & 24 deletions src/ucode/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def external_mcp_connection_names(connections: list[dict]) -> list[str]:
return sorted(names)


def discover_external_mcp_connection_names(workspace: str) -> list[str]:
return external_mcp_connection_names(list_databricks_connections(workspace))
def discover_external_mcp_connection_names(workspace: str, profile: str | None = None) -> list[str]:
return external_mcp_connection_names(list_databricks_connections(workspace, profile))


def genie_mcp_servers(spaces: list[dict], workspace: str) -> list[dict]:
Expand All @@ -372,8 +372,8 @@ def genie_mcp_servers(spaces: list[dict], workspace: str) -> list[dict]:
return sorted(servers, key=lambda server: str(server["title"]).lower())


def discover_genie_mcp_servers(workspace: str) -> list[dict]:
return genie_mcp_servers(list_genie_spaces(workspace), workspace)
def discover_genie_mcp_servers(workspace: str, profile: str | None = None) -> list[dict]:
return genie_mcp_servers(list_genie_spaces(workspace, profile), workspace)


def app_mcp_servers(apps: list[dict]) -> list[dict]:
Expand Down Expand Up @@ -403,8 +403,8 @@ def app_mcp_servers(apps: list[dict]) -> list[dict]:
return sorted(servers, key=lambda server: str(server["title"]).lower())


def discover_app_mcp_servers(workspace: str) -> list[dict]:
return app_mcp_servers(list_databricks_apps(workspace))
def discover_app_mcp_servers(workspace: str, profile: str | None = None) -> list[dict]:
return app_mcp_servers(list_databricks_apps(workspace, profile))


def _picker_style() -> questionary.Style:
Expand Down Expand Up @@ -674,35 +674,35 @@ def _resolve_mcp_selection(
selection: str,
workspace: str,
available_app_servers: list[dict] | None = None,
) -> tuple[str, str] | None:
) -> tuple[str, str]:
if selection.startswith(APP_MCP_SELECTION_PREFIX):
app_name = selection.removeprefix(APP_MCP_SELECTION_PREFIX)
if not app_name:
return None
raise RuntimeError("missing Databricks app name")
server = _servers_by_name(available_app_servers or []).get(f"databricks-app-{app_name}")
if not server:
return None
raise RuntimeError(f"Databricks app `{app_name}` was not in the discovered app list")
url = server.get("url")
if not isinstance(url, str) or not url:
return None
raise RuntimeError(f"Databricks app `{app_name}` has no MCP URL")
return f"databricks-app-{app_name}", url

if selection.startswith(GENIE_SPACE_SELECTION_PREFIX):
space_id = selection.removeprefix(GENIE_SPACE_SELECTION_PREFIX)
if not space_id:
return None
raise RuntimeError("missing Genie space id")
return f"databricks-genie-{space_id}", f"{workspace}/api/2.0/mcp/genie/{space_id}"

if selection.startswith(EXTERNAL_MCP_SELECTION_PREFIX):
server_name = selection.removeprefix(EXTERNAL_MCP_SELECTION_PREFIX)
if not server_name:
return None
raise RuntimeError("missing external connection name")
return server_name, f"{workspace}/api/2.0/mcp/external/{server_name}"

if selection == SQL_MCP_VALUE:
return "databricks-sql", f"{workspace}/api/2.0/mcp/sql"

return None
raise RuntimeError(f"unrecognized selection prefix in `{selection}`")


def _discover_mcp_source(label: str, discover: Callable[[], list[Any]]) -> list[Any]:
Expand Down Expand Up @@ -766,7 +766,8 @@ def configure_mcp_command() -> int:
client for client in MCP_CLIENTS if client in configured_tools and client not in clients
]

ensure_databricks_auth(workspace, state.get("profile"))
profile = state.get("profile")
ensure_databricks_auth(workspace, profile)

print_section("MCP Servers")
client_names = ", ".join(str(MCP_CLIENTS[client]["display"]) for client in clients)
Expand All @@ -779,15 +780,15 @@ def configure_mcp_command() -> int:

available_external_mcp_names = _discover_mcp_source(
"external connections",
lambda: discover_external_mcp_connection_names(workspace),
lambda: discover_external_mcp_connection_names(workspace, profile),
)
available_genie_mcp_servers = _discover_mcp_source(
"Genie spaces",
lambda: discover_genie_mcp_servers(workspace),
lambda: discover_genie_mcp_servers(workspace, profile),
)
available_app_mcp_servers = _discover_mcp_source(
"Databricks apps",
lambda: discover_app_mcp_servers(workspace),
lambda: discover_app_mcp_servers(workspace, profile),
)

original_mcp_servers: list[dict] = list(state.get("mcp_servers") or [])
Expand All @@ -814,14 +815,15 @@ def configure_mcp_command() -> int:
working_names.add(selection)

for selection in add_selections:
resolved = _resolve_mcp_selection(
selection,
workspace,
available_app_mcp_servers,
)
if resolved is None:
try:
entry_name, url = _resolve_mcp_selection(
selection,
workspace,
available_app_mcp_servers,
)
except RuntimeError as exc:
print_warning(f"Skipped MCP selection `{selection}`: {exc}.")
continue
entry_name, url = resolved
if entry_name in working_names:
continue
working_mcp_servers.append(
Expand All @@ -839,4 +841,7 @@ def configure_mcp_command() -> int:
state["mcp_servers"] = working_mcp_servers
save_state(state)
print_success("Saved")
elif not selections and not original_mcp_servers:
# User submitted the picker without toggling anything --> make it clear nothing was selected
print_note("No MCP servers selected. Press space to toggle an item, then enter to save.")
return 0
42 changes: 42 additions & 0 deletions tests/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ def fake_run(args, **kwargs):
assert calls[0]["kwargs"]["env"]["DATABRICKS_HOST"] == WS
assert calls[1]["args"][-2:] == ["--page-token", "next-page"]

def test_passes_profile_when_provided(self, monkeypatch):
calls: list[list[str]] = []

def fake_run(args, **kwargs):
calls.append(args)
return subprocess.CompletedProcess(args, 0, stdout=json.dumps({"connections": []}))

monkeypatch.setattr(db_mod, "run", fake_run)

list_databricks_connections(WS, "my-profile")

assert "--profile" in calls[0]
assert calls[0][calls[0].index("--profile") + 1] == "my-profile"

def test_raises_on_invalid_json(self, monkeypatch):
def fake_run(args, **kwargs):
return subprocess.CompletedProcess(args, 0, stdout="not-json")
Expand Down Expand Up @@ -418,6 +432,20 @@ def fake_run(args, **kwargs):
assert calls[0]["kwargs"]["env"]["DATABRICKS_HOST"] == WS
assert calls[1]["args"][-2:] == ["--page-token", "next-page"]

def test_passes_profile_when_provided(self, monkeypatch):
calls: list[list[str]] = []

def fake_run(args, **kwargs):
calls.append(args)
return subprocess.CompletedProcess(args, 0, stdout=json.dumps({"spaces": []}))

monkeypatch.setattr(db_mod, "run", fake_run)

list_genie_spaces(WS, "my-profile")

assert "--profile" in calls[0]
assert calls[0][calls[0].index("--profile") + 1] == "my-profile"

def test_raises_on_invalid_json(self, monkeypatch):
def fake_run(args, **kwargs):
return subprocess.CompletedProcess(args, 0, stdout="not-json")
Expand Down Expand Up @@ -461,6 +489,20 @@ def fake_run(args, **kwargs):
]
assert calls[0]["kwargs"]["env"]["DATABRICKS_HOST"] == WS

def test_passes_profile_when_provided(self, monkeypatch):
calls: list[list[str]] = []

def fake_run(args, **kwargs):
calls.append(args)
return subprocess.CompletedProcess(args, 0, stdout=json.dumps([]))

monkeypatch.setattr(db_mod, "run", fake_run)

list_databricks_apps(WS, "my-profile")

assert "--profile" in calls[0]
assert calls[0][calls[0].index("--profile") + 1] == "my-profile"

def test_accepts_object_wrapped_apps(self, monkeypatch):
def fake_run(args, **kwargs):
return subprocess.CompletedProcess(
Expand Down
Loading
Loading