Skip to content

Commit d268c0f

Browse files
committed
Adjust progress bars
1 parent 8b40c6e commit d268c0f

1 file changed

Lines changed: 93 additions & 51 deletions

File tree

cloudnet_api_client/dl.py

Lines changed: 93 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,50 @@
11
import asyncio
22
import logging
33
from collections.abc import Iterable
4+
from dataclasses import dataclass
45
from pathlib import Path
56

67
import aiohttp
78
from tqdm import tqdm
8-
from tqdm.asyncio import tqdm_asyncio
99

1010
from cloudnet_api_client import utils
1111
from cloudnet_api_client.containers import Metadata, ProductMetadata
1212

1313

14+
class BarConfig:
15+
def __init__(self, quiet: bool | None, max_workers: int, total_bytes: int) -> None:
16+
self.quiet = quiet
17+
self.position_queue = self._init_position_queue(max_workers)
18+
self.total_amount = tqdm(
19+
total=total_bytes,
20+
desc="Total amount",
21+
unit="iB",
22+
unit_scale=True,
23+
unit_divisor=1024,
24+
disable=self.quiet,
25+
position=0,
26+
leave=False,
27+
colour="green",
28+
)
29+
self.lock = asyncio.Lock()
30+
31+
def _init_position_queue(self, max_workers: int) -> asyncio.Queue:
32+
queue: asyncio.Queue = asyncio.Queue()
33+
for i in range(1, max_workers + 1):
34+
queue.put_nowait(i)
35+
return queue
36+
37+
38+
@dataclass
39+
class DlParams:
40+
url: str
41+
destination: Path
42+
session: aiohttp.ClientSession
43+
semaphore: asyncio.Semaphore
44+
bar_config: BarConfig
45+
quiet: bool | None
46+
47+
1448
async def download_files(
1549
base_url: str,
1650
metadata: Iterable[Metadata],
@@ -19,81 +53,89 @@ async def download_files(
1953
disable_progress: bool | None,
2054
validate_checksum: bool = False,
2155
) -> list[Path]:
56+
metas = list(metadata)
2257
file_exists = _checksum_matches if validate_checksum else _size_and_name_matches
2358
semaphore = asyncio.Semaphore(concurrency_limit)
59+
total_bytes = sum(meta.size for meta in metas)
60+
bar_config = BarConfig(disable_progress, concurrency_limit, total_bytes)
2461
full_paths = []
2562
async with aiohttp.ClientSession() as session:
2663
tasks = []
27-
for meta in metadata:
64+
for meta in metas:
2865
download_url = f"{base_url}{meta.download_url.split('/api/')[-1]}"
2966
destination = output_path / meta.download_url.split("/")[-1]
3067
full_paths.append(destination)
3168
if destination.exists() and file_exists(meta, destination):
3269
logging.debug(f"Already downloaded: {destination}")
3370
continue
34-
task = asyncio.create_task(
35-
_download_file_with_retries(
36-
session, download_url, destination, semaphore, disable_progress
37-
)
71+
dl_params = DlParams(
72+
url=download_url,
73+
destination=destination,
74+
session=session,
75+
semaphore=semaphore,
76+
bar_config=bar_config,
77+
quiet=disable_progress,
3878
)
79+
task = asyncio.create_task(_download_file_with_retries(dl_params))
3980
tasks.append(task)
40-
await tqdm_asyncio.gather(
41-
*tasks, desc="Completed files", disable=disable_progress
42-
)
81+
await asyncio.gather(*tasks)
82+
bar_config.total_amount.close()
83+
bar_config.total_amount.clear()
4384
return full_paths
4485

4586

4687
async def _download_file_with_retries(
47-
session: aiohttp.ClientSession,
48-
url: str,
49-
destination: Path,
50-
semaphore: asyncio.Semaphore,
51-
disable_progress: bool | None,
88+
params: DlParams,
5289
max_retries: int = 3,
5390
) -> None:
5491
"""Attempt to download a file, retrying up to max_retries times if needed."""
55-
for attempt in range(1, max_retries + 1):
56-
try:
57-
await _download_file(session, url, destination, semaphore, disable_progress)
58-
return
59-
except aiohttp.ClientError as e:
60-
logging.warning(f"Attempt {attempt} failed for {url}: {e}")
61-
if attempt == max_retries:
62-
logging.error(f"Giving up on {url} after {max_retries} attempts.")
63-
raise e
64-
else:
65-
# Exponential backoff before retrying
66-
await asyncio.sleep(2**attempt)
92+
position = await params.bar_config.position_queue.get()
93+
try:
94+
for attempt in range(1, max_retries + 1):
95+
try:
96+
await _download_file(params, position)
97+
return
98+
except aiohttp.ClientError as e:
99+
logging.warning(f"Attempt {attempt} failed for {params.url}: {e}")
100+
if attempt == max_retries:
101+
logging.error(
102+
f"Giving up on {params.url} after {max_retries} attempts."
103+
)
104+
raise e
105+
else:
106+
# Exponential backoff before retrying
107+
await asyncio.sleep(2**attempt)
108+
finally:
109+
params.bar_config.position_queue.put_nowait(position)
110+
raise RuntimeError("Unreachable code reached.")
67111

68112

69113
async def _download_file(
70-
session: aiohttp.ClientSession,
71-
url: str,
72-
destination: Path,
73-
semaphore: asyncio.Semaphore,
74-
disable_progress: bool | None,
114+
params: DlParams,
115+
position: int,
75116
) -> None:
76-
async with semaphore:
77-
async with session.get(url) as response:
78-
response.raise_for_status()
79-
with (
80-
destination.open("wb") as file_out,
81-
tqdm(
82-
desc=destination.name,
83-
total=response.content_length,
84-
unit="iB",
85-
unit_scale=True,
86-
unit_divisor=1024,
87-
disable=disable_progress,
88-
) as bar,
89-
):
90-
while True:
91-
chunk = await response.content.read(8192)
92-
if not chunk:
93-
break
94-
file_out.write(chunk)
117+
async with params.semaphore, params.session.get(params.url) as response:
118+
response.raise_for_status()
119+
bar = tqdm(
120+
desc=params.destination.name,
121+
total=response.content_length,
122+
unit="iB",
123+
unit_scale=True,
124+
unit_divisor=1024,
125+
disable=params.bar_config.quiet,
126+
position=position,
127+
leave=False,
128+
colour="cyan",
129+
)
130+
try:
131+
with params.destination.open("wb") as f:
132+
while chunk := await response.content.read(8192):
133+
f.write(chunk)
95134
bar.update(len(chunk))
96-
logging.debug(f"Downloaded: {destination}")
135+
params.bar_config.total_amount.update(len(chunk))
136+
finally:
137+
bar.close()
138+
bar.clear()
97139

98140

99141
def _checksum_matches(meta: Metadata, destination: Path) -> bool:

0 commit comments

Comments
 (0)