-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathutils.py
More file actions
249 lines (202 loc) · 8.06 KB
/
utils.py
File metadata and controls
249 lines (202 loc) · 8.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import os
import logging
import shutil
import subprocess
import json
import re
import errno
from sagemaker.core import s3
from six.moves.urllib.parse import urlparse
# Module-level logger shared by all helpers in this file.
logger = logging.getLogger(__name__)
# Studio app types for which Local Mode execution is allowed; any other
# AppType found in the Studio resource metadata raises NotImplementedError
# (see check_for_studio below).
STUDIO_APP_TYPES = ["KernelGateway", "CodeEditor", "JupyterLab"]
def copy_directory_structure(destination_directory, relative_path):
    """Creates intermediate directory structure for relative_path.

    Create all the intermediate directories required for relative_path to
    exist within destination_directory. This assumes that relative_path is a
    directory located within root_dir.

    Examples:
        destination_directory: /tmp/destination relative_path: test/unit/
        will create: /tmp/destination/test/unit

    Args:
        destination_directory (str): root of the destination directory where the
            directory structure will be created.
        relative_path (str): relative path that will be created within
            destination_directory
    """
    full_path = os.path.join(destination_directory, relative_path)
    if os.path.exists(full_path):
        return
    # Bug fix: the original called os.makedirs(destination_directory, relative_path),
    # which passes relative_path as the numeric ``mode`` argument and never
    # creates the joined path. Create the full joined path instead.
    os.makedirs(full_path)
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
    """Move source to destination.

    Can handle uploading to S3.

    Args:
        source (str): root directory to move
        destination (str): file:// or s3:// URI that source will be moved to.
        job_name (str): SageMaker job name.
        sagemaker_session (sagemaker.Session): a sagemaker_session to interact
            with S3 if needed
        prefix (str, optional): the directory on S3 used to save files, default
            to the root of ``destination``

    Returns:
        (str): destination URI
    """
    uri = urlparse(destination)
    scheme = uri.scheme
    # Reject unknown schemes up front, before any copy or delete happens.
    if scheme not in ("file", "s3"):
        raise ValueError("Invalid destination URI, must be s3:// or file://, got: %s" % destination)
    if scheme == "file":
        local_dir = os.path.abspath(uri.netloc + uri.path)
        recursive_copy(source, local_dir)
        final_uri = destination
    else:
        s3_bucket = uri.netloc
        s3_key = s3.s3_path_join(uri.path, job_name, prefix)
        final_uri = s3.s3_path_join("s3://", s3_bucket, s3_key)
        sagemaker_session.upload_data(source, s3_bucket, s3_key)
    try:
        shutil.rmtree(source)
    except OSError as exc:
        # on Linux, when docker writes to any mounted volume, it uses the container's user. In most
        # cases this is root. When the container exits and we try to delete them we can't because
        # root owns those files. We expect this to happen, so we handle EACCESS. Any other error
        # we will raise the exception up.
        if exc.errno != errno.EACCES:
            logger.error("Failed to delete: %s", source)
            raise
        logger.warning("Failed to delete: %s Please remove it manually.", source)
    return final_uri
def recursive_copy(source, destination):
"""A wrapper around shutil.copy_tree.
This won't throw any exception when the source directory does not exist.
Args:
source (str): source path
destination (str): destination path
"""
if os.path.isdir(source):
shutil.copytree(source, destination, dirs_exist_ok=True)
def kill_child_processes(pid):
    """Kill child processes.

    Sends signal 15 (SIGTERM) to every nested child process id of ``pid``.

    Args:
        pid (int): process id
    """
    for descendant_pid in get_child_process_ids(pid):
        os.kill(descendant_pid, 15)
