-
Notifications
You must be signed in to change notification settings - Fork 74
Expand file tree
/
Copy pathroles.py
More file actions
242 lines (197 loc) · 8.38 KB
/
roles.py
File metadata and controls
242 lines (197 loc) · 8.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import os
import random
import re
import string

try:
    from typing import List, NamedTuple, Optional, Sequence, Type, TypedDict, TypeVar
except ImportError:
    from typing import List, NamedTuple, Optional, Sequence, Type, TypeVar

    from typing_extensions import TypedDict

import boto3
from boto3 import Session, client
from botocore.client import BaseClient

from model_engine_server.core.loggers import logger_name, make_logger
logger = make_logger(logger_name())

# Public API of this module, i.e. what ``from <module> import *`` exposes.
# Every public (non-underscore) helper defined here is listed.
__all__: Sequence[str] = (
    "AwsCredentialsDict",
    "AwsCredentials",
    "assume_role",
    "ArnData",
    "parse_arn_string",
    "session",
    "get_current_user",
)

# Type variable for session objects: boto3's ``Session`` or a subclass of it.
# Lets :func:`session` return the same type as the session class it was given.
SessionT = TypeVar("SessionT", bound=Session)
class ArnData(NamedTuple):
    """An AWS ARN string, parsed into a structured object.

    Produced by :func:`parse_arn_string`; :meth:`as_arn_string` re-creates the ARN.
    """

    # Role name: the ``/``-separated segment right after ``role``/``assumed-role``.
    role: str
    # AWS account ID. Stored as an ``int``, so leading zeros of the 12-digit ID
    # are not preserved here; ``as_arn_string`` pads them back.
    account: int
    # Optional third ``/`` segment that follows the role (for assumed-role ARNs,
    # the session name); ``None`` when the ARN has only two ``/`` segments.
    user: Optional[str]
    # True -> ``arn:aws:sts::...:assumed-role/...``; False -> ``arn:aws:iam::...:role/...``.
    is_assumed: bool

    def as_arn_string(self) -> str:
        """Re-create the ARN string represented by this object.

        Returns ``arn:aws:sts::<account>:assumed-role/<role>[/<user>]`` when
        ``is_assumed`` is true, else ``arn:aws:iam::<account>:role/<role>[/<user>]``.
        The ``/<user>`` suffix is present only when ``user`` is not ``None``.
        """
        if self.is_assumed:
            kind, source = "sts", "assumed-role"
        else:
            kind, source = "iam", "role"
        maybe_user = f"/{self.user}" if self.user is not None else ""
        # AWS account IDs are fixed-width 12-digit strings. ``account`` is an int,
        # so restore any leading zeros lost in parsing; otherwise the rebuilt ARN
        # would name a different (non-existent) account. zfill is a no-op for
        # values that already have 12+ digits.
        account_id = str(self.account).zfill(12)
        return f"arn:aws:{kind}::{account_id}:{source}/{self.role}{maybe_user}"
class AwsCredentialsDict(TypedDict):
    """Plain-``dict`` view of an :class:`AwsCredentials` tuple.

    Returned by ``AwsCredentials.as_dict``. The keys are the same keyword names
    this module passes to ``boto3.client`` and to session constructors, so the
    dict can be splatted directly (``**creds.as_dict()``).
    """

    # Access key ID component of the credentials.
    aws_access_key_id: str
    # Secret access key paired with the access key ID.
    aws_secret_access_key: str
    # Session token accompanying the key pair.
    aws_session_token: str
class AwsCredentials(NamedTuple):
    """A complete set of authorized AWS credentials for a particular role.

    Produced by the :func:`assume_role` function. ``repr()``/``str()`` mask the
    secret access key and session token so an instance can be logged or appear
    in a traceback without leaking them.
    """

    # Access key ID; left visible in ``repr`` to identify which key is in use.
    aws_access_key_id: str
    # Secret half of the key pair; masked in ``repr``.
    aws_secret_access_key: str
    # Session token for the key pair; masked in ``repr``.
    aws_session_token: str

    def client(self, client_type: str, region_name: str = "us-west-2") -> BaseClient:
        """Create a boto3 client of kind ``client_type`` using these credentials.

        Args:
            client_type: Any service name accepted by ``boto3.client`` (e.g. ``"s3"``).
            region_name: Region for the client; defaults to ``"us-west-2"``.

        Returns:
            The newly created boto3 client.
        """
        return boto3.client(
            client_type,
            region_name=region_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
        )

    def as_dict(self) -> AwsCredentialsDict:
        """Return the credentials as a plain dict keyed by the boto3 keyword names."""
        return dict(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
        )

    def __repr__(self) -> str:
        # The NamedTuple-generated repr prints every field verbatim, which would
        # put the secret key and session token into logs and error reports.
        return (
            f"{type(self).__name__}("
            f"aws_access_key_id={self.aws_access_key_id!r}, "
            "aws_secret_access_key='***', "
            "aws_session_token='***')"
        )
