Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ucode/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def validate_tool(tool: str) -> tuple[bool, str]:
def validate_all_tools(state: dict) -> None:
from rich.panel import Panel # local to avoid bumping module-level deps

from ucode.agents.pi import PI_SETTINGS_BACKUP_PATH, PI_SETTINGS_PATH
from ucode.config_io import restore_file

console.print()
Expand All @@ -411,6 +412,9 @@ def validate_all_tools(state: dict) -> None:
print_err(f"{spec['display']}: {err}")
managed = bool(state.get("managed_configs", {}).get(tool))
restore_file(spec["config_path"], spec["backup_path"], managed)
# Rollback settings.json for Pi
if tool == "pi":
restore_file(PI_SETTINGS_PATH, PI_SETTINGS_BACKUP_PATH, managed)
available_tools.remove(tool)
state["available_tools"] = available_tools
save_state(state)
Expand Down
16 changes: 16 additions & 0 deletions src/ucode/agents/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
PI_UCODE_HOME = APP_DIR / "pi-home"
PI_CONFIG_DIR = PI_UCODE_HOME / ".pi" / "agent"
PI_CONFIG_PATH = PI_CONFIG_DIR / "models.json"
PI_SETTINGS_PATH = PI_CONFIG_DIR / "settings.json"
PI_BACKUP_PATH = APP_DIR / "pi-models.backup.json"
PI_SETTINGS_BACKUP_PATH = APP_DIR / "pi-settings.backup.json"

