Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 103 additions & 9 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import contextlib
import json
import logging
import re
import uuid
import warnings
from typing import Any
from urllib.parse import quote

import google.oauth2.credentials
Expand All @@ -14,6 +16,7 @@
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel, ValidationError
from sqlalchemy.engine import URL, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand All @@ -33,6 +36,18 @@
from deepnote_toolkit.sql.sql_utils import is_single_select_query
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url

logger = logging.getLogger(__name__)


class IntegrationFederatedAuthParams(BaseModel):
integrationId: str
authContextToken: str


class FederatedAuthResponseData(BaseModel):
integrationType: str
accessToken: str

Comment thread
tkislan marked this conversation as resolved.

def compile_sql_query(
skip_jinja_template_render,
Expand Down Expand Up @@ -242,11 +257,97 @@ def _generate_temporary_credentials(integration_id):

response = requests.post(url, timeout=10, headers=headers)

response.raise_for_status()

data = response.json()

return quote(data["username"]), quote(data["password"])


def _get_federated_auth_credentials(
integration_id: str, user_pod_auth_context_token: str
) -> FederatedAuthResponseData:
"""Get federated auth credentials for the given integration ID and user pod auth context token."""

url = get_absolute_userpod_api_url(
f"integrations/federated-auth-token/{integration_id}"
)

# Add project credentials in detached mode
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)

response.raise_for_status()

data = FederatedAuthResponseData.model_validate(response.json())

return data

Comment thread
tkislan marked this conversation as resolved.

def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Apply IAM credentials to the connection URL in-place."""

if "iamParams" not in sql_alchemy_dict:
return

integration_id = sql_alchemy_dict["iamParams"]["integrationId"]

temporary_username, temporary_password = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporary_username, temporary_password
)

Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Fetch and apply federated auth credentials to connection params in-place."""

if "federatedAuthParams" not in sql_alchemy_dict:
return

try:
federated_auth_params = IntegrationFederatedAuthParams.model_validate(
sql_alchemy_dict["federatedAuthParams"]
)
except ValidationError:
logger.exception("Invalid federated auth params, try updating toolkit version")
return

federated_auth = _get_federated_auth_credentials(
federated_auth_params.integrationId, federated_auth_params.authContextToken
)

if federated_auth.integrationType == "trino":
try:
sql_alchemy_dict["params"]["connect_args"]["http_headers"][
"Authorization"
] = f"Bearer {federated_auth.accessToken}"
except KeyError:
logger.exception(
"Invalid federated auth params, try updating toolkit version"
)
Comment thread
tkislan marked this conversation as resolved.
elif federated_auth.integrationType == "big-query":
try:
sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken
except KeyError:
logger.exception(
"Invalid federated auth params, try updating toolkit version"
)
elif federated_auth.integrationType == "snowflake":
logger.warning(
"Snowflake federated auth is not supported yet, using the original connection URL"
)
else:
logger.error(
"Unsupported integration type: %s, try updating toolkit version",
federated_auth.integrationType,
)
Comment thread
tkislan marked this conversation as resolved.

Comment thread
coderabbitai[bot] marked this conversation as resolved.

@contextlib.contextmanager
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
server = None
Expand Down Expand Up @@ -346,16 +447,9 @@ def _query_data_source(
):
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)

if "iamParams" in sql_alchemy_dict:
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
_handle_iam_params(sql_alchemy_dict)

temporaryUsername, temporaryPassword = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
)
_handle_federated_auth_params(sql_alchemy_dict)

with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
if url is None:
Expand Down
219 changes: 219 additions & 0 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,222 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
df_cleaned.to_parquet(in_memory_file)
except: # noqa: E722
self.fail(f"serializing to parquet failed for {key}")


class TestFederatedAuth(unittest.TestCase):
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
def test_federated_auth_params_trino(self, mock_get_credentials):
"""Test that Trino federated auth updates the Authorization header with Bearer token."""
from deepnote_toolkit.sql.sql_execution import (
FederatedAuthResponseData,
_handle_federated_auth_params,
)

# Setup mock to return Trino credentials
mock_get_credentials.return_value = FederatedAuthResponseData(
integrationType="trino",
accessToken="test-trino-access-token",
)

# Create a sql_alchemy_dict with federatedAuthParams and the expected structure
sql_alchemy_dict = {
"url": "trino://user@localhost:8080/catalog",
"params": {
"connect_args": {
"http_headers": {
"Authorization": "Bearer old-token",
}
}
},
"federatedAuthParams": {
"integrationId": "test-integration-id",
"authContextToken": "test-auth-context-token",
},
}

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify the API was called with correct params
mock_get_credentials.assert_called_once_with(
"test-integration-id", "test-auth-context-token"
)

