-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdl.py
More file actions
169 lines (152 loc) · 5.5 KB
/
dl.py
File metadata and controls
169 lines (152 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import asyncio
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
import aiohttp
from tqdm import tqdm
from cloudnet_api_client import utils
from cloudnet_api_client.containers import (
Metadata,
ProductMetadata,
)
class BarConfig:
def __init__(
self, disable: bool | None, max_workers: int, total_bytes: int
) -> None:
self.disable = disable
self.position_queue = self._init_position_queue(max_workers)
self.total_amount = tqdm(
total=total_bytes,
desc="Progress",
unit="iB",
unit_scale=True,
unit_divisor=1024,
disable=self.disable,
position=0,
leave=False,
colour="green",
)
self.lock = asyncio.Lock()
def _init_position_queue(self, max_workers: int) -> asyncio.Queue:
queue: asyncio.Queue = asyncio.Queue()
for i in range(1, max_workers + 1):
queue.put_nowait(i)
return queue
@dataclass
class DlParams:
url: str
destination: Path
session: aiohttp.ClientSession
semaphore: asyncio.Semaphore
bar_config: BarConfig
disable: bool | None
async def download_files(
base_url: str,
metadata: Iterable[Metadata] | Metadata,
output_path: Path,
concurrency_limit: int,
disable_progress: bool | None,
validate_checksum: bool = False,
) -> list[Path]:
metas = list(metadata) if isinstance(metadata, Iterable) else [metadata]
file_exists = _checksum_matches if validate_checksum else _size_and_name_matches
semaphore = asyncio.Semaphore(concurrency_limit)
total_bytes = sum(meta.size for meta in metas)
bar_config = BarConfig(disable_progress, concurrency_limit, total_bytes)
full_paths = []
async with aiohttp.ClientSession() as session:
tasks = []
for meta in metas:
download_url = f"{base_url}{meta.download_url.split('/api/')[-1]}"
destination = output_path / meta.download_url.split("/")[-1]
full_paths.append(destination)
if destination.exists() and file_exists(meta, destination):
logging.debug(f"Already downloaded: {destination}")
continue
dl_params = DlParams(
url=download_url,
destination=destination,
session=session,
semaphore=semaphore,
bar_config=bar_config,
disable=disable_progress,
)
task = asyncio.create_task(_download_file_with_retries(dl_params))
tasks.append(task)
if disable_progress is True:
print(f"Downloading {len(metas)} files...", end="", flush=True)
await asyncio.gather(*tasks)
bar_config.total_amount.close()
bar_config.total_amount.clear()
if disable_progress is True:
print(" done.", flush=True)
return full_paths
async def _download_file_with_retries(
params: DlParams,
max_retries: int = 3,
) -> None:
"""Attempt to download a file, retrying up to max_retries times if needed."""
position = await params.bar_config.position_queue.get()
try:
for attempt in range(1, max_retries + 1):
try:
await _download_file(params, position)
return
except aiohttp.ClientError as e:
logging.warning(f"Attempt {attempt} failed for {params.url}: {e}")
if attempt == max_retries:
logging.error(
f"Giving up on {params.url} after {max_retries} attempts."
)
raise e
else:
# Exponential backoff before retrying
await asyncio.sleep(2**attempt)
finally:
params.bar_config.position_queue.put_nowait(position)
raise RuntimeError("Unreachable code reached.")
async def _download_file(
params: DlParams,
position: int,
) -> None:
tmp_path = params.destination.with_suffix(f"{params.destination.suffix}.part")
async with params.semaphore, params.session.get(params.url) as response:
response.raise_for_status()
bar = tqdm(
desc=params.destination.name,
total=response.content_length,
unit="iB",
unit_scale=True,
unit_divisor=1024,
disable=params.bar_config.disable,
position=position,
leave=False,
colour="cyan",
)
try:
tmp_path.parent.mkdir(parents=True, exist_ok=True)
with tmp_path.open("wb") as f:
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
bar.update(len(chunk))
params.bar_config.total_amount.update(len(chunk))
tmp_path.replace(params.destination)
except Exception:
try:
if tmp_path.exists():
tmp_path.unlink()
except OSError:
pass
raise
finally:
bar.close()
bar.clear()
def _checksum_matches(meta: Metadata, destination: Path) -> bool:
fun = utils.sha256sum if isinstance(meta, ProductMetadata) else utils.md5sum
return fun(destination) == meta.checksum
def _size_and_name_matches(meta: Metadata, destination: Path) -> bool:
return (
destination.stat().st_size == meta.size
and destination.name == meta.download_url.split("/")[-1]
)