forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_utils.py
More file actions
158 lines (140 loc) · 5.69 KB
/
llm_utils.py
File metadata and controls
158 lines (140 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
from __future__ import absolute_import
import os
from typing import Optional
import importlib.util
import urllib.request
from urllib.error import HTTPError, URLError
import json
from json import JSONDecodeError
import logging
from sagemaker import image_uris
from sagemaker.session import Session
logger = logging.getLogger(__name__)
# Maps each supported `backend` value to the image_uris.retrieve() framework
# name plus any backend-specific keyword arguments. The "lmi" backend is
# handled separately below because it uses a different framework name, a
# pinned default version, and no image_scope.
_HF_LLM_BACKEND_CONFIG = {
    "huggingface": ("huggingface-llm", {}),
    "huggingface-neuronx": ("huggingface-llm-neuronx", {"inference_tool": "neuronx"}),
    "huggingface-vllm-neuronx": ("huggingface-vllm-neuronx", {"inference_tool": "neuronx"}),
    "huggingface-tei": ("huggingface-tei", {}),
    "huggingface-tei-cpu": ("huggingface-tei-cpu", {}),
}


def get_huggingface_llm_image_uri(
    backend: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
    version: Optional[str] = None,
) -> str:
    """Retrieves the image URI for inference.

    Args:
        backend (str): The backend to use. Valid values are "huggingface",
            "huggingface-neuronx", "huggingface-vllm-neuronx", "huggingface-tei",
            "huggingface-tei-cpu", and "lmi".
        session (Session): The SageMaker Session to use. (Default: None).
        region (str): The AWS region to use for image URI. (default: None).
        version (str): The framework version for which to retrieve an
            image URI. If no version is set, defaults to latest version. (default: None).

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    # Resolve the region from the given session, or a freshly-created default
    # session, when the caller did not pass one explicitly.
    if region is None:
        if session is None:
            region = Session().boto_session.region_name
        else:
            region = session.boto_session.region_name
    if backend in _HF_LLM_BACKEND_CONFIG:
        framework, extra_kwargs = _HF_LLM_BACKEND_CONFIG[backend]
        return image_uris.retrieve(
            framework,
            region=region,
            version=version,
            image_scope="inference",
            **extra_kwargs,
        )
    if backend == "lmi":
        # LMI images are published under the djl-deepspeed framework name;
        # pin a known-good default when the caller does not specify a version.
        version = version or "0.24.0"
        return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
    raise ValueError("Unsupported backend: %s" % backend)
def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
    """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.

    Args:
        model_id (str): The HuggingFace Model ID
        hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models

    Returns:
        dict: The model metadata retrieved with the HuggingFace API

    Raises:
        ValueError: If ``model_id`` is empty, the model is gated/private and no
            valid token was supplied, or no metadata could be retrieved.
    """
    if not model_id:
        raise ValueError("Model ID is empty. Please provide a valid Model ID.")
    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
    hf_model_metadata_json = None
    try:
        request = hf_model_metadata_url
        if hf_hub_token:
            # Gated/private models require a bearer token on the metadata request.
            # Keep the plain URL separate so the warning below logs the URL
            # rather than the Request object's repr.
            request = urllib.request.Request(
                hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
            )
        with urllib.request.urlopen(request) as response:
            hf_model_metadata_json = json.load(response)
    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
        # A 401 means the model exists but requires credentials; surface a
        # targeted error instead of the generic "no metadata" one below.
        # Checking the HTTP status code is more robust than matching str(e).
        if isinstance(e, HTTPError) and e.code == 401:
            raise ValueError(
                "Trying to access a gated/private HuggingFace model without valid credentials. "
                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
            ) from e
        logger.warning(
            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
            "Details: %s",
            hf_model_metadata_url,
            e,
        )
    if not hf_model_metadata_json:
        raise ValueError(
            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
        )
    return hf_model_metadata_json
def download_huggingface_model_metadata(
model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
"""Downloads the HuggingFace Model snapshot via HuggingFace API.
Args:
model_id (str): The HuggingFace Model ID
model_local_path (str): The local path to save the HuggingFace Model snapshot.
hf_hub_token (str): The HuggingFace Hub Token
Raises:
ImportError: If huggingface_hub is not installed.
"""
if not importlib.util.find_spec("huggingface_hub"):
raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")
from huggingface_hub import snapshot_download
os.makedirs(model_local_path, exist_ok=True)
logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token)