# vlm.py (forked from SearchSavior/OpenArc)
import asyncio
import base64
import gc
import logging
from io import BytesIO
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino as ov
from openvino_genai import (
    GenerationConfig,
    VLMPipeline,
)
from PIL import Image
from transformers import AutoTokenizer

from src.engine.ov_genai.streamers import ChunkStreamer
from src.server.model_registry import ModelRegistry
from src.server.models.ov_genai import OVGenAI_GenConfig, VLM_VISION_TOKENS
from src.server.models.registration import ModelLoadConfig
from src.server.utils.chat import flatten_message_content

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class OVGenAI_VLM:
    def __init__(self, load_config: ModelLoadConfig):
        # Note: despite its name, model_path holds the loaded VLMPipeline
        # instance once load_model() has run, not a filesystem path.
        self.model_path = None
        self.tokenizer = None
        self.vision_token = None
        self.load_config = load_config

    def _vision_token_for_index(self, index: int) -> str:
        """
        Return the correctly formatted vision token for the given image index.
        Handles templates that may contain an index placeholder like '{i}'.
        """
        token_template = self.vision_token if self.vision_token is not None else ""
        if "{i}" in token_template:
            return token_template.replace("{i}", str(index))
        return token_template
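
    # Illustrative expansion (the real templates live in VLM_VISION_TOKENS):
    # an indexed template such as "<|image_{i}|>" yields "<|image_0|>" for the
    # first image, while a fixed token such as "<image>" is returned unchanged.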

    def prepare_inputs(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Tuple[str, List[ov.Tensor]]:
        """
        Parse a messages list and prepare the text prompt plus image tensors
        for VLM inference.

        Args:
            messages: list of messages, optionally containing multimodal content
            tools: optional tool definitions forwarded to the chat template

        Returns:
            (tokenized_messages, ov_images)
        """
        images: List[Image.Image] = []
        text_messages: List[Dict[str, Any]] = []

        # Step 1: Extract text and images
        for message in messages:
            # Multimodal message (list of dict content items)
            if isinstance(message.get("content", ""), list):
                text_parts: List[str] = []
                for content_item in message["content"]:
                    if (
                        isinstance(content_item, dict)
                        and content_item.get("type") == "image_url"
                    ):
                        image_url = content_item.get("image_url", {})
                        # Check for embedded base64 data
                        if (
                            isinstance(image_url, dict)
                            and isinstance(image_url.get("url", ""), str)
                            and image_url["url"].startswith("data:image/")
                        ):
                            base64_data = image_url["url"].split(",", 1)
                            if len(base64_data) > 1:
                                image_data = base64.b64decode(base64_data[1])
                                image = Image.open(BytesIO(image_data)).convert("RGB")
                                images.append(image)
                                # Insert the model-specific image token where this image appears
                                token_str = self._vision_token_for_index(len(images) - 1)
                                text_parts.append(f" {token_str} ")
                    # Handle text segments
                    elif isinstance(content_item, dict) and content_item.get("type") == "text":
                        text_parts.append(content_item.get("text", ""))
                # Combine the extracted text back into a unified string
                text_message = message.copy()
                text_message["content"] = flatten_message_content(
                    " ".join(t for t in text_parts if isinstance(t, str)) if text_parts else ""
                )
                text_messages.append(text_message)
            # Simple text-only message
            else:
                text_messages.append(
                    {**message, "content": flatten_message_content(message.get("content"))}
                )

        # Step 2: Build the chat prompt with the cached tokenizer's template
        tokenized_messages: str = self.tokenizer.apply_chat_template(
            text_messages,
            tokenize=False,
            tools=tools,
            add_generation_prompt=True,
        )

        # Step 3: Convert images to OpenVINO tensors (uint8, HWC layout)
        ov_images: List[ov.Tensor] = []
        for img in images:
            ov_images.append(ov.Tensor(np.array(img, dtype=np.uint8)))

        return tokenized_messages, ov_images
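
    # Illustrative input (OpenAI-style multimodal message; values are
    # hypothetical). prepare_inputs() replaces the image_url item with the
    # model's vision token in the prompt and appends a matching ov.Tensor:
    #
    #     messages = [{
    #         "role": "user",
    #         "content": [
    #             {"type": "text", "text": "Describe this image."},
    #             {"type": "image_url",
    #              "image_url": {"url": "data:image/png;base64,iVBOR..."}},
    #         ],
    #     }]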

    def generate_type(self, gen_config: OVGenAI_GenConfig):
        """
        Unified generation entry point that routes to streaming or non-streaming
        based on the stream flag in gen_config. Both paths return an async iterator.
        """
        if gen_config.stream:
            return self.generate_stream(gen_config)
        return self.generate_text(gen_config)

    async def generate_text(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
        """
        Async non-streaming generation for VLM.
        Yields in order: metrics (dict), then the full generated text (str).
        """
        try:
            generation_kwargs = GenerationConfig(
                max_new_tokens=gen_config.max_tokens,
                temperature=gen_config.temperature,
                top_k=gen_config.top_k,
                top_p=gen_config.top_p,
                repetition_penalty=gen_config.repetition_penalty,
            )
            prompt, ov_images = self.prepare_inputs(gen_config.messages, gen_config.tools)
            # pipeline.generate() blocks, so run it off the event loop
            result = await asyncio.to_thread(
                self.model_path.generate,
                prompt=prompt,
                **({"images": ov_images} if ov_images else {}),
                generation_config=generation_kwargs,
            )
            perf_metrics = result.perf_metrics
            text = result.texts[0] if getattr(result, "texts", None) else ""
            logger.info(
                f"[{self.load_config.model_name}] Generation completed, "
                f"generated {len(text)} characters"
            )
            yield self.collect_metrics(gen_config, perf_metrics)
            yield text
        except Exception as e:
            logger.error(
                f"[{self.load_config.model_name}] Error during non-streaming generation: {e}",
                exc_info=True,
            )
            raise
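
    # Consumption sketch (hypothetical caller): the metrics dict is yielded
    # before the text, so a caller can unpack the two items in order:
    #
    #     async for item in vlm.generate_text(gen_config):
    #         ...  # first a metrics dict, then the generated string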

    async def generate_stream(
        self,
        gen_config: OVGenAI_GenConfig,
    ) -> AsyncIterator[Union[str, Dict[str, Any]]]:
        """
        Async streaming generation for VLM.
        Yields token chunks (str) as they arrive, then metrics (dict).
        """
        generation_kwargs = GenerationConfig(
            max_new_tokens=gen_config.max_tokens,
            temperature=gen_config.temperature,
            top_k=gen_config.top_k,
            top_p=gen_config.top_p,
            repetition_penalty=gen_config.repetition_penalty,
        )
        decoder_tokenizer = self.model_path.get_tokenizer()
        streamer = ChunkStreamer(decoder_tokenizer, gen_config)
        prompt, ov_images = self.prepare_inputs(gen_config.messages, gen_config.tools)

        async def _run_generation():
            # pipeline.generate() blocks, so run it off the event loop while
            # the streamer pushes decoded chunks onto text_queue
            return await asyncio.to_thread(
                self.model_path.generate,
                prompt=prompt,
                **({"images": ov_images} if ov_images else {}),
                generation_config=generation_kwargs,
                streamer=streamer,
            )

        gen_task = asyncio.create_task(_run_generation())
        try:
            while True:
                chunk = await streamer.text_queue.get()
                if chunk is None:  # sentinel: generation finished
                    break
                yield chunk
        finally:
            # Always await the worker so exceptions propagate and the thread
            # is not abandoned if the consumer stops iterating early.
            result = await gen_task
        # Yield metrics outside the finally block; yielding inside finally
        # raises RuntimeError if the generator is closed before completion.
        yield self.collect_metrics(gen_config, result.perf_metrics)
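
    # Consumption sketch (hypothetical caller):
    #
    #     async for item in vlm.generate_stream(gen_config):
    #         if isinstance(item, str):
    #             ...  # incremental text chunk
    #         else:
    #             ...  # final metrics dict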

    def collect_metrics(self, gen_config: OVGenAI_GenConfig, perf_metrics) -> Dict[str, Any]:
        """
        Collect and format performance metrics into a dictionary.
        Durations reported by openvino_genai are in milliseconds and are
        converted to seconds where the key says so.
        """
        ttft_seconds = perf_metrics.get_ttft().mean / 1000
        input_tokens = perf_metrics.get_num_input_tokens()
        generated_tokens = perf_metrics.get_num_generated_tokens()
        prefill_throughput = round(input_tokens / ttft_seconds, 2) if ttft_seconds > 0 else 0
        metrics: Dict[str, Any] = {
            "load_time (s)": round(perf_metrics.get_load_time() / 1000, 2),
            "ttft (s)": round(ttft_seconds, 2),
            "tpot (ms)": round(perf_metrics.get_tpot().mean, 5),
            "prefill_throughput (tokens/s)": prefill_throughput,
            "decode_throughput (tokens/s)": round(perf_metrics.get_throughput().mean, 5),
            "decode_duration (s)": round(perf_metrics.get_generate_duration().mean / 1000, 5),
            "input_token": input_tokens,
            "new_token": generated_tokens,
            "total_token": input_tokens + generated_tokens,
            "stream": gen_config.stream,
        }
        if gen_config.stream and hasattr(gen_config, "stream_chunk_tokens"):
            metrics["stream_chunk_tokens"] = gen_config.stream_chunk_tokens
        return metrics
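
    # Example return value (illustrative numbers only):
    #
    #     {"load_time (s)": 4.21, "ttft (s)": 0.35, "tpot (ms)": 42.0,
    #      "prefill_throughput (tokens/s)": 820.0,
    #      "decode_throughput (tokens/s)": 23.8, "decode_duration (s)": 6.3,
    #      "input_token": 287, "new_token": 150, "total_token": 437,
    #      "stream": False}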

    def load_model(self, loader: ModelLoadConfig):
        """
        Load the VLMPipeline and cache the tokenizer and vision token.
        """
        try:
            logger.info(f"Loading {loader.model_type} on {loader.device} with {loader.runtime_config}")
            self.model_path = VLMPipeline(
                loader.model_path,
                loader.device,
                **(loader.runtime_config or {}),
            )
            self.tokenizer = AutoTokenizer.from_pretrained(loader.model_path)
            # Look up the vision token template keyed by vlm_type
            self.vision_token = VLM_VISION_TOKENS.get(loader.vlm_type)
            if self.vision_token is None:
                raise ValueError(
                    f"Unknown VLM type: {loader.vlm_type}. "
                    f"Supported: {list(VLM_VISION_TOKENS.keys())}"
                )
            logger.info(f"{loader.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"[{loader.model_name}] Failed to initialize VLMPipeline: {e}", exc_info=True)
            # Re-raise so callers don't proceed with a half-initialized model
            raise

    async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
        """
        Unregister the model from the registry and free memory resources.
        """
        removed = await registry.register_unload(model_name)
        # Drop references so the pipeline and tokenizer can be reclaimed
        self.model_path = None
        self.tokenizer = None
        self.vision_token = None
        gc.collect()
        logger.info(f"[{self.load_config.model_name}] unloaded successfully")
        return removed
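
# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal driver, assuming a ModelLoadConfig and OVGenAI_GenConfig built
# elsewhere; the names below are placeholders, not real values:
#
#     import asyncio
#
#     async def main():
#         vlm = OVGenAI_VLM(load_config)
#         vlm.load_model(load_config)
#         async for item in vlm.generate_type(gen_config):
#             print(item)  # text chunks / full text, then a metrics dict
#         await vlm.unload_model(registry, load_config.model_name)
#
#     asyncio.run(main())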