SPEC: ToolSpec = {
"binary": "pi",
Expand Down Expand Up @@ -184,11 +186,25 @@ def write_tool_config(
providers.pop(stale, None)
merged = deep_merge_dict(existing, overlay)
write_json_file(PI_CONFIG_PATH, merged)
_write_settings(overlay["model"])
state = mark_tool_managed(state, "pi", managed_keys)
save_state(state)
return state, token


def _write_settings(model_selector: str) -> None:
# Pin defaultProvider/defaultModel in settings.json so Pi doesn't fall
# through to an env-key-backed provider (e.g. HF_TOKEN exposing
# huggingface) in `findInitialModel` when no --model is passed.
provider, _, model_id = model_selector.partition("/")
if not model_id:
return
backup_existing_file(PI_SETTINGS_PATH, PI_SETTINGS_BACKUP_PATH)
existing = read_json_safe(PI_SETTINGS_PATH)
merged = deep_merge_dict(existing, {"defaultProvider": provider, "defaultModel": model_id})
write_json_file(PI_SETTINGS_PATH, merged)


def default_model(state: dict) -> str | None:
"""Prefer Claude opus → sonnet → haiku; fall back to codex, gemini."""
claude_models = state.get("claude_models") or {}
Expand Down
5 changes: 5 additions & 0 deletions src/ucode/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ucode.agents import (
launch as launch_agent,
)
from ucode.agents.pi import PI_SETTINGS_BACKUP_PATH, PI_SETTINGS_PATH
from ucode.config_io import restore_file, set_dry_run
from ucode.databricks import (
build_shared_base_urls,
Expand Down Expand Up @@ -398,12 +399,16 @@ def revert() -> int:
)
for tool, spec in TOOL_SPECS.items()
}
pi_settings_restored = restore_file(
PI_SETTINGS_PATH, PI_SETTINGS_BACKUP_PATH, bool(managed_configs.get("pi"))
)
clear_state()

print_heading("Revert")
print_kv("Workspace", state.get("workspace") or "none")
for tool, spec in TOOL_SPECS.items():
print_kv(f"{spec['display']} config", "restored" if results[tool] else "unchanged")
print_kv("Pi settings", "restored" if pi_settings_restored else "unchanged")
for client, spec in MCP_CLIENTS.items():
print_kv(
f"{spec['display']} MCP config",
Expand Down
71 changes: 67 additions & 4 deletions tests/test_agent_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from contextlib import nullcontext
from unittest.mock import patch

from ucode.agents import pi
Expand Down Expand Up @@ -267,9 +268,13 @@ def _setup(self, tmp_path, monkeypatch):
monkeypatch.setattr(config_io_mod, "APP_DIR", tmp_path)
config_file = tmp_path / "models.json"
backup_file = tmp_path / "pi-backup.json"
settings_file = tmp_path / "settings.json"
settings_backup_file = tmp_path / "pi-settings-backup.json"
monkeypatch.setattr(pi_mod, "PI_CONFIG_PATH", config_file)
monkeypatch.setattr(pi_mod, "PI_SETTINGS_PATH", settings_file)
monkeypatch.setattr(pi_mod, "PI_BACKUP_PATH", backup_file)
return pi_mod, config_file
monkeypatch.setattr(pi_mod, "PI_SETTINGS_BACKUP_PATH", settings_backup_file)
return pi_mod, config_file, settings_file, settings_backup_file

def _state(self, **overrides) -> dict:
state = {
Expand All @@ -284,7 +289,7 @@ def _state(self, **overrides) -> dict:
return state

def test_stale_managed_providers_removed_before_merge(self, tmp_path, monkeypatch):
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _, _ = self._setup(tmp_path, monkeypatch)

stale = {
"providers": {
Expand Down Expand Up @@ -312,7 +317,7 @@ def test_legacy_providers_removed_on_upgrade(self, tmp_path, monkeypatch):
"""Earlier ucode versions wrote `databricks-anthropic`, `databricks-codex`,
and `databricks-oss` providers. They must be stripped on the next write
so users don't end up with stale entries pointing at routes that 400."""
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _, _ = self._setup(tmp_path, monkeypatch)

config_file.write_text(
json.dumps(
Expand All @@ -339,7 +344,7 @@ def test_legacy_providers_removed_on_upgrade(self, tmp_path, monkeypatch):
assert "databricks-claude" in written_providers

def test_config_written_with_correct_model_and_token(self, tmp_path, monkeypatch):
pi_mod, config_file = self._setup(tmp_path, monkeypatch)
pi_mod, config_file, _, _ = self._setup(tmp_path, monkeypatch)

with (
patch("ucode.agents.pi.get_databricks_token", return_value="tok"),
Expand All @@ -350,3 +355,61 @@ def test_config_written_with_correct_model_and_token(self, tmp_path, monkeypatch
written = json.loads(config_file.read_text())
assert written["model"] == "databricks-claude/claude-sonnet"
assert written["providers"]["databricks-claude"]["apiKey"] == "tok"

def test_settings_pins_default_provider_and_model(self, tmp_path, monkeypatch):
# Without this, Pi's `findInitialModel` can fall through to a built-in
# provider when an unrelated env var (e.g. HF_TOKEN) makes one look
# auth-configured. Pinning the default keeps Pi on our provider.
pi_mod, _, settings_file, _ = self._setup(tmp_path, monkeypatch)

with (
patch("ucode.agents.pi.get_databricks_token", return_value="tok"),
patch("ucode.agents.pi.save_state"),
):
pi_mod.write_tool_config(self._state(), "claude-sonnet", token="tok")

settings = json.loads(settings_file.read_text())
assert settings["defaultProvider"] == "databricks-claude"
assert settings["defaultModel"] == "claude-sonnet"

def test_pre_existing_settings_are_backed_up_before_first_write(self, tmp_path, monkeypatch):
pi_mod, _, settings_file, settings_backup_file = self._setup(tmp_path, monkeypatch)

original = '{"theme": "Default Dark", "defaultProvider": "openai"}'
settings_file.parent.mkdir(parents=True, exist_ok=True)
settings_file.write_text(original, encoding="utf-8")

with (
patch("ucode.agents.pi.get_databricks_token", return_value="tok"),
patch("ucode.agents.pi.save_state"),
):
pi_mod.write_tool_config(self._state(), "claude-sonnet", token="tok")

assert settings_backup_file.read_text(encoding="utf-8") == original
# The on-disk settings still get the ucode pin applied via deep_merge.
merged = json.loads(settings_file.read_text())
assert merged["defaultProvider"] == "databricks-claude"
assert merged["theme"] == "Default Dark"


class TestValidateAllToolsPiRollback:
def test_failed_pi_validation_rolls_back_settings(self, tmp_path, monkeypatch):
import ucode.agents as agents_mod
import ucode.agents.pi as pi_mod

settings_file = tmp_path / "settings.json"
settings_file.write_text("{}", encoding="utf-8")
monkeypatch.setattr(pi_mod, "PI_SETTINGS_PATH", settings_file)
monkeypatch.setattr(pi_mod, "PI_SETTINGS_BACKUP_PATH", tmp_path / "settings.backup.json")
# Keep the generic models.json rollback off the user's real config dir.
monkeypatch.setitem(agents_mod.TOOL_SPECS["pi"], "config_path", tmp_path / "models.json")
monkeypatch.setitem(
agents_mod.TOOL_SPECS["pi"], "backup_path", tmp_path / "models.backup.json"
)
monkeypatch.setattr(agents_mod, "validate_tool", lambda tool: (False, "boom"))
monkeypatch.setattr(agents_mod, "save_state", lambda s: None)
monkeypatch.setattr(agents_mod, "spinner", lambda *_a, **_kw: nullcontext())

agents_mod.validate_all_tools({"available_tools": ["pi"], "managed_configs": {"pi": True}})

assert not settings_file.exists()
Loading