feat: add ModelWeightsManager to auto-sync HF weights on endpoint creation #761
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| import asyncio | ||
| import functools | ||
| import tempfile | ||
| from typing import List | ||
| | ||
| from huggingface_hub import snapshot_download | ||
| from model_engine_server.common.config import hmi_config | ||
| from model_engine_server.core.loggers import logger_name, make_logger | ||
| from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway | ||
| | ||
| logger = make_logger(logger_name()) | ||
| | ||
| # Match the internal sync_model_weights.py inclusion/exclusion patterns | ||
| HF_IGNORE_PATTERNS: List[str] = [ | ||
|     "optimizer*", | ||
|     "*.msgpack", | ||
|     "*.h5", | ||
|     "flax_model*", | ||
|     "tf_model*", | ||
|     "rust_model*", | ||
| ] | ||
| | ||
| | ||
| class ModelWeightsManager: | ||
|     def __init__(self, llm_artifact_gateway: LLMArtifactGateway): | ||
|         self.llm_artifact_gateway = llm_artifact_gateway | ||
| | ||
|     def _get_remote_path(self, hf_repo: str) -> str: | ||
|         prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/") | ||
|         return f"{prefix}/{hf_repo}" | ||
| | ||
|     async def ensure_model_weights_available(self, hf_repo: str) -> str: | ||
|         """ | ||
|         Ensures model weights for ``hf_repo`` are available at the configured remote path. | ||
| | ||
|         If the weights are already cached (remote path is non-empty), returns immediately. | ||
|         Otherwise downloads from HuggingFace Hub and uploads to the remote path. | ||
| | ||
|         Args: | ||
|             hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``. | ||
| | ||
|         Returns: | ||
|             The remote path (s3://, gs://, or https://) where the weights are stored. | ||
|         """ | ||
|         remote_path = self._get_remote_path(hf_repo) | ||
|         files = self.llm_artifact_gateway.list_files(remote_path) | ||
|         if files: | ||
|             logger.info(f"Cache hit: {len(files)} files at {remote_path}") | ||
|             return remote_path | ||
| | ||
|         logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...") | ||
|         loop = asyncio.get_event_loop() | ||
|         with tempfile.TemporaryDirectory() as tmp_dir: | ||
|             await loop.run_in_executor( | ||
|                 None, | ||
|                 functools.partial( | ||
|                     snapshot_download, | ||
|                     repo_id=hf_repo, | ||
|                     local_dir=tmp_dir, | ||
|                     ignore_patterns=HF_IGNORE_PATTERNS, | ||
|                 ), | ||
|             ) | ||
|             await loop.run_in_executor( | ||
|                 None, | ||
|                 functools.partial( | ||
|                     self.llm_artifact_gateway.upload_files, | ||
|                     tmp_dir, | ||
|                     remote_path, | ||
|                 ), | ||
|             ) | ||
Comment on lines +79 to +98

**Unhandled errors from `snapshot_download` will crash endpoint creation**

Path: model-engine/model_engine_server/domain/use_cases/model_weights_manager.py
Line: 51-70

If `snapshot_download` fails (e.g., gated model requiring HF auth token, rate limiting, network timeout), the exception propagates uncaught and returns a 500 to the caller. Many models in `SUPPORTED_MODELS_INFO` (like `meta-llama/*`) are gated and require authentication. Consider wrapping this in a try/except that logs the error and either raises a user-friendly error or falls back to `checkpoint_path = None` (allowing downstream logic to handle it):
```python
try:
    await loop.run_in_executor(...)
    await loop.run_in_executor(...)
except Exception as e:
    logger.error(f"Failed to download/upload weights for {hf_repo}: {e}")
    raise ObjectHasInvalidValueException(
        f"Could not download model weights for {hf_repo}. "
        "Ensure the model is accessible and try again, or provide a checkpoint_path explicitly."
    )
```
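As a runnable illustration of the failure path this comment describes, here is a minimal sketch that does not touch the real Hub or the real codebase. `fake_snapshot_download` and `WeightsUnavailableError` are hypothetical stand-ins (for `snapshot_download` and `ObjectHasInvalidValueException` respectively), and the sketch additionally chains the original exception with `from e`, which the suggestion above does not do.

```python
import asyncio
import functools


class WeightsUnavailableError(Exception):
    """Hypothetical stand-in for ObjectHasInvalidValueException."""


def fake_snapshot_download(repo_id: str, local_dir: str) -> str:
    # Simulates a gated repository; the real snapshot_download raises its own error types.
    raise PermissionError(f"401: access to {repo_id} is restricted")


async def download_or_raise(hf_repo: str, tmp_dir: str) -> None:
    loop = asyncio.get_running_loop()
    try:
        # An exception raised in the worker thread is re-raised here when awaited.
        await loop.run_in_executor(
            None,
            functools.partial(fake_snapshot_download, repo_id=hf_repo, local_dir=tmp_dir),
        )
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise WeightsUnavailableError(
            f"Could not download model weights for {hf_repo}. "
            "Ensure the model is accessible and try again, "
            "or provide a checkpoint_path explicitly."
        ) from e


def run_demo() -> str:
    try:
        asyncio.run(download_or_raise("org/gated-model", "/tmp/unused"))
    except WeightsUnavailableError as err:
        return f"{err} (cause: {type(err.__cause__).__name__})"
    return "no error"


result = run_demo()
print(result)
```

The caller sees one domain-level error with an actionable message, while the underlying `PermissionError` remains reachable through `__cause__` for logging.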
| | ||
|         logger.info(f"Weights for {hf_repo} uploaded to {remote_path}") | ||
|         return remote_path | ||
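The cache-hit and cache-miss flow in the diff above can be exercised without cloud storage. The following is a simplified, synchronous sketch: `InMemoryGateway`, `remote_path_for`, and `ensure_weights` are hypothetical names introduced for illustration, the bucket prefix is made up, and the Hub download step is omitted entirely.

```python
from typing import Dict, List


class InMemoryGateway:
    """Hypothetical in-memory stand-in for LLMArtifactGateway (illustration only)."""

    def __init__(self) -> None:
        self.store: Dict[str, List[str]] = {}
        self.uploads = 0

    def list_files(self, path: str) -> List[str]:
        return self.store.get(path, [])

    def upload_files(self, local_dir: str, remote_path: str) -> None:
        self.uploads += 1
        self.store[remote_path] = [f"{remote_path}/config.json"]


def remote_path_for(prefix: str, hf_repo: str) -> str:
    # Same join as _get_remote_path: strip trailing slashes, then add one separator.
    return f"{prefix.rstrip('/')}/{hf_repo}"


def ensure_weights(gateway: InMemoryGateway, prefix: str, hf_repo: str) -> str:
    remote_path = remote_path_for(prefix, hf_repo)
    if gateway.list_files(remote_path):
        return remote_path  # cache hit: nothing to upload
    # Cache miss: the real code downloads from the Hub first; here we only record the upload.
    gateway.upload_files("/tmp/unused", remote_path)
    return remote_path


gateway = InMemoryGateway()
first = ensure_weights(gateway, "s3://example-bucket/weights/", "org/model")
second = ensure_weights(gateway, "s3://example-bucket/weights/", "org/model")
print(first, gateway.uploads)
```

The second call finds files at the remote path and returns without uploading again, which is the behavior the docstring describes for already-cached weights.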
**`list_files()` blocks the event loop**

`snapshot_download` and `upload_files` are correctly offloaded via `run_in_executor`, but `list_files()` on line 46 is a synchronous I/O call (S3 `ListObjects` / GCS `list_blobs` / ABS `list_blob_names`) that runs directly on the async event loop. For consistency with the other two calls, this should also be wrapped in `run_in_executor`.

Note: `loop` would need to be obtained before this line; move the `loop = asyncio.get_event_loop()` line above this call.
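The reviewer's suggested snippet did not survive in this copy of the page, so the following is only a sketch of what offloading the listing call might look like, not the reviewer's exact code. `FakeGateway` is a hypothetical stand-in for `LLMArtifactGateway`, the path is made up, and the sketch uses `asyncio.get_running_loop()`, which the asyncio documentation prefers inside coroutines, where the PR uses `get_event_loop()`.

```python
import asyncio
import functools
import time
from typing import List


class FakeGateway:
    """Hypothetical stand-in for LLMArtifactGateway; not the real class."""

    def list_files(self, path: str) -> List[str]:
        time.sleep(0.05)  # simulate a blocking storage listing call
        return [f"{path}/config.json"]


async def list_files_off_loop(gateway: FakeGateway, remote_path: str) -> List[str]:
    # Obtain the loop before the listing call, as the note above points out.
    loop = asyncio.get_running_loop()
    # Run the synchronous call in the default thread pool so the event loop stays free.
    return await loop.run_in_executor(
        None, functools.partial(gateway.list_files, remote_path)
    )


files = asyncio.run(
    list_files_off_loop(FakeGateway(), "s3://example-bucket/weights/org/model")
)
print(files)
```

While the worker thread waits on storage, other coroutines on the same loop can keep serving requests, which is the point of the review comment.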