Skip to content

Commit 5671a27

Browse files
committed
Fix progress reporting for the initial model download
Nowadays Hugging Face mostly uses the Xet protocol for downloads, but we only monkey-patched the pure HTTP-based flow. Progress updates are driven by the patch, hence there were none for Xet-based downloads. In addition, increase the reporting granularity so we report downloaded bytes before reaching the first MB.
1 parent 049cb1f commit 5671a27

7 files changed

Lines changed: 79 additions & 44 deletions

File tree

backend/api_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ class ModelsStatusResponse(BaseModel):
149149
class DownloadProgressResponse(BaseModel):
150150
status: str
151151
current_downloading_file: ModelFileType | None
152-
current_file_progress: int
153-
total_progress: int
152+
current_file_progress: float
153+
total_progress: float
154154
total_downloaded_bytes: int
155155
expected_total_bytes: int
156156
completed_files: set[ModelFileType]
157157
all_files: set[ModelFileType]
158158
error: str | None
159-
speed_mbps: int
159+
speed_bytes_per_sec: float
160160

161161

162162
class SuggestGapPromptResponse(BaseModel):

backend/handlers/download_handler.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def start_file(self, file_type: ModelFileType, target: str) -> None:
7272
file_type=file_type,
7373
target_path=target,
7474
downloaded_bytes=0,
75-
speed_mbps=0.0,
75+
speed_bytes_per_sec=0.0,
7676
)
7777

7878
@with_state_lock
@@ -87,15 +87,15 @@ def finish_download(self) -> None:
8787
self.state.downloading_session = None
8888

8989
@with_state_lock
90-
def update_file_progress(self, file_type: ModelFileType, downloaded: int, speed_mbps: float) -> None:
90+
def update_file_progress(self, file_type: ModelFileType, downloaded: int, speed_bytes_per_sec: float) -> None:
9191
session = self.state.downloading_session
9292
if session is None:
9393
return
9494
rf = session.current_running_file
9595
if rf is None or rf.file_type != file_type:
9696
return
9797
rf.downloaded_bytes = downloaded
98-
rf.speed_mbps = speed_mbps
98+
rf.speed_bytes_per_sec = speed_bytes_per_sec
9999

100100
@with_state_lock
101101
def fail_download(self, error: str) -> None:
@@ -106,12 +106,25 @@ def fail_download(self, error: str) -> None:
106106
self.state.downloading_session = None
107107

108108
def _make_progress_callback(self, file_type: ModelFileType) -> Callable[[int], None]:
109-
start_time = time.monotonic()
109+
last_sample_time = time.monotonic()
110+
last_sample_bytes = 0
111+
smoothed_speed = 0.0
110112

111113
def on_progress(downloaded: int) -> None:
112-
elapsed = time.monotonic() - start_time
113-
speed_mbps = (downloaded / elapsed / (1024 * 1024)) if elapsed > 0 else 0.0
114-
self.update_file_progress(file_type, downloaded, speed_mbps)
114+
nonlocal last_sample_time, last_sample_bytes, smoothed_speed
115+
now = time.monotonic()
116+
elapsed = now - last_sample_time
117+
if elapsed >= 1.0:
118+
instant_speed = (downloaded - last_sample_bytes) / elapsed
119+
# EWMA: weight new sample at 30%, keep 70% of previous.
120+
# On first sample (smoothed_speed == 0) use instant value.
121+
if smoothed_speed == 0.0:
122+
smoothed_speed = instant_speed
123+
else:
124+
smoothed_speed = 0.3 * instant_speed + 0.7 * smoothed_speed
125+
last_sample_time = now
126+
last_sample_bytes = downloaded
127+
self.update_file_progress(file_type, downloaded, smoothed_speed)
115128

116129
return on_progress
117130

@@ -130,15 +143,15 @@ def get_download_progress(self, session_id: str) -> DownloadProgressResponse:
130143
self.config.spec_for(ft).expected_size_bytes for ft in session.files_to_download
131144
)
132145