def assume_role(role_arn: str, role_session_name: Optional[str] = None) -> AwsCredentials:
    """Use the currently active AWS credentials to assume the role ``role_arn``.

    Args:
        role_arn: ARN of the role to assume via STS ``AssumeRole``.
        role_session_name: STS ``RoleSessionName``. A caller-supplied value is
            passed through unchanged. When ``None``, a unique name is generated as
            ``ml-infra-services--<USER>--<10 random ASCII letters>``. ``<USER>`` is
            the ``USER`` environment variable, or ``"no_user"`` if it is unset.
            It is sanitized (characters STS rejects become ``_``) and truncated
            so the whole name stays within STS's 64-character limit.

    Returns:
        The ``AccessKeyId``/``SecretAccessKey``/``SessionToken`` from the
        ``AssumeRole`` response, wrapped in :class:`AwsCredentials`.
    """
    if role_session_name is None:
        max_session_name_len = 64  # STS RoleSessionName upper length bound
        prefix = "ml-infra-services--"
        suffix = "--" + "".join(random.choices(string.ascii_letters, k=10))
        username = os.environ.get("USER", "no_user")
        # STS requires RoleSessionName to match [\w+=,.@-]* and be <= 64 chars.
        # An arbitrary $USER (spaces, long e-mail-style names, ...) would
        # otherwise make AssumeRole fail with a validation error.
        username = re.sub(r"[^A-Za-z0-9_+=,.@-]", "_", username)
        username = username[: max_session_name_len - len(prefix) - len(suffix)]
        role_session_name = f"{prefix}{username}{suffix}"
    sts_client = boto3.client("sts")
    response = sts_client.assume_role(
        RoleArn=role_arn,
        RoleSessionName=role_session_name,
    )
    credentials = response["Credentials"]
    return AwsCredentials(
        aws_access_key_id=credentials["AccessKeyId"],
        aws_secret_access_key=credentials["SecretAccessKey"],
        aws_session_token=credentials["SessionToken"],
    )
def session(role: Optional[str], session_type: Type[SessionT] = Session) -> SessionT:
    """Obtain an AWS session using an arbitrary caller-specified role.

    Args:
        role: AWS profile name to load. ``None`` or ``""`` means no profile, and
            the session falls back to the default credential chain (e.g. the
            ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY`` environment
            variables). On-prem deployments without AWS profiles use this path.
        session_type: The session *class* to instantiate (not an instance).
            Most users keep the default boto3 ``Session``. Some users require a
            special type (e.g. an aioboto3 session).

    Returns:
        A new ``session_type`` instance.

    If the ``CIRCLECI`` environment variable is set to a non-empty value,
    ``role`` is ignored (with a warning) and the default credential chain is used.
    """
    # Do not assume roles in CIRCLECI
    if os.getenv("CIRCLECI"):
        logger.warning(f"In circleci, not assuming role (ignoring: {role})")
        role = None
    # Profile-based auth only when a role is given; otherwise (e.g. on-prem with
    # MinIO) use the default credential chain / env var credentials.
    if role:
        return session_type(profile_name=role)
    return session_type()
def _session_aws_okta(
    session_type: Type[SessionT],
    arn: str,
) -> SessionT:
    """Create a ``session_type`` instance whose credentials target the role ``arn``.

    Calls STS ``GetCallerIdentity`` to find the current caller ARN, then:

    * no ARN returned: logs an error and assumes ``arn`` via :func:`assume_role`;
    * caller's account equals the account in ``arn``: does NOT call STS, and
      instead reads ``AWS_SESSION_TOKEN``, ``AWS_ACCESS_KEY_ID`` and
      ``AWS_SECRET_ACCESS_KEY`` from the environment (e.g. as exported by aws-okta);
    * otherwise: assumes ``arn`` via :func:`assume_role`.

    Args:
        session_type: Session class to instantiate with the resolved credentials.
        arn: ARN of the desired role.

    Raises:
        EnvironmentError: in the same-account case, if any of the three
            environment variables is missing.
        ValueError: if either ARN cannot be parsed by :func:`parse_arn_string`.
    """
    current_arn: Optional[str] = boto3.client("sts").get_caller_identity().get("Arn")
    if current_arn is None:
        logger.error(
            "Could not get current identity from STS to check. This is unexpected! "
            "Is aws configuration setup correctly?"
        )
        creds = assume_role(arn)
    else:
        current_role = parse_arn_string(current_arn)
        desired_role = parse_arn_string(arn)
        # NOTE(review): only the account IDs are compared, not the role names, so
        # a caller holding a *different* role in the same account also takes the
        # env-var path. Confirm this is intended.
        if current_role.account == desired_role.account:
            logger.warning(
                f"Current user {current_role} is the same as desired {desired_role} -- "
                f"**NOT** assuming desired role with STS as this will be an error! "
                f"Using environment variables to create {session_type}"
            )
            try:
                creds = AwsCredentials(
                    aws_session_token=os.environ["AWS_SESSION_TOKEN"],
                    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
                    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
                )
            except KeyError as err:
                raise EnvironmentError(
                    "Cannot find all 3 environment variables required for AWS authentication. "
                    "Did you run aws-okta to get credentials?"
                ) from err
        else:
            creds = assume_role(arn)
    return session_type(**creds.as_dict())