# Verify the Authorization header was updated with the new token
self.assertEqual(
sql_alchemy_dict["params"]["connect_args"]["http_headers"]["Authorization"],
"Bearer test-trino-access-token",
)

@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
def test_federated_auth_params_bigquery(self, mock_get_credentials):
"""Test that BigQuery federated auth updates the access_token in params."""
from deepnote_toolkit.sql.sql_execution import (
FederatedAuthResponseData,
_handle_federated_auth_params,
)

# Setup mock to return BigQuery credentials
mock_get_credentials.return_value = FederatedAuthResponseData(
integrationType="big-query",
accessToken="test-bigquery-access-token",
)

# Create a sql_alchemy_dict with federatedAuthParams
sql_alchemy_dict = {
"url": "bigquery://?user_supplied_client=true",
"params": {
"access_token": "old-access-token",
"project": "test-project",
},
"federatedAuthParams": {
"integrationId": "test-bigquery-integration-id",
"authContextToken": "test-bigquery-auth-context-token",
},
}

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify the API was called with correct params
mock_get_credentials.assert_called_once_with(
"test-bigquery-integration-id", "test-bigquery-auth-context-token"
)

# Verify the access_token was updated with the new token
self.assertEqual(
sql_alchemy_dict["params"]["access_token"],
"test-bigquery-access-token",
)

@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger):
"""Test that Snowflake federated auth logs a warning since it's not supported yet."""
from deepnote_toolkit.sql.sql_execution import (
FederatedAuthResponseData,
_handle_federated_auth_params,
)

# Setup mock to return Snowflake credentials
mock_get_credentials.return_value = FederatedAuthResponseData(
integrationType="snowflake",
accessToken="test-snowflake-access-token",
)

# Create a sql_alchemy_dict with federatedAuthParams
sql_alchemy_dict = {
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
"params": {},
"federatedAuthParams": {
"integrationId": "test-snowflake-integration-id",
"authContextToken": "test-snowflake-auth-context-token",
},
}

# Store original params to verify they remain unchanged
original_params = sql_alchemy_dict["params"].copy()

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify the API was called with correct params
mock_get_credentials.assert_called_once_with(
"test-snowflake-integration-id", "test-snowflake-auth-context-token"
)

# Verify a warning was logged
mock_logger.warning.assert_called_once_with(
"Snowflake federated auth is not supported yet, using the original connection URL"
)

# Verify params were NOT modified (snowflake is not supported yet)
self.assertEqual(sql_alchemy_dict["params"], original_params)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def test_federated_auth_params_not_present(self):
"""Test that no action is taken when federatedAuthParams is not present."""
from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params

# Create a sql_alchemy_dict without federatedAuthParams
sql_alchemy_dict = {
"url": "trino://user@localhost:8080/catalog",
"params": {
"connect_args": {
"http_headers": {"Authorization": "Bearer original-token"}
}
},
}

original_dict = json.loads(json.dumps(sql_alchemy_dict))
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify the dict was not modified
self.assertEqual(sql_alchemy_dict, original_dict)

@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
def test_federated_auth_params_invalid_params(self, mock_logger):
"""Test that invalid federated auth params logs an error and returns early."""
from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params

# Create a sql_alchemy_dict with invalid federatedAuthParams (missing required fields)
sql_alchemy_dict = {
"url": "trino://user@localhost:8080/catalog",
"params": {},
"federatedAuthParams": {
"invalidField": "value",
},
}

original_dict = json.loads(json.dumps(sql_alchemy_dict))

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify an exception was logged
mock_logger.exception.assert_called_once()
call_args = mock_logger.exception.call_args
self.assertIn("Invalid federated auth params", call_args[0][0])

self.assertEqual(sql_alchemy_dict, original_dict)

@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
def test_federated_auth_params_unsupported_integration_type(
self, mock_get_credentials, mock_logger
):
"""Test that unsupported integration type logs an error."""
from deepnote_toolkit.sql.sql_execution import (
FederatedAuthResponseData,
_handle_federated_auth_params,
)

# Setup mock to return unknown integration type
mock_get_credentials.return_value = FederatedAuthResponseData(
integrationType="unknown-database",
accessToken="test-token",
)

# Create a sql_alchemy_dict with federatedAuthParams
sql_alchemy_dict = {
"url": "unknown://host/db",
"params": {},
"federatedAuthParams": {
"integrationId": "test-integration-id",
"authContextToken": "test-auth-context-token",
},
}

original_dict = json.loads(json.dumps(sql_alchemy_dict))

# Call the function
_handle_federated_auth_params(sql_alchemy_dict)

# Verify an error was logged for unsupported integration type
mock_logger.error.assert_called_once_with(
"Unsupported integration type: %s, try updating toolkit version",
"unknown-database",
)

self.assertEqual(sql_alchemy_dict, original_dict)
Comment thread
coderabbitai[bot] marked this conversation as resolved.