133-
current_file_progress = 0
146+
current_file_progress = 0.0
134147
if rf is not None:
135148
spec = self.config.spec_for(rf.file_type)
136149
if spec.expected_size_bytes > 0:
137-
current_file_progress = min(99, int(rf.downloaded_bytes / spec.expected_size_bytes * 100))
150+
current_file_progress = min(99.0, rf.downloaded_bytes / spec.expected_size_bytes * 100)
138151

139-
total_progress = 0
152+
total_progress = 0.0
140153
if expected_total_bytes > 0:
141-
total_progress = min(99, int(total_downloaded / expected_total_bytes * 100))
154+
total_progress = min(99.0, total_downloaded / expected_total_bytes * 100)
142155

143156
return DownloadProgressResponse(
144157
status="downloading",
@@ -149,7 +162,7 @@ def get_download_progress(self, session_id: str) -> DownloadProgressResponse:
149162
expected_total_bytes=expected_total_bytes,
150163
completed_files=set(session.completed_files),
151164
all_files=set(session.files_to_download),
152-
speed_mbps=int(rf.speed_mbps) if rf else 0,
165+
speed_bytes_per_sec=rf.speed_bytes_per_sec if rf else 0.0,
153166
error=None,
154167
)
155168

@@ -159,26 +172,26 @@ def get_download_progress(self, session_id: str) -> DownloadProgressResponse:
159172
return DownloadProgressResponse(
160173
status="complete",
161174
current_downloading_file=None,
162-
current_file_progress=100,
163-
total_progress=100,
175+
current_file_progress=100.0,
176+
total_progress=100.0,
164177
total_downloaded_bytes=0,
165178
expected_total_bytes=0,
166179
completed_files=set(),
167180
all_files=set(),
168-
speed_mbps=0,
181+
speed_bytes_per_sec=0.0,
169182
error=None,
170183
)
171184
else:
172185
return DownloadProgressResponse(
173186
status="error",
174187
current_downloading_file=None,
175-
current_file_progress=0,
176-
total_progress=0,
188+
current_file_progress=0.0,
189+
total_progress=0.0,
177190
total_downloaded_bytes=0,
178191
expected_total_bytes=0,
179192
completed_files=set(),
180193
all_files=set(),
181-
speed_mbps=0,
194+
speed_bytes_per_sec=0.0,
182195
error=result,
183196
)
184197

backend/services/model_downloader/hugging_face_downloader.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ def update(self, n: float | int | None = 1) -> bool | None: # type: ignore[repo
4040

4141

