diff --git a/app.py b/app.py
index 005b997..46e35b6 100644
--- a/app.py
+++ b/app.py
@@ -23,6 +23,7 @@ import app_state
 from utils import ensure_https, get_gateway_host
 from pat_rotator import PATRotator
+from telemetry import log_telemetry, set_product_info
 
 # Sanitize DATABRICKS_TOKEN early — the platform sometimes injects trailing
 # newlines / whitespace which causes auth failures. Cleaning it here prevents
@@ -175,6 +176,7 @@ def _setup_git_config():
     db_token = os.environ.get("DATABRICKS_TOKEN")
     if db_host and db_token:
         w = WorkspaceClient(host=db_host, token=db_token, auth_type="pat")
+        set_product_info(w)
         me = w.current_user.me()
         user_email = me.user_name
         display_name = me.display_name or user_email.split("@")[0]
@@ -412,6 +414,7 @@ def get_token_owner():
     if app_name:
         try:
             w = WorkspaceClient()  # auto-detects SP credentials
+            set_product_info(w)
             app = w.apps.get(name=app_name)
             owner = (app.creator or "").lower()
             logger.info(f"Owner resolved from app.creator: {owner}")
@@ -426,6 +429,7 @@ def get_token_owner():
         if not host or not token:
             return None
         w = WorkspaceClient(host=host, token=token, auth_type="pat")
+        set_product_info(w)
         username = w.current_user.me().user_name
         return username.lower() if username else username
     except Exception as e:
@@ -1060,6 +1064,9 @@ def create_session():
         thread = threading.Thread(target=read_pty_output, args=(session_id, master_fd), daemon=True)
         thread.start()
 
+        # Telemetry: track session creation with agent type
+        log_telemetry("agent", label or "shell")
+
         return jsonify({"session_id": session_id})
     except Exception as e:
         return jsonify({"error": str(e)}), 500
@@ -1112,6 +1119,10 @@ def upload_file():
 
     file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
     logger.info(f"Upload saved: {file_path} ({file_size} bytes)")
+
+    # Telemetry: track file uploads
+    log_telemetry("event", "file_upload")
+
     return jsonify({"path": file_path})
 
 
@@ -1269,6 +1280,9 @@ def initialize_app(local_dev=False):
os.environ.pop("DATABRICKS_CLIENT_SECRET", None) logger.info("SP credentials stripped — PAT-only auth from this point") + # Telemetry: app startup ping (fire-and-forget in background thread) + log_telemetry("event", "app_startup") + # Start background cleanup thread cleanup_thread = threading.Thread(target=cleanup_stale_sessions, daemon=True) cleanup_thread.start() diff --git a/pat_rotator.py b/pat_rotator.py index 28e0319..cc8734f 100644 --- a/pat_rotator.py +++ b/pat_rotator.py @@ -159,6 +159,13 @@ def _rotate_once(self): logger.info(f"INFO: PAT rotation complete — new token (id={new_token_id}, " f"expires in {self._token_lifetime}s). First rotation — no old token to revoke.") + # Telemetry: track PAT rotation events (import here to avoid circular deps) + try: + from telemetry import log_telemetry + log_telemetry("event", "pat_rotation") + except Exception: + pass # Telemetry must never break rotation + return True def revoke_bootstrap_token(self): diff --git a/pyproject.toml b/pyproject.toml index 4eed3e4..fd3579e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "coda" -version = "0.17.2" +version = "0.17.3" description = "CoDA - Coding Agents on Databricks Apps" requires-python = ">=3.10" dependencies = [ diff --git a/sync_to_workspace.py b/sync_to_workspace.py index 0134925..1dcff4d 100644 --- a/sync_to_workspace.py +++ b/sync_to_workspace.py @@ -35,6 +35,11 @@ def get_user_email(): if not host or not token: raise RuntimeError("~/.databrickscfg missing host or token") w = WorkspaceClient(host=host, token=token, auth_type="pat") + try: + from telemetry import set_product_info + set_product_info(w) + except Exception: + pass return w.current_user.me().user_name @@ -68,6 +73,12 @@ def sync_project(project_path: Path): if result.returncode == 0: print(f"✓ Synced to {workspace_dest}") + # Telemetry: track workspace sync events + try: + from telemetry import log_telemetry + log_telemetry("event", "workspace_sync") + except Exception: 
+            pass  # Telemetry must never break sync
     else:
         print(f"⚠ Sync warning: {result.stderr}", file=sys.stderr)
 
 
