Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional

from pydantic import alias_generators
Expand Down Expand Up @@ -80,6 +81,14 @@ class OAuth2Auth(BaseModelWithConfig):
expires_at: Optional[int] = None
expires_in: Optional[int] = None
audience: Optional[str] = None
token_endpoint_auth_method: Optional[
Literal[
"client_secret_basic",
"client_secret_post",
"client_secret_jwt",
"private_key_jwt",
]
] = "client_secret_basic"


class ServiceAccountCredential(BaseModelWithConfig):
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/auth/oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def create_oauth2_session(
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method,
),
token_endpoint,
)
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/auth/test_oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,36 @@
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.auth.oauth2_credential_util import create_oauth2_session
from google.adk.auth.oauth2_credential_util import update_credential_with_tokens
import pytest


@pytest.fixture
def openid_connect_scheme():
"""Fixture providing a standard OpenIdConnectWithConfig scheme."""
return OpenIdConnectWithConfig(
type_="openIdConnect",
openId_connect_url="https://example.com/.well-known/openid_configuration",
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
)


def create_oauth2_auth_credential(token_endpoint_auth_method=None):
"""Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method."""
oauth2_auth = OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
state="test_state",
)
if token_endpoint_auth_method is not None:
oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method

return AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=oauth2_auth,
)


class TestOAuth2CredentialUtil:
Expand Down Expand Up @@ -122,6 +152,68 @@ def test_create_oauth2_session_missing_credentials(self):
assert client is None
assert token_endpoint is None

def test_create_oauth2_session_with_token_endpoint_auth_method(
self, openid_connect_scheme
):
"""Test create_oauth2_session with token_endpoint_auth_method specified."""
credential = create_oauth2_auth_credential(
token_endpoint_auth_method="client_secret_post"
)

client, token_endpoint = create_oauth2_session(
openid_connect_scheme, credential
)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.client_id == "test_client_id"
assert client.client_secret == "test_client_secret"
assert client.token_endpoint_auth_method == "client_secret_post"

def test_create_oauth2_session_with_default_token_endpoint_auth_method(
self, openid_connect_scheme
):
"""Test create_oauth2_session with default token_endpoint_auth_method."""
credential = create_oauth2_auth_credential()

client, token_endpoint = create_oauth2_session(
openid_connect_scheme, credential
)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.client_id == "test_client_id"
assert client.client_secret == "test_client_secret"
assert client.token_endpoint_auth_method == "client_secret_basic"

def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method(
self,
):
"""Test create_oauth2_session with OAuth2 scheme and token_endpoint_auth_method."""
flows = OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://example.com/auth",
tokenUrl="https://example.com/token",
scopes={"read": "Read access", "write": "Write access"},
)
)
scheme = OAuth2(type_="oauth2", flows=flows)
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
token_endpoint_auth_method="client_secret_jwt",
),
)

client, token_endpoint = create_oauth2_session(scheme, credential)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.token_endpoint_auth_method == "client_secret_jwt"

def test_update_credential_with_tokens(self):
"""Test update_credential_with_tokens function."""
credential = AuthCredential(
Expand Down
Loading