4242
@contextlib.contextmanager
43-
def _patch_http_get_progress(callback: Callable[[int], None]) -> Iterator[None]:
43+
def _patch_download_progress(callback: Callable[[int], None]) -> Iterator[None]:
4444
"""Temporarily monkey-patch ``huggingface_hub.file_download.http_get``
45-
to inject a custom tqdm bar that forwards progress to *callback*.
45+
and ``xet_get`` to inject a custom tqdm bar that forwards progress to
46+
*callback*.
4647
4748
``hf_hub_download`` does not expose a ``tqdm_class`` parameter (unlike
48-
``snapshot_download``), but its internal ``http_get`` accepts a private
49-
``_tqdm_bar`` kwarg. We wrap ``http_get`` to inject our own bar when
50-
the caller hasn't already provided one.
49+
``snapshot_download``), but its internal ``http_get`` and ``xet_get``
50+
both accept a private ``_tqdm_bar`` kwarg. We wrap them to inject our
51+
own bar when the caller hasn't already provided one.
5152
5253
See ``test_http_get_accepts_tqdm_bar`` — if that test breaks after a
5354
huggingface_hub upgrade, this patch needs to be revisited.
@@ -60,8 +61,20 @@ def _wrapped_http_get(*args: Any, **kwargs: Any) -> None:
6061
kwargs["_tqdm_bar"] = tqdm_cls(disable=True)
6162
return original_http_get(*args, **kwargs)
6263

64+
xet_get_fn: Callable[..., Any] | None = getattr(file_download, "xet_get", None)
65+
66+
def _wrapped_xet_get(*args: Any, **kwargs: Any) -> None:
67+
if kwargs.get("_tqdm_bar") is None:
68+
kwargs["_tqdm_bar"] = tqdm_cls(disable=True)
69+
assert xet_get_fn is not None
70+
return xet_get_fn(*args, **kwargs)
71+
6372
with patch.object(file_download, "http_get", _wrapped_http_get):
64-
yield
73+
if xet_get_fn is not None:
74+
with patch.object(file_download, "xet_get", _wrapped_xet_get):
75+
yield
76+
else:
77+
yield
6578

6679

6780
class HuggingFaceDownloader:
@@ -74,7 +87,7 @@ def download_file(
7487
local_dir: str,
7588
on_progress: Callable[[int], None] | None = None,
7689
) -> Path:
77-
ctx = _patch_http_get_progress(on_progress) if on_progress is not None else contextlib.nullcontext()
90+
ctx = _patch_download_progress(on_progress) if on_progress is not None else contextlib.nullcontext()
7891
with ctx:
7992
path: str = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
8093
return Path(path)
@@ -85,7 +98,7 @@ def download_snapshot(
8598
local_dir: str,
8699
on_progress: Callable[[int], None] | None = None,
87100
) -> Path:
88-
ctx = _patch_http_get_progress(on_progress) if on_progress is not None else contextlib.nullcontext()
101+
ctx = _patch_download_progress(on_progress) if on_progress is not None else contextlib.nullcontext()
89102
with ctx:
90103
path: str = snapshot_download(repo_id=repo_id, local_dir=local_dir)
91104
return Path(path)

backend/state/app_state_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class FileDownloadRunning:
4343
file_type: ModelFileType
4444
target_path: str
4545
downloaded_bytes: int
46-
speed_mbps: float
46+
speed_bytes_per_sec: float
4747

4848

4949
@dataclass

backend/tests/test_models.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_active(self, client, test_state):
105105
file_type="checkpoint",
106106
target_path="checkpoint",
107107
downloaded_bytes=5_000_000_000,
108-
speed_mbps=50,
108+
speed_bytes_per_sec=50_000_000.0,
109109
),
110110
files_to_download={"checkpoint"},
111111
completed_files=set(),
@@ -344,9 +344,10 @@ def test_failed_download_cleans_up_downloading_dir(self, test_state):
344344
class TestHuggingFaceInternals:
345345
"""Guard tests for huggingface_hub internals we rely on.
346346
347-
We monkey-patch ``file_download.http_get`` to inject a custom tqdm bar
348-
for progress tracking during ``hf_hub_download`` (which has no public
349-
``tqdm_class`` parameter, unlike ``snapshot_download``).
347+
We monkey-patch ``file_download.http_get`` and ``file_download.xet_get``
348+
to inject a custom tqdm bar for progress tracking during
349+
``hf_hub_download`` (which has no public ``tqdm_class`` parameter,
350+
unlike ``snapshot_download``).
350351
351352
If these tests break after a huggingface_hub upgrade, the internal API
352353
has changed. Find an alternative approach and raise to a developer.
@@ -363,3 +364,12 @@ def test_http_get_accepts_tqdm_bar(self):
363364
assert "_tqdm_bar" in sig.parameters, (
364365
"file_download.http_get no longer accepts _tqdm_bar — progress patch for hf_hub_download is broken"
365366
)
367+
368+
def test_xet_get_accepts_tqdm_bar(self):
369+
xet_get = getattr(file_download, "xet_get", None)
370+
if xet_get is None:
371+
return # xet_get not present in this version; patch skips it gracefully
372+
sig = inspect.signature(xet_get)
373+
assert "_tqdm_bar" in sig.parameters, (
374+
"file_download.xet_get no longer accepts _tqdm_bar — progress patch for xet downloads is broken"
375+
)

