Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
---

### Connect
Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Supabase, CloudFlare D1, Turso, Athena, BigQuery, Spanner, RedShift, IBM Db2, SAP HANA, Teradata, Trino, Presto, Apache Flight SQL, Apache Impala, SurrealDB and osquery.
Supports all major databases: SQL Server, PostgreSQL, MySQL, SQLite, MariaDB, FirebirdSQL, Oracle, DuckDB, CockroachDB, ClickHouse, Snowflake, Databricks, Supabase, Cloudflare D1, Turso, Athena, BigQuery, Spanner, Redshift, IBM Db2, SAP HANA, Teradata, Trino, Presto, Apache Flight SQL, Apache Impala, SurrealDB and osquery.

![Database Providers](docs/demos/demo-providers.gif)

Expand Down Expand Up @@ -289,6 +289,7 @@ Most of the time you can just run `sqlit` and connect. If a Python driver is mis
| Turso | `libsql` | `pipx inject sqlit-tui libsql` | `python -m pip install libsql` |
| Cloudflare D1 | `requests` | `pipx inject sqlit-tui requests` | `python -m pip install requests` |
| Snowflake | `snowflake-connector-python` | `pipx inject sqlit-tui snowflake-connector-python` | `python -m pip install snowflake-connector-python` |
| Databricks | `databricks-sql-connector` | `pipx inject sqlit-tui databricks-sql-connector` | `python -m pip install databricks-sql-connector` |
| Firebird | `firebirdsql` | `pipx inject sqlit-tui firebirdsql` | `python -m pip install firebirdsql` |
| Athena | `pyathena` | `pipx inject sqlit-tui pyathena` | `python -m pip install pyathena` |
| BigQuery | `google-cloud-bigquery` | `pipx inject sqlit-tui google-cloud-bigquery` | `python -m pip install google-cloud-bigquery` |
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ all = [
"impyla>=0.18.0",
"osquery>=3.0.0",
"surrealdb>=1.0.0",
"databricks-sql-connector>=3.0.0",
]
postgres = ["psycopg2-binary>=2.9.0"]
cockroachdb = ["psycopg2-binary>=2.9.0"]
Expand All @@ -88,6 +89,7 @@ flight = ["adbc-driver-flightsql>=1.0.0"]
impala = ["impyla>=0.18.0"]
osquery = ["osquery>=3.0.0"]
surrealdb = ["surrealdb>=1.0.0"]
databricks = ["databricks-sql-connector>=3.0.0"]
ssh = [
"sshtunnel>=0.4.0",
"paramiko>=2.0.0,<4.0.0",
Expand Down Expand Up @@ -253,6 +255,10 @@ module = [
"impala.dbapi",
"osquery",
"surrealdb",
"databricks",
"databricks.sql",
"databricks.sdk",
"databricks.sdk.core",
"google.cloud",
"google.cloud.bigquery",
"google.cloud.bigquery.dbapi",
Expand Down
2 changes: 2 additions & 0 deletions sqlit/domains/connections/domain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DatabaseType(str, Enum):
CLICKHOUSE = "clickhouse"
COCKROACHDB = "cockroachdb"
D1 = "d1"
DATABRICKS = "databricks"
DUCKDB = "duckdb"
DB2 = "db2"
FIREBIRD = "firebird"
Expand Down Expand Up @@ -53,6 +54,7 @@ class DatabaseType(str, Enum):
DatabaseType.HANA,
DatabaseType.TERADATA,
DatabaseType.SNOWFLAKE,
DatabaseType.DATABRICKS,
DatabaseType.BIGQUERY,
DatabaseType.SPANNER,
DatabaseType.TRINO,
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/databricks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Provider package."""
252 changes: 252 additions & 0 deletions sqlit/domains/connections/providers/databricks/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"""Databricks adapter using databricks-sql-connector.

Databricks SQL uses a three-level namespace via Unity Catalog:
catalog.schema.table

We map Databricks' "catalog" to the generic `database` slot in
sqlit's connection model, mirroring how Trino is handled.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from sqlit.domains.connections.providers.adapters.base import (
ColumnInfo,
CursorBasedAdapter,
IndexInfo,
SequenceInfo,
TableInfo,
TriggerInfo,
)

if TYPE_CHECKING:
from sqlit.domains.connections.domain.config import ConnectionConfig


class DatabricksAdapter(CursorBasedAdapter):
    """Adapter for Databricks SQL warehouses and clusters.

    Databricks addresses objects as ``catalog.schema.table``. The catalog is
    carried in the generic ``database`` slot of the connection model, so every
    ``database`` argument on this class is a catalog name.
    """

    @property
    def name(self) -> str:
        """Display name of the driver, also used in driver-import errors."""
        return "Databricks"

    @property
    def install_extra(self) -> str:
        """Name of the optional-dependency extra that installs the driver."""
        return "databricks"

    @property
    def install_package(self) -> str:
        """PyPI package that provides the driver."""
        return "databricks-sql-connector"

    @property
    def driver_import_names(self) -> tuple[str, ...]:
        """Importable module names that indicate the driver is installed."""
        return ("databricks.sql",)

    @property
    def supports_multiple_databases(self) -> bool:
        """Catalogs are exposed as multiple databases."""
        return True

    @property
    def supports_cross_database_queries(self) -> bool:
        """Queries may reference objects in other catalogs."""
        return True

    @property
    def supports_stored_procedures(self) -> bool:
        """Stored procedures are not listed by this adapter."""
        return False

    @property
    def supports_indexes(self) -> bool:
        """Indexes are not listed by this adapter."""
        return False

    @property
    def supports_triggers(self) -> bool:
        """Triggers are not listed by this adapter."""
        return False

    @property
    def supports_sequences(self) -> bool:
        """Sequences are not listed by this adapter."""
        return False

    @property
    def default_schema(self) -> str:
        """Schema assumed when a caller does not name one."""
        return "default"

    @property
    def system_databases(self) -> frozenset[str]:
        """Built-in catalogs we usually want to hide from the primary picker."""
        # Only `system` (Unity Catalog telemetry) is classified as a system
        # catalog. The public `samples` demo catalog is NOT listed here, so it
        # is treated like any user catalog.
        return frozenset({"system"})

    def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
        """Apply a default catalog for unqualified queries.

        Returns ``config`` unchanged when ``database`` is empty; otherwise
        returns the result of ``config.with_endpoint(database=database)``.
        """
        if not database:
            return config
        return config.with_endpoint(database=database)

    def connect(self, config: ConnectionConfig) -> Any:
        """Open a connection through ``databricks.sql.connect``.

        The HTTP path is read from ``config.options`` first and from
        ``config.extra_options`` as a fallback. ``config.extra_options`` is
        merged into the connect arguments last, so it can override any
        argument built here except ``http_path``, which is re-applied after
        the merge.

        Raises:
            ValueError: if the config has no TCP endpoint, no HTTP path, the
                selected auth type is missing its credentials, or the auth
                type is not one of ``pat``, ``oauth-u2m``, ``oauth-m2m``.
        """
        sql_module = self._import_driver_module(
            "databricks.sql",
            driver_name=self.name,
            extra_name=self.install_extra,
            package_name=self.install_package,
        )

        endpoint = config.tcp_endpoint
        if endpoint is None:
            raise ValueError("Databricks connections require a TCP-style endpoint.")

        extras = config.options
        http_path = extras.get("http_path") or config.extra_options.get("http_path")
        if not http_path:
            raise ValueError("Databricks requires an HTTP Path (SQL warehouse or cluster).")

        connect_args: dict[str, Any] = {
            "server_hostname": endpoint.host,
            "http_path": http_path,
        }

        catalog = endpoint.database
        if catalog:
            connect_args["catalog"] = catalog
        schema = extras.get("schema")
        if schema:
            connect_args["schema"] = schema

        connect_args.update(self._databricks_auth_args(extras, endpoint))

        connect_args.update(config.extra_options)
        # extra_options may carry its own (possibly empty) http_path; always
        # send the value resolved above.
        connect_args["http_path"] = http_path

        return sql_module.connect(**connect_args)

    @staticmethod
    def _databricks_auth_args(extras: Any, endpoint: Any) -> dict[str, Any]:
        """Translate the configured auth type into ``connect()`` keyword arguments."""
        # A missing *or empty* auth_type means personal-access-token auth; an
        # unset form field must not surface as "Unknown Databricks auth_type".
        auth_type = extras.get("auth_type") or "pat"

        if auth_type == "pat":
            token = extras.get("access_token") or endpoint.password
            if not token:
                raise ValueError("Databricks PAT authentication requires an access token.")
            return {"access_token": token}

        if auth_type == "oauth-u2m":
            return {"auth_type": "databricks-oauth"}

        if auth_type == "oauth-m2m":
            client_id = extras.get("client_id")
            client_secret = extras.get("client_secret")
            if not client_id or not client_secret:
                raise ValueError(
                    "Databricks OAuth (Service Principal) requires client_id and client_secret."
                )
            return {
                "credentials_provider": _build_m2m_credentials_provider(
                    endpoint.host, client_id, client_secret
                )
            }

        raise ValueError(f"Unknown Databricks auth_type: {auth_type}")

    @staticmethod
    def _run_metadata_query(
        conn: Any, sql: str, params: tuple[Any, ...] | None = None
    ) -> list[Any]:
        """Execute ``sql`` on a fresh cursor and return all rows, always closing the cursor."""
        cursor = conn.cursor()
        try:
            if params is None:
                cursor.execute(sql)
            else:
                cursor.execute(sql, params)
            return cursor.fetchall()
        finally:
            # Release the cursor even when execute/fetch raises, instead of
            # leaving it open until the connection is closed.
            cursor.close()

    def get_databases(self, conn: Any) -> list[str]:
        """List Unity Catalog catalogs."""
        # SHOW CATALOGS is universally supported and avoids needing
        # SELECT privilege on system.information_schema.
        rows = self._run_metadata_query(conn, "SHOW CATALOGS")
        return [row[0] for row in rows]

    def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Return ``(schema, table)`` pairs.

        With ``database`` set, that catalog's ``information_schema.tables`` is
        queried; otherwise a bare ``SHOW TABLES`` is issued.
        """
        if database:
            sql = (
                "SELECT table_schema, table_name FROM "
                f"{self.quote_identifier(database)}.information_schema.tables "
                "WHERE table_type IN ('MANAGED', 'EXTERNAL', 'BASE TABLE') "
                "ORDER BY table_schema, table_name"
            )
        else:
            sql = "SHOW TABLES"
        return [(row[0], row[1]) for row in self._run_metadata_query(conn, sql)]

    def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
        """Return ``(schema, view)`` pairs.

        With ``database`` set, that catalog's ``information_schema.views`` is
        queried; otherwise a bare ``SHOW VIEWS`` is issued.
        """
        if database:
            sql = (
                "SELECT table_schema, table_name FROM "
                f"{self.quote_identifier(database)}.information_schema.views "
                "ORDER BY table_schema, table_name"
            )
        else:
            # SHOW VIEWS columns: database, viewName, isTemporary
            sql = "SHOW VIEWS"
        return [(row[0], row[1]) for row in self._run_metadata_query(conn, sql)]

    def get_columns(
        self, conn: Any, table: str, database: str | None = None, schema: str | None = None
    ) -> list[ColumnInfo]:
        """Return the columns of ``table`` ordered by ordinal position.

        ``schema`` falls back to ``default_schema``. When ``database`` is
        given, the lookup is qualified with that catalog; otherwise the
        unqualified ``information_schema.columns`` is used.
        """
        schema_name = schema or self.default_schema
        source = "information_schema.columns"
        if database:
            source = f"{self.quote_identifier(database)}.{source}"
        sql = (
            f"SELECT column_name, data_type FROM {source} "
            "WHERE table_schema = ? AND table_name = ? "
            "ORDER BY ordinal_position"
        )
        rows = self._run_metadata_query(conn, sql, (schema_name, table))
        return [ColumnInfo(name=row[0], data_type=row[1]) for row in rows]

    def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
        """Always empty; see ``supports_stored_procedures``."""
        return []

    def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]:
        """Always empty; see ``supports_indexes``."""
        return []

    def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]:
        """Always empty; see ``supports_triggers``."""
        return []

    def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
        """Always empty; see ``supports_sequences``."""
        return []

    def quote_identifier(self, name: str) -> str:
        """Quote identifier using backticks (Databricks/Spark SQL standard)."""
        # An embedded backtick is escaped by doubling it.
        escaped = name.replace("`", "``")
        return f"`{escaped}`"

    def build_select_query(
        self, table: str, limit: int, database: str | None = None, schema: str | None = None
    ) -> str:
        """Build ``SELECT * FROM <qualified table> LIMIT <limit>``.

        ``schema`` falls back to ``default_schema``. The catalog prefix is
        only emitted together with a schema, because a two-part name would
        otherwise be read as ``schema.table``.
        """
        schema_name = schema or self.default_schema
        parts: list[str] = []
        if schema_name:
            if database:
                parts.append(database)
            parts.append(schema_name)
        parts.append(table)
        qualified = ".".join(self.quote_identifier(part) for part in parts)
        return f"SELECT * FROM {qualified} LIMIT {limit}"


def _build_m2m_credentials_provider(host: str, client_id: str, client_secret: str) -> Any:
"""Return a credentials_provider callable for Databricks OAuth M2M.

Imported lazily so the databricks-sdk dependency is only required
when the user actually selects service-principal auth.
"""

def _factory() -> Any:
from databricks.sdk.core import Config, oauth_service_principal

cfg = Config(
host=host if host.startswith(("http://", "https://")) else f"https://{host}",
client_id=client_id,
client_secret=client_secret,
)
return oauth_service_principal(cfg)

return _factory
29 changes: 29 additions & 0 deletions sqlit/domains/connections/providers/databricks/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Provider registration."""

from sqlit.domains.connections.providers.adapter_provider import build_adapter_provider
from sqlit.domains.connections.providers.catalog import register_provider
from sqlit.domains.connections.providers.databricks.schema import SCHEMA
from sqlit.domains.connections.providers.model import DatabaseProvider, ProviderSpec


def _provider_factory(spec: ProviderSpec) -> DatabaseProvider:
    """Create the Databricks provider for ``spec``.

    The adapter class is imported inside the function, so the adapter module
    is only loaded when a provider is actually built.
    """
    from sqlit.domains.connections.providers.databricks.adapter import DatabricksAdapter

    adapter = DatabricksAdapter()
    return build_adapter_provider(spec, SCHEMA, adapter)


# Static registration metadata for the Databricks provider. The adapter is
# not instantiated here; it is built on demand by `_provider_factory`.
SPEC = ProviderSpec(
db_type="databricks",
display_name="Databricks",
# (module path, attribute name) locating the connection schema.
schema_path=("sqlit.domains.connections.providers.databricks.schema", "SCHEMA"),
supports_ssh=False,
is_file_based=False,
has_advanced_auth=True,
# No default port is declared for this provider.
default_port="",
requires_auth=True,
badge_label="DBRX",
url_schemes=("databricks",),
provider_factory=_provider_factory,
)

# Runs at import time: importing this module registers the spec with the
# provider catalog.
register_provider(SPEC)
Loading
Loading