Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 82 additions & 9 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import contextlib
import json
import logging
import re
import uuid
import warnings
from typing import Any
from urllib.parse import quote

import google.oauth2.credentials
Expand All @@ -14,6 +16,7 @@
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel, ValidationError
from sqlalchemy.engine import URL, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand All @@ -33,6 +36,18 @@
from deepnote_toolkit.sql.sql_utils import is_single_select_query
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url

logger = logging.getLogger(__name__)


class IntegrationFederatedAuthParams(BaseModel):
integrationId: str
userPodAuthContextToken: str


class FederatedAuthResponseData(BaseModel):
integrationType: str
accessToken: str

Comment thread
tkislan marked this conversation as resolved.

def compile_sql_query(
skip_jinja_template_render,
Expand Down Expand Up @@ -247,6 +262,71 @@ def _generate_temporary_credentials(integration_id):
return quote(data["username"]), quote(data["password"])


def _get_federated_auth_credentials(integration_id: str, user_pod_auth_context_token: str) -> FederatedAuthResponseData:
url = get_absolute_userpod_api_url(
f"integrations/federated-auth-token/{integration_id}"
)

# Add project credentials in detached mode
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)

data = FederatedAuthResponseData.model_validate(response.json())

return data
Comment thread
tkislan marked this conversation as resolved.
Outdated

Comment thread
tkislan marked this conversation as resolved.

def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Apply IAM credentials to the connection URL in-place."""

if "iamParams" not in sql_alchemy_dict:
return

integration_id = sql_alchemy_dict["iamParams"]["integrationId"]

temporary_username, temporary_password = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporary_username, temporary_password
)

Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Fetch and apply federated auth credentials to connection params in-place."""

if "federatedAuthParams" not in sql_alchemy_dict:
return

try:
federated_auth_params = IntegrationFederatedAuthParams.model_validate(
sql_alchemy_dict["federatedAuthParams"]
)
except ValidationError as e:
logger.error(
"Invalid federated auth params, try updating toolkit version:", exc_info=e
)
return

federated_auth = _get_federated_auth_credentials(
federated_auth_params.integrationId, federated_auth_params.userPodAuthContextToken
)

if federated_auth.integrationType == "trino":
sql_alchemy_dict["params"]["connect_args"]["http_headers"][
"Authorization"
] = f"Bearer {federated_auth.access_token}"
elif federated_auth.integrationType == "big-query":
sql_alchemy_dict["params"]["access_token"] = federated_auth.access_token
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
else:
logger.error(
"Unsupported integration type: %s, try updating toolkit version", federated_auth.integrationType
)
Comment thread
tkislan marked this conversation as resolved.

Comment thread
coderabbitai[bot] marked this conversation as resolved.

@contextlib.contextmanager
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
server = None
Expand Down Expand Up @@ -346,16 +426,9 @@ def _query_data_source(
):
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)

if "iamParams" in sql_alchemy_dict:
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]

temporaryUsername, temporaryPassword = _generate_temporary_credentials(
integration_id
)
_handle_iam_params(sql_alchemy_dict)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
)
_handle_federated_auth_params(sql_alchemy_dict)

with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
if url is None:
Expand Down
Loading