backend/tests/test_response_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_snake_case_keys(self, client, test_state):
3939
file_type="checkpoint",
4040
target_path="checkpoint",
4141
downloaded_bytes=5_000_000_000,
42-
speed_mbps=50,
42+
speed_bytes_per_sec=50_000_000.0,
4343
),
4444
files_to_download={"checkpoint"},
4545
completed_files=set(),
@@ -60,7 +60,7 @@ def test_snake_case_keys(self, client, test_state):
6060
"completed_files",
6161
"all_files",
6262
"error",
63-
"speed_mbps",
63+
"speed_bytes_per_sec",
6464
}
6565
assert set(data.keys()) == expected_keys
6666

frontend/components/FirstRunSetup.tsx

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ interface DownloadProgress {
2222
completed_files: string[]
2323
all_files: string[]
2424
error: string | null
25-
speed_mbps: number
25+
speed_bytes_per_sec: number
2626
}
2727

2828
interface RequiredModelsResponse {
@@ -82,11 +82,10 @@ export function LaunchGate({
8282

8383
// Calculate ETA based on speed and remaining bytes
8484
const getTimeRemaining = (): string => {
85-
if (!downloadProgress || downloadProgress.speed_mbps <= 0) return '--'
85+
if (!downloadProgress || downloadProgress.speed_bytes_per_sec <= 0) return '--'
8686
const remainingBytes = downloadProgress.expected_total_bytes - downloadProgress.total_downloaded_bytes
8787
if (remainingBytes <= 0) return '--'
88-
const speedBytesPerSec = downloadProgress.speed_mbps * 1024 * 1024
89-
const secondsRemaining = remainingBytes / speedBytesPerSec
88+
const secondsRemaining = remainingBytes / downloadProgress.speed_bytes_per_sec
9089
return formatTimeRemaining(secondsRemaining)
9190
}
9291

@@ -701,7 +700,7 @@ export function LaunchGate({
701700
{(downloadProgress?.total_progress || 0) > 85 ? 'Installing...' : 'Downloading...'}
702701
</span>
703702
<span style={{ fontSize: 13, color: '#A98BD9', fontWeight: 600 }}>
704-
{downloadProgress?.total_progress || 0}%
703+
{Math.round(downloadProgress?.total_progress || 0)}%
705704
</span>
706705
</div>
707706

@@ -739,17 +738,17 @@ export function LaunchGate({
739738

740739
{/* Speed and ETA */}
741740
<div style={{ display: 'flex', gap: 16, marginLeft: 16, flexShrink: 0 }}>
742-
{downloadProgress && downloadProgress.speed_mbps > 0 && (
741+
{downloadProgress && downloadProgress.speed_bytes_per_sec > 0 && (
743742
<span style={{ color: '#6D28D9', fontWeight: 500 }}>
744-
{downloadProgress.speed_mbps.toFixed(1)} MB/s
743+
{(downloadProgress.speed_bytes_per_sec / (1024 * 1024)).toFixed(1)} MB/s
745744
</span>
746745
)}
747746
{downloadProgress && downloadProgress.expected_total_bytes > 0 && (
748747
<span>
749748
{formatBytes(downloadProgress.total_downloaded_bytes)} / {formatBytes(downloadProgress.expected_total_bytes)}
750749
</span>
751750
)}
752-
{downloadProgress && downloadProgress.speed_mbps > 0 && (
751+
{downloadProgress && downloadProgress.speed_bytes_per_sec > 0 && (
753752
<span>
754753
ETA: {getTimeRemaining()}
755754
</span>

0 commit comments

Comments
 (0)