From 9e7906899c76f466556ac6b0042241072594b8aa Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Mon, 11 May 2026 06:48:14 +0000 Subject: [PATCH] issue/344 add prefix hashing for mm_data --- README.md | 5 + python/infinilm/llm/cache_manager.py | 174 +++++++--- python/infinilm/llm/llm.py | 18 +- python/infinilm/llm/request.py | 5 + python/infinilm/llm/scheduler.py | 4 +- python/infinilm/multimodal/multimodal.py | 37 +- python/infinilm/processors/__init__.py | 3 + .../processors/basic_llm_processor.py | 11 + .../infinilm/processors/minicpmv_processor.py | 317 ++++++++++++++++++ python/infinilm/processors/processor.py | 19 +- python/infinilm/server/inference_server.py | 3 +- test/service/request.py | 122 +++++++ 12 files changed, 656 insertions(+), 62 deletions(-) create mode 100644 python/infinilm/processors/minicpmv_processor.py create mode 100644 test/service/request.py diff --git a/README.md b/README.md index fde56ce7..9dc528ec 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,11 @@ python scripts/test_perf.py --verbose ``` + - 单请求推理服务测试 + ```bash + python test/service/request.py --content="text:Image 1:" --content="image_url:xxx.jpg" --content="text:Image 2:" --content="image_url:xxxx.jpg" --content="text:Compare the 2 images." + ``` + - 运行推理基准测试(C-Eval/MMLU) ```bash diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index 44ca1376..ea857b02 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -3,6 +3,7 @@ """ from collections import deque +import queue from typing import List, Dict, Set import xxhash import numpy as np @@ -45,9 +46,9 @@ class BlockManager: """ def __init__(self, num_blocks: int, block_size: int): - assert ( - num_blocks > 0 and block_size > 0 - ), "num_blocks and block_size must be positive" + assert num_blocks > 0 and block_size > 0, ( + "num_blocks and block_size must be positive" + ) self.num_blocks = num_blocks self.block_size = block_size @@ -67,12 +68,20 @@ def reset_req_blocks(self) -> None: self.req_block_ids.clear() @classmethod - def compute_hash(cls, token_ids: List[int], prefix_hash: int = -1) -> int: + def compute_hash( + cls, + token_ids: List[int], + prefix_hash: int = -1, + mm_data_identifiers: List[str] = None, + ) -> int: """Compute hash for token sequence with optional prefix chaining.""" h = xxhash.xxh64() if prefix_hash != -1: h.update(prefix_hash.to_bytes(8, "little")) h.update(np.array(token_ids, dtype=np.int32).tobytes()) + if mm_data_identifiers is not None: + for identifier in mm_data_identifiers: + h.update(identifier.encode("utf-8")) return h.intdigest() def _allocate_partial_block(self, block_id: int) -> Block: @@ -100,9 +109,9 @@ def _allocate_full_block(self, block_id: int) -> Block: def _deallocate_block(self, block_id: int): """Deallocate a block and return it to free list.""" block = self.blocks[block_id] - assert ( - block.ref_count == 0 - ), f"Block {block_id} ref_count not zero, cannot deallocate" + assert block.ref_count == 0, ( + f"Block {block_id} ref_count not zero, cannot deallocate" + ) if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: del self.hash_to_block_id[block.hash] @@ -115,73 +124,142 @@ def can_allocate(self, num_required_blocks: int) -> bool: return len(self.free_block_ids) >= num_required_blocks def allocate_blocks( - self, token_ids: List[int], block_table: List[int] = None + self, + token_ids: List[int], + block_table: List[int] = None, + mm_token_index_mappings: List[dict] = None, ) -> tuple[List[int], 
List[int], int]: """Allocate cache blocks for new request with prefix caching support. Args: token_ids: Input token sequence block_table: Existing block_table (for decode phase) - + mm_token_index_mappings: List of multimodal token index mappings Returns: Tuple of (block_table, slot_mapping, num_cached_tokens) """ if block_table is None: block_table = [] + # Static args num_tokens = len(token_ids) num_blocks = (num_tokens + self.block_size - 1) // self.block_size + num_full_blocks = num_tokens // self.block_size + remain_tokens = num_tokens % self.block_size + num_mm_inputs = ( + 0 if not mm_token_index_mappings else len(mm_token_index_mappings) + ) + + # Variables slot_mapping = [] num_cached_tokens = 0 prefix_hash = -1 cache_miss = False + mm_start_counter = 0 + mm_caching_queue = queue.Queue(maxsize=len(mm_token_index_mappings)) + blocks_blueprint = [] # [{"prefix_hash": int or -1 if not a full block, "block_id": int or -1 if not cached}, ...] + max_blocks_to_reuse = num_full_blocks for block_idx in range(num_blocks): start_idx = block_idx * self.block_size end_idx = min(start_idx + self.block_size, num_tokens) block_tokens = token_ids[start_idx:end_idx] - # Only full blocks can be hashed for reuse - if len(block_tokens) == self.block_size: - prefix_hash = self.compute_hash(block_tokens, prefix_hash) - - # Try to reuse existing block - if not cache_miss: - cached_block_id = self.hash_to_block_id.get(prefix_hash, -1) - if ( - cached_block_id != -1 - and self.blocks[cached_block_id].token_ids == block_tokens - ): - # Check if all tokens are cached - if num_cached_tokens + self.block_size == len(token_ids): - cache_miss = True - else: - # Reuse successful - block = self.blocks[cached_block_id] - block.ref_count += 1 - block_table.append(cached_block_id) - num_cached_tokens += self.block_size - continue - else: - cache_miss = True - else: - prefix_hash = -1 - - # Cannot reuse, allocate new block - if not self.free_block_ids: - raise RuntimeError("No available cache blocks") + # Process multimodal token index mappings for this block + mm_data_identifiers = [] + while ( + mm_start_counter < num_mm_inputs + and mm_token_index_mappings[mm_start_counter]["start_index"] < end_idx + and mm_token_index_mappings[mm_start_counter]["start_index"] + >= start_idx + ): + # for all mm_data whose start_index is within this block's token range, add its identifier to the list + mm_data_identifiers.append( + mm_token_index_mappings[mm_start_counter]["identifier"] + ) + mm_caching_queue.put((mm_start_counter)) + mm_start_counter += 1 + + prefix_hash = ( + self.compute_hash(block_tokens, prefix_hash, mm_data_identifiers) + if len(block_tokens) == self.block_size + else -1 + ) + + # Try to reuse existing block if no previous cache miss yet + cached_block_id = ( + self.hash_to_block_id.get(prefix_hash, -1) if not cache_miss else -1 + ) + if ( + cached_block_id != -1 + and self.blocks[cached_block_id].token_ids != block_tokens + ): + cached_block_id = -1 + if end_idx == num_tokens and remain_tokens == 0: + # Spicial case, when the last block is fully packed, we cannot reuse it because we need to leave at least one uncached token for forward + cached_block_id = -1 + + # Deal with the first cache miss + if not cache_miss and cached_block_id == -1: + max_blocks_to_reuse = min(max_blocks_to_reuse, block_idx) + cache_miss = True + + if not cache_miss: + # pop fully cached mm_data + while ( + not mm_caching_queue.empty() + and mm_token_index_mappings[mm_caching_queue.queue[0]]["end_index"] + < end_idx + ): + 
mm_caching_queue.get() + + blocks_blueprint.append( + {"prefix_hash": prefix_hash, "block_id": cached_block_id} + ) + + # If there is one incomplete mm_data, tailing blocks need to fall back until all included mm_data are complete + if not mm_caching_queue.empty(): + incomplete_mm = mm_token_index_mappings[mm_caching_queue.get()] + incomplete_mm_start = incomplete_mm[ + "start_index" + ] # Fall back until this index is no longer included in the block + max_blocks_to_reuse = min( + max_blocks_to_reuse, incomplete_mm_start // self.block_size + ) + + num_cached_tokens = max_blocks_to_reuse * self.block_size + + for block_id in range(num_blocks): + n_block_tokens = self.block_size + + if block_id < max_blocks_to_reuse: + # Reuse block + block = self.blocks[blocks_blueprint[block_id]["block_id"]] + block.ref_count += 1 - new_block_id = self.free_block_ids[0] - if prefix_hash != -1: - block = self._allocate_full_block(new_block_id) - block.update(prefix_hash, block_tokens) else: - block = self._allocate_partial_block(new_block_id) - block_table.append(new_block_id) - - # Generate slot_mapping - for i in range(len(block_tokens)): - slot_mapping.append(new_block_id * self.block_size + i) + new_block_id = self.free_block_ids[0] + if blocks_blueprint[block_id]["prefix_hash"] != -1: + start_idx = block_id * self.block_size + end_idx = start_idx + self.block_size + block_tokens = token_ids[start_idx:end_idx] + block = self._allocate_full_block(new_block_id) + block.update( + blocks_blueprint[block_id]["prefix_hash"], block_tokens + ) + else: + block = self._allocate_partial_block(new_block_id) + n_block_tokens = remain_tokens + slot_mapping.extend( + list( + range( + block.block_id * self.block_size, + block.block_id * self.block_size + n_block_tokens, + ) + ) + ) + + block_table.append(block.block_id) return block_table, slot_mapping, num_cached_tokens diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83..af57844b 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -688,7 +688,7 @@ def add_request( if request_id is None: request_id = f"cmpl-{uuid.uuid4().hex}" - images, videos, audios = None, None, None + mm_index_mappings = None processed_inputs = None if prompt_token_ids is not None: @@ -708,12 +708,23 @@ def add_request( messages, add_generation_prompt=add_generation_prompt ) - images, videos, audios = resolve_multimodal_inputs(messages) + mm_inputs = resolve_multimodal_inputs(messages) + processed_inputs = self.engine.process( - prompt, images, videos, audios, return_tensors="pt" + prompt, + mm_inputs["images"], + mm_inputs["videos"], + mm_inputs["audios"], + return_tensors="pt", ) prompt_token_ids = processed_inputs.get("input_ids").flatten().tolist() + mm_index_mappings = self.engine.processor.get_mm_token_index_list( + prompt_token_ids, + image_ids=mm_inputs["image_urls"], + video_ids=mm_inputs["video_urls"], + audio_ids=mm_inputs["audio_urls"], + ) if sampling_params is None: sampling_params = SamplingParams(max_tokens=self.config.max_tokens) @@ -726,6 +737,7 @@ def add_request( prompt=prompt, prompt_token_ids=prompt_token_ids, processed_inputs=processed_inputs, + mm_token_index_mappings=mm_index_mappings, sampling_params=sampling_params, eos_token_ids=self.engine.eos_token_ids, request_data=request_data, diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 15bcf69f..a94c2f68 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -106,6 +106,7 @@ def __init__( prompt: 
Optional[str] = None, prompt_token_ids: Optional[List[int]] = None, processed_inputs: Optional[dict] = None, + mm_token_index_mappings: Optional[List[dict]] = None, sampling_params: Optional[SamplingParams] = None, eos_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, @@ -122,6 +123,7 @@ def __init__( self.prompt_token_ids: List[int] = prompt_token_ids or [] self.prompt_length: int = len(self.prompt_token_ids) self.processed_inputs: Optional[dict] = processed_inputs + self.mm_token_index_mappings: Optional[List[dict]] = mm_token_index_mappings # Sampling parameters self.sampling_params: SamplingParams = sampling_params or SamplingParams() @@ -186,6 +188,9 @@ def get_num_blocks_required(self, block_size: int) -> int: def get_max_tokens(self) -> Optional[int]: return self.sampling_params.max_tokens + def get_mm_token_index_mappings(self) -> Optional[List[dict]]: + return self.mm_token_index_mappings + def is_finished(self) -> bool: return self.status in [ RequestStatus.FINISHED, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index f9c11635..c5f4921a 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -86,7 +86,9 @@ def schedule(self) -> Optional[SchedulerOutput]: # Allocate blocks with automatic prefix caching support req.block_table, req.slot_mapping, req.num_cached_tokens = ( - self.cache_manager.allocate_blocks(req_tokens, req.block_table) + self.cache_manager.allocate_blocks( + req_tokens, req.block_table, req.get_mm_token_index_mappings() + ) ) req.num_blocks = len(req.block_table) diff --git a/python/infinilm/multimodal/multimodal.py b/python/infinilm/multimodal/multimodal.py index 7b067734..7e14a8c5 100644 --- a/python/infinilm/multimodal/multimodal.py +++ b/python/infinilm/multimodal/multimodal.py @@ -1,5 +1,23 @@ -from typing import List, Union +from typing import List, Optional, Union from PIL import Image +import xxhash + + +def has_multimodal_inputs(messages: Union[List[dict], dict]) -> bool: + """Check if the input messages contain any multimodal inputs.""" + if isinstance(messages, dict): + messages = [messages] + + for msg in messages: + content = msg.get("content", []) + if not isinstance(content, list): + return False + + for item in content: + if item.get("type") in ["image_url", "video_url", "audio_url"]: + return True + + return False def resolve_multimodal_inputs(messages: Union[List[dict], dict]): @@ -8,8 +26,11 @@ def resolve_multimodal_inputs(messages: Union[List[dict], dict]): messages = [messages] images = [] + image_urls = [] videos = [] + video_urls = [] audios = [] + audio_urls = [] for msg in messages: content = msg.get("content", []) @@ -19,11 +40,19 @@ def resolve_multimodal_inputs(messages: Union[List[dict], dict]): for item in content: if item.get("type") == "text": pass - elif item.get("type") == "image": + elif item.get("type") == "image_url": # TODO support other image url formats - images.append(Image.open(item["image_url"])) + images.append(Image.open(item["image_url"]["url"])) + image_urls.append(item["image_url"]["url"]) else: # TODO support video/audio raise NotImplementedError("Only image input is supported for now") - return images, videos, audios + return { + "images": images, + "image_urls": image_urls, + "videos": videos, + "video_urls": video_urls, + "audios": audios, + "audio_urls": audio_urls, + } diff --git a/python/infinilm/processors/__init__.py b/python/infinilm/processors/__init__.py index 61adff6d..ac2c42c0 100644 --- 
a/python/infinilm/processors/__init__.py +++ b/python/infinilm/processors/__init__.py @@ -1,6 +1,7 @@ from .processor import InfinilmProcessor from .basic_llm_processor import BasicLLMProcessor from .llama_processor import LlamaProcessor +from .minicpmv_processor import MiniCPMVProcessor from transformers import AutoConfig @@ -14,5 +15,7 @@ def from_pretrained(cls, model_dir_path: str, **kwargs) -> InfinilmProcessor: if model_type in ["llama"]: return LlamaProcessor(model_dir_path) + elif model_type in ["minicpmv"]: + return MiniCPMVProcessor(model_dir_path) else: return BasicLLMProcessor(model_dir_path) diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 070a4062..97ff84c4 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -2,6 +2,7 @@ from transformers import AutoTokenizer from ..llm.static_scheduler import StaticSchedulerOutput from ..llm.scheduler import SchedulerOutput +from typing import override class BasicLLMProcessor(InfinilmProcessor): @@ -10,6 +11,7 @@ def __init__(self, model_dir_path: str): model_dir_path, trust_remote_code=True ) + @override def __call__(self, prompt: str, return_tensors: str = None, **kwargs) -> dict: if return_tensors is None: return self.tokenizer(prompt) @@ -24,6 +26,7 @@ def __call__(self, prompt: str, return_tensors: str = None, **kwargs) -> dict: # "pt" or "np" or "tf". return self.tokenizer(prompt, return_tensors="pt") + @override def apply_chat_template( self, conversation, @@ -49,6 +52,7 @@ def apply_chat_template( **kwargs, ) + @override def build_model_inputs( self, scheduler_output: SchedulerOutput | StaticSchedulerOutput, @@ -236,5 +240,12 @@ def _build_model_input_from_batch_scheduler_output( "top_p": top_p, } + @override def get_tokenizer(self): return self.tokenizer + + @override + def get_mm_token_index_list( + self, prompt_token_ids, image_ids=None, video_ids=None, audio_ids=None, **kwargs + ): + return [] diff --git a/python/infinilm/processors/minicpmv_processor.py b/python/infinilm/processors/minicpmv_processor.py new file mode 100644 index 00000000..dd9d756c --- /dev/null +++ b/python/infinilm/processors/minicpmv_processor.py @@ -0,0 +1,317 @@ +from typing import override + +from transformers import AutoConfig, AutoProcessor + +from .processor import InfinilmProcessor + + +class MiniCPMVProcessor(InfinilmProcessor): + def __init__(self, model_dir_path: str): + """Initialize the processor with the model directory path.""" + self.processor = AutoProcessor.from_pretrained( + model_dir_path, trust_remote_code=True + ) + self.tokenizer = self.processor.tokenizer + self.config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True) + self.pixel_values_dtype = self.config.dtype + + @override + def __call__( + self, + prompt, + images=None, + videos=None, + audios=None, + return_tensors: str = None, + **kwargs, + ) -> dict: + """ + Process the input prompt and media into final inputs. + + { + 'input_ids': TensorShape(1, seq_len), + 'attention_mask': TensorShape(1, seq_len), + 'pixel_values': [[TensorShape(patch_channel, patch_height, patch_width * dim) * n_patches]], + 'image_sizes': [[TensorShape(2,) * n_images]], + 'image_bound': [TensorShape(total_patch, 2)], + 'tgt_sizes': [TensorShape(total_patch, 2)], + } + + For text-only input, result only contains 'input_ids' and 'attention_mask'. 
+ """ + if not images and not videos and not audios: + return self.tokenizer(prompt, return_tensors=return_tensors, **kwargs) + + results = self.processor( + prompt, images=images, return_tensors="pt", max_slice_nums=9, **kwargs + ) + + return results + + @override + def apply_chat_template( + self, + conversation, + add_generation_prompt: bool = False, + tokenize: bool = True, + **kwargs, + ): + """Apply chant template given input messages""" + processed_msg = [] + for msg in conversation: + content = msg["content"] + if not isinstance(content, list): + if isinstance(content, str): + processed_msg.append( + {"role": msg.get("role", "user"), "content": content} + ) + else: + raise ValueError("Content must be a list of items or a string") + continue + + processed_content = [] + for item in content: + if item.get("type") == "text": + processed_content.append(item.get("text", "")) + elif item.get("type") == "image_url": + processed_content.append("(./)") + else: + raise NotImplementedError("Only image input is supported for now") + + processed_msg.append( + { + "role": msg.get("role", "user"), + "content": "\n".join(processed_content), + } + ) + + return self.tokenizer.apply_chat_template( + conversation=processed_msg, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + **kwargs, + ) + + @override + def build_model_inputs( + self, + scheduler_output, + temperature: float = 1.0, + top_p: float = 0.8, + top_k: int = 1, + **kwargs, + ) -> dict: + """Build batched infinilm model inputs from the scheduler output.""" + import infinicore + + if not scheduler_output.scheduled_requests: + raise RuntimeError( + "build_model_inputs called with empty scheduled_requests" + ) + + tokens = [] + seq_lens = [] + seq_offsets = [0] + block_tables = [] + slot_mapping = [] + cached_lens = [] + position_ids = [] + cu_seqlens = [0] + mm_data = {} + + max_block_table_len = max( + len(req.block_table) for req in scheduler_output.scheduled_requests + ) + current_offset = 0 + + for req in scheduler_output.scheduled_requests: + num_cached = req.num_cached_tokens + if scheduler_output.is_prefill: + # Prefill phase + req_tokens = req.get_input_tokens() + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) + + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) + + current_offset += compute_len + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + compute_len)) + + if ( + req.processed_inputs is not None + and "pixel_values" in req.processed_inputs + ): + import torch + + assert len(scheduler_output.scheduled_requests) == 1, ( + "Batching is not supported for image inputs yet" + ) + + num_cached_patch = ( + (req.processed_inputs["image_bound"][0][:, 1] <= num_cached) + .sum() + .item() + ) + + # if all patches are already cached, skip processing multimodal inputs and return text-only inputs for this request + if ( + num_cached_patch + < req.processed_inputs["image_bound"][0].shape[0] + ): + # 1. 
pixel_values + all_pixel_values = [] + pixel_values = req.processed_inputs["pixel_values"] + tgt_sizes = req.processed_inputs["tgt_sizes"] + image_bound = req.processed_inputs["image_bound"] + for pv in pixel_values: + all_pixel_values.extend( + [ + t.flatten(end_dim=1) + .permute(1, 0) + .to(self.pixel_values_dtype) + for i, t in enumerate(pv) + if i >= num_cached_patch + ] + ) + + pixel_values_tensor = torch.nn.utils.rnn.pad_sequence( + all_pixel_values, batch_first=True, padding_value=0.0 + ) + B, L, _ = pixel_values_tensor.shape + pixel_values_tensor = ( + pixel_values_tensor.permute(0, 2, 1) + .reshape(B, 3, -1, L) + .contiguous() + ) + pixel_values_infini = infinicore.from_torch(pixel_values_tensor) + + # 2. tgt_sizes + all_tgt_sizes = [ + tgt_size + for i, tgt_size in enumerate(tgt_sizes) + if isinstance(tgt_size, torch.Tensor) + and i >= num_cached_patch + ] + + tgt_sizes_tensor = torch.vstack(all_tgt_sizes).to(torch.int64) + + tgt_sizes_infini = infinicore.from_torch(tgt_sizes_tensor) + + # 3. image_bound + batch_size = len(image_bound) + max_ranges = max(len(b) for b in image_bound) + + bound = torch.zeros( + (batch_size, max_ranges, 2), dtype=torch.int64 + ) + + for i, bnd in enumerate(image_bound): + bnd = bnd[num_cached_patch:, :] + if len(bnd) > 0: + bound[i, : len(bnd), :] = bnd + + image_bound_infini = infinicore.from_torch(bound) + mm_data["pixel_values"] = pixel_values_infini + mm_data["tgt_sizes"] = tgt_sizes_infini + mm_data["image_bound"] = image_bound_infini + + else: + # Decode phase + seq_len = req.get_total_length() + last_token = req.generated_token_ids[-1] + tokens.append(last_token) + seq_lens.append(seq_len) + + current_offset += 1 + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.append(seq_len - 1) + + # Pad block_table to same length + padded_block_table = req.block_table + [-1] * ( + max_block_table_len - len(req.block_table) + ) + block_tables.append(padded_block_table) + cu_seqlens.append(cu_seqlens[-1] + seq_len) + + return { + "input_ids": infinicore.from_list([tokens], dtype=infinicore.int64), + "position_ids": infinicore.from_list(position_ids, dtype=infinicore.int64), + "past_kv_lengths": infinicore.from_list( + cached_lens, dtype=infinicore.int32 + ), + "total_kv_lengths": infinicore.from_list(seq_lens, dtype=infinicore.int32), + "input_offsets": infinicore.from_list(seq_offsets, dtype=infinicore.int32), + "cu_seqlens": infinicore.from_list(cu_seqlens, dtype=infinicore.int32), + "block_tables": infinicore.from_list(block_tables, dtype=infinicore.int32), + "slot_mapping": infinicore.from_list(slot_mapping, dtype=infinicore.int64), + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + **mm_data, + } + + @override + def get_tokenizer(self): + """Return the text tokenizer associated with this processor.""" + return self.tokenizer + + @override + def get_mm_token_index_list( + self, prompt_token_ids, image_ids=None, video_ids=None, **kwargs + ): + image_idx = -1 + patch_start = [] + patch_end = [] + mm_token_index_list = [] + for i, token_id in enumerate(prompt_token_ids): + if token_id == self.tokenizer.im_id_start_id: + assert len(patch_start) == len(patch_end), ( + "Invalid prompt format: image start token found before previous image end token is closed" + ) + # deal with previous image patches + if patch_start: + for start, end in zip(patch_start, patch_end): + mm_token_index_list.append( + { + "start_index": start, + "end_index": end, + "identifier": 
image_ids[image_idx], + } + ) + # reset patch start and end for next image + patch_start = [] + patch_end = [] + + # increment image index for next image + image_idx += 1 + patch_start.append(i + 1) + elif token_id == self.tokenizer.slice_start_id: + patch_start.append(i + 1) + elif ( + token_id == self.tokenizer.im_id_end_id + or token_id == self.tokenizer.slice_end_id + ): + patch_end.append(i - 1) + + if patch_start: + for start, end in zip(patch_start, patch_end): + mm_token_index_list.append( + { + "start_index": start, + "end_index": end, + "identifier": image_ids[image_idx], + } + ) + assert image_idx + 1 == len(image_ids), ( + "The number of image tokens does not match the number of images data provided" + ) + return mm_token_index_list diff --git a/python/infinilm/processors/processor.py b/python/infinilm/processors/processor.py index fd8353fe..402e6e97 100644 --- a/python/infinilm/processors/processor.py +++ b/python/infinilm/processors/processor.py @@ -13,22 +13,31 @@ def __call__( **kwargs, ) -> dict: """Process the input prompt and media into final inputs.""" - raise NotImplementedError("InfinilmProcessor is not implemented yet") + raise NotImplementedError("__call__ is not implemented yet") def apply_chat_template( self, - messages, + conversation, add_generation_prompt: bool = False, tokenize: bool = True, **kwargs, ): """Apply chant template given input messages""" - raise NotImplementedError("InfinilmProcessor is not implemented yet") + raise NotImplementedError("apply_chat_template is not implemented yet") def build_model_inputs(self, scheduler_output, **kwargs) -> dict: """Build batched infinilm model inputs from the scheduler output.""" - raise NotImplementedError("InfinilmProcessor is not implemented yet") + raise NotImplementedError("build_model_inputs is not implemented yet") def get_tokenizer(self): """Return the text tokenizer associated with this processor.""" - raise NotImplementedError("InfinilmProcessor is not implemented yet") + raise NotImplementedError("get_tokenizer is not implemented yet") + + def get_mm_token_index_list( + self, prompt_token_ids, image_ids=None, video_ids=None, audio_ids=None, **kwargs + ): + """ + Get the list of starting token index and identifier mapping for multimodal inputs, sorted by index. + Return: [{"start_index": , "identifier": }, ...] + """ + raise NotImplementedError("get_mm_token_index_list is not implemented yet") diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 71e9c992..3d35941c 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -218,7 +218,8 @@ async def chat_completions(request: Request): data["messages"] = [{"role": "user", "content": data.get("prompt")}] # Normalize messages to handle multimodal content (list format) - data["messages"] = self._normalize_messages(data.get("messages", [])) + # data["messages"] = self._normalize_messages(data.get("messages", [])) + data["messages"] = data.get("messages", []) stream = data.get("stream", False) request_id = f"cmpl-{uuid.uuid4().hex}" diff --git a/test/service/request.py b/test/service/request.py new file mode 100644 index 00000000..d5fa008e --- /dev/null +++ b/test/service/request.py @@ -0,0 +1,122 @@ +import argparse +import asyncio +import time + +from openai import AsyncOpenAI + + +def get_args(): + # 1. 
创建参数解析器(支持重复 --message 构建列表) + parser = argparse.ArgumentParser(description="向推理服务发送 OpenAI 格式请求") + + # 核心:重复 --content 自动拼成列表 + parser.add_argument( + "--system", + type=str, + default="", + help="system prompt", + ) + parser.add_argument( + "--content", + action="append", + default=[], + help="start with content type['text', 'image_url'] and colon, e.g. text:hello or image_url:http://example.com/image.jpg", + ) + + # 目标服务地址与端口 + parser.add_argument( + "--port", type=int, default=8000, help="推理服务端口,默认 8000" + ) + parser.add_argument( + "--host", default="127.0.0.1", help="推理服务地址,默认 127.0.0.1" + ) + + # 解析参数 + return parser.parse_args() + + +def build_messages(content_args, system_prompt): + contents = [] + for content in content_args: + if ":" not in content: + raise ValueError( + f"Invalid content format: '{content}'. Expected format is 'type:value'." + ) + ctype, cvalue = content.split(":", 1) + + if ctype == "text": + contents.append({"type": "text", "text": cvalue}) + elif ctype == "image_url": + contents.append({"type": "image_url", "image_url": {"url": cvalue}}) + else: + raise ValueError( + f"Unsupported content type: '{ctype}'. Supported types are 'text' and 'image_url'." + ) + + messages = ( + [] if not system_prompt else [{"role": "system", "content": system_prompt}] + ) + messages.append({"role": "user", "content": contents}) + return messages + + +async def benchmark_user(client, messages): + try: + print(f" ❓ 提问: {messages}") + start_time = time.time() + stream = await client.chat.completions.create( + model="default", + messages=messages, + stream=True, + ) + + first_token_time = None + total_tokens = 0 + answer_chunks = [] + + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + delta = chunk.choices[0].delta.content + if delta: + answer_chunks.append(delta) + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + + end_time = time.time() + + ttft = first_token_time - start_time if first_token_time else None + elapsed_time = end_time - start_time if start_time else None + ms_per_token_decode = ( + ((elapsed_time - ttft) / (total_tokens - 1) * 1000) + if total_tokens - 1 > 0 and elapsed_time + else None + ) + + answer = "".join(answer_chunks) + print(f" 💬 回答: {answer}\n") + print(f" 总耗时: {elapsed_time:.3f}s") + print(f" 首字延迟 TTFT: {ttft:.3f}s") + print(f" Token间延迟 ITL: {ms_per_token_decode:.2f} ms") + print( + f" Decode吞吐: {1000 / ms_per_token_decode:.2f} tokens/s" + if ms_per_token_decode + else " Decode吞吐: N/A" + ) + + except Exception as e: + print(f" ❌ Error: {e}\n") + + +def main(): + args = get_args() + if not args.content: + args.content = ["text:山东最高的山是?"] + messages = build_messages(args.content, args.system) + client = AsyncOpenAI(base_url=f"http://{args.host}:{args.port}", api_key="default") + asyncio.run(benchmark_user(client, messages)) + + +if __name__ == "__main__": + main()
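For reference, the request body that `test/service/request.py` assembles from the repeated `--content` flags follows the OpenAI chat-completions multimodal format. For the invocation shown in the README above (no `--system` flag, so no system message is prepended), `build_messages` produces the structure below; the `.jpg` values are the placeholders from that command:

```python
# Messages built by build_messages() for the README example invocation.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Image 1:"},
            {"type": "image_url", "image_url": {"url": "xxx.jpg"}},
            {"type": "text", "text": "Image 2:"},
            {"type": "image_url", "image_url": {"url": "xxxx.jpg"}},
            {"type": "text", "text": "Compare the 2 images."},
        ],
    }
]
```

On the server side, `resolve_multimodal_inputs` loads each `image_url` with PIL and records the URL itself, which later serves as the identifier fed into prefix hashing.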
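The caching change itself lives in `BlockManager.compute_hash` and `allocate_blocks`: each full block's hash chains the previous block's hash, the block's token ids, and the identifiers (image URLs) of any multimodal inputs whose `start_index` falls inside that block. Below is a minimal sketch of that chaining under made-up values; `compute_hash` mirrors the patch, while `block_hashes`, the `block_size`, token ids, and mapping entries are hypothetical and exist only to illustrate the behavior:

```python
# Illustrative sketch of per-block prefix-hash chaining with multimodal identifiers.
from typing import List, Optional

import numpy as np
import xxhash


def compute_hash(token_ids: List[int], prefix_hash: int = -1,
                 mm_data_identifiers: Optional[List[str]] = None) -> int:
    h = xxhash.xxh64()
    if prefix_hash != -1:
        h.update(prefix_hash.to_bytes(8, "little"))           # chain the previous block's hash
    h.update(np.array(token_ids, dtype=np.int32).tobytes())   # this block's token ids
    for identifier in mm_data_identifiers or []:               # e.g. image URLs starting in this block
        h.update(identifier.encode("utf-8"))
    return h.intdigest()


def block_hashes(token_ids: List[int], mappings: List[dict], block_size: int) -> List[int]:
    """Hypothetical helper: replay the chaining allocate_blocks performs for full blocks."""
    hashes, prefix = [], -1
    for b in range(len(token_ids) // block_size):
        lo, hi = b * block_size, (b + 1) * block_size
        ids = [m["identifier"] for m in mappings if lo <= m["start_index"] < hi]
        prefix = compute_hash(token_ids[lo:hi], prefix, ids)
        hashes.append(prefix)
    return hashes


# Same token ids, different image: the hashes diverge from the block containing the
# image's start_index onward, so those blocks are not shared between the two requests.
tokens = [151644, 8, 8, 8, 8, 8, 8, 151645]  # made-up prompt token ids
map_a = [{"start_index": 2, "end_index": 5, "identifier": "cat.jpg"}]
map_b = [{"start_index": 2, "end_index": 5, "identifier": "dog.jpg"}]
print(block_hashes(tokens, map_a, block_size=4))
print(block_hashes(tokens, map_b, block_size=4))
```

In `allocate_blocks` these hashes are computed only for full blocks, and reuse additionally stops at the first cache miss, falls back to the block where an image starts if that image's patch tokens are not fully covered by cached blocks, and never reuses a fully packed final block so that at least one uncached token remains for the forward pass.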