Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
UpdateLLMModelEndpointV1UseCase,
)
from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
from pydantic import RootModel
from sse_starlette.sse import EventSourceResponse

Expand Down Expand Up @@ -168,11 +169,15 @@ async def create_model_endpoint(
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
docker_repository=external_interfaces.docker_repository,
)
model_weights_manager = ModelWeightsManager(
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
)
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
model_endpoint_service=external_interfaces.model_endpoint_service,
docker_repository=external_interfaces.docker_repository,
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
model_weights_manager=model_weights_manager,
)
return await use_case.execute(user=auth, request=request)
except ObjectAlreadyExistsException as exc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[
"""
pass

@abstractmethod
def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """
    Recursively upload the contents of a local directory to remote storage.

    Args:
        local_path (str): directory on the local filesystem whose files are uploaded
        remote_path (str): destination URL (s3://, gs://, or https://)
    """
    pass

@abstractmethod
def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1322,12 +1322,14 @@ def __init__(
model_endpoint_service: ModelEndpointService,
docker_repository: DockerRepository,
llm_artifact_gateway: LLMArtifactGateway,
model_weights_manager=None,
):
self.authz_module = LiveAuthorizationModule()
self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
self.model_endpoint_service = model_endpoint_service
self.docker_repository = docker_repository
self.llm_artifact_gateway = llm_artifact_gateway
self.model_weights_manager = model_weights_manager

async def execute(
self, user: User, request: CreateLLMModelEndpointV1Request
Expand Down Expand Up @@ -1387,6 +1389,19 @@ async def execute(
"Multinode endpoints are only supported for VLLM models."
)

# Resolve checkpoint path: auto-download from HF Hub to remote storage if not cached
checkpoint_path = request.checkpoint_path
if (
checkpoint_path is None
and request.source == LLMSource.HUGGING_FACE
and self.model_weights_manager is not None
):
models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
if models_info and models_info.hf_repo:
checkpoint_path = await self.model_weights_manager.ensure_model_weights_available(
hf_repo=models_info.hf_repo
)

bundle = await self.create_llm_model_bundle_use_case.execute(
user,
endpoint_name=request.name,
Expand All @@ -1397,7 +1412,7 @@ async def execute(
endpoint_type=request.endpoint_type,
num_shards=request.num_shards,
quantize=request.quantize,
checkpoint_path=request.checkpoint_path,
checkpoint_path=checkpoint_path,
chat_template_override=request.chat_template_override,
nodes_per_worker=request.nodes_per_worker,
additional_args=request.model_dump(exclude_none=True),
Expand Down Expand Up @@ -1430,7 +1445,7 @@ async def execute(
inference_framework_image_tag=request.inference_framework_image_tag,
num_shards=request.num_shards,
quantize=request.quantize,
checkpoint_path=request.checkpoint_path,
checkpoint_path=checkpoint_path,
chat_template_override=request.chat_template_override,
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
import functools
import tempfile
from typing import List

from huggingface_hub import snapshot_download
from model_engine_server.common.config import hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import ObjectHasInvalidValueException
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway

logger = make_logger(logger_name())

# Match the internal sync_model_weights.py inclusion/exclusion patterns
HF_IGNORE_PATTERNS: List[str] = [
"optimizer*",
"*.msgpack",
"*.h5",
"flax_model*",
"tf_model*",
"rust_model*",
]


class ModelWeightsManager:
    """Mirrors HuggingFace Hub model weights into the configured remote storage.

    Used at endpoint-creation time to resolve a ``checkpoint_path`` automatically
    when the caller did not supply one.
    """

    def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
        self.llm_artifact_gateway = llm_artifact_gateway

    def _get_remote_path(self, hf_repo: str) -> str:
        # Remote layout mirrors the HF repo id under the configured weights prefix.
        prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
        return f"{prefix}/{hf_repo}"

    async def ensure_model_weights_available(self, hf_repo: str) -> str:
        """
        Ensures model weights for ``hf_repo`` are available at the configured remote path.

        If the weights are already cached (remote path is non-empty), returns immediately.
        Otherwise downloads from HuggingFace Hub and uploads to the remote path.

        Args:
            hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.

        Returns:
            The remote path (s3://, gs://, or https://) where the weights are stored.

        Raises:
            ObjectHasInvalidValueException: if downloading from the Hub or uploading
                to remote storage fails (e.g. gated model without credentials,
                rate limiting, network errors).
        """
        remote_path = self._get_remote_path(hf_repo)
        loop = asyncio.get_event_loop()
        # list_files() is synchronous storage I/O (S3 ListObjects / GCS list_blobs /
        # ABS list_blob_names); offload it so the event loop is never blocked,
        # consistent with the download/upload calls below.
        files = await loop.run_in_executor(
            None,
            functools.partial(self.llm_artifact_gateway.list_files, remote_path),
        )
        if files:
            logger.info(f"Cache hit: {len(files)} files at {remote_path}")
            return remote_path

        logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                # Both calls are blocking (network + disk I/O); run them off-loop.
                await loop.run_in_executor(
                    None,
                    functools.partial(
                        snapshot_download,
                        repo_id=hf_repo,
                        local_dir=tmp_dir,
                        ignore_patterns=HF_IGNORE_PATTERNS,
                    ),
                )
                await loop.run_in_executor(
                    None,
                    functools.partial(
                        self.llm_artifact_gateway.upload_files,
                        tmp_dir,
                        remote_path,
                    ),
                )
            except Exception as e:
                # Gated repos (e.g. meta-llama/*), rate limits, or transient network
                # failures would otherwise surface as an opaque 500 to the caller.
                logger.error(f"Failed to download/upload weights for {hf_repo}: {e}")
                raise ObjectHasInvalidValueException(
                    f"Could not download model weights for {hf_repo}. "
                    "Ensure the model is accessible and try again, or provide a "
                    "checkpoint_path explicitly."
                ) from e

        logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
        return remote_path
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
downloaded_files.append(local_path)
return downloaded_files

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """Upload every file under ``local_path`` to the ABS location ``remote_path``.

    The directory structure relative to ``local_path`` is preserved in the
    destination blob names; existing blobs are overwritten.
    """
    destination = parse_attachment_url(remote_path, clean_key=False)
    container = _get_abs_container_client(destination.bucket)
    for dirpath, _dirnames, filenames in os.walk(local_path):
        for filename in filenames:
            source_file = os.path.join(dirpath, filename)
            relative = os.path.relpath(source_file, local_path)
            with open(source_file, "rb") as stream:
                container.upload_blob(
                    name=os.path.join(destination.key, relative),
                    data=stream,
                    overwrite=True,
                )

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
downloaded_files.append(local_path)
return downloaded_files

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """Upload all files under ``local_path`` to the GCS location ``remote_path``,
    preserving the relative directory layout in the destination object names."""
    destination = parse_attachment_url(remote_path, clean_key=False)
    target_bucket = get_gcs_sync_client().bucket(destination.bucket)
    for dirpath, _dirnames, filenames in os.walk(local_path):
        for filename in filenames:
            source_file = os.path.join(dirpath, filename)
            object_name = os.path.join(
                destination.key, os.path.relpath(source_file, local_path)
            )
            target_bucket.blob(object_name).upload_from_filename(source_file)

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
logger.info(f"Downloaded {len(downloaded_files)} files to {target_path}")
return downloaded_files

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    """Upload all files under ``local_path`` to the S3 location ``remote_path``,
    preserving the relative directory layout in the destination keys."""
    parsed = parse_attachment_url(remote_path, clean_key=False)
    target_bucket = get_s3_resource(kwargs).Bucket(parsed.bucket)
    for dirpath, _dirnames, filenames in os.walk(local_path):
        for filename in filenames:
            source_file = os.path.join(dirpath, filename)
            s3_key = os.path.join(parsed.key, os.path.relpath(source_file, local_path))
            logger.info(f"Uploading {source_file} → s3://{parsed.bucket}/{s3_key}")
            target_bucket.upload_file(source_file, s3_key)

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
s3 = get_s3_resource(kwargs)
parsed_remote = parse_attachment_url(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from typing import Dict, NamedTuple, Optional

from huggingface_hub import list_repo_refs
from huggingface_hub.utils._errors import RepositoryNotFoundError

try:
    # Compatibility shim: import RepositoryNotFoundError from whichever location
    # this installed huggingface_hub version provides (older versions expose it
    # under utils._errors; the fallback location is huggingface_hub.errors).
    from huggingface_hub.utils._errors import RepositoryNotFoundError
except ImportError:
    from huggingface_hub.errors import RepositoryNotFoundError
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import ObjectNotFoundException
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
Expand Down
3 changes: 3 additions & 0 deletions model-engine/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
if path in self.s3_bucket:
return self.s3_bucket[path]

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
    # No-op: this fake gateway does not persist uploads; unit tests only need
    # the method to exist so production code can call it without error.
    pass

def get_model_weights_urls(self, owner: str, model_name: str):
if (owner, model_name) in self.existing_models:
return self.urls
Expand Down
Loading