forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth_handler.py
More file actions
206 lines (174 loc) · 6.87 KB
/
auth_handler.py
File metadata and controls
206 lines (174 loc) · 6.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi.openapi.models import SecurityBase
from .auth_credential import AuthCredential
from .auth_schemes import AuthSchemeType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger
if TYPE_CHECKING:
from ..sessions.state import State
try:
from authlib.integrations.requests_client import OAuth2Session
AUTHLIB_AVAILABLE = True
except ImportError:
AUTHLIB_AVAILABLE = False
class AuthHandler:
    """Orchestrates the credential request/response flow (e.g. OAuth) in ADK.

    A handler wraps a single ``AuthConfig`` and offers helpers to build the
    auth request handed to the client, to store the client's auth response in
    session state, and to read that response back.

    This class should only be used by Agent Development Kit.
    """

    # Prefix of the session-state key under which the credential is stored.
    _STATE_KEY_PREFIX = "temp:"

    def __init__(self, auth_config: AuthConfig):
        """Binds the handler to the auth configuration it operates on."""
        self.auth_config = auth_config

    def _state_key(self) -> str:
        """Returns the session-state key used for this config's credential."""
        return self._STATE_KEY_PREFIX + self.auth_config.credential_key

    def _uses_oauth2_or_oidc(self) -> bool:
        """Tells whether the scheme is a SecurityBase of type OAuth2 or OIDC."""
        auth_scheme = self.auth_config.auth_scheme
        return isinstance(auth_scheme, SecurityBase) and auth_scheme.type_ in (
            AuthSchemeType.oauth2,
            AuthSchemeType.openIdConnect,
        )

    async def exchange_auth_token(
        self,
    ) -> AuthCredential:
        """Exchanges the config's exchanged credential via the OAuth2 exchanger.

        Returns:
          The first element of the result of
          ``OAuth2CredentialExchanger.exchange``.
        """
        exchanger = OAuth2CredentialExchanger()
        exchanged_credential, _ = await exchanger.exchange(
            self.auth_config.exchanged_auth_credential,
            self.auth_config.auth_scheme,
        )
        return exchanged_credential

    async def parse_and_store_auth_response(self, state: State) -> None:
        """Stores the auth response credential in ``state``.

        The credential from the config is always written first. For OAuth2 and
        OpenID Connect schemes it is then replaced by the result of
        ``exchange_auth_token``.

        Args:
          state: The session state mapping to write the credential into.
        """
        credential_key = self._state_key()

        # Written before the exchange so the entry exists even if the exchange
        # below raises.
        state[credential_key] = self.auth_config.exchanged_auth_credential
        if not self._uses_oauth2_or_oidc():
            return

        state[credential_key] = await self.exchange_auth_token()

    def _validate(self) -> None:
        """Checks that the auth config carries an auth scheme.

        Raises:
          ValueError: If the auth scheme of the config is empty.
        """
        # The scheme lives on the config; the handler itself has no
        # ``auth_scheme`` attribute.
        if not self.auth_config.auth_scheme:
            raise ValueError("auth_scheme is empty.")

    def get_auth_response(self, state: State) -> AuthCredential | None:
        """Reads the stored credential for this config from ``state``.

        Args:
          state: The session state mapping to read from.

        Returns:
          The stored credential, or None when no entry exists under the key.
        """
        return state.get(self._state_key(), None)

    def generate_auth_request(self) -> AuthConfig:
        """Builds the auth config to send out as the auth request.

        For schemes other than OAuth2 / OpenID Connect, or when the exchanged
        credential already has an auth URI, a deep copy of the config is
        returned. Otherwise the raw credential is validated and an auth URI is
        taken from it or freshly generated.

        Returns:
          An AuthConfig whose exchanged credential carries the auth URI (when
          one applies).

        Raises:
          ValueError: If the raw credential, its oauth2 part, or the client id
            / client secret needed to generate an auth URI is missing.
        """
        config = self.auth_config
        if not self._uses_oauth2_or_oidc():
            return config.model_copy(deep=True)

        # auth_uri already in exchanged credential
        exchanged = config.exchanged_auth_credential
        if exchanged and exchanged.oauth2 and exchanged.oauth2.auth_uri:
            return config.model_copy(deep=True)

        # Check if raw_auth_credential exists
        raw = config.raw_auth_credential
        if not raw:
            raise ValueError(
                f"Auth Scheme {config.auth_scheme.type_} requires"
                " auth_credential."
            )

        # Check if oauth2 exists in raw_auth_credential
        if not raw.oauth2:
            raise ValueError(
                f"Auth Scheme {config.auth_scheme.type_} requires oauth2 in"
                " auth_credential."
            )

        # auth_uri in raw credential
        if raw.oauth2.auth_uri:
            return AuthConfig(
                auth_scheme=config.auth_scheme,
                raw_auth_credential=raw,
                exchanged_auth_credential=raw.model_copy(deep=True),
            )

        # Check for client_id and client_secret
        if not raw.oauth2.client_id or not raw.oauth2.client_secret:
            raise ValueError(
                f"Auth Scheme {config.auth_scheme.type_} requires both"
                " client_id and client_secret in auth_credential.oauth2."
            )

        # Generate new auth URI
        return AuthConfig(
            auth_scheme=config.auth_scheme,
            raw_auth_credential=raw,
            exchanged_auth_credential=self.generate_auth_uri(),
        )

    @staticmethod
    def _select_oauth2_flow_settings(flows):
        """Picks the endpoint URL and scope names from an OAuth2 flows object.

        Flows are consulted in the order implicit, authorizationCode,
        clientCredentials, password. The URL and the scopes are chosen
        independently: each is the first non-empty value found in that order.

        Returns:
          A ``(endpoint, scope_names)`` tuple; ``endpoint`` is None and
          ``scope_names`` is an empty list when no flow provides a value.
        """
        candidates = (
            (flows.implicit, "authorizationUrl"),
            (flows.authorizationCode, "authorizationUrl"),
            (flows.clientCredentials, "tokenUrl"),
            (flows.password, "tokenUrl"),
        )
        endpoint = next(
            (
                getattr(flow, url_attr)
                for flow, url_attr in candidates
                if flow and getattr(flow, url_attr)
            ),
            None,
        )
        scopes = next(
            (flow.scopes for flow, _ in candidates if flow and flow.scopes),
            None,
        )
        # Scopes are a mapping keyed by scope name; only the names are needed.
        return endpoint, list(scopes) if scopes else []

    def generate_auth_uri(
        self,
    ) -> AuthCredential | None:
        """Generates a response containing the auth uri for user to sign in.

        When authlib is not installed no URI can be built; a deep copy of the
        raw credential (or None if there is none) is returned instead.

        Returns:
          An AuthCredential object containing the auth URI and state.

        Raises:
          ValueError: If the authorization endpoint is not configured in the
            auth scheme.
        """
        auth_credential = self.auth_config.raw_auth_credential
        if not AUTHLIB_AVAILABLE:
            return (
                auth_credential.model_copy(deep=True)
                if auth_credential
                else None
            )

        auth_scheme = self.auth_config.auth_scheme
        if isinstance(auth_scheme, OpenIdConnectWithConfig):
            authorization_endpoint = auth_scheme.authorization_endpoint
            # Tolerate a scheme without scopes instead of failing in join().
            scopes = list(auth_scheme.scopes or [])
        else:
            authorization_endpoint, scopes = self._select_oauth2_flow_settings(
                auth_scheme.flows
            )

        if not authorization_endpoint:
            raise ValueError(
                "Authorization endpoint is not configured in the auth scheme."
            )

        client = OAuth2Session(
            auth_credential.oauth2.client_id,
            auth_credential.oauth2.client_secret,
            scope=" ".join(scopes),
            redirect_uri=auth_credential.oauth2.redirect_uri,
        )
        params = {
            "access_type": "offline",
            "prompt": "consent",
        }
        if auth_credential.oauth2.audience:
            params["audience"] = auth_credential.oauth2.audience
        uri, state = client.create_authorization_url(
            url=authorization_endpoint, **params
        )

        # Return a copy so the raw credential on the config stays untouched.
        exchanged_auth_credential = auth_credential.model_copy(deep=True)
        exchanged_auth_credential.oauth2.auth_uri = uri
        exchanged_auth_credential.oauth2.state = state
        return exchanged_auth_credential