-
Notifications
You must be signed in to change notification settings - Fork 518
Expand file tree
/
Copy path__init__.py
More file actions
75 lines (64 loc) · 2.31 KB
/
__init__.py
File metadata and controls
75 lines (64 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import os
import subprocess
import yaml
class MultiFilter:
    """Predicate that accepts a sample only if every rule passes.

    ``rules`` maps a key to a callable. A tuple key ``(k1, k2, ...)`` calls
    ``rule(record[k1], record[k2], ...)``; any other key ``k`` calls
    ``rule(record[k])``. ``record`` is the sample's ``"json"`` entry, decoded
    with ``json.loads`` first when it is ``bytes``.

    If the sample cannot be evaluated -- missing ``"json"`` entry, malformed
    JSON, a missing field, or a rule (or the truth test of its result)
    raising -- ``default`` is returned instead.
    """

    def __init__(self, rules, default=False):
        self.rules = rules
        # Value returned for samples that cannot be evaluated.
        self.default = default

    def __call__(self, x):
        try:
            x_json = x["json"]
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            # Evaluate every rule before combining (no short-circuit), so an
            # error in any rule falls back to ``default`` regardless of order.
            validations = [
                r(*[x_json[kv] for kv in k]) if isinstance(k, tuple) else r(x_json[k])
                for k, r in self.rules.items()
            ]
            return all(validations)
        except Exception:
            # Deliberate best-effort filter: unreadable samples get the
            # configured fallback instead of aborting the pipeline.
            return self.default
class MultiGetter:
    """Compute derived values from a JSON record using a dict of rules.

    Each ``rules`` entry maps a key to a callable. A tuple key passes the
    corresponding record fields positionally; any other key passes the single
    field ``record[key]``. ``bytes`` input is decoded with ``json.loads`` first.

    Returns the lone value when there is exactly one rule. Otherwise it
    returns a list of values in ``rules`` iteration order, which is an empty
    list when there are no rules. Lookup errors (e.g. ``KeyError`` for a
    missing field) propagate to the caller.
    """

    def __init__(self, rules):
        self.rules = rules

    def __call__(self, x_json):
        record = json.loads(x_json) if isinstance(x_json, bytes) else x_json

        def _apply(key, fn):
            if not isinstance(key, tuple):
                return fn(record[key])
            return fn(*[record[name] for name in key])

        results = [_apply(key, fn) for key, fn in self.rules.items()]
        # A single rule yields its bare value rather than a one-element list.
        return results[0] if len(results) == 1 else results
def _list_s3_tars(path):
    """Return ``<bucket>/<key>`` URLs for every ``.tar`` object under an S3 prefix.

    Runs ``aws s3 ls <path> --recursive``. It keeps the fourth
    whitespace-separated column of each output line, the same field the former
    ``awk '{print $4}'`` step printed. Each kept key is re-prefixed with the
    bucket portion of *path*.

    Raises:
        subprocess.CalledProcessError: if the ``aws`` CLI exits non-zero.
    """
    # "s3://bucket/some/prefix" -> "s3://bucket"
    bucket = "/".join(path.split("/")[:3])
    # Argument list, no shell: *path* is never shell-interpreted, and
    # check=True sees the aws exit status (a shell pipeline only reported awk's).
    result = subprocess.run(
        ["aws", "s3", "ls", path, "--recursive"],
        stdout=subprocess.PIPE,
        check=True,
    )
    keys = []
    for line in result.stdout.decode("utf-8").splitlines():
        fields = line.split()
        if len(fields) >= 4:
            keys.append(fields[3])
    return [f"{bucket}/{key}" for key in keys if key.endswith(".tar")]


def setup_webdataset_path(paths, cache_path=None):
    """Build a ``pipe:`` URL that streams a set of tar shards from S3.

    Args:
        paths: A single path or a list of paths. Entries are stripped of
            surrounding whitespace. Entries ending in ``.tar`` are used as-is;
            any other entry is treated as an S3 prefix and expanded to every
            ``.tar`` object beneath it via the ``aws`` CLI.
        cache_path: Optional YAML file holding the resolved tar list.
            - If the file exists, it is read and *paths* is ignored.
            - If it is given but missing, the resolved list is written there.
            - ``None`` disables caching.

    Returns:
        The string ``"pipe:aws s3 cp { url1,url2,... } -"``.
        NOTE(review): the ``{ ... }`` group presumably relies on the consumer
        (e.g. WebDataset) brace-expanding it into one command per shard;
        confirm. The URLs end up in a shell command, so they must be trusted.

    Raises:
        subprocess.CalledProcessError: if listing an S3 prefix fails.
    """
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as file:
            tar_paths = yaml.safe_load(file)
    else:
        if isinstance(paths, str):
            paths = [paths]
        tar_paths = []
        for path in paths:
            path = path.strip()
            if path.endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
            else:
                tar_paths.extend(_list_s3_tars(path))
        # Only persist the resolved list when a cache file was requested.
        if cache_path is not None:
            with open(cache_path, "w", encoding="utf-8") as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)
    tar_paths_str = ",".join(str(p) for p in tar_paths)
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"