def get_current_user() -> str:
    """Return the current caller's user name as reported by AWS STS.

    Calls STS ``GetCallerIdentity`` and parses the returned ARN with
    :func:`parse_arn_string`. The ARN's third ``/``-separated segment (e.g. the
    ``<user>`` in ``.../assumed-role/<role>/<user>``) is taken as the user, and
    everything from its first ``@`` onward is dropped.

    Raises:
        ValueError: if STS returns no ARN, if the ARN has no third segment, or
            if :func:`parse_arn_string` rejects the ARN.
    """
    caller_arn = client("sts").get_caller_identity().get("Arn")
    if caller_arn is None:
        raise ValueError("Failed to get identity from STS")

    parsed_user = parse_arn_string(caller_arn).user
    if parsed_user is None:
        raise ValueError(f"No user identified from STS! arn={caller_arn}")

    # Keep only the part before the first "@" (e.g. "alice@corp.com" -> "alice").
    local_part, _, _ = parsed_user.partition("@")
    return local_part
def parse_arn_string(arn: str) -> ArnData:
    """Parse an AWS ARN string into structured :class:`ArnData`.

    Recognised shapes (2 or 3 ``/``-separated segments)::

        arn:aws:iam::<account>:role/<role>[/<user>]
        arn:aws:sts::<account>:assumed-role/<role>[/<user>]
        arn:aws:<service>::<account>/<role>[/<user>]   # no resource type

    ``is_assumed`` comes from the resource type (``assumed-role``) when one is
    present. Without a resource type, it comes from the service (``sts``).
    Note that no other resource-type check is made: e.g. an IAM ``user/<name>``
    ARN parses with ``<name>`` as ``role`` and ``is_assumed=False``.

    Raises:
        ValueError: if the ARN does not match the layout above or the account
            part is not an integer.
    """
    slash_parts: List[str] = arn.split("/")
    if len(slash_parts) not in (2, 3):
        raise ValueError(
            f"Invalid format for AWS ARN string: {arn} -- "
            f"Expecting either 2 or 3 parts seperated by '/'"
        )

    # "arn:aws:<service>" :: "<account>[:<resource type>]"
    head_halves: List[str] = slash_parts[0].split("::")
    if len(head_halves) != 2:
        raise ValueError(
            f"Expecting ARN string to have 2 parts in the first '/' part, "
            f"seperated by '::'. Instead found {head_halves} from "
            f"arn={arn}"
        )
    service_prefix, account_section = head_halves

    tail_fields: List[str] = account_section.split(":")
    if len(tail_fields) not in (1, 2):
        raise ValueError(
            f"Expecting ARN string to have 1 or 2 parts in the first '/' part "
            f"of the second '::' part. Instead found {len(tail_fields)}: "
            f"{tail_fields} for arn={arn}"
        )
    account_text: str = tail_fields[0]

    if len(tail_fields) == 2:
        # Resource type present: "assumed-role" marks an STS assumed role.
        is_assumed = tail_fields[1] == "assumed-role"
    else:
        # No resource type: fall back to the service name ("arn:aws:<service>").
        prefix_fields: List[str] = service_prefix.split(":")
        if len(prefix_fields) != 3:
            raise ValueError(
                f"Expecting to find 3 parts in the first part '/' and first part "
                f"of '::'. Instead found {len(prefix_fields)}: {prefix_fields} in arn={arn}"
            )
        is_assumed = prefix_fields[2] == "sts"

    try:
        account = int(account_text)
    except ValueError as err:
        raise ValueError(
            "ARN format invalid: expecting account ID to appear as 2nd to last "
            "value seperated by ':' within the first value seperated by '/' and "
            "second value seperated by '::' -- "
            f"arn={arn} and expecting {account_text} to be account ID"
        ) from err

    return ArnData(
        role=slash_parts[1],
        account=account,
        user=slash_parts[2] if len(slash_parts) == 3 else None,
        is_assumed=is_assumed,
    )