11"""Utilities for downloading CrossScore model checkpoints."""
22
33import os
4- import urllib .request
5- import shutil
64from pathlib import Path
75
8- # Download directly from GitHub (served via Git LFS)
9- CHECKPOINT_URL = (
10- "https://github.com/ActiveVisionLab/CrossScore/raw/main/ckpt/CrossScore-v1.0.0.ckpt"
11- )
6+ HF_REPO_ID = "ActiveVisionLab/CrossScore"
127CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt"
138
149
15- def get_cache_dir () -> Path :
16- """Return the cache directory for CrossScore model checkpoints."""
17- cache_dir = Path (os .environ .get ("CROSSSCORE_CACHE_DIR" , Path .home () / ".cache" / "crossscore" ))
18- cache_dir .mkdir (parents = True , exist_ok = True )
19- return cache_dir
20-
21-
2210def get_checkpoint_path () -> str :
2311 """Get path to the CrossScore checkpoint, downloading it if necessary.
2412
25- Downloads from GitHub (Git LFS) on first use and caches locally at
26- ~/.cache/crossscore/. Set environment variables to customize:
13+ Downloads from HuggingFace Hub on first use and caches locally.
14+ Set environment variables to customize:
2715 CROSSSCORE_CKPT_PATH - use a specific local checkpoint file
28- CROSSSCORE_CACHE_DIR - custom cache directory
2916
3017 Returns:
3118 Path to the checkpoint file.
@@ -37,35 +24,10 @@ def get_checkpoint_path() -> str:
3724 raise FileNotFoundError (f"Checkpoint not found at CROSSSCORE_CKPT_PATH={ custom_path } " )
3825 return custom_path
3926
40- cache_dir = get_cache_dir ()
41- ckpt_path = cache_dir / CHECKPOINT_FILENAME
42-
43- if ckpt_path .exists ():
44- return str (ckpt_path )
45-
46- print (f"Downloading CrossScore checkpoint (~129MB)..." )
47- print (f" From: { CHECKPOINT_URL } " )
48- print (f" To: { ckpt_path } " )
49- print (" (Set CROSSSCORE_CKPT_PATH to skip download and use a local file)" )
50-
51- tmp_path = str (ckpt_path ) + ".tmp"
52- try :
53- urllib .request .urlretrieve (CHECKPOINT_URL , tmp_path , _download_progress )
54- os .rename (tmp_path , str (ckpt_path ))
55- except Exception :
56- if os .path .exists (tmp_path ):
57- os .remove (tmp_path )
58- raise
59-
60- print (f"\n Download complete." )
61- return str (ckpt_path )
62-
27+ from huggingface_hub import hf_hub_download
6328
64- def _download_progress (block_count , block_size , total_size ):
65- """Progress callback for urlretrieve."""
66- downloaded = block_count * block_size
67- if total_size > 0 :
68- pct = min (100 , downloaded * 100 // total_size )
69- mb_done = downloaded / (1024 * 1024 )
70- mb_total = total_size / (1024 * 1024 )
71- print (f"\r { mb_done :.1f} /{ mb_total :.1f} MB ({ pct } %)" , end = "" , flush = True )
29+ path = hf_hub_download (
30+ repo_id = HF_REPO_ID ,
31+ filename = CHECKPOINT_FILENAME ,
32+ )
33+ return path
0 commit comments