Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions setup_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import subprocess
from pathlib import Path

from utils import ensure_https, get_gateway_host
from utils import discover_serving_endpoints, ensure_https, get_gateway_host, pick_in_geo_model

# Set HOME if not properly set
if not os.environ.get("HOME") or os.environ["HOME"] == "/":
Expand Down Expand Up @@ -40,13 +40,46 @@
else:
settings = {}

# Discover models actually served at this workspace. The direct serving-
# endpoints list reflects Databricks Geo Designated Services policy — a
# workspace in AU only sees in-geo models, etc. Validating env-set defaults
# against this list avoids configuring Claude Code with a model the gateway
# claims to serve but the user's geo can't access.
available = discover_serving_endpoints(databricks_host, token)
if available:
print(f"Discovered {len(available)} READY serving endpoints at workspace")

requested_model = os.environ.get("ANTHROPIC_MODEL", "databricks-claude-opus-4-7")
active_model = pick_in_geo_model(
[requested_model, "databricks-claude-opus-4-6", "databricks-claude-sonnet-4-6"],
available,
fallback=requested_model,
)
opus_model = pick_in_geo_model(
["databricks-claude-opus-4-7", "databricks-claude-opus-4-6"],
available,
fallback="databricks-claude-opus-4-7",
)
sonnet_model = pick_in_geo_model(
["databricks-claude-sonnet-4-6", "databricks-claude-sonnet-4-5"],
available,
fallback="databricks-claude-sonnet-4-6",
)
haiku_model = pick_in_geo_model(
["databricks-claude-haiku-4-5"],
available,
fallback="databricks-claude-haiku-4-5",
)
if available and active_model != requested_model:
print(f"ANTHROPIC_MODEL={requested_model} not served at this workspace, using {active_model}")

settings.setdefault("env", {})
settings["env"]["ANTHROPIC_MODEL"] = os.environ.get("ANTHROPIC_MODEL", "databricks-claude-opus-4-7")
settings["env"]["ANTHROPIC_MODEL"] = active_model
settings["env"]["ANTHROPIC_BASE_URL"] = anthropic_base_url
settings["env"]["ANTHROPIC_AUTH_TOKEN"] = token
settings["env"]["ANTHROPIC_DEFAULT_OPUS_MODEL"] = "databricks-claude-opus-4-7"
settings["env"]["ANTHROPIC_DEFAULT_SONNET_MODEL"] = "databricks-claude-sonnet-4-6"
settings["env"]["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = "databricks-claude-haiku-4-5"
settings["env"]["ANTHROPIC_DEFAULT_OPUS_MODEL"] = opus_model
settings["env"]["ANTHROPIC_DEFAULT_SONNET_MODEL"] = sonnet_model
settings["env"]["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = haiku_model
settings["env"]["ANTHROPIC_CUSTOM_HEADERS"] = "x-databricks-use-coding-agent-mode: true"
settings["env"]["CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS"] = "1"

Expand Down
53 changes: 53 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,64 @@

from __future__ import annotations

import logging
import os
import re
import subprocess
from pathlib import Path

import requests

logger = logging.getLogger(__name__)


def discover_serving_endpoints(host: str, token: str, timeout: float = 5.0) -> set[str]:
"""Return the set of READY serving-endpoint names at the workspace.

The workspace's direct serving-endpoints list naturally reflects in-geo
model availability — Databricks Geo Designated Services restricts which
models are deployed to each region. Validating an env-set model against
this list is therefore equivalent to "is this model in the workspace's
geo / data-residency policy", without parsing GDS rules ourselves.

Returns an empty set on any failure (auth error, network blip, JSON parse,
etc.) — caller should treat empty as "discovery unavailable, keep defaults".
"""
if not host or not token:
return set()
try:
resp = requests.get(
Comment on lines 10 to +31
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger = logging.getLogger(__name__)
def discover_serving_endpoints(host: str, token: str, timeout: float = 5.0) -> set[str]:
"""Return the set of READY serving-endpoint names at the workspace.
The workspace's direct serving-endpoints list naturally reflects in-geo
model availabilityDatabricks Geo Designated Services restricts which
models are deployed to each region. Validating an env-set model against
this list is therefore equivalent to "is this model in the workspace's
geo / data-residency policy", without parsing GDS rules ourselves.
Returns an empty set on any failure (auth error, network blip, JSON parse,
etc.) — caller should treat empty as "discovery unavailable, keep defaults".
"""
if not host or not token:
return set()
try:
import requests
resp = requests.get(
import requests
logger = logging.getLogger(__name__)
def discover_serving_endpoints(host: str, token: str, timeout: float = 5.0) -> set[str]:
"""Return the set of READY serving-endpoint names at the workspace.
The workspace's direct serving-endpoints list naturally reflects in-geo
model availabilityDatabricks Geo Designated Services restricts which
models are deployed to each region. Validating an env-set model against
this list is therefore equivalent to "is this model in the workspace's
geo / data-residency policy", without parsing GDS rules ourselves.
Returns an empty set on any failure (auth error, network blip, JSON parse,
etc.) — caller should treat empty as "discovery unavailable, keep defaults".
"""
if not host or not token:
return set()
try:
resp = requests.get(

f"{host}/api/2.0/serving-endpoints",
headers={"Authorization": f"Bearer {token}"},
timeout=timeout,
)
resp.raise_for_status()
endpoints = resp.json().get("endpoints", [])
return {
ep["name"]
for ep in endpoints
if ep.get("name") and ep.get("state", {}).get("ready") == "READY"
}
except Exception as e:
logger.warning("Could not discover serving endpoints at %s: %s", host, e)
return set()


def pick_in_geo_model(preferred: list[str], available: set[str], fallback: str) -> str:
"""Pick the highest-priority preferred model that's actually served here.

`preferred` is the caller's degradation chain (e.g. opus-4-7 → opus-4-6).
Returns the first entry that's in `available`. If none match (or `available`
is empty because discovery failed), returns `fallback` — typically the
original env-set default. The user will see a clean ENDPOINT_NOT_FOUND
later if they actually try to use a missing model, rather than getting
silently downgraded to a different model tier.
"""
for m in preferred:
if m in available:
return m
return fallback


def get_npm_version(package_name):
"""Resolve the latest stable version of an npm package.
Expand Down