Skip to content

Commit 24de915

Browse files
authored
feat: nvimgcodec + threadpool for image processor (#10)
* fix * ups Signed-off-by: Qidong Su <soodoshll@gmail.com> * upd * update * fix * upd * fix * upd * upd * upd * upd * add limit * simplify * fix * upd * fix * fix * clean * format * clean * upd --------- Signed-off-by: Qidong Su <soodoshll@gmail.com>
1 parent 876e6d6 commit 24de915

4 files changed

Lines changed: 173 additions & 23 deletions

File tree

components/src/dynamo/vllm/handlers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,11 @@ async def _generate_token_mode(self, request, context, request_id):
13631363
logger.warning("Initiating Dynamo Runtime shutdown.")
13641364
self.runtime.shutdown()
13651365
os._exit(1)
1366+
finally:
1367+
if multi_modal_data is not None:
1368+
images = multi_modal_data.get("image")
1369+
count = len(images) if isinstance(images, list) else 1
1370+
self.image_loader.mark_consumed(count)
13661371

13671372
async def _generate_text_mode(self, request, context, request_id):
13681373
"""Generate text using OpenAI-compatible format (text-in-text-out)."""
@@ -1455,6 +1460,11 @@ async def _generate_text_mode(self, request, context, request_id):
14551460
logger.warning("Initiating Dynamo Runtime shutdown.")
14561461
self.runtime.shutdown()
14571462
os._exit(1)
1463+
finally:
1464+
if multi_modal_data is not None:
1465+
images = multi_modal_data.get("image")
1466+
count = len(images) if isinstance(images, list) else 1
1467+
self.image_loader.mark_consumed(count)
14581468

14591469

14601470
class PrefillWorkerHandler(BaseWorkerHandler):
@@ -1608,3 +1618,8 @@ async def _generate_token_mode(self, request, context, request_id):
16081618
raise GeneratorExit(
16091619
"Prefill engine was shut down during token generation"
16101620
) from None
1621+
finally:
1622+
if multi_modal_data is not None:
1623+
images = multi_modal_data.get("image")
1624+
count = len(images) if isinstance(images, list) else 1
1625+
self.image_loader.mark_consumed(count)

components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
self.engine_args = engine_args
5959
self.model = self.engine_args.model
6060

61-
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
61+
self.image_loader = ImageLoader()
6262
self.image_processor = AutoImageProcessor.from_pretrained(
6363
self.model, trust_remote_code=True
6464
)

components/src/dynamo/vllm/multimodal_handlers/worker_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ async def generate(self, request: vLLMMultimodalRequest, context):
170170
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
171171

172172
multi_modal_data = defaultdict(list)
173+
num_loaded_images = 0
173174
for mi in request.multimodal_inputs:
174175
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
175176
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
@@ -273,6 +274,8 @@ async def generate(self, request: vLLMMultimodalRequest, context):
273274
await self.image_loader.load_image(mi.multimodal_input.image_url)
274275
)
275276

277+
num_loaded_images += 1
278+
276279
# Remove the image features from the request as they are not required
277280
request.multimodal_inputs = None
278281

@@ -362,3 +365,5 @@ async def generate(self, request: vLLMMultimodalRequest, context):
362365
metrics=response.metrics,
363366
kv_transfer_params=response.kv_transfer_params,
364367
).model_dump_json()
368+
369+
self.image_loader.mark_consumed(num_loaded_images)

components/src/dynamo/vllm/multimodal_utils/image_loader.py

Lines changed: 152 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,144 @@
1717
import base64
1818
import binascii
1919
import logging
20+
import os
21+
import threading
22+
from concurrent.futures import ThreadPoolExecutor
2023
from io import BytesIO
24+
from typing import TypeAlias, Union
2125
from urllib.parse import urlparse
2226

2327
import httpx
28+
import torch
2429
from PIL import Image
2530

2631
from .http_client import get_http_client
2732

2833
logger = logging.getLogger(__name__)
2934

