-
Notifications
You must be signed in to change notification settings - Fork 74
Expand file tree
/
Copy pathabs_llm_artifact_gateway.py
More file actions
95 lines (75 loc) · 4.06 KB
/
abs_llm_artifact_gateway.py
File metadata and controls
95 lines (75 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import json
import os
from typing import Any, Dict, List
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient, ContainerClient
from model_engine_server.common.config import get_model_cache_directory_name, hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.core.utils.url import parse_attachment_url
from model_engine_server.domain.gateways import LLMArtifactGateway
# Module-level logger, created through the project's logging helpers.
logger = make_logger(logger_name())
def _get_abs_container_client(bucket: str) -> ContainerClient:
    """Build a client for the Azure Blob Storage container named ``bucket``.

    The storage account name is read from the ``ABS_ACCOUNT_NAME`` environment
    variable and authentication uses ``DefaultAzureCredential``. A new
    ``BlobServiceClient`` is constructed on every call.

    Args:
        bucket: Name of the container to connect to.

    Returns:
        A ``ContainerClient`` for ``bucket``.

    Raises:
        ValueError: If ``ABS_ACCOUNT_NAME`` is unset or empty.
    """
    account_name = os.getenv("ABS_ACCOUNT_NAME")
    if not account_name:
        # Without this guard the account URL would be formatted as
        # "https://None.blob.core.windows.net" and only fail later, on the
        # first request, with an unrelated-looking connection error.
        raise ValueError(
            "ABS_ACCOUNT_NAME environment variable must be set to use Azure Blob Storage"
        )

    blob_service_client = BlobServiceClient(
        f"https://{account_name}.blob.core.windows.net",
        DefaultAzureCredential(),
    )
    return blob_service_client.get_container_client(container=bucket)
class ABSLLMArtifactGateway(LLMArtifactGateway):
    """
    Concrete implementation using Azure Blob Storage.
    """

    def list_files(self, path: str, **kwargs) -> List[str]:
        """List the names of all blobs whose name starts with the key in ``path``.

        Args:
            path: Remote URL; parsed with ``parse_attachment_url`` into a
                container (``bucket``) and a blob-name prefix (``key``).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The matching blob names, in the order the container client yields them.
        """
        parsed_remote = parse_attachment_url(path, clean_key=False)
        container_client = _get_abs_container_client(parsed_remote.bucket)
        return list(container_client.list_blob_names(name_starts_with=parsed_remote.key))

    def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
        """Download every blob under ``path`` into ``target_path``.

        Each blob is written to ``target_path`` joined with the blob's name
        relative to the key parsed from ``path``.

        Args:
            path: Remote URL; parsed with ``parse_attachment_url`` into a
                container (``bucket``) and a blob-name prefix (``key``).
            target_path: Local path that the downloaded files are placed under.
            overwrite: When false, a blob whose local path already exists is not
                downloaded again (its path is still included in the result).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Local paths of all files, both freshly downloaded and skipped.
        """
        parsed_remote = parse_attachment_url(path, clean_key=False)
        key = parsed_remote.key
        container_client = _get_abs_container_client(parsed_remote.bucket)

        downloaded_files: List[str] = []
        for blob in container_client.list_blobs(name_starts_with=key):
            # Strip only the leading prefix. str.replace(key, "") would also
            # remove any later occurrence of `key` inside the blob name and
            # corrupt the local path.
            blob_name = blob.name
            relative_name = blob_name[len(key):] if blob_name.startswith(key) else blob_name
            file_path_suffix = relative_name.lstrip("/")
            local_path = os.path.join(target_path, file_path_suffix).rstrip("/")

            if not overwrite and os.path.exists(local_path):
                downloaded_files.append(local_path)
                continue

            # dirname is "" for a bare filename; os.makedirs("") would raise.
            # exist_ok avoids a race between the existence check and creation.
            local_dir = os.path.dirname(local_path)
            if local_dir:
                os.makedirs(local_dir, exist_ok=True)

            logger.info(f"Downloading {blob.name} to {local_path}")
            with open(file=local_path, mode="wb") as f:
                # Stream straight into the file instead of buffering the whole
                # blob in memory with readall().
                container_client.download_blob(blob.name).readinto(f)
            downloaded_files.append(local_path)
        return downloaded_files

    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
        """Recursively upload the files under ``local_path`` to ``remote_path``.

        Each file's blob name is the key parsed from ``remote_path`` joined with
        the file's path relative to ``local_path``. Existing blobs are
        overwritten. ``os.walk`` yields nothing when ``local_path`` is not a
        directory, so in that case nothing is uploaded.

        Args:
            local_path: Local directory to walk.
            remote_path: Remote URL; parsed with ``parse_attachment_url`` into a
                container (``bucket``) and a blob-name prefix (``key``).
            **kwargs: Accepted for interface compatibility; unused here.
        """
        parsed = parse_attachment_url(remote_path, clean_key=False)
        container_client = _get_abs_container_client(parsed.bucket)
        for root, _, files in os.walk(local_path):
            for file in files:
                local_file = os.path.join(root, file)
                blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
                with open(local_file, "rb") as f:
                    container_client.upload_blob(name=blob_name, data=f, overwrite=True)

    def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
        """Return HTTPS URLs of the fine-tuned weight blobs for ``owner``/``model_name``.

        The account, container and base prefix come from
        ``hmi_config.hf_user_fine_tuned_weights_prefix``; the listing prefix is
        ``<base prefix>/<owner>/<model cache directory name>``.

        Args:
            owner: Owner segment of the blob prefix.
            model_name: Model name, mapped through ``get_model_cache_directory_name``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            One ``https://<account>.blob.core.windows.net/<bucket>/<blob>`` URL per
            blob whose name starts with the prefix.
        """
        parsed_remote = parse_attachment_url(
            hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
        )
        account = parsed_remote.account
        bucket = parsed_remote.bucket
        container_client = _get_abs_container_client(bucket)

        model_cache_name = get_model_cache_directory_name(model_name)
        # NOTE(review): the prefix has no trailing "/", so blobs under a sibling
        # path that merely starts with the same model cache name would match
        # too - confirm whether that is intended.
        prefix = f"{parsed_remote.key}/{owner}/{model_cache_name}"
        return [
            f"https://{account}.blob.core.windows.net/{bucket}/{blob_name}"
            for blob_name in container_client.list_blob_names(name_starts_with=prefix)
        ]

    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
        """Download and parse the ``config.json`` blob located under ``path``.

        Args:
            path: Remote URL; parsed with ``parse_attachment_url`` into a
                container (``bucket``) and a key that ``config.json`` is joined to.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The result of ``json.loads`` on the blob's contents.

        Errors from the blob download or from JSON parsing are not caught here
        and propagate to the caller.
        """
        parsed_remote = parse_attachment_url(path, clean_key=False)
        config_key = os.path.join(parsed_remote.key, "config.json")
        container_client = _get_abs_container_client(parsed_remote.bucket)
        return json.loads(container_client.download_blob(blob=config_key).readall())