diff --git a/setup_claude.py b/setup_claude.py index 9815ef5..ab2ed87 100644 --- a/setup_claude.py +++ b/setup_claude.py @@ -4,7 +4,7 @@ import subprocess from pathlib import Path -from utils import ensure_https, get_gateway_host +from utils import discover_serving_endpoints, ensure_https, get_gateway_host, pick_in_geo_model # Set HOME if not properly set if not os.environ.get("HOME") or os.environ["HOME"] == "/": @@ -40,13 +40,46 @@ else: settings = {} + # Discover models actually served at this workspace. The direct serving- + # endpoints list reflects Databricks Geo Designated Services policy — a + # workspace in AU only sees in-geo models, etc. Validating env-set defaults + # against this list avoids configuring Claude Code with a model the gateway + # claims to serve but the user's geo can't access. + available = discover_serving_endpoints(databricks_host, token) + if available: + print(f"Discovered {len(available)} READY serving endpoints at workspace") + + requested_model = os.environ.get("ANTHROPIC_MODEL", "databricks-claude-opus-4-7") + active_model = pick_in_geo_model( + [requested_model, "databricks-claude-opus-4-6", "databricks-claude-sonnet-4-6"], + available, + fallback=requested_model, + ) + opus_model = pick_in_geo_model( + ["databricks-claude-opus-4-7", "databricks-claude-opus-4-6"], + available, + fallback="databricks-claude-opus-4-7", + ) + sonnet_model = pick_in_geo_model( + ["databricks-claude-sonnet-4-6", "databricks-claude-sonnet-4-5"], + available, + fallback="databricks-claude-sonnet-4-6", + ) + haiku_model = pick_in_geo_model( + ["databricks-claude-haiku-4-5"], + available, + fallback="databricks-claude-haiku-4-5", + ) + if available and active_model != requested_model: + print(f"ANTHROPIC_MODEL={requested_model} not served at this workspace, using {active_model}") + settings.setdefault("env", {}) - settings["env"]["ANTHROPIC_MODEL"] = os.environ.get("ANTHROPIC_MODEL", "databricks-claude-opus-4-7") + settings["env"]["ANTHROPIC_MODEL"] = 
def _ready_endpoint_name(entry: object) -> str | None:
    """Return the name of *entry* if it is a well-formed READY endpoint, else None."""
    if not isinstance(entry, dict):
        return None
    name = entry.get("name")
    state = entry.get("state")
    # `state` may be missing or JSON null; treat anything but a dict as "not ready".
    if not isinstance(name, str) or not name or not isinstance(state, dict):
        return None
    return name if state.get("ready") == "READY" else None


def discover_serving_endpoints(host: str, token: str, timeout: float = 5.0) -> set[str]:
    """Return the set of READY serving-endpoint names at the workspace.

    The workspace's direct serving-endpoints list naturally reflects in-geo
    model availability — Databricks Geo Designated Services restricts which
    models are deployed to each region. Validating an env-set model against
    this list is therefore equivalent to "is this model in the workspace's
    geo / data-residency policy", without parsing GDS rules ourselves.

    Args:
        host: Workspace base URL; a trailing slash is tolerated.
        token: Bearer token sent in the Authorization header.
        timeout: Timeout in seconds passed to the HTTP request.

    Returns:
        Names of endpoints whose ``state.ready`` equals ``"READY"``. Entries
        that are malformed (not a dict, no string name, missing/null state)
        are skipped individually rather than aborting the whole discovery.
        Returns an empty set on any failure (auth error, network blip, JSON
        parse, etc.) or when *host* / *token* is empty — callers should treat
        empty as "discovery unavailable, keep defaults".
    """
    if not host or not token:
        return set()
    try:
        # Strip a trailing slash so the request path never becomes "//api/...".
        url = f"{host.rstrip('/')}/api/2.0/serving-endpoints"
        resp = requests.get(
            url,
            headers={"Authorization": f"Bearer {token}"},
            timeout=timeout,
        )
        resp.raise_for_status()
        # `or []` also covers an explicit `"endpoints": null` in the payload.
        endpoints = resp.json().get("endpoints") or []
        return {name for name in map(_ready_endpoint_name, endpoints) if name}
    except Exception as e:
        # Deliberately broad: discovery is best-effort and must never block setup.
        logger.warning("Could not discover serving endpoints at %s: %s", host, e)
        return set()


def pick_in_geo_model(preferred: list[str], available: set[str], fallback: str) -> str:
    """Pick the highest-priority preferred model that's actually served here.

    `preferred` is the caller's degradation chain (e.g. opus-4-7 → opus-4-6).
    Returns the first entry that's in `available`. If none match (or `available`
    is empty/None because discovery failed), returns `fallback` — typically the
    original env-set default. The user will see a clean ENDPOINT_NOT_FOUND
    later if they actually try to use a missing model, rather than getting
    silently downgraded to a different model tier.
    """
    # Discovery unavailable: keep the caller's default untouched.
    if not available:
        return fallback
    return next((model for model in preferred if model in available), fallback)