Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel
from requests.adapters import HTTPAdapter, Retry
from sqlalchemy.engine import URL, Connection, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand Down Expand Up @@ -263,13 +264,27 @@ class ExecuteSqlError(Exception):
)


def _generate_temporary_credentials(integration_id):
def _create_retry_session() -> requests.Session:
"""Create a requests session with retry on 5xx for POST requests."""
session = requests.Session()
retries = Retry(
total=3,
backoff_factor=0.5,
status_forcelist=[500, 502, 503, 504],
)
Comment thread
tkislan marked this conversation as resolved.
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))
return session
Comment thread
tkislan marked this conversation as resolved.


def _generate_temporary_credentials(integration_id) -> tuple[str, str]:
url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}")

# Add project credentials in detached mode
headers = get_project_auth_headers()

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand All @@ -291,7 +306,8 @@ def _get_federated_auth_credentials(
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand Down
94 changes: 90 additions & 4 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
class TestFederatedAuth(unittest.TestCase):
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
@mock.patch("deepnote_toolkit.sql.sql_execution.requests.post")
@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
def test_get_federated_auth_credentials_returns_validated_response(
self, mock_post, mock_get_url, mock_get_headers
self, mock_create_session, mock_get_url, mock_get_headers
):
"""Test that _get_federated_auth_credentials properly validates and returns response data."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials
Expand All @@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response(
mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id"
mock_get_headers.return_value = {"Authorization": "Bearer project-token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-access-token-123",
}
mock_post.return_value = mock_response
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

# Call the function
result = _get_federated_auth_credentials(
Expand All @@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response(
)

# Verify headers include both project auth and user pod auth context token
mock_post.assert_called_once_with(
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-integration-id",
timeout=10,
headers={
Expand Down Expand Up @@ -1019,3 +1021,87 @@ def test_databricks_connector_dialect_alias_is_registered(self):

self.assertEqual(url.drivername, "databricks+connector")
self.assertIsNotNone(dialect_cls)


class TestCreateRetrySession(unittest.TestCase):
def test_retry_session_has_correct_config(self):
"""Test that _create_retry_session configures retries correctly."""
from deepnote_toolkit.sql.sql_execution import _create_retry_session

session = _create_retry_session()

# Check that both http and https adapters are mounted with retry config
for prefix in ("http://", "https://"):
adapter = session.get_adapter(prefix)
retries = adapter.max_retries
self.assertEqual(retries.total, 3)
self.assertEqual(retries.backoff_factor, 0.5)
self.assertEqual(list(retries.status_forcelist), [500, 502, 503, 504])
self.assertIn("POST", retries.allowed_methods)
Comment thread
tkislan marked this conversation as resolved.
Outdated

@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_generate_temporary_credentials_uses_retry_session(
self, mock_get_url, mock_get_headers, mock_create_session
):
"""Test that _generate_temporary_credentials uses a retry session."""
from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/credentials/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"username": "user",
"password": "pass",
}
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

_generate_temporary_credentials("test-id")

mock_create_session.assert_called_once()
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/credentials/test-id",
timeout=10,
headers={"Authorization": "Bearer token"},
)

@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_get_federated_auth_credentials_uses_retry_session(
self, mock_get_url, mock_get_headers, mock_create_session
):
"""Test that _get_federated_auth_credentials uses a retry session."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

mock_get_url.return_value = (
"https://api.example.com/integrations/federated-auth-token/test-id"
)
mock_get_headers.return_value = {"Authorization": "Bearer token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-token",
}
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

_get_federated_auth_credentials("test-id", "auth-context-token")

mock_create_session.assert_called_once()
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-id",
timeout=10,
headers={
"Authorization": "Bearer token",
"UserPodAuthContextToken": "auth-context-token",
},
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Loading