Skip to content

Commit cce8ca4

Browse files
committed
support hf artifact URIs for API inference runtime
1 parent 5760ef4 commit cce8ca4

6 files changed

Lines changed: 181 additions & 29 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ quote-style = "double"
5656

5757
[dependency-groups]
5858
api = [
59+
"huggingface-hub>=1.4.1",
5960
"fastapi>=0.121.1",
6061
"python-jose[cryptography]>=3.5.0",
6162
"python-dotenv>=1.2.1",

src/api/deps/inference.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi import HTTPException, status
66

77
from api.config import Settings, get_settings
8+
from api.inference_artifacts import resolve_artifact_uri
89
from inference.service import InferenceService
910

1011

@@ -30,9 +31,14 @@ def _build_inference_service(
3031
def get_inference_service_dep() -> InferenceService:
3132
settings: Settings = get_settings()
3233
try:
34+
resolved_checkpoint = resolve_artifact_uri(settings.model_checkpoint_path)
35+
resolved_onnx = resolve_artifact_uri(settings.model_onnx_path)
36+
37+
# InferenceService expects a checkpoint path value, but ONNX-only runtime is valid.
38+
checkpoint_arg = resolved_checkpoint or "__missing_checkpoint__.ckpt"
3339
return _build_inference_service(
34-
checkpoint_path=settings.model_checkpoint_path,
35-
onnx_path=settings.model_onnx_path,
40+
checkpoint_path=checkpoint_arg,
41+
onnx_path=resolved_onnx or "",
3642
device=settings.inference_device,
3743
mcts_sims=settings.inference_mcts_sims,
3844
c_puct=settings.inference_c_puct,

src/api/inference_artifacts.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import os
5+
from functools import lru_cache
6+
from urllib.parse import parse_qs, unquote, urlparse
7+
8+
9+
def _normalize_local_artifact_path(uri: str) -> str:
10+
parsed = urlparse(uri)
11+
if parsed.scheme == "file":
12+
if parsed.netloc not in {"", "localhost"}:
13+
raise ValueError(f"Unsupported artifact URI host: {uri}")
14+
path = unquote(parsed.path)
15+
if path == "":
16+
raise ValueError(f"Invalid artifact URI (empty path): {uri}")
17+
return path
18+
if parsed.scheme == "":
19+
return uri
20+
raise ValueError(
21+
f"Unsupported artifact URI scheme '{parsed.scheme}'. Use local paths, file://, or hf:// URIs."
22+
)
23+
24+
25+
def _parse_hf_uri(
26+
uri: str,
27+
*,
28+
default_repo_id: str | None = None,
29+
default_revision: str | None = None,
30+
) -> tuple[str, str, str | None]:
31+
parsed = urlparse(uri)
32+
if parsed.scheme != "hf":
33+
raise ValueError(f"Invalid HF URI: {uri}")
34+
35+
# Expected format: hf://<owner>/<repo>/<path/to/file>[?revision=<rev>]
36+
payload = f"{parsed.netloc}{parsed.path}".lstrip("/")
37+
segments = [seg for seg in payload.split("/") if seg]
38+
if len(segments) >= 3:
39+
repo_id = f"{segments[0]}/{segments[1]}"
40+
filename = "/".join(segments[2:])
41+
elif len(segments) >= 1 and default_repo_id:
42+
repo_id = default_repo_id
43+
filename = "/".join(segments)
44+
else:
45+
raise ValueError(
46+
f"Invalid HF URI '{uri}'. Expected hf://<owner>/<repo>/<artifact_path>."
47+
)
48+
49+
if filename == "":
50+
raise ValueError(f"Invalid HF URI '{uri}': missing artifact path.")
51+
52+
query = parse_qs(parsed.query)
53+
revision = query.get("revision", [None])[0] or default_revision
54+
return repo_id, filename, revision
55+
56+
57+
@lru_cache(maxsize=128)
58+
def _download_hf_artifact(
59+
*,
60+
repo_id: str,
61+
filename: str,
62+
revision: str | None,
63+
) -> str:
64+
try:
65+
hf_module = importlib.import_module("huggingface_hub")
66+
except ModuleNotFoundError as exc:
67+
raise ValueError(
68+
"huggingface-hub is required for hf:// artifacts. "
69+
"Install it in the API runtime dependencies."
70+
) from exc
71+
72+
hf_hub_download = getattr(hf_module, "hf_hub_download", None)
73+
if hf_hub_download is None:
74+
raise ValueError("huggingface_hub.hf_hub_download is unavailable in this runtime.")
75+
76+
token = os.getenv("HF_TOKEN")
77+
return str(
78+
hf_hub_download(
79+
repo_id=repo_id,
80+
filename=filename,
81+
revision=revision,
82+
token=token if token else None,
83+
)
84+
)
85+
86+
87+
def resolve_artifact_uri(
88+
uri: str | None,
89+
*,
90+
default_repo_id: str | None = None,
91+
default_revision: str | None = None,
92+
) -> str | None:
93+
if uri is None:
94+
return None
95+
cleaned = uri.strip()
96+
if cleaned == "":
97+
return None
98+
99+
parsed = urlparse(cleaned)
100+
if parsed.scheme in {"", "file"}:
101+
return _normalize_local_artifact_path(cleaned)
102+
if parsed.scheme == "hf":
103+
repo_id, filename, revision = _parse_hf_uri(
104+
cleaned,
105+
default_repo_id=default_repo_id,
106+
default_revision=default_revision,
107+
)
108+
return _download_hf_artifact(
109+
repo_id=repo_id,
110+
filename=filename,
111+
revision=revision,
112+
)
113+
raise ValueError(
114+
f"Unsupported artifact URI scheme '{parsed.scheme}'. Use local paths, file://, or hf:// URIs."
115+
)

src/api/modules/matches/model_runtime.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,12 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4-
from urllib.parse import unquote, urlparse
54

65
from api.db.models import ModelVersion
6+
from api.inference_artifacts import resolve_artifact_uri
77
from inference.service import InferenceService
88

99

10-
def _normalize_local_artifact_path(uri: str | None) -> str | None:
11-
if uri is None:
12-
return None
13-
cleaned = uri.strip()
14-
if cleaned == "":
15-
return None
16-
17-
parsed = urlparse(cleaned)
18-
if parsed.scheme in {"", "file"}:
19-
if parsed.scheme == "file":
20-
if parsed.netloc not in {"", "localhost"}:
21-
raise ValueError(f"Unsupported artifact URI host: {cleaned}")
22-
path = unquote(parsed.path)
23-
if path == "":
24-
raise ValueError(f"Invalid artifact URI (empty path): {cleaned}")
25-
return path
26-
return cleaned
27-
28-
raise ValueError(
29-
f"Unsupported artifact URI scheme '{parsed.scheme}'. Use local paths or file:// URIs."
30-
)
31-
32-
3310
def _runtime_config_from_base(base_service: InferenceService | None) -> tuple[str, int, float, bool]:
3411
if base_service is None:
3512
return "auto", 160, 1.5, True
@@ -65,11 +42,19 @@ def resolve_model_inference_service(
6542
version: ModelVersion,
6643
base_service: InferenceService | None,
6744
) -> InferenceService:
68-
checkpoint_path = _normalize_local_artifact_path(version.checkpoint_uri)
69-
onnx_path = _normalize_local_artifact_path(version.onnx_uri)
45+
checkpoint_path = resolve_artifact_uri(
46+
version.checkpoint_uri,
47+
default_repo_id=version.hf_repo_id,
48+
default_revision=version.hf_revision,
49+
)
50+
onnx_path = resolve_artifact_uri(
51+
version.onnx_uri,
52+
default_repo_id=version.hf_repo_id,
53+
default_revision=version.hf_revision,
54+
)
7055
if checkpoint_path is None and onnx_path is None:
7156
raise ValueError(
72-
f"Model version '{version.name}' has no local checkpoint_uri/onnx_uri configured."
57+
f"Model version '{version.name}' has no usable checkpoint_uri/onnx_uri configured."
7358
)
7459

7560
# Keep runtime knobs aligned with the API default inference service so
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
import unittest
5+
from pathlib import Path
6+
from types import SimpleNamespace
7+
from unittest.mock import MagicMock, patch
8+
9+
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
10+
11+
from api.inference_artifacts import resolve_artifact_uri
12+
13+
14+
class TestApiInferenceArtifacts(unittest.TestCase):
    """Unit tests for ``resolve_artifact_uri`` covering each supported scheme."""

    def setUp(self) -> None:
        # The HF download helper is memoized with lru_cache at module level;
        # clear it so a result cached by another test cannot bypass the mock.
        from api.inference_artifacts import _download_hf_artifact

        _download_hf_artifact.cache_clear()

    def test_resolve_local_path_returns_same_value(self) -> None:
        path = resolve_artifact_uri("checkpoints/model_iter_039.pt")
        self.assertEqual(path, "checkpoints/model_iter_039.pt")

    def test_resolve_file_uri_returns_local_path(self) -> None:
        path = resolve_artifact_uri("file:///var/lib/ataxx/model.ckpt")
        self.assertEqual(path, "/var/lib/ataxx/model.ckpt")

    # Force an empty HF_TOKEN so the ``token=None`` assertion below does not
    # depend on whatever token the developer or CI machine has exported.
    @patch.dict("os.environ", {"HF_TOKEN": ""})
    @patch("api.inference_artifacts.importlib.import_module")
    def test_resolve_hf_uri_downloads_artifact(self, mock_import_module: MagicMock) -> None:
        hf_download = MagicMock(return_value="/var/lib/ataxx/model_iter_039.pt")
        # Stand-in for the huggingface_hub module returned by import_module.
        mock_import_module.return_value = SimpleNamespace(hf_hub_download=hf_download)

        resolved = resolve_artifact_uri("hf://dieg0code/ataxx-zero/model_iter_039.pt")
        self.assertEqual(resolved, "/var/lib/ataxx/model_iter_039.pt")
        hf_download.assert_called_once_with(
            repo_id="dieg0code/ataxx-zero",
            filename="model_iter_039.pt",
            revision=None,
            token=None,
        )

    def test_rejects_unknown_scheme(self) -> None:
        with self.assertRaises(ValueError):
            resolve_artifact_uri("s3://bucket/model.pt")
40+
41+
42+
if __name__ == "__main__":
43+
unittest.main()

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)