diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1a246b7c1..266c6eff4 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -93,6 +93,78 @@ # Transaction isolation level constants (extension to PEP 249) TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" +# All supported **kwargs for Connection.__init__. Used to warn on unknown params. +KNOWN_KWARGS = frozenset( + [ + # Authentication + "access_token", + "auth_type", + "username", + "password", + "credentials_provider", + "_use_cert_as_auth", + "oauth_client_id", + "oauth_redirect_port", + "oauth_scopes", + "experimental_oauth_persistence", + "identity_federation_client_id", + "azure_client_id", + "azure_client_secret", + "azure_tenant_id", + "azure_workspace_resource_id", + # TLS / SSL + "_enable_ssl", + "_tls_no_verify", + "_tls_verify_hostname", + "_tls_trusted_ca_file", + "_tls_client_cert_file", + "_tls_client_cert_key_file", + "_tls_client_cert_key_password", + # Connection + "_port", + "_connection_uri", + "_skip_routing_headers", + # Retry policy + "_retry_delay_min", + "_retry_delay_max", + "_retry_delay_default", + "_retry_stop_after_attempts_count", + "_retry_stop_after_attempts_duration", + "_retry_dangerous_codes", + "_retry_max_redirects", + "_enable_v3_retries", + # Socket / network + "_socket_timeout", + "_proxy_auth_method", + "_pool_connections", + "_pool_maxsize", + "pool_maxsize", + # User agent & telemetry + "user_agent_entry", + "_user_agent_entry", + "enable_telemetry", + "force_enable_telemetry", + "telemetry_batch_size", + "_telemetry_circuit_breaker_enabled", + # Query / result + "use_inline_params", + "enable_query_result_lz4_compression", + "use_cloud_fetch", + "max_download_threads", + "use_hybrid_disposition", + "staging_allowed_local_path", + # Data type / format + "_use_arrow_native_decimals", + "_use_arrow_native_timestamps", + "_disable_pandas", + # Behavior + "enable_metric_view_metadata", + "fetch_autocommit_from_server", + # Backend selection + "use_sea", + ] +) + class Connection: def __init__( @@ -269,6 +341,14 @@ def read(self) -> Optional[OAuthToken]: http_path, ) + unknown_params = set(kwargs.keys()) - KNOWN_KWARGS + if unknown_params: + logger.warning( + "Unsupported connection parameter(s) will be ignored: %s. " + "Check the Connection documentation for supported parameters.", + ", ".join(sorted(unknown_params)), + ) + if access_token: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5b6991931..ece617c21 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -633,6 +633,61 @@ def mock_close_normal(): cursors_closed, [1, 2], "Both cursors should have close called" ) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + @patch("databricks.sql.client.logger") + def test_unknown_connection_param_issues_warning( + self, mock_logger, mock_client_class + ): + """Unknown kwargs passed to Connection should trigger a warning.""" + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, totally_unknown_param="value" + ) + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + self.assertIn("Unsupported connection parameter", warning_msg) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + @patch("databricks.sql.client.logger") + def test_unknown_connection_param_warning_names_the_param( + self, mock_logger, mock_client_class + ): + """Warning message should include the name of the unknown parameter.""" + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, totally_unknown_param="value" + ) + # The unknown param name should appear in the warning args + call_args = mock_logger.warning.call_args + formatted_msg = call_args[0][0] % call_args[0][1:] + self.assertIn("totally_unknown_param", formatted_msg) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + @patch("databricks.sql.client.logger") + def test_known_connection_params_do_not_issue_warning( + self, mock_logger, mock_client_class + ): + """Known kwargs passed to Connection should not trigger an unknown-param warning.""" + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, use_cloud_fetch=False) + for call in mock_logger.warning.call_args_list: + msg = call[0][0] + self.assertNotIn("Unsupported connection parameter", msg) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + @patch("databricks.sql.client.logger") + def test_multiple_unknown_params_all_appear_in_warning( + self, mock_logger, mock_client_class + ): + """All unknown param names should appear in the warning message.""" + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + bad_param_one="a", + bad_param_two="b", + ) + mock_logger.warning.assert_called_once() + call_args = mock_logger.warning.call_args + formatted_msg = call_args[0][0] % call_args[0][1:] + self.assertIn("bad_param_one", formatted_msg) + self.assertIn("bad_param_two", formatted_msg) + class TransactionTestSuite(unittest.TestCase): """