diff --git a/telemetry.py b/telemetry.py
new file mode 100644
index 0000000..a94c2c6
--- /dev/null
+++ b/telemetry.py
@@ -0,0 +1,98 @@
+"""Databricks Labs telemetry for CoDA.
+
+Follows the DQX pattern: piggybacks telemetry on the Databricks SDK's
+User-Agent header. Each log_telemetry() call creates a throwaway
+WorkspaceClient, augments the User-Agent with key-value data, and fires
+clusters.select_spark_version() to transmit the header to Databricks
+servers where it's recorded.
+
+All telemetry runs in background daemon threads -- never blocks the
+Flask request path or terminal I/O.
+
+Reference: https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/telemetry.py
+"""
+
+import functools
+import logging
+import os
+import threading
+
+import re  # parse pyproject.toml via regex: tomllib is stdlib only on 3.11+, project supports 3.10
+
+logger = logging.getLogger(__name__)
+
+_version_cache = None
+
+
+def _get_version():
+    """Get CoDA version from pyproject.toml (cached after first call)."""
+    global _version_cache
+    if _version_cache is not None:
+        return _version_cache
+    try:
+        pyproject = os.path.join(os.path.dirname(__file__), "pyproject.toml")
+        with open(pyproject, encoding="utf-8") as f:
+            _version_cache = re.search(r'(?m)^version\s*=\s*"([^"]+)"', f.read()).group(1)
+    except Exception:
+        _version_cache = "0.0.0"
+    return _version_cache
+
+
+def set_product_info(ws):
+    """Set CoDA product info on a WorkspaceClient for telemetry attribution.
+
+    Call this on any WorkspaceClient so all SDK API calls carry the 'coda'
+    product identifier in the User-Agent header.
+    """
+    product_info = getattr(ws.config, "_product_info", None)
+    if product_info is None or product_info[0] != "coda":
+        setattr(ws.config, "_product_info", ("coda", _get_version()))
+
+
+def log_telemetry(key, value):
+    """Send a telemetry key-value pair via the Databricks SDK User-Agent header.
+
+    Creates a throwaway WorkspaceClient from ~/.databrickscfg, adds the
+    key-value to the User-Agent, and fires clusters.select_spark_version()
+    to transmit. Runs in a background daemon thread. Errors are caught and
+    logged, never raised.
+    """
+
+    def _send():
+        try:
+            from databricks.sdk import WorkspaceClient
+            from databricks.sdk.errors import DatabricksError
+
+            ws = WorkspaceClient()
+            set_product_info(ws)
+            new_config = ws.config.copy().with_user_agent_extra(key, value)
+            temp_ws = WorkspaceClient(config=new_config)
+            try:
+                temp_ws.clusters.select_spark_version()
+            except DatabricksError as e:
+                logger.debug(f"Telemetry transmit failed: {e}")
+        except Exception as e:
+            logger.debug(f"Telemetry error ({key}={value}): {e}")
+
+    threading.Thread(target=_send, daemon=True, name=f"telemetry-{key}").start()
+
+
+def telemetry_logger(key, value):
+    """Decorator that fires telemetry before executing the wrapped function.
+
+    Works on standalone functions and class methods alike. Creates its own
+    WorkspaceClient from ~/.databrickscfg -- no self.ws required.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                log_telemetry(key, value)
+            except Exception:
+                pass  # Telemetry must never break the wrapped function
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py
new file mode 100644
index 0000000..94ecbac
--- /dev/null
+++ b/tests/test_telemetry.py
@@ -0,0 +1,297 @@
+"""Tests for Databricks Labs telemetry — telemetry.py module."""
+
+import threading
+import time
+from unittest import mock
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _wait_for_telemetry_threads(timeout=2.0):
+    """Wait for any background telemetry threads to complete."""
+    deadline = time.monotonic() + timeout
+    for t in threading.enumerate():
+        if t.name.startswith("telemetry-"):
+            remaining = deadline - time.monotonic()
+            if remaining > 0:
+                t.join(timeout=remaining)
+
+
+# ---------------------------------------------------------------------------
+# _get_version
+# ---------------------------------------------------------------------------
+
+
+class TestGetVersion:
+    def test_reads_from_pyproject(self):
+        import telemetry
+
+        # Reset cache so it re-reads
+        telemetry._version_cache = None
+        version = telemetry._get_version()
+        assert version != "0.0.0"
+        assert "." in version  # semver-ish
+
+    def test_caches_result(self):
+        import telemetry
+
+        telemetry._version_cache = None
+        v1 = telemetry._get_version()
+        v2 = telemetry._get_version()
+        assert v1 == v2
+        assert telemetry._version_cache == v1
+
+    def test_falls_back_on_missing_file(self, tmp_path):
+        import telemetry
+
+        telemetry._version_cache = None
+        with mock.patch("telemetry.os.path.dirname", return_value=str(tmp_path)):
+            version = telemetry._get_version()
+        assert version == "0.0.0"
+        # Reset for other tests
+        telemetry._version_cache = None
+
+
+# ---------------------------------------------------------------------------
+# set_product_info
+# ---------------------------------------------------------------------------
+
+
+class TestSetProductInfo:
+    def test_sets_product_info_on_ws(self):
+        from telemetry import set_product_info
+
+        ws = mock.MagicMock()
+        ws.config._product_info = None
+
+        set_product_info(ws)
+
+        assert ws.config._product_info == ("coda", mock.ANY)
+        assert ws.config._product_info[0] == "coda"
+
+    def test_idempotent_when_already_set(self):
+        from telemetry import set_product_info
+
+        ws = mock.MagicMock()
+        ws.config._product_info = ("coda", "0.17.2")
+
+        set_product_info(ws)
+
+        # Should not overwrite
+        assert ws.config._product_info == ("coda", "0.17.2")
+
+    def test_overwrites_different_product(self):
+        from telemetry import set_product_info
+
+        ws = mock.MagicMock()
+        ws.config._product_info = ("other-project", "1.0.0")
+
+        set_product_info(ws)
+
+        assert ws.config._product_info[0] == "coda"
+
+
+# ---------------------------------------------------------------------------
+# log_telemetry
+# ---------------------------------------------------------------------------
+
+
+class TestLogTelemetry:
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_fires_in_background_thread(self, mock_ws_cls):
+        from telemetry import log_telemetry
+
+        mock_ws = mock.MagicMock()
+        mock_ws.config._product_info = None
+        mock_ws.config.copy.return_value.with_user_agent_extra.return_value = (
+            mock.MagicMock()
+        )
+        mock_ws_cls.return_value = mock_ws
+
+        log_telemetry("event", "test_event")
+        _wait_for_telemetry_threads()
+
+        # WorkspaceClient() called twice: once for initial ws, once for temp_ws with config
+        assert mock_ws_cls.call_count == 2
+
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_calls_select_spark_version(self, mock_ws_cls):
+        from telemetry import log_telemetry
+
+        mock_ws = mock.MagicMock()
+        mock_ws.config._product_info = None
+
+        mock_config_copy = mock.MagicMock()
+        mock_ws.config.copy.return_value.with_user_agent_extra.return_value = (
+            mock_config_copy
+        )
+
+        mock_temp_ws = mock.MagicMock()
+        # First call: WorkspaceClient() -> mock_ws, Second call: WorkspaceClient(config=...) -> mock_temp_ws
+        mock_ws_cls.side_effect = [mock_ws, mock_temp_ws]
+
+        log_telemetry("agent", "claude")
+        _wait_for_telemetry_threads()
+
+        # The temp WS should have clusters.select_spark_version() called
+        assert mock_ws_cls.call_count == 2
+        mock_temp_ws.clusters.select_spark_version.assert_called_once()
+
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_adds_user_agent_extra(self, mock_ws_cls):
+        from telemetry import log_telemetry
+
+        mock_ws = mock.MagicMock()
+        mock_ws.config._product_info = None
+        mock_ws_cls.return_value = mock_ws
+
+        log_telemetry("event", "file_upload")
+        _wait_for_telemetry_threads()
+
+        mock_ws.config.copy.return_value.with_user_agent_extra.assert_called_once_with(
+            "event", "file_upload"
+        )
+
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_fire_and_forget_on_ws_error(self, mock_ws_cls):
+        """Telemetry errors must never propagate to caller."""
+        from telemetry import log_telemetry
+
+        mock_ws_cls.side_effect = Exception("No databrickscfg")
+
+        # Should not raise
+        log_telemetry("event", "startup")
+        _wait_for_telemetry_threads()
+
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_fire_and_forget_on_api_error(self, mock_ws_cls):
+        """DatabricksError during transmit must be swallowed."""
+        from databricks.sdk.errors import DatabricksError
+        from telemetry import log_telemetry
+
+        mock_ws = mock.MagicMock()
+        mock_ws.config._product_info = None
+        mock_config_copy = mock.MagicMock()
+        mock_ws.config.copy.return_value.with_user_agent_extra.return_value = (
+            mock_config_copy
+        )
+        mock_temp_ws = mock.MagicMock()
+        mock_temp_ws.clusters.select_spark_version.side_effect = DatabricksError(
+            "Forbidden"
+        )
+        mock_ws_cls.side_effect = [mock_ws, mock_temp_ws]
+
+        # Should not raise
+        log_telemetry("event", "test")
+        _wait_for_telemetry_threads()
+
+    def test_runs_in_daemon_thread(self):
+        """Telemetry threads must be daemons so they don't block shutdown."""
+        from telemetry import log_telemetry
+
+        with mock.patch("databricks.sdk.WorkspaceClient") as mock_ws_cls:
+            # Make the WS constructor block so we can inspect the thread
+            barrier = threading.Event()
+
+            def slow_init(*args, **kwargs):
+                barrier.wait(timeout=5)
+                return mock.MagicMock()
+
+            mock_ws_cls.side_effect = slow_init
+
+            log_telemetry("event", "test")
+
+            # Find the telemetry thread
+            telemetry_threads = [
+                t for t in threading.enumerate() if t.name.startswith("telemetry-")
+            ]
+            assert len(telemetry_threads) >= 1
+            assert telemetry_threads[0].daemon is True
+
+            barrier.set()  # unblock
+            _wait_for_telemetry_threads()
+
+
+# ---------------------------------------------------------------------------
+# telemetry_logger decorator
+# ---------------------------------------------------------------------------
+
+
+class TestTelemetryLogger:
+    @mock.patch("telemetry.log_telemetry")
+    def test_decorator_fires_telemetry(self, mock_log):
+        from telemetry import telemetry_logger
+
+        @telemetry_logger("event", "decorated_fn")
+        def my_function(x, y):
+            return x + y
+
+        result = my_function(1, 2)
+
+        assert result == 3
+        mock_log.assert_called_once_with("event", "decorated_fn")
+
+    @mock.patch("telemetry.log_telemetry")
+    def test_preserves_function_metadata(self, mock_log):
+        from telemetry import telemetry_logger
+
+        @telemetry_logger("event", "test")
+        def documented_function():
+            """This is the docstring."""
+            pass
+
+        assert documented_function.__name__ == "documented_function"
+        assert documented_function.__doc__ == "This is the docstring."
+
+    @mock.patch("telemetry.log_telemetry")
+    def test_passes_args_and_kwargs(self, mock_log):
+        from telemetry import telemetry_logger
+
+        @telemetry_logger("event", "test")
+        def func_with_args(a, b, c=None):
+            return (a, b, c)
+
+        result = func_with_args("x", "y", c="z")
+        assert result == ("x", "y", "z")
+
+    @mock.patch("telemetry.log_telemetry", side_effect=Exception("boom"))
+    def test_telemetry_failure_doesnt_break_function(self, mock_log):
+        """If telemetry itself fails, the wrapped function must still execute."""
+        from telemetry import telemetry_logger
+
+        @telemetry_logger("event", "test")
+        def important_function():
+            return "success"
+
+        # The decorator catches the exception from log_telemetry
+        # but log_telemetry itself is fire-and-forget, so this tests
+        # that the function still returns correctly
+        result = important_function()
+        assert result == "success"
+
+
+# ---------------------------------------------------------------------------
+# Integration: product_info + telemetry together
+# ---------------------------------------------------------------------------
+
+
+class TestProductInfoIntegration:
+    @mock.patch("databricks.sdk.WorkspaceClient")
+    def test_product_info_set_during_telemetry(self, mock_ws_cls):
+        """log_telemetry should set product_info before transmitting."""
+        from telemetry import log_telemetry
+
+        mock_ws = mock.MagicMock()
+        mock_ws.config._product_info = None
+        mock_ws_cls.return_value = mock_ws
+
+        log_telemetry("event", "startup")
+        _wait_for_telemetry_threads()
+
+        # product_info should have been set to ('coda', version)
+        assert mock_ws.config._product_info[0] == "coda"