1717import base64
1818import binascii
1919import logging
20+ import os
21+ import threading
22+ from concurrent .futures import ThreadPoolExecutor
2023from io import BytesIO
24+ from typing import TypeAlias , Union
2125from urllib .parse import urlparse
2226
2327import httpx
28+ import torch
2429from PIL import Image
2530
2631from .http_client import get_http_client
2732
logger = logging.getLogger(__name__)

# A decoded image is either a PIL image or a torch tensor (nvimgcodec path).
ImageOutput: TypeAlias = Union[Image.Image, torch.Tensor]

# Per-thread storage used to hold one nvimgcodec decoder per thread.
_thread_local = threading.local()

# nvimgcodec is imported lazily: the module object and the availability
# flag both stay None until the first probe.
_nvimgcodec = None
_nvimgcodec_available: bool | None = None  # None = not yet probed

# Process-wide thread pool for nvimgcodec decoding.
# The worker count is read from DYN_IMAGE_DECODE_WORKERS and defaults to 8.
_IMAGE_DECODE_WORKERS = int(os.getenv("DYN_IMAGE_DECODE_WORKERS", "8"))
_decode_thread_pool = ThreadPoolExecutor(
    thread_name_prefix="image_decode_",
    max_workers=_IMAGE_DECODE_WORKERS,
)
52+
53+
def _is_nvimgcodec_available() -> bool:
    """Report whether nvimgcodec can be imported, probing at most once.

    The first call attempts the lazy import and records the outcome in the
    module-level ``_nvimgcodec_available`` flag; later calls return that
    recorded value without importing again.

    Returns:
        True if the import succeeded, False if it raised ImportError.
    """
    global _nvimgcodec_available
    if _nvimgcodec_available is not None:
        return _nvimgcodec_available

    try:
        _get_nvimgcodec()
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
        _nvimgcodec_available = False
    else:
        _nvimgcodec_available = True
    return _nvimgcodec_available
64+
65+
def _get_nvimgcodec():
    """Return the ``nvidia.nvimgcodec`` module, importing it on first use.

    The imported module is stored in the module-level ``_nvimgcodec``
    variable so that subsequent calls skip the import statement.

    Raises:
        ImportError: If ``nvidia.nvimgcodec`` cannot be imported.
    """
    global _nvimgcodec
    if _nvimgcodec is not None:
        return _nvimgcodec

    from nvidia import nvimgcodec as codec_module

    _nvimgcodec = codec_module
    return _nvimgcodec
74+
75+
def get_decoder():
    """Return the calling thread's nvimgcodec decoder, creating it if absent.

    The decoder is stored on the module's thread-local object, so each
    thread that calls this function gets its own ``Decoder`` instance.

    Raises:
        ImportError: If nvimgcodec has to be imported and is not installed.
    """
    try:
        return _thread_local.decoder
    except AttributeError:
        # No decoder has been created on this thread yet.
        pass

    codec_module = _get_nvimgcodec()
    _thread_local.decoder = codec_module.Decoder()
    logger.info("nvimgcodec decoder initialized for thread")
    return _thread_local.decoder
83+
3084
3185class ImageLoader :
3286 CACHE_SIZE_MAXIMUM = 8
87+ DEFAULT_MAX_PENDING = 64
3388
def __init__(
    self,
    cache_size: int = CACHE_SIZE_MAXIMUM,
    http_timeout: float = 30.0,
    use_nvimgcodec: bool = True,
    max_pending: int | None = None,
):
    """Set up the image cache, decoder selection and pending-image limit.

    Args:
        cache_size: Maximum size of the queue that tracks cached URLs.
        http_timeout: Stored as the loader's HTTP timeout.
        use_nvimgcodec: Request nvimgcodec decoding. If nvimgcodec is not
            importable, a warning is logged and the loader uses PIL.
        max_pending: Initial value of the pending-image semaphore. When
            None, it is read from the DYN_IMAGE_MAX_PENDING environment
            variable, defaulting to ``DEFAULT_MAX_PENDING``.
    """
    self._http_timeout = http_timeout
    self._image_cache: dict[str, Image.Image] = {}
    self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)

    # Only probe for nvimgcodec when the caller actually asked for it.
    codec_missing = bool(use_nvimgcodec) and not _is_nvimgcodec_available()
    if codec_missing:
        logger.warning(
            "nvimgcodec requested but not installed — "
            "falling back to PIL for image decoding"
        )
    self._use_nvimgcodec = False if codec_missing else use_nvimgcodec

    pending_limit = max_pending
    if pending_limit is None:
        env_default = self.DEFAULT_MAX_PENDING
        pending_limit = int(os.environ.get("DYN_IMAGE_MAX_PENDING", env_default))
    self._pending_semaphore = asyncio.Semaphore(pending_limit)
    self._max_pending = pending_limit
115+
def mark_consumed(self, count: int = 1):
    """Release ``count`` slots of the pending-image semaphore.

    Intended to be called once the vLLM prefill batch has consumed decoded
    images, so that further images may be decoded.

    Args:
        count: Number of images consumed, i.e. the number of semaphore
            releases to perform (default: 1).
    """
    pending = self._pending_semaphore
    for _slot in range(count):
        pending.release()
