-
Notifications
You must be signed in to change notification settings - Fork 87
Expand file tree
/
Copy pathsaml_credentials_provider.py
More file actions
290 lines (258 loc) · 13.2 KB
/
saml_credentials_provider.py
File metadata and controls
290 lines (258 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import base64
import logging
import random
import re
import typing
from abc import abstractmethod
from redshift_connector.credentials_holder import CredentialsHolder
from redshift_connector.error import InterfaceError
from redshift_connector.idp_auth_helper import IdpAuthHelper
from redshift_connector.plugin.credential_provider_constants import SAML_RESP_NAMESPACES
from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider
from redshift_connector.redshift_property import RedshiftProperty
# Module-level logger, named after this module; every method below emits its debug trace through it.
_logger: logging.Logger = logging.getLogger(__name__)
class SamlCredentialsProvider(IdpCredentialsProvider):
    """
    Generic Identity Provider Plugin providing single sign-on access to an Amazon Redshift cluster using an identity provider of your choice.

    Concrete plugins implement :meth:`get_saml_assertion`, which must return a
    base64-encoded SAML assertion. This class decodes the assertion, picks an
    IAM role from it, exchanges it for temporary AWS credentials through STS
    ``assume_role_with_saml`` and keeps the result in a per-instance cache.
    """

    # ARN matchers used when scanning the SAML Role attribute. They are compiled
    # once for the class instead of on every refresh, and accept any AWS
    # partition prefix ("aws", "aws-cn", "aws-us-gov", ...).
    _ROLE_ARN_PATTERN = re.compile(r"arn:aws[-a-z]*:iam::\d*:role/\S+")
    _PROVIDER_ARN_PATTERN = re.compile(r"arn:aws[-a-z]*:iam::\d*:saml-provider/\S+")

    # STS client exceptions that refresh() labels individually in its debug log,
    # listed in the order they are tested.
    _STS_EXCEPTION_NAMES = (
        "MalformedPolicyDocumentException",
        "PackedPolicyTooLargeException",
        "IDPRejectedClaimException",
        "InvalidIdentityTokenException",
        "ExpiredTokenException",
        "RegionDisabledException",
    )

    def __init__(self: "SamlCredentialsProvider") -> None:
        """Create a provider with every connection setting unset and an empty cache."""
        super().__init__()
        self.user_name: typing.Optional[str] = None
        self.password: typing.Optional[str] = None
        self.idp_host: typing.Optional[str] = None
        self.idpPort: int = 443
        self.duration: typing.Optional[int] = None
        self.preferred_role: typing.Optional[str] = None
        self.ssl_insecure: typing.Optional[bool] = None
        self.db_user: typing.Optional[str] = None
        self.db_groups: typing.List[str] = []
        self.force_lowercase: typing.Optional[bool] = None
        self.auto_create: typing.Optional[bool] = None
        self.region: typing.Optional[str] = None
        self.principal: typing.Optional[str] = None
        self.group_federation: bool = False
        # Maps get_cache_key() -> credentials stored by refresh().
        self.cache: dict = {}

    def add_parameter(self: "SamlCredentialsProvider", info: RedshiftProperty) -> None:
        """Copy the connection settings this plugin uses from ``info`` onto the provider."""
        self.user_name = info.user_name
        self.password = info.password
        self.idp_host = info.idp_host
        self.idpPort = info.idpPort
        self.duration = info.duration
        self.preferred_role = info.preferred_role
        self.ssl_insecure = info.ssl_insecure
        self.db_user = info.db_user
        self.db_groups = info.db_groups
        self.force_lowercase = info.force_lowercase
        self.auto_create = info.auto_create
        self.region = info.region
        self.principal = info.principal

    def set_group_federation(self: "SamlCredentialsProvider", group_federation: bool) -> None:
        """Record the ``group_federation`` flag on the provider."""
        self.group_federation = group_federation

    def get_sub_type(self) -> int:
        """Return the plugin sub-type constant ``IdpAuthHelper.SAML_PLUGIN``."""
        return IdpAuthHelper.SAML_PLUGIN

    def do_verify_ssl_cert(self: "SamlCredentialsProvider") -> bool:
        """Return ``True`` unless ``ssl_insecure`` is set to a truthy value."""
        return not self.ssl_insecure

    def get_credentials(self: "SamlCredentialsProvider") -> CredentialsHolder:
        """
        Return cached credentials, refreshing them first when missing or expired.

        Raises
        ------
        InterfaceError
            If refreshing fails (the original error is chained as the cause), or
            if no credentials are available in the cache after the refresh.
        """
        _logger.debug("SamlCredentialsProvider.get_credentials")
        key: str = self.get_cache_key()

        cached = self.cache.get(key)
        if cached is None or cached.is_expired():
            try:
                self.refresh()
                _logger.debug("Successfully refreshed credentials")
            except Exception as e:
                _logger.debug("Refreshing IdP credentials failed")
                raise InterfaceError(e) from e

        # if the SAML response has db_user argument, it will be picked up at this point.
        # .get() rather than [] so that a missing entry reaches the explicit
        # InterfaceError below instead of surfacing as a bare KeyError.
        credentials: typing.Optional[CredentialsHolder] = self.cache.get(key)
        if credentials is None:
            exec_msg = "Unable to load AWS credentials from IdP"
            _logger.debug(exec_msg)
            raise InterfaceError(exec_msg)

        # if db_user argument has been passed in the connection string, add it to metadata.
        if self.db_user:
            _logger.debug("adding db_user to metadata")
            credentials.metadata.set_db_user(self.db_user)
        return credentials

    def refresh(self: "SamlCredentialsProvider") -> None:
        """
        Obtain a SAML assertion, exchange it for AWS credentials and cache them.

        The role is ``preferred_role`` when set, otherwise a random role from
        the assertion. The resulting credentials, with metadata read from the
        assertion, are stored in ``self.cache`` under :meth:`get_cache_key`.

        Raises
        ------
        InterfaceError
            If the assertion cannot be retrieved, contains no role/provider
            pair, or does not contain ``preferred_role``.
        Exception
            Any error raised while calling STS or building the credentials is
            logged at debug level and re-raised unchanged.
        """
        _logger.debug("SamlCredentialsProvider.refresh")
        # Imported lazily, as in the rest of this module, so the dependencies are
        # only needed when a refresh actually happens.
        import boto3  # type: ignore
        import bs4  # type: ignore

        try:
            # get SAML assertion from specific identity provider
            saml_assertion = self.get_saml_assertion()
            _logger.debug("Successfully retrieved SAML assertion")
        except Exception as e:
            exec_msg = "Failed to get SAML assertion"
            _logger.debug(exec_msg)
            raise InterfaceError(exec_msg) from e

        # decode SAML assertion into xml format
        doc: bytes = base64.b64decode(saml_assertion)
        _logger.debug("decoded SAML assertion into xml format")

        roles: typing.Dict[str, str] = self._extract_roles(bs4.BeautifulSoup(doc, "xml"))
        _logger.debug("Done reading SAML assertion attributes")
        _logger.debug("%s roles identified in SAML assertion", len(roles))

        if len(roles) == 0:
            exec_msg = "No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute."
            _logger.debug(exec_msg)
            raise InterfaceError(exec_msg)

        role_arn, principle = self._select_role(roles)

        client = boto3.client("sts")
        try:
            _logger.debug(
                "Attempting to retrieve temporary AWS credentials using the SAML assertion, principal ARN, and role ARN."
            )
            response = client.assume_role_with_saml(
                RoleArn=role_arn,
                PrincipalArn=principle,
                SAMLAssertion=saml_assertion,
            )
            _logger.debug("Extracting temporary AWS credentials from assume_role_with_saml response")
            stscred: typing.Dict[str, typing.Any] = response["Credentials"]
            credentials: CredentialsHolder = CredentialsHolder(stscred)
            # get metadata from SAML assertion
            credentials.set_metadata(self.read_metadata(doc))
            key: str = self.get_cache_key()
            self.cache[key] = credentials
        except Exception as e:
            # Every failure is handled the same way: log it under a label naming
            # the kind of error, then re-raise it untouched.
            _logger.debug("%s: %s", self._sts_failure_label(client, e), e)
            raise

    def _extract_roles(self: "SamlCredentialsProvider", soup) -> typing.Dict[str, str]:
        """Return a role ARN -> provider ARN mapping read from the assertion's Role attribute."""
        roles: typing.Dict[str, str] = {}
        _logger.debug("searching SAML assertion for values matching patterns for RoleArn and PrincipalArn")

        for attr in soup.find_all("Attribute"):
            name: str = attr.attrs["Name"]
            if name != "https://aws.amazon.com/SAML/Attributes/Role":
                continue
            _logger.debug("Attribute with name %s found. Checking if pattern match occurs", name)

            for value in attr.find_all("AttributeValue"):
                if not value.contents:
                    # An empty <AttributeValue/> carries no ARNs; skip it rather
                    # than fail on contents[0].
                    continue
                role: str = ""
                provider: str = ""
                for arn in value.contents[0].split(","):
                    arn = arn.strip()  # remove trailing or leading whitespace
                    if self._ROLE_ARN_PATTERN.match(arn):
                        _logger.debug("RoleArn pattern matched")
                        role = arn
                    if self._PROVIDER_ARN_PATTERN.match(arn):
                        _logger.debug("PrincipleArn pattern matched")
                        provider = arn
                # Only a value that supplied both ARNs yields a usable entry.
                if role != "" and provider != "":
                    roles[role] = provider
        return roles

    def _select_role(
        self: "SamlCredentialsProvider", roles: typing.Dict[str, str]
    ) -> typing.Tuple[str, str]:
        """Return the ``(role ARN, provider ARN)`` pair to assume, honouring ``preferred_role``."""
        if self.preferred_role:
            _logger.debug("User provided preferred_role, trying to use...")
            role_arn: str = self.preferred_role
            if role_arn not in roles:
                exec_msg = "User specified preferred_role was not found in SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute"
                _logger.debug(exec_msg)
                raise InterfaceError(exec_msg)
        else:
            _logger.debug(
                "User did not specify a preferred_role. A randomly selected role from the SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute will be used."
            )
            role_arn = random.choice(list(roles))
        return role_arn, roles[role_arn]

    def _sts_failure_label(self: "SamlCredentialsProvider", client, error: Exception) -> str:
        """Return the debug-log label for ``error``: its recognised kind, else ``"Other Exception"``."""
        for builtin_type in (AttributeError, KeyError):
            if isinstance(error, builtin_type):
                return builtin_type.__name__
        # The client's exception classes are looked up one at a time, and only
        # after the built-in checks above have failed to match.
        for exception_name in self._STS_EXCEPTION_NAMES:
            if isinstance(error, getattr(client.exceptions, exception_name)):
                return exception_name
        return "Other Exception"

    def get_cache_key(self: "SamlCredentialsProvider") -> str:
        """
        Return the key under which this provider's credentials are cached.

        The key is the plain concatenation of user name, password, IdP host,
        IdP port, duration and preferred role, with no separators. It therefore
        contains the password in clear text and should not be logged.
        """
        return "{username}{password}{idp_host}{idp_port}{duration}{preferred_role}".format(
            username=self.user_name,
            password=self.password,
            idp_host=self.idp_host,
            idp_port=self.idpPort,
            duration=self.duration,
            preferred_role=self.preferred_role,
        )

    @abstractmethod
    def get_saml_assertion(self: "SamlCredentialsProvider"):
        """Return the base64-encoded SAML assertion; implemented by each concrete plugin."""
        pass

    def check_required_parameters(self: "SamlCredentialsProvider") -> None:
        """Report ``user_name``, ``password`` and ``idp_host`` as missing when ``None`` or empty."""
        _logger.debug("SamlCredentialsProvider.check_required_parameters")
        for property_name in ("user_name", "password", "idp_host"):
            property_value = getattr(self, property_name)
            if property_value is None or property_value == "":
                SamlCredentialsProvider.handle_missing_required_property(property_name)

    def read_metadata(self: "SamlCredentialsProvider", doc: bytes) -> CredentialsHolder.IamMetadata:
        """
        Build IAM metadata from the Redshift/AWS attributes of a decoded SAML document.

        Attributes are looked up under each prefix in ``SAML_RESP_NAMESPACES`` in
        order, and the first prefix that yields any attributes is used for the
        values as well. Attributes with no value, or an empty first value, are
        ignored.

        Raises
        ------
        AttributeError, KeyError
            Logged at debug level and re-raised if raised while reading the
            document.
        """
        _logger.debug("SamlCredentialsProvider.read_metadata")
        import bs4  # type: ignore

        try:
            soup = bs4.BeautifulSoup(doc, "xml")
            attrs: typing.Any = []
            # Tag name of the value elements; only read when attrs is non-empty,
            # which implies it was set in the loop below.
            value_tag_name: str = ""
            # prefer using Attributes in saml-compliant namespace
            for namespace in SAML_RESP_NAMESPACES:
                _logger.debug("Looking for attributes under %s namespace", namespace)
                attrs = soup.find_all("{}Attribute".format(namespace))
                if len(attrs) > 0:
                    _logger.debug("Attributes found under SAML response namespace %s", namespace)
                    value_tag_name = "{}AttributeValue".format(namespace)
                    break

            metadata: CredentialsHolder.IamMetadata = CredentialsHolder.IamMetadata()

            for attr in attrs:
                name: str = attr.attrs["Name"]
                _logger.debug("Searching SAML attribute %s for attribute values", name)
                values: typing.Any = attr.find_all(value_tag_name)
                if len(values) == 0 or not values[0].contents:
                    _logger.debug("No SAML attribute %s found. Continuing to search", value_tag_name)
                    # Ignore empty-valued attributes.
                    continue
                value: str = values[0].contents[0]

                if name == "https://redshift.amazon.com/SAML/Attributes/AllowDbUserOverride":
                    metadata.set_allow_db_user_override(value)
                elif name == "https://redshift.amazon.com/SAML/Attributes/DbUser":
                    metadata.set_saml_db_user(value)
                elif name == "https://aws.amazon.com/SAML/Attributes/RoleSessionName":
                    # RoleSessionName is only a fallback; an explicit DbUser wins.
                    if metadata.get_saml_db_user() is None:
                        metadata.set_saml_db_user(value)
                elif name == "https://redshift.amazon.com/SAML/Attributes/AutoCreate":
                    metadata.set_auto_create(value)
                elif name == "https://redshift.amazon.com/SAML/Attributes/DbGroups":
                    # One group per value element, lower-cased; empty value
                    # elements after the first are skipped.
                    metadata.set_db_groups(
                        [group.contents[0].lower() for group in values if group.contents]
                    )
                elif name == "https://redshift.amazon.com/SAML/Attributes/ForceLowercase":
                    metadata.set_force_lowercase(value)

            return metadata
        except AttributeError as e:
            _logger.debug("AttributeError: %s", e)
            raise
        except KeyError as e:
            _logger.debug("KeyError: %s", e)
            raise

    def get_form_action(self: "SamlCredentialsProvider", soup) -> typing.Optional[str]:
        """Return the first non-empty ``action`` among tags matching ``(FORM|form)``, else ``None``."""
        for form_tag in soup.find_all(re.compile("(FORM|form)")):
            action: str = form_tag.get("action")
            if action:
                return action
        return None

    def is_text(self: "SamlCredentialsProvider", inputtag) -> bool:
        """Return whether ``inputtag``'s ``type`` attribute equals ``"text"``."""
        return typing.cast(bool, "text" == inputtag.get("type"))

    def is_password(self: "SamlCredentialsProvider", inputtag) -> bool:
        """Return whether ``inputtag``'s ``type`` attribute equals ``"password"``."""
        return typing.cast(bool, "password" == inputtag.get("type"))