35+
# Image output can be either PIL Image or Tensor (from nvimgcodec)
36+
ImageOutput: TypeAlias = Union[Image.Image, torch.Tensor]
37+
38+
# Thread-local storage for nvimgcodec decoders
39+
_thread_local = threading.local()
40+
41+
# Lazy import for nvimgcodec
42+
_nvimgcodec = None
43+
_nvimgcodec_available: bool | None = None # None = not yet probed
44+
45+
# Global thread pool for nvimgcodec decoding operations
46+
# Default to 8 workers, configurable via DYN_IMAGE_DECODE_WORKERS env var
47+
_IMAGE_DECODE_WORKERS = int(os.environ.get("DYN_IMAGE_DECODE_WORKERS", 8))
48+
_decode_thread_pool = ThreadPoolExecutor(
49+
max_workers=_IMAGE_DECODE_WORKERS,
50+
thread_name_prefix="image_decode_",
51+
)
52+
53+
54+
def _is_nvimgcodec_available() -> bool:
55+
"""Check whether nvimgcodec can be imported. Result is cached."""
56+
global _nvimgcodec_available
57+
if _nvimgcodec_available is None:
58+
try:
59+
_get_nvimgcodec()
60+
_nvimgcodec_available = True
61+
except (ImportError, ModuleNotFoundError):
62+
_nvimgcodec_available = False
63+
return _nvimgcodec_available
64+
65+
66+
def _get_nvimgcodec():
67+
"""Lazy import nvimgcodec. Raises ImportError if not installed."""
68+
global _nvimgcodec
69+
if _nvimgcodec is None:
70+
from nvidia import nvimgcodec
71+
72+
_nvimgcodec = nvimgcodec
73+
return _nvimgcodec
74+
75+
76+
def get_decoder():
77+
"""Get or create a thread-local nvimgcodec decoder instance."""
78+
if not hasattr(_thread_local, "decoder"):
79+
nvimgcodec = _get_nvimgcodec()
80+
_thread_local.decoder = nvimgcodec.Decoder()
81+
logger.info("nvimgcodec decoder initialized for thread")
82+
return _thread_local.decoder
83+
3084

3185
class ImageLoader:
3286
CACHE_SIZE_MAXIMUM = 8
87+
DEFAULT_MAX_PENDING = 64
3388

3489
def __init__(
35-
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
90+
self,
91+
cache_size: int = CACHE_SIZE_MAXIMUM,
92+
http_timeout: float = 30.0,
93+
use_nvimgcodec: bool = True,
94+
max_pending: int | None = None,
3695
):
3796
self._http_timeout = http_timeout
3897
self._image_cache: dict[str, Image.Image] = {}
3998
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
4099

41-
async def load_image(self, image_url: str) -> Image.Image:
100+
# Fall back to PIL if nvimgcodec was requested but is not installed
101+
if use_nvimgcodec and not _is_nvimgcodec_available():
102+
logger.warning(
103+
"nvimgcodec requested but not installed — "
104+
"falling back to PIL for image decoding"
105+
)
106+
use_nvimgcodec = False
107+
self._use_nvimgcodec = use_nvimgcodec
108+
109+
if max_pending is None:
110+
max_pending = int(
111+
os.environ.get("DYN_IMAGE_MAX_PENDING", self.DEFAULT_MAX_PENDING)
112+
)
113+
self._pending_semaphore = asyncio.Semaphore(max_pending)
114+
self._max_pending = max_pending
115+
116+
def mark_consumed(self, count: int = 1):
117+
"""
118+
Signal that decoded images have been consumed by the vLLM prefill batch.
119+
Call this after the prefill batch completes to allow more images to be decoded.
120+
121+
Args:
122+
count: Number of images consumed (default: 1)
123+
"""
124+
for _ in range(count):
125+
self._pending_semaphore.release()
126+
127+
def _decode_with_nvimgcodec(self, data: bytes) -> torch.Tensor:
128+
"""
129+
Decode image bytes using nvimgcodec for GPU-accelerated decoding.
130+
131+
Returns:
132+
torch.Tensor in NCHW format (4D) on CUDA device.
133+
Shape: (1, C, H, W) - batch dimension added so vLLM treats it as
134+
a batch of images, not as embeddings.
135+
"""
136+
nvimgcodec = _get_nvimgcodec()
137+
decoder = get_decoder()
138+
code_stream = nvimgcodec.CodeStream(data)
139+
decoded = decoder.decode(code_stream)
140+
141+
device = torch.device("cuda", torch.cuda.current_device())
142+
tensor = torch.as_tensor(decoded, device=device)
143+
# HWC -> CHW
144+
tensor = tensor.permute(2, 0, 1)
145+
# Add batch dimension: CHW -> NCHW (1, C, H, W)
146+
# This is critical: 3D tensors are interpreted as embeddings by vLLM,
147+
# but 4D tensors are interpreted as a batch of images.
148+
tensor = tensor.unsqueeze(0)
149+
150+
return tensor
151+
152+
async def load_image(self, image_url: str) -> ImageOutput:
153+
"""Load an image from a URL or data URI."""
42154
parsed_url = urlparse(image_url)
43155