def get_child_process_ids(pid):
    """Retrieve all child pids for a certain pid.

    Recursively scans each child's process tree (via ``pgrep -P``) and adds
    every descendant to the output.

    Args:
        pid (int): process id

    Returns:
        (List[int]): Child process ids

    Raises:
        ValueError: if ``pid`` is not a plain non-negative integer. This guards
            against injecting arbitrary arguments into the ``pgrep`` call.
    """
    if not str(pid).isdigit():
        raise ValueError("Invalid PID")
    cmd = ["pgrep", "-P", str(pid)]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, err = process.communicate()
    if err:
        return []
    child_pids = [int(token) for token in output.decode("utf-8").split()]
    # Bug fix: the original returned from inside the loop on the first
    # iteration, so only the FIRST child's subtree was ever explored and
    # descendants of the remaining children were silently dropped.
    descendants = list(child_pids)
    for child_pid in child_pids:
        descendants.extend(get_child_process_ids(child_pid))
    return descendants
def get_docker_host():
    """Discover remote docker host address (if applicable) or use "localhost".

    Uses "docker context inspect" to read the current docker host endpoint
    url; the url must start with "tcp://" to be treated as a remote host.

    Returns:
        docker_host (str): Docker host DNS or IP address
    """
    proc = subprocess.Popen(
        "docker context inspect".split(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = proc.communicate()
    # Any stderr output (docker missing/misconfigured) falls back to localhost.
    if err:
        return "localhost"
    contexts = json.loads(out.decode("utf-8"))
    host_url = contexts[0]["Endpoints"]["docker"]["Host"]
    parsed = urlparse(host_url)
    if parsed.scheme == "tcp" and parsed.hostname:
        return parsed.hostname
    return "localhost"
def get_using_dot_notation(dictionary, keys):
    """Extract `keys` from dictionary where keys is a string in dot notation.

    Supports bracket accessors on each dotted segment: quoted brackets
    (``a['b']`` / ``a["b"]``) are dict-key lookups, unquoted brackets
    (``a[0]``) are integer list indices.

    Args:
        dictionary (Dict): object to traverse.
        keys (str): dot-notation path, e.g. ``"a.b[0]['c']"``; ``None``
            returns ``dictionary`` unchanged.

    Returns:
        Nested object within dictionary as defined by "keys"

    Raises:
        ValueError if the provided key does not exist in input dictionary
    """
    try:
        if keys is None:
            return dictionary
        split_keys = keys.split(".", 1)
        key = split_keys[0]
        rest = split_keys[1] if len(split_keys) > 1 else None
        bracket_accessors = re.findall(r"\[(.+?)]", key)
        if bracket_accessors:
            # Strip the bracket part off to get the plain dict key in front.
            inner_dict = dictionary[key.split("[", 1)[0]]
        else:
            inner_dict = dictionary[key]
        for accessor in bracket_accessors:
            quoted = (accessor.startswith("'") and accessor.endswith("'")) or (
                accessor.startswith('"') and accessor.endswith('"')
            )
            # quoted -> dict key accessor; unquoted -> integer list accessor
            inner_key = accessor[1:-1] if quoted else int(accessor)
            inner_dict = inner_dict[inner_key]
        return get_using_dot_notation(inner_dict, rest)
    except (KeyError, IndexError, TypeError) as exc:
        # Bug fix: chain the original exception so the root cause
        # (which key/index failed) is preserved in the traceback.
        raise ValueError(f"{keys} does not exist in input dictionary.") from exc
def check_for_studio():
    """Helper function to determine if the run environment is studio.

    Returns (bool): Returns True if valid Studio request.

    Raises:
        NotImplementedError:
            if run environment = Studio and AppType not in STUDIO_APP_TYPES
    """
    metadata_path = "/opt/ml/metadata/resource-metadata.json"
    if not os.path.exists(metadata_path):
        return False
    with open(metadata_path, "r") as handle:
        metadata = json.load(handle)
    app_type = metadata.get("AppType")
    if not app_type:
        # no AppType: case of classic notebooks
        return False
    # execution is triggered from a Studio app; only some app types are supported
    if app_type not in STUDIO_APP_TYPES:
        raise NotImplementedError(
            f"AppType {app_type} in Studio does not support Local Mode."
        )
    return True