diff --git a/README.md b/README.md index 8ca45e6b..cceeb0aa 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ --- ### Connect -Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Supabase, CloudFlare D1, Turso, Athena, BigQuery, Spanner, RedShift, IBM Db2, SAP HANA, Teradata, Trino, Presto, Apache Flight SQL, Apache Impala, SurrealDB and osquery. +Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Databricks, Supabase, CloudFlare D1, Turso, Athena, BigQuery, Spanner, RedShift, IBM Db2, SAP HANA, Teradata, Trino, Presto, Apache Flight SQL, Apache Impala, SurrealDB and osquery. ![Database Providers](docs/demos/demo-providers.gif) @@ -289,6 +289,7 @@ Most of the time you can just run `sqlit` and connect. If a Python driver is mis | Turso | `libsql` | `pipx inject sqlit-tui libsql` | `python -m pip install libsql` | | Cloudflare D1 | `requests` | `pipx inject sqlit-tui requests` | `python -m pip install requests` | | Snowflake | `snowflake-connector-python` | `pipx inject sqlit-tui snowflake-connector-python` | `python -m pip install snowflake-connector-python` | +| Databricks | `databricks-sql-connector` | `pipx inject sqlit-tui databricks-sql-connector` | `python -m pip install databricks-sql-connector` | | Firebird | `firebirdsql` | `pipx inject sqlit-tui firebirdsql` | `python -m pip install firebirdsql` | | Athena | `pyathena` | `pipx inject sqlit-tui pyathena` | `python -m pip install pyathena` | | BigQuery | `google-cloud-bigquery` | `pipx inject sqlit-tui google-cloud-bigquery` | `python -m pip install google-cloud-bigquery` | diff --git a/pyproject.toml b/pyproject.toml index 7940b1ef..3d894366 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ all = [ "impyla>=0.18.0", "osquery>=3.0.0", "surrealdb>=1.0.0", + "databricks-sql-connector>=3.0.0", ] postgres = ["psycopg2-binary>=2.9.0"] cockroachdb = ["psycopg2-binary>=2.9.0"] @@ -88,6 +89,7 @@ flight = ["adbc-driver-flightsql>=1.0.0"] impala = ["impyla>=0.18.0"] osquery = ["osquery>=3.0.0"] surrealdb = ["surrealdb>=1.0.0"] +databricks = ["databricks-sql-connector>=3.0.0"] ssh = [ "sshtunnel>=0.4.0", "paramiko>=2.0.0,<4.0.0", @@ -253,6 +255,10 @@ module = [ "impala.dbapi", "osquery", "surrealdb", + "databricks", + "databricks.sql", + "databricks.sdk", + "databricks.sdk.core", "google.cloud", "google.cloud.bigquery", "google.cloud.bigquery.dbapi", diff --git a/sqlit/domains/connections/domain/config.py b/sqlit/domains/connections/domain/config.py index 8ad839de..6c37675f 100644 --- a/sqlit/domains/connections/domain/config.py +++ b/sqlit/domains/connections/domain/config.py @@ -14,6 +14,7 @@ class DatabaseType(str, Enum): CLICKHOUSE = "clickhouse" COCKROACHDB = "cockroachdb" D1 = "d1" + DATABRICKS = "databricks" DUCKDB = "duckdb" DB2 = "db2" FIREBIRD = "firebird" @@ -53,6 +54,7 @@ class DatabaseType(str, Enum): DatabaseType.HANA, DatabaseType.TERADATA, DatabaseType.SNOWFLAKE, + DatabaseType.DATABRICKS, DatabaseType.BIGQUERY, DatabaseType.SPANNER, DatabaseType.TRINO, diff --git a/sqlit/domains/connections/providers/databricks/__init__.py b/sqlit/domains/connections/providers/databricks/__init__.py new file mode 100644 index 00000000..3bbc4837 --- /dev/null +++ b/sqlit/domains/connections/providers/databricks/__init__.py @@ -0,0 +1 @@ +"""Provider package.""" diff --git a/sqlit/domains/connections/providers/databricks/adapter.py b/sqlit/domains/connections/providers/databricks/adapter.py new file mode 100644 index 00000000..2332a797 --- /dev/null +++ b/sqlit/domains/connections/providers/databricks/adapter.py @@ -0,0 +1,252 @@ +"""Databricks adapter using databricks-sql-connector. + +Databricks SQL uses a three-level namespace via Unity Catalog: + catalog.schema.table + +We map Databricks' "catalog" to the generic `database` slot in +sqlit's connection model, mirroring how Trino is handled. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from sqlit.domains.connections.providers.adapters.base import ( + ColumnInfo, + CursorBasedAdapter, + IndexInfo, + SequenceInfo, + TableInfo, + TriggerInfo, +) + +if TYPE_CHECKING: + from sqlit.domains.connections.domain.config import ConnectionConfig + + +class DatabricksAdapter(CursorBasedAdapter): + """Adapter for Databricks SQL warehouses and clusters.""" + + @property + def name(self) -> str: + return "Databricks" + + @property + def install_extra(self) -> str: + return "databricks" + + @property + def install_package(self) -> str: + return "databricks-sql-connector" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("databricks.sql",) + + @property + def supports_multiple_databases(self) -> bool: + return True + + @property + def supports_cross_database_queries(self) -> bool: + return True + + @property + def supports_stored_procedures(self) -> bool: + return False + + @property + def supports_indexes(self) -> bool: + return False + + @property + def supports_triggers(self) -> bool: + return False + + @property + def supports_sequences(self) -> bool: + return False + + @property + def default_schema(self) -> str: + return "default" + + @property + def system_databases(self) -> frozenset[str]: + # Built-in Databricks catalogs we usually want to hide from the + # primary picker. `samples` is the public demo catalog and + # `system` holds Unity Catalog telemetry. + return frozenset({"system"}) + + def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig: + """Apply a default catalog for unqualified queries.""" + if not database: + return config + return config.with_endpoint(database=database) + + def connect(self, config: ConnectionConfig) -> Any: + sql_module = self._import_driver_module( + "databricks.sql", + driver_name=self.name, + extra_name=self.install_extra, + package_name=self.install_package, + ) + + endpoint = config.tcp_endpoint + if endpoint is None: + raise ValueError("Databricks connections require a TCP-style endpoint.") + + extras = config.options + http_path = extras.get("http_path") or config.extra_options.get("http_path") + if not http_path: + raise ValueError("Databricks requires an HTTP Path (SQL warehouse or cluster).") + + connect_args: dict[str, Any] = { + "server_hostname": endpoint.host, + "http_path": http_path, + } + + catalog = endpoint.database + if catalog: + connect_args["catalog"] = catalog + schema = extras.get("schema") + if schema: + connect_args["schema"] = schema + + auth_type = extras.get("auth_type", "pat") + if auth_type == "pat": + token = extras.get("access_token") or endpoint.password + if not token: + raise ValueError("Databricks PAT authentication requires an access token.") + connect_args["access_token"] = token + elif auth_type == "oauth-u2m": + connect_args["auth_type"] = "databricks-oauth" + elif auth_type == "oauth-m2m": + client_id = extras.get("client_id") + client_secret = extras.get("client_secret") + if not client_id or not client_secret: + raise ValueError( + "Databricks OAuth (Service Principal) requires client_id and client_secret." + ) + connect_args["credentials_provider"] = _build_m2m_credentials_provider( + endpoint.host, client_id, client_secret + ) + else: + raise ValueError(f"Unknown Databricks auth_type: {auth_type}") + + connect_args.update(config.extra_options) + # http_path may have been passed via extra_options; drop the legacy key + # so it isn't sent twice if both schemes were used. + connect_args.pop("http_path", None) + connect_args["http_path"] = http_path + + return sql_module.connect(**connect_args) + + def get_databases(self, conn: Any) -> list[str]: + """List Unity Catalog catalogs.""" + cursor = conn.cursor() + # SHOW CATALOGS is universally supported and avoids needing + # SELECT privilege on system.information_schema. + cursor.execute("SHOW CATALOGS") + return [row[0] for row in cursor.fetchall()] + + def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT table_schema, table_name FROM " + f"{self.quote_identifier(database)}.information_schema.tables " + "WHERE table_type IN ('MANAGED', 'EXTERNAL', 'BASE TABLE') " + "ORDER BY table_schema, table_name" + ) + return [(row[0], row[1]) for row in cursor.fetchall()] + + cursor.execute("SHOW TABLES") + return [(row[0], row[1]) for row in cursor.fetchall()] + + def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT table_schema, table_name FROM " + f"{self.quote_identifier(database)}.information_schema.views " + "ORDER BY table_schema, table_name" + ) + return [(row[0], row[1]) for row in cursor.fetchall()] + + cursor.execute("SHOW VIEWS") + # SHOW VIEWS columns: database, viewName, isTemporary + return [(row[0], row[1]) for row in cursor.fetchall()] + + def get_columns( + self, conn: Any, table: str, database: str | None = None, schema: str | None = None + ) -> list[ColumnInfo]: + cursor = conn.cursor() + schema_name = schema or self.default_schema + if database: + cursor.execute( + "SELECT column_name, data_type FROM " + f"{self.quote_identifier(database)}.information_schema.columns " + "WHERE table_schema = ? AND table_name = ? " + "ORDER BY ordinal_position", + (schema_name, table), + ) + else: + cursor.execute( + "SELECT column_name, data_type FROM information_schema.columns " + "WHERE table_schema = ? AND table_name = ? " + "ORDER BY ordinal_position", + (schema_name, table), + ) + return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + + def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: + return [] + + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: + return [] + + def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: + return [] + + def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: + return [] + + def quote_identifier(self, name: str) -> str: + """Quote identifier using backticks (Databricks/Spark SQL standard).""" + escaped = name.replace("`", "``") + return f"`{escaped}`" + + def build_select_query( + self, table: str, limit: int, database: str | None = None, schema: str | None = None + ) -> str: + schema_name = schema or self.default_schema + if database and schema_name: + return ( + f"SELECT * FROM {self.quote_identifier(database)}." + f"{self.quote_identifier(schema_name)}." + f"{self.quote_identifier(table)} LIMIT {limit}" + ) + if schema_name: + return f"SELECT * FROM {self.quote_identifier(schema_name)}.{self.quote_identifier(table)} LIMIT {limit}" + return f"SELECT * FROM {self.quote_identifier(table)} LIMIT {limit}" + + +def _build_m2m_credentials_provider(host: str, client_id: str, client_secret: str) -> Any: + """Return a credentials_provider callable for Databricks OAuth M2M. + + Imported lazily so the databricks-sdk dependency is only required + when the user actually selects service-principal auth. + """ + + def _factory() -> Any: + from databricks.sdk.core import Config, oauth_service_principal + + cfg = Config( + host=host if host.startswith(("http://", "https://")) else f"https://{host}", + client_id=client_id, + client_secret=client_secret, + ) + return oauth_service_principal(cfg) + + return _factory diff --git a/sqlit/domains/connections/providers/databricks/provider.py b/sqlit/domains/connections/providers/databricks/provider.py new file mode 100644 index 00000000..05cff445 --- /dev/null +++ b/sqlit/domains/connections/providers/databricks/provider.py @@ -0,0 +1,29 @@ +"""Provider registration.""" + +from sqlit.domains.connections.providers.adapter_provider import build_adapter_provider +from sqlit.domains.connections.providers.catalog import register_provider +from sqlit.domains.connections.providers.databricks.schema import SCHEMA +from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec + + +def _provider_factory(spec: ProviderSpec) -> DatabaseProvider: + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + return build_adapter_provider(spec, SCHEMA, DatabricksAdapter()) + + +SPEC = ProviderSpec( + db_type="databricks", + display_name="Databricks", + schema_path=("sqlit.domains.connections.providers.databricks.schema", "SCHEMA"), + supports_ssh=False, + is_file_based=False, + has_advanced_auth=True, + default_port="", + requires_auth=True, + badge_label="DBRX", + url_schemes=("databricks",), + provider_factory=_provider_factory, +) + +register_provider(SPEC) diff --git a/sqlit/domains/connections/providers/databricks/schema.py b/sqlit/domains/connections/providers/databricks/schema.py new file mode 100644 index 00000000..413acda0 --- /dev/null +++ b/sqlit/domains/connections/providers/databricks/schema.py @@ -0,0 +1,91 @@ +"""Connection schema for Databricks SQL.""" + +from sqlit.domains.connections.providers.schema_helpers import ( + ConnectionSchema, + FieldType, + SchemaField, + SelectOption, +) + + +def _get_databricks_auth_options() -> tuple[SelectOption, ...]: + return ( + SelectOption("pat", "Personal Access Token"), + SelectOption("oauth-u2m", "OAuth (Browser)"), + SelectOption("oauth-m2m", "OAuth (Service Principal)"), + ) + + +_AUTH_NEEDS_TOKEN = {"pat"} +_AUTH_NEEDS_SP = {"oauth-m2m"} + + +SCHEMA = ConnectionSchema( + db_type="databricks", + display_name="Databricks", + fields=( + SchemaField( + name="server", + label="Server Hostname", + placeholder="dbc-a1b2cd34-e5f6.cloud.databricks.com", + required=True, + description="Databricks SQL warehouse server hostname (no protocol)", + ), + SchemaField( + name="http_path", + label="HTTP Path", + placeholder="/sql/1.0/warehouses/abcdef1234567890", + required=True, + description="HTTP path of the SQL warehouse or cluster", + ), + SchemaField( + name="auth_type", + label="Authentication", + field_type=FieldType.DROPDOWN, + options=_get_databricks_auth_options(), + default="pat", + ), + SchemaField( + name="access_token", + label="Access Token", + field_type=FieldType.PASSWORD, + placeholder="dapi...", + group="credentials", + description="Personal Access Token (PAT)", + visible_when=lambda v: v.get("auth_type", "pat") in _AUTH_NEEDS_TOKEN, + ), + SchemaField( + name="client_id", + label="Client ID", + placeholder="service-principal-client-id", + required=False, + group="credentials", + visible_when=lambda v: v.get("auth_type") in _AUTH_NEEDS_SP, + ), + SchemaField( + name="client_secret", + label="Client Secret", + field_type=FieldType.PASSWORD, + placeholder="(secret)", + required=False, + group="credentials", + visible_when=lambda v: v.get("auth_type") in _AUTH_NEEDS_SP, + ), + SchemaField( + name="database", + label="Catalog", + placeholder="main", + required=False, + description="Unity Catalog name (top-level namespace)", + ), + SchemaField( + name="schema", + label="Schema", + placeholder="default", + required=False, + description="Default schema within the catalog", + ), + ), + supports_ssh=False, + has_advanced_auth=True, +) diff --git a/tests/unit/test_databricks_adapter.py b/tests/unit/test_databricks_adapter.py new file mode 100644 index 00000000..28420a26 --- /dev/null +++ b/tests/unit/test_databricks_adapter.py @@ -0,0 +1,228 @@ +"""Unit tests for Databricks adapter.""" + +from __future__ import annotations + +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from tests.helpers import ConnectionConfig + + +def _fake_databricks_sql() -> tuple[MagicMock, dict[str, types.ModuleType]]: + """Build a fake `databricks.sql` module hierarchy.""" + databricks_pkg = types.ModuleType("databricks") + databricks_sql = types.ModuleType("databricks.sql") + connect = MagicMock(name="databricks.sql.connect") + databricks_sql.connect = connect # type: ignore[attr-defined] + databricks_pkg.sql = databricks_sql # type: ignore[attr-defined] + modules = {"databricks": databricks_pkg, "databricks.sql": databricks_sql} + return connect, modules + + +class TestDatabricksAdapter: + def test_connect_pat_default(self): + connect, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig( + name="test", + db_type="databricks", + server="dbc-xyz.cloud.databricks.com", + database="main", + options={ + "http_path": "/sql/1.0/warehouses/abcdef", + "auth_type": "pat", + "access_token": "dapi-xxx", + "schema": "default", + }, + ) + adapter.connect(config) + connect.assert_called_once_with( + server_hostname="dbc-xyz.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abcdef", + catalog="main", + schema="default", + access_token="dapi-xxx", + ) + + def test_connect_pat_token_from_password_field(self): + """If access_token isn't in options, the endpoint password is used.""" + connect, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig( + name="test", + db_type="databricks", + server="host", + password="legacy-pat", + options={"http_path": "/sql/1.0/warehouses/x"}, + ) + adapter.connect(config) + args = connect.call_args.kwargs + assert args["access_token"] == "legacy-pat" + + def test_connect_oauth_u2m(self): + connect, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig( + name="test", + db_type="databricks", + server="host", + options={ + "http_path": "/sql/1.0/warehouses/x", + "auth_type": "oauth-u2m", + }, + ) + adapter.connect(config) + args = connect.call_args.kwargs + assert args["auth_type"] == "databricks-oauth" + assert "access_token" not in args + + def test_connect_missing_http_path_raises(self): + _, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig(name="t", db_type="databricks", server="host", password="x") + with pytest.raises(ValueError, match="HTTP Path"): + adapter.connect(config) + + def test_connect_missing_pat_raises(self): + _, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig( + name="t", + db_type="databricks", + server="host", + options={"http_path": "/sql/1.0/warehouses/x", "auth_type": "pat"}, + ) + with pytest.raises(ValueError, match="access token"): + adapter.connect(config) + + def test_connect_m2m_requires_client_credentials(self): + _, modules = _fake_databricks_sql() + with patch.dict(sys.modules, modules): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + config = ConnectionConfig( + name="t", + db_type="databricks", + server="host", + options={"http_path": "/sql/1.0/warehouses/x", "auth_type": "oauth-m2m"}, + ) + with pytest.raises(ValueError, match="client_id"): + adapter.connect(config) + + def test_get_databases_uses_show_catalogs(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchall.return_value = [("main",), ("samples",), ("hive_metastore",)] + + result = adapter.get_databases(mock_conn) + + mock_cursor.execute.assert_called_with("SHOW CATALOGS") + assert result == ["main", "samples", "hive_metastore"] + + def test_get_tables_with_catalog_uses_info_schema(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchall.return_value = [ + ("default", "trips"), + ("analytics", "events"), + ] + + tables = adapter.get_tables(mock_conn, database="main") + sql = mock_cursor.execute.call_args[0][0] + assert "`main`.information_schema.tables" in sql + assert tables == [("default", "trips"), ("analytics", "events")] + + def test_get_columns_uses_info_schema(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchall.return_value = [ + ("id", "BIGINT"), + ("name", "STRING"), + ] + + cols = adapter.get_columns( + mock_conn, "trips", database="main", schema="default" + ) + + args = mock_cursor.execute.call_args[0] + assert "`main`.information_schema.columns" in args[0] + assert args[1] == ("default", "trips") + assert [c.name for c in cols] == ["id", "name"] + assert cols[0].data_type == "BIGINT" + + def test_quote_identifier_uses_backticks(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + assert adapter.quote_identifier("foo") == "`foo`" + assert adapter.quote_identifier("a`b") == "`a``b`" + + def test_build_select_query(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + assert ( + adapter.build_select_query("trips", 10, database="main", schema="default") + == "SELECT * FROM `main`.`default`.`trips` LIMIT 10" + ) + assert ( + adapter.build_select_query("trips", 10, schema="default") + == "SELECT * FROM `default`.`trips` LIMIT 10" + ) + + def test_capabilities(self): + from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter + + adapter = DatabricksAdapter() + assert adapter.supports_multiple_databases is True + assert adapter.supports_cross_database_queries is True + assert adapter.supports_stored_procedures is False + assert adapter.supports_indexes is False + assert adapter.supports_triggers is False + assert adapter.supports_sequences is False + assert adapter.default_schema == "default" + + def test_provider_registration(self): + from sqlit.domains.connections.providers.catalog import ( + get_provider_schema, + get_supported_db_types, + ) + + assert "databricks" in get_supported_db_types() + schema = get_provider_schema("databricks") + field_names = [f.name for f in schema.fields] + assert "server" in field_names + assert "http_path" in field_names + assert "auth_type" in field_names + assert "access_token" in field_names