44-
# For HTTP(S) URLs, check cache first
45-
if parsed_url.scheme in ("http", "https"):
156+
# For HTTP(S) URLs, check cache first (PIL path only)
157+
if not self._use_nvimgcodec and parsed_url.scheme in ("http", "https"):
46158
image_url_lower = image_url.lower()
47159
if image_url_lower in self._image_cache:
48160
logger.debug(f"Image found in cache for URL: {image_url}")
@@ -61,7 +173,6 @@ async def load_image(self, image_url: str) -> Image.Image:
61173

62174
try:
63175
image_bytes = base64.b64decode(data)
64-
image_data = BytesIO(image_bytes)
65176
except binascii.Error as e:
66177
raise ValueError(f"Invalid base64 encoding: {e}")
67178
elif parsed_url.scheme in ("http", "https"):
@@ -73,31 +184,50 @@ async def load_image(self, image_url: str) -> Image.Image:
73184
if not response.content:
74185
raise ValueError("Empty response content from image URL")
75186

76-
image_data = BytesIO(response.content)
187+
image_bytes = response.content
77188
else:
78189
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
79190

80-
# PIL is sync, so offload to a thread to avoid blocking the event loop
81-
image = await asyncio.to_thread(Image.open, image_data)
191+
# Wait if too many decoded images are pending in the vLLM scheduler.
192+
# Released when the caller invokes mark_consumed() after prefill.
193+
await self._pending_semaphore.acquire()
194+
195+
try:
196+
if self._use_nvimgcodec:
197+
# nvimgcodec decoding (GPU-accelerated, returns 4D tensor)
198+
loop = asyncio.get_running_loop()
199+
return await loop.run_in_executor(
200+
_decode_thread_pool,
201+
self._decode_with_nvimgcodec,
202+
image_bytes,
203+
)
204+
else:
205+
# Original PIL path
206+
image_data = BytesIO(image_bytes)
207+
image = await asyncio.to_thread(Image.open, image_data)
208+
209+
# Validate image format and convert to RGB
210+
if image.format not in ("JPEG", "PNG", "WEBP"):
211+
raise ValueError(f"Unsupported image format: {image.format}")
82212

83-
# Validate image format and convert to RGB
84-
if image.format not in ("JPEG", "PNG", "WEBP"):
85-
raise ValueError(f"Unsupported image format: {image.format}")
213+
image_converted = image.convert("RGB")
86214

87-
image_converted = image.convert("RGB")
215+
# Cache HTTP(S) URLs
216+
if parsed_url.scheme in ("http", "https"):
217+
image_url_lower = image_url.lower()
218+
if self._cache_queue.full():
219+
oldest_image_url = await self._cache_queue.get()
220+
del self._image_cache[oldest_image_url]
88221

89-
# Cache HTTP(S) URLs
90-
if parsed_url.scheme in ("http", "https"):
91-
image_url_lower = image_url.lower()
92-
# Cache the image for future use, and evict the oldest image if the cache is full
93-
if self._cache_queue.full():
94-
oldest_image_url = await self._cache_queue.get()
95-
del self._image_cache[oldest_image_url]
222+
self._image_cache[image_url_lower] = image_converted
223+
await self._cache_queue.put(image_url_lower)
96224

97-
self._image_cache[image_url_lower] = image_converted
98-
await self._cache_queue.put(image_url_lower)
225+
return image_converted
99226

100-
return image_converted
227+
except Exception:
228+
# Release semaphore on decode failure to prevent leak
229+
self._pending_semaphore.release()
230+
raise
101231

102232
except httpx.HTTPError as e:
103233
logger.error(f"HTTP error loading image: {e}")

0 commit comments

Comments (0)