126+
def _decode_with_nvimgcodec(self, data: bytes) -> torch.Tensor:
    """Decode encoded image bytes with nvimgcodec into a CUDA tensor.

    Args:
        data: Encoded image bytes.

    Returns:
        A 4D torch.Tensor on the current CUDA device. The decoder output is
        treated as HWC, permuted to CHW, and given a leading batch
        dimension, yielding shape (1, C, H, W).

    Raises:
        ValueError: If nvimgcodec could not decode the data.
        ImportError: If nvimgcodec is not installed.
    """
    nvimgcodec = _get_nvimgcodec()
    decoder = get_decoder()
    decoded = decoder.decode(nvimgcodec.CodeStream(data))
    if decoded is None:
        # nvimgcodec signals an undecodable stream by returning None instead
        # of raising; fail here with a clear message rather than letting
        # torch.as_tensor raise an unrelated error.
        raise ValueError("nvimgcodec failed to decode image data")

    device = torch.device("cuda", torch.cuda.current_device())
    hwc = torch.as_tensor(decoded, device=device)
    # HWC -> CHW
    chw = hwc.permute(2, 0, 1)
    # CHW -> NCHW (1, C, H, W). This is critical: vLLM interprets 3D tensors
    # as embeddings, but 4D tensors as a batch of images.
    return chw.unsqueeze(0)
151+
152+ async def load_image (self , image_url : str ) -> ImageOutput :
153+ """Load an image from a URL or data URI."""
42154 parsed_url = urlparse (image_url )
43155
44- # For HTTP(S) URLs, check cache first
45- if parsed_url .scheme in ("http" , "https" ):
156+ # For HTTP(S) URLs, check cache first (PIL path only)
157+ if not self . _use_nvimgcodec and parsed_url .scheme in ("http" , "https" ):
46158 image_url_lower = image_url .lower ()
47159 if image_url_lower in self ._image_cache :
48160 logger .debug (f"Image found in cache for URL: { image_url } " )
@@ -61,7 +173,6 @@ async def load_image(self, image_url: str) -> Image.Image:
61173
62174 try :
63175 image_bytes = base64 .b64decode (data )
64- image_data = BytesIO (image_bytes )
65176 except binascii .Error as e :
66177 raise ValueError (f"Invalid base64 encoding: { e } " )
67178 elif parsed_url .scheme in ("http" , "https" ):
@@ -73,31 +184,50 @@ async def load_image(self, image_url: str) -> Image.Image:
73184 if not response .content :
74185 raise ValueError ("Empty response content from image URL" )
75186
76- image_data = BytesIO ( response .content )
187+ image_bytes = response .content
77188 else :
78189 raise ValueError (f"Invalid image source scheme: { parsed_url .scheme } " )
79190
80- # PIL is sync, so offload to a thread to avoid blocking the event loop
81- image = await asyncio .to_thread (Image .open , image_data )
191+ # Wait if too many decoded images are pending in the vLLM scheduler.
192+ # Released when the caller invokes mark_consumed() after prefill.
193+ await self ._pending_semaphore .acquire ()
194+
195+ try :
196+ if self ._use_nvimgcodec :
197+ # nvimgcodec decoding (GPU-accelerated, returns 4D tensor)
198+ loop = asyncio .get_running_loop ()
199+ return await loop .run_in_executor (
200+ _decode_thread_pool ,
201+ self ._decode_with_nvimgcodec ,
202+ image_bytes ,
203+ )
204+ else :
205+ # Original PIL path
206+ image_data = BytesIO (image_bytes )
207+ image = await asyncio .to_thread (Image .open , image_data )
208+
209+ # Validate image format and convert to RGB
210+ if image .format not in ("JPEG" , "PNG" , "WEBP" ):
211+ raise ValueError (f"Unsupported image format: { image .format } " )
82212
83- # Validate image format and convert to RGB
84- if image .format not in ("JPEG" , "PNG" , "WEBP" ):
85- raise ValueError (f"Unsupported image format: { image .format } " )
213+ image_converted = image .convert ("RGB" )
86214
87- image_converted = image .convert ("RGB" )
215+ # Cache HTTP(S) URLs
216+ if parsed_url .scheme in ("http" , "https" ):
217+ image_url_lower = image_url .lower ()
218+ if self ._cache_queue .full ():
219+ oldest_image_url = await self ._cache_queue .get ()
220+ del self ._image_cache [oldest_image_url ]
88221
89- # Cache HTTP(S) URLs
90- if parsed_url .scheme in ("http" , "https" ):
91- image_url_lower = image_url .lower ()
92- # Cache the image for future use, and evict the oldest image if the cache is full
93- if self ._cache_queue .full ():
94- oldest_image_url = await self ._cache_queue .get ()
95- del self ._image_cache [oldest_image_url ]
222+ self ._image_cache [image_url_lower ] = image_converted
223+ await self ._cache_queue .put (image_url_lower )
96224
97- self ._image_cache [image_url_lower ] = image_converted
98- await self ._cache_queue .put (image_url_lower )
225+ return image_converted
99226
100- return image_converted
227+ except Exception :
228+ # Release semaphore on decode failure to prevent leak
229+ self ._pending_semaphore .release ()
230+ raise
101231
102232 except httpx .HTTPError as e :
103233 logger .error (f"HTTP error loading image: { e } " )
0 commit comments