Skip to content

Commit 01a4099

Browse files
committed
Use HuggingFace Hub for checkpoint download
HuggingFace has no bandwidth limits for public models, unlike GitHub LFS (1 GB/month free). huggingface_hub is already a transitive dependency of transformers, so this adds no new dependencies. Upload the checkpoint with: `huggingface-cli upload ActiveVisionLab/CrossScore ckpt/CrossScore-v1.0.0.ckpt CrossScore-v1.0.0.ckpt`. Session: https://claude.ai/code/session_0114iFoswRfTkMai4JTrgMrB
1 parent bc61416 commit 01a4099

File tree

1 file changed

+9
-47
lines changed

1 file changed

+9
-47
lines changed

crossscore/_download.py

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,18 @@
11
"""Utilities for downloading CrossScore model checkpoints."""
22

33
import os
4-
import urllib.request
5-
import shutil
64
from pathlib import Path
75

8-
# Download directly from GitHub (served via Git LFS)
9-
CHECKPOINT_URL = (
10-
"https://github.com/ActiveVisionLab/CrossScore/raw/main/ckpt/CrossScore-v1.0.0.ckpt"
11-
)
6+
HF_REPO_ID = "ActiveVisionLab/CrossScore"
127
CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt"
138

149

15-
def get_cache_dir() -> Path:
16-
"""Return the cache directory for CrossScore model checkpoints."""
17-
cache_dir = Path(os.environ.get("CROSSSCORE_CACHE_DIR", Path.home() / ".cache" / "crossscore"))
18-
cache_dir.mkdir(parents=True, exist_ok=True)
19-
return cache_dir
20-
21-
2210
def get_checkpoint_path() -> str:
2311
"""Get path to the CrossScore checkpoint, downloading it if necessary.
2412
25-
Downloads from GitHub (Git LFS) on first use and caches locally at
26-
~/.cache/crossscore/. Set environment variables to customize:
13+
Downloads from HuggingFace Hub on first use and caches locally.
14+
Set environment variables to customize:
2715
CROSSSCORE_CKPT_PATH - use a specific local checkpoint file
28-
CROSSSCORE_CACHE_DIR - custom cache directory
2916
3017
Returns:
3118
Path to the checkpoint file.
@@ -37,35 +24,10 @@ def get_checkpoint_path() -> str:
3724
raise FileNotFoundError(f"Checkpoint not found at CROSSSCORE_CKPT_PATH={custom_path}")
3825
return custom_path
3926

40-
cache_dir = get_cache_dir()
41-
ckpt_path = cache_dir / CHECKPOINT_FILENAME
42-
43-
if ckpt_path.exists():
44-
return str(ckpt_path)
45-
46-
print(f"Downloading CrossScore checkpoint (~129MB)...")
47-
print(f" From: {CHECKPOINT_URL}")
48-
print(f" To: {ckpt_path}")
49-
print(" (Set CROSSSCORE_CKPT_PATH to skip download and use a local file)")
50-
51-
tmp_path = str(ckpt_path) + ".tmp"
52-
try:
53-
urllib.request.urlretrieve(CHECKPOINT_URL, tmp_path, _download_progress)
54-
os.rename(tmp_path, str(ckpt_path))
55-
except Exception:
56-
if os.path.exists(tmp_path):
57-
os.remove(tmp_path)
58-
raise
59-
60-
print(f"\n Download complete.")
61-
return str(ckpt_path)
62-
27+
from huggingface_hub import hf_hub_download
6328

64-
def _download_progress(block_count, block_size, total_size):
65-
"""Progress callback for urlretrieve."""
66-
downloaded = block_count * block_size
67-
if total_size > 0:
68-
pct = min(100, downloaded * 100 // total_size)
69-
mb_done = downloaded / (1024 * 1024)
70-
mb_total = total_size / (1024 * 1024)
71-
print(f"\r {mb_done:.1f}/{mb_total:.1f} MB ({pct}%)", end="", flush=True)
29+
path = hf_hub_download(
30+
repo_id=HF_REPO_ID,
31+
filename=CHECKPOINT_FILENAME,
32+
)
33+
return path

0 commit comments

Comments
 (0)