Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,78 @@
# Transaction isolation level constants (extension to PEP 249)
TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ"

# All supported **kwargs for Connection.__init__. Used to warn on unknown params.
KNOWN_KWARGS = frozenset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add a test test_known_kwargs_is_complete

[
# Authentication
"access_token",
"auth_type",
"username",
"password",
"credentials_provider",
"_use_cert_as_auth",
"oauth_client_id",
"oauth_redirect_port",
"oauth_scopes",
"experimental_oauth_persistence",
"identity_federation_client_id",
"azure_client_id",
"azure_client_secret",
"azure_tenant_id",
"azure_workspace_resource_id",
# TLS / SSL
"_enable_ssl",
"_tls_no_verify",
"_tls_verify_hostname",
"_tls_trusted_ca_file",
"_tls_client_cert_file",
"_tls_client_cert_key_file",
"_tls_client_cert_key_password",
# Connection
"_port",
"_connection_uri",
"_skip_routing_headers",
# Retry policy
"_retry_delay_min",
"_retry_delay_max",
"_retry_delay_default",
"_retry_stop_after_attempts_count",
"_retry_stop_after_attempts_duration",
"_retry_dangerous_codes",
"_retry_max_redirects",
"_enable_v3_retries",
# Socket / network
"_socket_timeout",
"_proxy_auth_method",
"_pool_connections",
"_pool_maxsize",
"pool_maxsize",
# User agent & telemetry
"user_agent_entry",
"_user_agent_entry",
"enable_telemetry",
"force_enable_telemetry",
"telemetry_batch_size",
"_telemetry_circuit_breaker_enabled",
# Query / result
"use_inline_params",
"enable_query_result_lz4_compression",
"use_cloud_fetch",
"max_download_threads",
"use_hybrid_disposition",
"staging_allowed_local_path",
# Data type / format
"_use_arrow_native_decimals",
"_use_arrow_native_timestamps",
"_disable_pandas",
# Behavior
"enable_metric_view_metadata",
"fetch_autocommit_from_server",
# Backend selection
"use_sea",
]
)


class Connection:
def __init__(
Expand Down Expand Up @@ -269,6 +341,14 @@ def read(self) -> Optional[OAuthToken]:
http_path,
)

unknown_params = set(kwargs.keys()) - KNOWN_KWARGS
if unknown_params:
logger.warning(
"Unsupported connection parameter(s) will be ignored: %s. "
"Check the Connection documentation for supported parameters.",
", ".join(sorted(unknown_params)),
)

if access_token:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,61 @@ def mock_close_normal():
cursors_closed, [1, 2], "Both cursors should have close called"
)

@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@patch("databricks.sql.client.logger")
def test_unknown_connection_param_issues_warning(
self, mock_logger, mock_client_class
):
"""Unknown kwargs passed to Connection should trigger a warning."""
databricks.sql.connect(
**self.DUMMY_CONNECTION_ARGS, totally_unknown_param="value"
)
mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0]
self.assertIn("Unsupported connection parameter", warning_msg)

@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@patch("databricks.sql.client.logger")
def test_unknown_connection_param_warning_names_the_param(
self, mock_logger, mock_client_class
):
"""Warning message should include the name of the unknown parameter."""
databricks.sql.connect(
**self.DUMMY_CONNECTION_ARGS, totally_unknown_param="value"
)
# The unknown param name should appear in the warning args
call_args = mock_logger.warning.call_args
formatted_msg = call_args[0][0] % call_args[0][1:]
self.assertIn("totally_unknown_param", formatted_msg)

@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@patch("databricks.sql.client.logger")
def test_known_connection_params_do_not_issue_warning(
self, mock_logger, mock_client_class
):
"""Known kwargs passed to Connection should not trigger an unknown-param warning."""
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, use_cloud_fetch=False)
for call in mock_logger.warning.call_args_list:
msg = call[0][0]
self.assertNotIn("Unsupported connection parameter", msg)

@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@patch("databricks.sql.client.logger")
def test_multiple_unknown_params_all_appear_in_warning(
self, mock_logger, mock_client_class
):
"""All unknown param names should appear in the warning message."""
databricks.sql.connect(
**self.DUMMY_CONNECTION_ARGS,
bad_param_one="a",
bad_param_two="b",
)
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args
formatted_msg = call_args[0][0] % call_args[0][1:]
self.assertIn("bad_param_one", formatted_msg)
self.assertIn("bad_param_two", formatted_msg)


class TransactionTestSuite(unittest.TestCase):
"""
Expand Down
Loading