Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions databricks-mcp-server/databricks_mcp_server/tools/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def execute_sql(
catalog: str = None,
schema: str = None,
timeout: int = 180,
query_tags: str = None,
) -> List[Dict[str, Any]]:
"""
Execute a SQL query on a Databricks SQL Warehouse.
Expand All @@ -37,6 +38,8 @@ def execute_sql(
catalog: Optional catalog context for unqualified table names.
schema: Optional schema context for unqualified table names.
timeout: Timeout in seconds (default: 180)
query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701").
Appears in system.query.history and Query History UI.

Returns:
List of dictionaries, each representing a row with column names as keys.
Expand All @@ -47,6 +50,7 @@ def execute_sql(
catalog=catalog,
schema=schema,
timeout=timeout,
query_tags=query_tags,
)


Expand All @@ -58,6 +62,7 @@ def execute_sql_multi(
schema: str = None,
timeout: int = 180,
max_workers: int = 4,
query_tags: str = None,
) -> Dict[str, Any]:
"""
Execute multiple SQL statements with dependency-aware parallelism.
Expand All @@ -76,6 +81,7 @@ def execute_sql_multi(
schema: Optional schema context for unqualified table names.
timeout: Timeout per query in seconds (default: 180)
max_workers: Maximum parallel queries per group (default: 4)
query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701").

Returns:
Dictionary with results per query and execution summary.
Expand All @@ -87,6 +93,7 @@ def execute_sql_multi(
schema=schema,
timeout=timeout,
max_workers=max_workers,
query_tags=query_tags,
)


Expand Down
8 changes: 8 additions & 0 deletions databricks-tools-core/databricks_tools_core/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def execute_sql(
catalog: Optional[str] = None,
schema: Optional[str] = None,
timeout: int = 180,
query_tags: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Execute a SQL query on a Databricks SQL Warehouse.
Expand All @@ -32,6 +33,9 @@ def execute_sql(
catalog: Optional catalog context. If not provided, use fully qualified names.
schema: Optional schema context. If not provided, use fully qualified names.
timeout: Timeout in seconds (default: 180)
query_tags: Optional query tags for cost attribution and filtering.
Format: "key:value,key2:value2" (e.g., "team:eng,cost_center:701").
Appears in system.query.history and Query History UI.

Returns:
List of dictionaries, each representing a row with column names as keys.
Expand Down Expand Up @@ -64,6 +68,7 @@ def execute_sql(
catalog=catalog,
schema=schema,
timeout=timeout,
query_tags=query_tags,
)


Expand All @@ -74,6 +79,7 @@ def execute_sql_multi(
schema: Optional[str] = None,
timeout: int = 180,
max_workers: int = 4,
query_tags: Optional[str] = None,
) -> Dict[str, Any]:
"""
Execute multiple SQL statements with dependency-aware parallelism.
Expand All @@ -92,6 +98,7 @@ def execute_sql_multi(
schema: Optional schema context. If not provided, use fully qualified names.
timeout: Timeout per query in seconds (default: 180)
max_workers: Maximum parallel queries per group (default: 4)
query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701").

Returns:
Dictionary with:
Expand Down Expand Up @@ -148,4 +155,5 @@ def execute_sql_multi(
catalog=catalog,
schema=schema,
timeout=timeout,
query_tags=query_tags,
)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def execute(
schema: Optional[str] = None,
row_limit: Optional[int] = None,
timeout: int = 180,
query_tags: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Execute a SQL query and return results as a list of dictionaries.
Expand All @@ -60,6 +61,9 @@ def execute(
schema: Optional schema context for the query
row_limit: Optional maximum number of rows to return
timeout: Timeout in seconds (default: 180)
query_tags: Optional query tags for cost attribution and filtering.
Format: "key:value,key2:value2" (e.g., "team:eng,cost_center:701").
Appears in system.query.history and Query History UI.

Returns:
List of dictionaries, each representing a row with column names as keys
Expand All @@ -81,6 +85,8 @@ def execute(
exec_params["schema"] = schema
if row_limit is not None:
exec_params["row_limit"] = row_limit
if query_tags:
exec_params["query_tags"] = query_tags

# Submit the statement
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def execute(
catalog: Optional[str] = None,
schema: Optional[str] = None,
timeout: int = 180,
query_tags: Optional[str] = None,
) -> Dict[str, Any]:
"""
Execute multiple SQL statements with dependency-aware parallelism.
Expand All @@ -63,6 +64,7 @@ def execute(
catalog: Optional catalog context for queries
schema: Optional schema context for queries
timeout: Timeout per query in seconds (default: 180)
query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701")

Returns:
Dictionary with:
Expand Down Expand Up @@ -109,6 +111,7 @@ def execute(
catalog=catalog,
schema=schema,
timeout=timeout,
query_tags=query_tags,
)

# Store results and check for errors
Expand Down Expand Up @@ -149,6 +152,7 @@ def _execute_group(
catalog: Optional[str],
schema: Optional[str],
timeout: int,
query_tags: Optional[str] = None,
) -> Dict[int, Dict[str, Any]]:
"""Execute a group of queries in parallel using ThreadPoolExecutor."""
results: Dict[int, Dict[str, Any]] = {}
Expand All @@ -168,6 +172,7 @@ def execute_single(query_idx: int) -> Dict[str, Any]:
catalog=catalog,
schema=schema,
timeout=timeout,
query_tags=query_tags,
)

dt = round(time.time() - t0, 2)
Expand Down
120 changes: 120 additions & 0 deletions databricks-tools-core/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Unit tests for SQL execution functions."""

from unittest import mock

import pytest
from databricks.sdk.service.sql import StatementState

from databricks_tools_core.sql import execute_sql, execute_sql_multi
from databricks_tools_core.sql.sql_utils import SQLExecutor


class TestExecuteSQLQueryTags:
    """Check that the public SQL helpers forward ``query_tags`` to their executors."""

    @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
    @mock.patch("databricks_tools_core.sql.sql.SQLExecutor")
    def test_execute_sql_passes_query_tags_to_executor(self, mock_executor_cls, mock_warehouse):
        """execute_sql() hands the caller's query_tags to SQLExecutor.execute()."""
        # The patched class returns a stub whose execute() yields one row.
        fake_executor = mock.Mock(**{"execute.return_value": [{"num": 1}]})
        mock_executor_cls.return_value = fake_executor

        execute_sql(
            sql_query="SELECT 1",
            warehouse_id="wh-123",
            query_tags="team:eng,cost_center:701",
        )

        fake_executor.execute.assert_called_once()
        # call_args unpacks to (positional args, keyword args).
        _, forwarded = fake_executor.execute.call_args
        assert forwarded["query_tags"] == "team:eng,cost_center:701"

    @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
    @mock.patch("databricks_tools_core.sql.sql.SQLExecutor")
    def test_execute_sql_without_query_tags(self, mock_executor_cls, mock_warehouse):
        """If query_tags is not supplied, the executor sees it as absent or None."""
        fake_executor = mock.Mock(**{"execute.return_value": [{"num": 1}]})
        mock_executor_cls.return_value = fake_executor

        execute_sql(sql_query="SELECT 1", warehouse_id="wh-123")

        fake_executor.execute.assert_called_once()
        _, forwarded = fake_executor.execute.call_args
        # .get() tolerates both "key omitted" and "key present with None".
        assert forwarded.get("query_tags") is None

    @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
    @mock.patch("databricks_tools_core.sql.sql.SQLParallelExecutor")
    def test_execute_sql_multi_passes_query_tags(self, mock_parallel_cls, mock_warehouse):
        """execute_sql_multi() hands query_tags to SQLParallelExecutor.execute()."""
        canned_result = {
            "results": {0: {"status": "success", "query_index": 0}},
            "execution_summary": {"total_queries": 1, "total_groups": 1},
        }
        fake_parallel = mock.Mock(**{"execute.return_value": canned_result})
        mock_parallel_cls.return_value = fake_parallel

        execute_sql_multi(
            sql_content="SELECT 1;",
            warehouse_id="wh-123",
            query_tags="app:agent,env:dev",
        )

        fake_parallel.execute.assert_called_once()
        _, forwarded = fake_parallel.execute.call_args
        assert forwarded["query_tags"] == "app:agent,env:dev"


class TestSQLExecutorQueryTags:
    """Check what SQLExecutor.execute() sends to the statement-execution API."""

    @staticmethod
    def _client_with_succeeded_statement():
        """Build a mock client whose submitted statement reports SUCCEEDED with no rows."""
        submitted = mock.Mock()
        submitted.statement_id = "stmt-1"

        # Status object returned by get_statement: SUCCEEDED state, empty data, no manifest.
        finished = mock.Mock()
        finished.status.state = StatementState.SUCCEEDED
        finished.result = mock.Mock()
        finished.result.data_array = []
        finished.manifest = None

        client = mock.Mock()
        client.statement_execution.execute_statement.return_value = submitted
        client.statement_execution.get_statement.return_value = finished
        return client

    @mock.patch("databricks_tools_core.sql.sql_utils.executor.get_workspace_client")
    def test_executor_passes_query_tags_to_api(self, mock_get_client):
        """A supplied query_tags value reaches the execute_statement call."""
        fake_client = self._client_with_succeeded_statement()
        mock_get_client.return_value = fake_client

        sql_executor = SQLExecutor(warehouse_id="wh-123", client=fake_client)
        sql_executor.execute(
            sql_query="SELECT 1",
            query_tags="team:eng,cost_center:701",
        )

        _, api_kwargs = fake_client.statement_execution.execute_statement.call_args
        assert api_kwargs.get("query_tags") == "team:eng,cost_center:701"

    @mock.patch("databricks_tools_core.sql.sql_utils.executor.get_workspace_client")
    def test_executor_without_query_tags_omits_from_api(self, mock_get_client):
        """With no query_tags given, the key is absent from the execute_statement call."""
        fake_client = self._client_with_succeeded_statement()
        mock_get_client.return_value = fake_client

        sql_executor = SQLExecutor(warehouse_id="wh-123", client=fake_client)
        sql_executor.execute(sql_query="SELECT 1")

        _, api_kwargs = fake_client.statement_execution.execute_statement.call_args
        assert "query_tags" not in api_kwargs